For motivation, see the demo that Nvidia released for this paper.

What is Semantic Image Synthesis?

It is the opposite of image segmentation. Here we take a segmentation map (seg map) and aim to produce a photorealistic colored image that matches it. As in segmentation tasks, each value in the seg map corresponds to a particular class.

New things in the paper

The SPADE paper introduces a new normalization technique called spatially-adaptive normalization. Earlier models fed the seg map only to the input layer, so the information it carried was washed away in the deeper layers. SPADE solves this problem by giving the seg map as input to all the intermediate layers.

How to train the model?

Before getting into the details of the model, I will discuss how models are trained for a task like Semantic Image Synthesis.

The core idea behind the model training is a GAN. Why is a GAN needed? Because when we want generated images to look photorealistic, or more technically to be close to the distribution of the real images, an adversarial loss is the usual choice.

So for a GAN we need three things: 1) a generator, 2) a discriminator and 3) a loss function. The generator takes a random vector as input, which you can simply draw from a standard normal distribution. But if you want your output image to take on the style of some other image, you also need an image encoder that provides the mean and variance of the Gaussian distribution the vector is sampled from.

For the loss function, we will use the one from the pix2pixHD paper with some modifications. I will also discuss the perceptual loss, where we extract features from a VGG model and compute the loss on them.

SPADE

This is the basic block that we would use.

How to resize segmentation map?

Every pixel value in your seg map corresponds to a class, and you cannot introduce new pixel values. The default resizing in many libraries uses some form of interpolation, such as linear, which blends neighbouring pixels and can produce values that were not there before. To avoid this, whenever you have to resize your segmentation map, use ‘nearest’ as the upsampling or downsampling method.
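To make the resizing issue concrete, here is a small sketch that upsamples a tiny label map both ways. The 2x2 map and the class ids 0 and 4 are made up for illustration.

```python
import torch
import torch.nn.functional as F

# A 2x2 label map with two hypothetical class ids (0 and 4), shaped (N, C, H, W).
seg = torch.tensor([[0.0, 4.0],
                    [4.0, 0.0]]).reshape(1, 1, 2, 2)

nearest = F.interpolate(seg, size=(4, 4), mode="nearest")
bilinear = F.interpolate(seg, size=(4, 4), mode="bilinear", align_corners=False)

# Nearest-neighbour only copies labels that already exist.
print(sorted(nearest.unique().tolist()))
# Bilinear blends neighbours, so values appear that are not valid class ids.
print(sorted(bilinear.unique().tolist()))
```

The second list contains in-between values, which would be read as classes that were never in the original map.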

How do we use it? Consider some layer in your model whose output you want to enrich with the information from the segmentation map. That is done using SPADE.

SPADE first resizes the seg map to match the size of the features and then applies a conv layer to the resized seg map to extract features. The feature map is first normalized with a parameter-free BatchNorm and then denormalized using the scale and shift values computed from the seg map.
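Written as a formula, the operation looks like this. The notation is my own shorthand: $h_{n,c,y,x}$ is the activation of sample $n$, channel $c$ at pixel $(y,x)$, and $\boldsymbol{\mathrm{s}}$ is the seg map.

$$\gamma_{c,y,x}(\boldsymbol{\mathrm{s}})\,\frac{h_{n,c,y,x}-\mu_{c}}{\sigma_{c}}+\beta_{c,y,x}(\boldsymbol{\mathrm{s}})$$

$$\mu_{c}=\frac{1}{NHW}\sum_{n,y,x}h_{n,c,y,x},\qquad \sigma_{c}=\sqrt{\frac{1}{NHW}\sum_{n,y,x}h_{n,c,y,x}^{2}-\mu_{c}^{2}}$$

The mean and std are shared per channel, as in BatchNorm, while $\gamma$ and $\beta$ vary per pixel because they are produced by convolutions over the seg map. The product is element-wise.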

import torch
import torch.nn.functional as F
from torch.nn import Module, Conv2d
from torch.nn.functional import relu
from torch.nn.utils import spectral_norm

class SPADE(Module):
    def __init__(self, args, k):
        super().__init__()
        num_filters = args.spade_filter
        kernel_size = args.spade_kernel
        padding = kernel_size // 2  # keeps H and W unchanged for odd kernel sizes
        # The seg map is expected as a single-channel float tensor of shape (N, 1, H, W).
        self.conv = spectral_norm(Conv2d(1, num_filters, kernel_size=kernel_size, padding=padding))
        self.conv_gamma = spectral_norm(Conv2d(num_filters, k, kernel_size=kernel_size, padding=padding))
        self.conv_beta = spectral_norm(Conv2d(num_filters, k, kernel_size=kernel_size, padding=padding))

    def forward(self, x, seg):
        N, C, H, W = x.size()

        # Parameter-free batch normalization: one mean and std per channel,
        # computed over the batch and spatial dimensions.
        mean = x.mean(dim=(0, 2, 3), keepdim=True)
        var = x.var(dim=(0, 2, 3), keepdim=True, unbiased=False)
        x = (x - mean) / torch.sqrt(var + 1e-5)

        # Resize with 'nearest' so that no new label values are created.
        seg = F.interpolate(seg, size=(H, W), mode='nearest')
        seg = relu(self.conv(seg))
        seg_gamma = self.conv_gamma(seg)
        seg_beta = self.conv_beta(seg)

        # Element-wise modulation (not a matrix product): every channel and
        # pixel gets its own scale and shift.
        x = seg_gamma * x + seg_beta

        return x
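As a sanity check on the normalization step, the sketch below compares per-channel statistics computed by hand with PyTorch's own batch norm, and then applies a per-pixel scale and shift. The tensor sizes are arbitrary, and the random `gamma` and `beta` only stand in for the outputs of the two seg-map convolutions.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(4, 3, 8, 8)  # (N, C, H, W) activations

# One mean and one variance per channel, pooled over batch and spatial dims.
eps = 1e-5
mean = x.mean(dim=(0, 2, 3), keepdim=True)
var = x.var(dim=(0, 2, 3), keepdim=True, unbiased=False)
manual = (x - mean) / torch.sqrt(var + eps)

# Reference: batch norm in training mode without learned affine parameters.
reference = F.batch_norm(x, None, None, training=True, eps=eps)
print(torch.allclose(manual, reference, atol=1e-5))

# Stand-ins for the per-pixel scale and shift predicted from the seg map.
gamma = torch.randn_like(x)
beta = torch.randn_like(x)
out = gamma * manual + beta  # element-wise, so the shape is unchanged
print(out.shape)
```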

SPADE ResBlk

Just as ResNet combines conv layers into a ResNet block, we combine SPADE layers into a SPADEResBlk.

The idea is simple: we are just extending the ResNet block. The skip connection is important because it allows us to train deeper networks without suffering from vanishing gradients.

class SPADEResBlk(Module):
    def __init__(self, args, k, skip=False):
        super().__init__()
        kernel_size = args.spade_resblk_kernel
        padding = kernel_size // 2  # keeps H and W unchanged for odd kernel sizes
        # skip=True is for blocks that change the channel count: the input has
        # 2*k channels and the output has k, so the shortcut needs its own SPADE and conv.
        self.skip = skip

        if self.skip:
            self.spade1 = SPADE(args, 2*k)
            self.conv1 = Conv2d(2*k, k, kernel_size=kernel_size, padding=padding, bias=False)
            self.spade_skip = SPADE(args, 2*k)
            self.conv_skip = Conv2d(2*k, k, kernel_size=kernel_size, padding=padding, bias=False)
        else:
            self.spade1 = SPADE(args, k)
            self.conv1 = Conv2d(k, k, kernel_size=kernel_size, padding=padding, bias=False)

        self.spade2 = SPADE(args, k)
        self.conv2 = Conv2d(k, k, kernel_size=kernel_size, padding=padding, bias=False)
    
    def forward(self, x, seg):
        x_skip = x
        x = relu(self.spade1(x, seg))
        x = self.conv1(x)
        x = relu(self.spade2(x, seg))
        x = self.conv2(x)

        if self.skip:
            x_skip = relu(self.spade_skip(x_skip, seg))
            x_skip = self.conv_skip(x_skip)
        
        return x_skip + x

Now that we have our basic blocks, we can start coding up our GAN. Again, the three things that we need for a GAN are:

  1. Generator
  2. Discriminator
  3. Loss Function

Generator

import torch.nn as nn
from torch import tanh
from torch.nn import Linear, Conv2d
from torch.nn.functional import interpolate
from torch.nn.utils import spectral_norm

class SPADEGenerator(nn.Module):
    def __init__(self, args):
        super().__init__()
        # args.gen_hidden_size must equal 1024 * 4 * 4 = 16384 for the reshape in forward().
        self.linear = Linear(args.gen_input_size, args.gen_hidden_size)
        self.spade_resblk1 = SPADEResBlk(args, 1024)
        self.spade_resblk2 = SPADEResBlk(args, 1024)
        self.spade_resblk3 = SPADEResBlk(args, 1024)
        # skip=True: these blocks receive 2*k channels and output k.
        self.spade_resblk4 = SPADEResBlk(args, 512, skip=True)
        self.spade_resblk5 = SPADEResBlk(args, 256, skip=True)
        self.spade_resblk6 = SPADEResBlk(args, 128, skip=True)
        self.spade_resblk7 = SPADEResBlk(args, 64, skip=True)
        self.conv = spectral_norm(Conv2d(64, 3, kernel_size=(3,3), padding=1))

    def forward(self, x, seg):
        b = seg.size(0)
        x = self.linear(x)
        x = x.view(b, -1, 4, 4)

        # Each block is followed by a 2x nearest-neighbour upsampling of the
        # features, starting from 4x4. SPADE resizes seg to match internally.
        x = interpolate(self.spade_resblk1(x, seg), scale_factor=2, mode='nearest')
        x = interpolate(self.spade_resblk2(x, seg), scale_factor=2, mode='nearest')
        x = interpolate(self.spade_resblk3(x, seg), scale_factor=2, mode='nearest')
        x = interpolate(self.spade_resblk4(x, seg), scale_factor=2, mode='nearest')
        x = interpolate(self.spade_resblk5(x, seg), scale_factor=2, mode='nearest')
        x = interpolate(self.spade_resblk6(x, seg), scale_factor=2, mode='nearest')
        x = interpolate(self.spade_resblk7(x, seg), scale_factor=2, mode='nearest')

        # tanh maps the output image to the range [-1, 1].
        x = tanh(self.conv(x))

        return x
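To see which sizes the generator needs, here is a plain-Python trace of the channel counts and resolutions. It assumes a 4x4 starting grid with 1024 channels (from the `view` call) and a 2x upsampling after each of the seven blocks. The variable names are my own.

```python
# Assumed setup: 4x4 start, 1024 channels, one 2x upsampling per block.
start_hw = 4
start_channels = 1024
block_channels = [1024, 1024, 1024, 512, 256, 128, 64]

# Size the linear layer must output so it can be viewed as (B, 1024, 4, 4).
gen_hidden_size = start_channels * start_hw * start_hw
print(gen_hidden_size)  # 16384

resolution = start_hw
trace = []
for channels in block_channels:
    resolution *= 2  # nearest-neighbour upsampling after the block
    trace.append((channels, resolution))

for channels, res in trace:
    print(f"{channels:4d} channels at {res}x{res}")
```

With seven doublings the final feature map is 512x512. If you want a different output size, change the number of upsampling steps or the starting grid.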

Discriminator

def custom_model1(in_chan, out_chan):
    return nn.Sequential(
        spectral_norm(nn.Conv2d(in_chan, out_chan, kernel_size=(4,4), stride=2, padding=1)),
        nn.LeakyReLU(inplace=True)
    )

def custom_model2(in_chan, out_chan, stride=2):
    return nn.Sequential(
        spectral_norm(nn.Conv2d(in_chan, out_chan, kernel_size=(4,4), stride=stride, padding=1)),
        nn.InstanceNorm2d(out_chan),
        nn.LeakyReLU(inplace=True)
    )

class SPADEDiscriminator(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.layer1 = custom_model1(4, 64)
        self.layer2 = custom_model2(64, 128)
        self.layer3 = custom_model2(128, 256)
        self.layer4 = custom_model2(256, 512, stride=1)
        self.inst_norm = nn.InstanceNorm2d(512)
        self.conv = spectral_norm(nn.Conv2d(512, 1, kernel_size=(4,4), padding=1))

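    def forward(self, img, seg):
        # Do not detach img here: the generator needs gradients that flow through
        # the discriminator. Pass img.detach() from the training loop when you
        # update the discriminator on generated images.
        x = torch.cat((seg, img), dim=1)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = nn.functional.leaky_relu(self.inst_norm(x))
        # The output is a map of per-patch real/fake scores, not a single number.
        x = self.conv(x)
        return x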
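The discriminator does not output one score per image but a grid of scores, one per image patch. A quick way to see the grid size is the usual conv output-size formula applied to the five conv layers above. The 256x256 input is an assumption for illustration.

```python
def conv_out(size, kernel, stride, padding):
    """Spatial output size of a convolution along one dimension."""
    return (size + 2 * padding - kernel) // stride + 1

# (kernel, stride, padding) of layer1..layer4 and the final conv.
layers = [(4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 1, 1), (4, 1, 1)]

size = 256  # assumed input height and width
sizes = [size]
for kernel, stride, padding in layers:
    size = conv_out(size, kernel, stride, padding)
    sizes.append(size)

print(sizes)  # [256, 128, 64, 32, 31, 30]
```

So for a 256x256 input the discriminator returns a 30x30 map, and the GAN loss is averaged over all of these patch scores.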

Loss function

This is the most important piece for training a GAN. We are all familiar with the minimax objective, which the discriminator tries to maximize and the generator tries to minimize. It looks something like this.

$$\mathbb{E}_{(\boldsymbol{\mathrm{s}},\boldsymbol{\mathrm{x}})}[\log D(\boldsymbol{\mathrm{s}},\boldsymbol{\mathrm{x}})]+\mathbb{E}_{\boldsymbol{\mathrm{s}}}[\log (1-D(\boldsymbol{\mathrm{s}},G(\boldsymbol{\mathrm{s}})))]$$

Now we extend this loss function with a feature matching loss. What do I mean? The objective above is computed at a single image size, but we can also compute the loss at different sizes of the image and then sum the results.

This loss stabilizes training, as the generator has to produce natural statistics at multiple scales. To compute it, we extract features from multiple layers of the discriminator and learn to match these intermediate representations between the real and the synthesized images. A closely related term does the same with features taken from a pretrained VGG model instead of the discriminator. This is called perceptual loss. The code makes it easier to understand.

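class VGGLoss(nn.Module):
    def __init__(self):
        super().__init__()
        # VGG19 is not torchvision's model class. It is assumed to be a frozen,
        # pretrained wrapper that returns a list of five intermediate feature maps.
        # Move the whole loss to the GPU with .to(device) instead of hard-coding .cuda().
        self.vgg = VGG19()
        self.criterion = nn.L1Loss()
        # One weight per feature map; deeper layers count more.
        self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]

    def forward(self, x, y):
        # x: synthesized image, y: real image
        x_vgg, y_vgg = self.vgg(x), self.vgg(y)
        loss = 0
        for i in range(len(x_vgg)):
            # The real image's features are a fixed target, so they are detached.
            loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())
        return loss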

So we take the two images, real and synthesized, and pass them through the VGG network. We compare the intermediate feature maps to compute the loss. We could also use ResNet, but VGG works well, and the earlier layers of VGG are generally good at extracting the features of an image.
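The feature matching term on discriminator features follows the same pattern as the VGG loss. Below is a minimal sketch, assuming the discriminator has been changed to return a list of its intermediate feature maps (the version in this post returns only the final output). The function name and the random stand-in features are hypothetical, and averaging over layers is my simplification.

```python
import torch
import torch.nn.functional as F

def feature_matching_loss(real_feats, fake_feats):
    """Mean L1 distance between discriminator features of real and fake images."""
    loss = 0.0
    for real, fake in zip(real_feats, fake_feats):
        # Real features are a fixed target, so no gradient flows into them.
        loss = loss + F.l1_loss(fake, real.detach())
    return loss / len(real_feats)

# Stand-ins for features taken from two discriminator layers.
torch.manual_seed(0)
real_feats = [torch.randn(2, 64, 16, 16), torch.randn(2, 128, 8, 8)]
fake_feats = [f + 0.1 for f in real_feats]

fm_loss = feature_matching_loss(real_feats, fake_feats)
print(fm_loss)  # close to 0.1, since every feature is off by 0.1
```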

This is not the complete loss function. Below I show my implementation without the perceptual loss. I strongly recommend reading the loss function implementation used by Nvidia themselves for this project, as it also includes the loss above and provides a general guideline on how to train GANs in 2019.

import torch
import torch.nn as nn

class GANLoss(nn.Module):
    def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0):
        super().__init__()
        self.real_label = target_real_label
        self.fake_label = target_fake_label
        if use_lsgan:
            # LSGAN is a least-squares objective, so it uses MSE rather than L1.
            self.loss = nn.MSELoss()
        else:
            # BCELoss expects probabilities, so the discriminator output
            # must go through a sigmoid first.
            self.loss = nn.BCELoss()

    def get_target_tensor(self, input, target_is_real):
        # Build the target with the same shape, dtype and device as the prediction.
        label = self.real_label if target_is_real else self.fake_label
        return torch.full_like(input, label)

    def forward(self, input, target_is_real):
        target_tensor = self.get_target_tensor(input, target_is_real)
        return self.loss(input, target_tensor)
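The SPADE paper keeps the pix2pixHD objective but replaces the least-squares term with a hinge loss. Here is a minimal sketch of the hinge formulation on raw discriminator scores. The function names and the toy scores are my own, and this is not Nvidia's implementation.

```python
import torch
import torch.nn.functional as F

def d_hinge_loss(real_logits, fake_logits):
    """Discriminator: push real scores above +1 and fake scores below -1."""
    return F.relu(1.0 - real_logits).mean() + F.relu(1.0 + fake_logits).mean()

def g_hinge_loss(fake_logits):
    """Generator: raise the discriminator's score on generated images."""
    return -fake_logits.mean()

real_logits = torch.tensor([2.0, 0.5])
fake_logits = torch.tensor([-2.0, 0.0])

d_loss = d_hinge_loss(real_logits, fake_logits)
g_loss = g_hinge_loss(fake_logits)
print(d_loss)  # tensor(0.7500)
print(g_loss)  # tensor(1.)
```

Scores that are already on the right side of the margin (2.0 for real, -2.0 for fake) contribute nothing to the discriminator loss.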

Weight Init

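In the paper, they used Glorot initialization (another name for Xavier initialization). I prefer He initialization. Note that the function below does neither: it draws the convolution weights from a normal distribution with mean 0 and standard deviation 0.02, as in DCGAN.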

def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
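If you want Glorot or He initialization instead, PyTorch ships both in `torch.nn.init`. The helper below is a sketch with a hypothetical `scheme` argument, applied to a small stand-in network rather than the SPADE model.

```python
import torch
import torch.nn as nn

def init_weights(m, scheme="xavier"):
    """Initialize conv and linear layers with Glorot (Xavier) or He (Kaiming) init."""
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        if scheme == "xavier":
            nn.init.xavier_normal_(m.weight)
        elif scheme == "he":
            nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
        else:
            raise ValueError(f"unknown scheme: {scheme}")
        if m.bias is not None:
            nn.init.zeros_(m.bias)

torch.manual_seed(0)
net = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(16, 3, kernel_size=3, padding=1),
)
net.apply(lambda m: init_weights(m, scheme="he"))

# He init targets std = sqrt(2 / fan_in); fan_in is 3 * 3 * 3 = 27 for the first conv.
print(net[0].weight.std().item(), (2 / 27) ** 0.5)
```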

Image Encoder

This is the final part of our model. It is used if you want to transfer the style of one image to the output of SPADE. It works by outputting the mean and variance values of a Gaussian distribution, from which we sample the random noise that we input to the generator.

def conv_inst_lrelu(in_chan, out_chan):
    return nn.Sequential(
        nn.Conv2d(in_chan, out_chan, kernel_size=(3,3), stride=2, bias=False, padding=1),
        nn.InstanceNorm2d(out_chan),
        nn.LeakyReLU(inplace=True)
    )

class SPADEEncoder(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.layer1 = conv_inst_lrelu(3, 64)
        self.layer2 = conv_inst_lrelu(64, 128)
        self.layer3 = conv_inst_lrelu(128, 256)
        self.layer4 = conv_inst_lrelu(256, 512)
        self.layer5 = conv_inst_lrelu(512, 512)
        self.layer6 = conv_inst_lrelu(512, 512)
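        # 8192 = 512 channels * 4 * 4, which assumes 256x256 input images
        # (six stride-2 convs: 256 -> 4).
        self.linear_mean = nn.Linear(8192, args.gen_input_size)
        self.linear_var = nn.Linear(8192, args.gen_input_size)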

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.layer5(x)
        x = self.layer6(x)
        x = x.view(x.size(0), -1)

        return self.linear_mean(x), self.linear_var(x)
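The two encoder outputs are turned into the generator input with the reparameterization trick known from VAEs. The sketch below assumes the second head is interpreted as the log-variance, which is a common convention but is not stated in the code above. The latent size of 256 is also just an example.

```python
import torch

def reparameterize(mean, logvar):
    """Sample z ~ N(mean, exp(logvar)) in a way that keeps gradients."""
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mean + eps * std

def kl_divergence(mean, logvar):
    """KL divergence between N(mean, exp(logvar)) and a standard normal."""
    return 0.5 * torch.sum(mean.pow(2) + logvar.exp() - 1 - logvar)

torch.manual_seed(0)
mean = torch.zeros(2, 256)    # stand-in for the first encoder output
logvar = torch.zeros(2, 256)  # stand-in for the second encoder output

z = reparameterize(mean, logvar)
kl = kl_divergence(mean, logvar)
print(z.shape)
print(kl)  # zero here, because the distribution already is a standard normal
```

The KL term is usually added to the generator loss so that, at test time, you can also sample `z` from a standard normal without a style image.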

Why Spectral Normalization?

See "Spectral Normalization Explained" by Christian Cosgrove. The article discusses spectral norm in detail, with all the maths behind it. Ian Goodfellow even commented on spectral normalization and considers it to be an important tool.

The reason we need spectral norm is that when we are generating images, it can become hard to train our model to generate images of, say, the 1000 categories of ImageNet. Spectral norm helps by stabilizing the training of the discriminator: it rescales each layer's weights so that their largest singular value stays close to 1. There are theoretical justifications for why this should be done, and they are beautifully explained in the blog post mentioned above.

To use spectral norm in your model, just apply spectral_norm to all convolutional layers in your generator and discriminator.
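Here is a small sketch of what the wrapper does to a single layer. It uses `torch.nn.utils.spectral_norm` as in the code above. The layer sizes and the number of warm-up passes are arbitrary choices.

```python
import torch
import torch.nn as nn
from torch.nn.utils import spectral_norm

torch.manual_seed(0)
conv = spectral_norm(nn.Conv2d(3, 8, kernel_size=3, padding=1))

# Each forward pass in training mode runs one power-iteration step, which
# refines the estimate of the largest singular value.
conv.train()
x = torch.randn(1, 3, 16, 16)
for _ in range(100):
    conv(x)

# Flatten the normalized kernel to a matrix and measure its spectral norm.
w = conv.weight.detach().reshape(8, -1)
sigma = torch.linalg.svdvals(w)[0].item()
print(round(sigma, 3))  # close to 1
```

The original, unnormalized parameter is kept as `conv.weight_orig`, and that is what the optimizer updates.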

Brief Discussion on Instance normalization

Batch normalization computes one mean and std per channel over the complete batch and then normalizes every image in the batch with those same values. This is good when we are doing classification, but when we are generating images, we want to keep the normalization of each image independent.

One simple reason is this: if one image in my batch is being generated for a blue sky and another for a road, then normalizing both with the same mean and std would add extra noise to the images, which would make training worse. So instance norm is used instead of batch normalization here.
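The difference is easy to check: change one image in a batch and see whether the normalized version of another image moves. The batch below is random data, and the shift of 5.0 is an arbitrary choice.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch = torch.randn(2, 3, 8, 8)
changed = batch.clone()
changed[1] += 5.0  # alter only the second image

# Instance norm: statistics per image and per channel.
inst_a = F.instance_norm(batch)
inst_b = F.instance_norm(changed)

# Batch norm in training mode: statistics per channel over the whole batch.
bn_a = F.batch_norm(batch, None, None, training=True)
bn_b = F.batch_norm(changed, None, None, training=True)

print(torch.allclose(inst_a[0], inst_b[0]))  # True: first image is unaffected
print(torch.allclose(bn_a[0], bn_b[0]))      # False: first image changes too
```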

Resources