SPADE: State of the art in Image-to-Image Translation by Nvidia
A new state-of-the-art method for generating color images from segmentation masks. It uses a GAN to learn to produce photorealistic images.
- What is Semantic Image Synthesis?
- New things in the paper
- How to train the model?
- SPADE
- How to resize segmentation map?
- SPADEResBlk
- Generator
- Discriminator
- Loss function
- Weight Init
- Image Encoder
- Why Spectral Normalization?
- Resources
Link to implementation code, paper
To give motivation for this paper, see the demo released by Nvidia.
New things in the paper
The SPADE paper introduces a new normalization technique called spatially-adaptive normalization. Earlier models fed the seg map only to the input layer, so the information it contained was gradually washed away in the deeper layers. SPADE solves this problem by feeding the seg map to all the intermediate layers as well.
How to train the model?
Before getting into the details of the model, I will discuss how models are trained for a task like semantic image synthesis.
The core idea behind the training is a GAN. Why do we need a GAN? Because whenever we want to generate outputs that look photorealistic, i.e. close to the real target images, GANs are the tool of choice.
So for a GAN we need three things: 1) a Generator, 2) a Discriminator, 3) a Loss Function. The Generator needs some random values as input. You can simply sample these from a standard normal distribution. But if you want your output image to resemble some other image, i.e. take on the style of that image, you will also need an image encoder, which provides the mean and variance of the Gaussian distribution that the random values are sampled from.
For the loss function, we use the loss from the pix2pixHD paper with some modifications. I will also discuss the technique where we extract features from a VGG model and compute a loss on them (the perceptual loss).
How to resize segmentation map?
Every pixel value in your seg map corresponds to a class, and resizing must not introduce new pixel values. The defaults in most libraries perform some form of interpolation, such as bilinear, which blends neighbouring pixels and produces values that were not there before. To avoid this, always use 'nearest' as the mode when upsampling or downsampling a segmentation map.
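A toy example makes the difference concrete (a minimal sketch using PyTorch's F.interpolate):

import torch
import torch.nn.functional as F

# A tiny 2x2 "seg map" with class ids 0 and 3
seg = torch.tensor([[[[0., 3.], [0., 0.]]]])

# Bilinear interpolation blends neighbouring pixels and produces fractional
# values that are not valid class ids
print(F.interpolate(seg, size=(4, 4), mode='bilinear', align_corners=False))

# Nearest-neighbour keeps only the original class ids 0 and 3
print(F.interpolate(seg, size=(4, 4), mode='nearest'))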
SPADE
How do we use it? Consider some layer in your model whose output you want to enrich with information from the segmentation map. That is exactly what SPADE does.
SPADE first resizes the seg map to match the spatial size of the feature map, then applies a conv layer to the resized seg map to extract features. The feature map itself is first normalized with batch statistics (as in BatchNorm) and then denormalized with the scale and shift values computed from the seg map.
import torch
import torch.nn.functional as F
from torch.nn import Module, Conv2d
from torch.nn.functional import relu
from torch.nn.utils import spectral_norm

class SPADE(Module):
    def __init__(self, args, k):
        super().__init__()
        num_filters = args.spade_filter
        kernel_size = args.spade_kernel
        self.conv = spectral_norm(Conv2d(1, num_filters, kernel_size=(kernel_size, kernel_size), padding=1))
        self.conv_gamma = spectral_norm(Conv2d(num_filters, k, kernel_size=(kernel_size, kernel_size), padding=1))
        self.conv_beta = spectral_norm(Conv2d(num_filters, k, kernel_size=(kernel_size, kernel_size), padding=1))

    def forward(self, x, seg):
        N, C, H, W = x.size()

        # Normalize with batch statistics: one mean/std per channel, over N, H and W
        mean = x.mean(dim=(0, 2, 3), keepdim=True)
        std = x.std(dim=(0, 2, 3), keepdim=True)
        x = (x - mean) / (std + 1e-5)

        # Resize the seg map with 'nearest' so no new class values are introduced
        seg = F.interpolate(seg, size=(H, W), mode='nearest')
        seg = relu(self.conv(seg))
        seg_gamma = self.conv_gamma(seg)
        seg_beta = self.conv_beta(seg)

        # Denormalize: element-wise scale and shift that vary spatially with the seg map
        x = seg_gamma * x + seg_beta
        return x
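A quick usage sketch; the args namespace here is hypothetical (spade_filter and spade_kernel are the only fields the module reads):

from types import SimpleNamespace

args = SimpleNamespace(spade_filter=128, spade_kernel=3)
spade = SPADE(args, k=64)

x = torch.randn(2, 64, 32, 32)     # a feature map somewhere in the network
seg = torch.randn(2, 1, 256, 256)  # single-channel segmentation map
print(spade(x, seg).shape)         # torch.Size([2, 64, 32, 32])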
SPADEResBlk
The idea is simple: we are just extending the ResNet block. The skip connection is important because it allows training of deeper networks without suffering from vanishing gradients.
class SPADEResBlk(Module):
    def __init__(self, args, k, skip=False):
        super().__init__()
        kernel_size = args.spade_resblk_kernel
        self.skip = skip

        if self.skip:
            # Halve the channel count: input has 2*k channels, output has k
            self.spade1 = SPADE(args, 2*k)
            self.conv1 = Conv2d(2*k, k, kernel_size=(kernel_size, kernel_size), padding=1, bias=False)
            self.spade_skip = SPADE(args, 2*k)
            self.conv_skip = Conv2d(2*k, k, kernel_size=(kernel_size, kernel_size), padding=1, bias=False)
        else:
            self.spade1 = SPADE(args, k)
            self.conv1 = Conv2d(k, k, kernel_size=(kernel_size, kernel_size), padding=1, bias=False)

        self.spade2 = SPADE(args, k)
        self.conv2 = Conv2d(k, k, kernel_size=(kernel_size, kernel_size), padding=1, bias=False)

    def forward(self, x, seg):
        x_skip = x
        x = relu(self.spade1(x, seg))
        x = self.conv1(x)
        x = relu(self.spade2(x, seg))
        x = self.conv2(x)

        if self.skip:
            # Learned shortcut, needed when the channel count changes
            x_skip = relu(self.spade_skip(x_skip, seg))
            x_skip = self.conv_skip(x_skip)

        return x_skip + x
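The block can be smoke-tested the same way; with skip=True it maps 2*k input channels down to k (args extends the hypothetical namespace from above):

args = SimpleNamespace(spade_filter=128, spade_kernel=3, spade_resblk_kernel=3)

blk = SPADEResBlk(args, k=512, skip=True)  # learned shortcut: 1024 -> 512 channels
x = torch.randn(2, 1024, 8, 8)
seg = torch.randn(2, 1, 256, 256)
print(blk(x, seg).shape)                   # torch.Size([2, 512, 8, 8])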
Now that we have our basic blocks, we can start coding up our GAN. Again, the three things that we need for a GAN are:
- Generator
- Discriminator
- Loss Function
Generator
from torch import nn, tanh
from torch.nn import Linear
from torch.nn.functional import interpolate

class SPADEGenerator(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.linear = Linear(args.gen_input_size, args.gen_hidden_size)
        self.spade_resblk1 = SPADEResBlk(args, 1024)
        self.spade_resblk2 = SPADEResBlk(args, 1024)
        self.spade_resblk3 = SPADEResBlk(args, 1024)
        # These blocks halve the channel count (1024->512->256->128->64), so they
        # take 2*k input channels and need the learned shortcut
        self.spade_resblk4 = SPADEResBlk(args, 512, skip=True)
        self.spade_resblk5 = SPADEResBlk(args, 256, skip=True)
        self.spade_resblk6 = SPADEResBlk(args, 128, skip=True)
        self.spade_resblk7 = SPADEResBlk(args, 64, skip=True)
        self.conv = spectral_norm(Conv2d(64, 3, kernel_size=(3,3), padding=1))

    def forward(self, x, seg):
        b = seg.size(0)
        # Project the latent vector and reshape it to a 4x4 map with 1024 channels
        # (so args.gen_hidden_size must equal 1024*4*4)
        x = self.linear(x)
        x = x.view(b, 1024, 4, 4)
        # Seven SPADE ResBlks, each followed by 2x nearest-neighbour upsampling,
        # grow the 4x4 map to 512x512
        x = interpolate(self.spade_resblk1(x, seg), scale_factor=2, mode='nearest')
        x = interpolate(self.spade_resblk2(x, seg), scale_factor=2, mode='nearest')
        x = interpolate(self.spade_resblk3(x, seg), scale_factor=2, mode='nearest')
        x = interpolate(self.spade_resblk4(x, seg), scale_factor=2, mode='nearest')
        x = interpolate(self.spade_resblk5(x, seg), scale_factor=2, mode='nearest')
        x = interpolate(self.spade_resblk6(x, seg), scale_factor=2, mode='nearest')
        x = interpolate(self.spade_resblk7(x, seg), scale_factor=2, mode='nearest')
        x = tanh(self.conv(x))
        return x
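Running it end to end (gen_input_size and gen_hidden_size are hypothetical values, chosen to be consistent with the reshape above):

args = SimpleNamespace(spade_filter=128, spade_kernel=3, spade_resblk_kernel=3,
                       gen_input_size=256, gen_hidden_size=1024*4*4)

gen = SPADEGenerator(args)
z = torch.randn(1, 256)            # random input vector
seg = torch.randn(1, 1, 512, 512)  # seg map at the output resolution
print(gen(z, seg).shape)           # torch.Size([1, 3, 512, 512])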
Discriminator
The discriminator is a stack of strided convolutions that outputs a grid of patch-level real/fake scores rather than a single number.

from torch.nn.functional import leaky_relu

def custom_model1(in_chan, out_chan):
    return nn.Sequential(
        spectral_norm(nn.Conv2d(in_chan, out_chan, kernel_size=(4,4), stride=2, padding=1)),
        nn.LeakyReLU(inplace=True)
    )

def custom_model2(in_chan, out_chan, stride=2):
    return nn.Sequential(
        spectral_norm(nn.Conv2d(in_chan, out_chan, kernel_size=(4,4), stride=stride, padding=1)),
        nn.InstanceNorm2d(out_chan),
        nn.LeakyReLU(inplace=True)
    )
class SPADEDiscriminator(nn.Module):
    def __init__(self, args):
        super().__init__()
        # Input has 4 channels: 3 for the RGB image + 1 for the seg map
        self.layer1 = custom_model1(4, 64)
        self.layer2 = custom_model2(64, 128)
        self.layer3 = custom_model2(128, 256)
        self.layer4 = custom_model2(256, 512, stride=1)
        self.inst_norm = nn.InstanceNorm2d(512)
        self.conv = spectral_norm(nn.Conv2d(512, 1, kernel_size=(4,4), padding=1))

    def forward(self, img, seg):
        # Note: do not detach img here; detaching belongs in the training loop
        # (for the discriminator step only), otherwise the generator never
        # receives gradients through the discriminator
        x = torch.cat((seg, img), dim=1)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = leaky_relu(self.inst_norm(x))
        x = self.conv(x)
        return x
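A quick shape check (the output is a grid of patch scores, not one scalar per image):

disc = SPADEDiscriminator(args)
img = torch.randn(1, 3, 256, 256)
seg = torch.randn(1, 1, 256, 256)
print(disc(img, seg).shape)  # torch.Size([1, 1, 30, 30])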
Loss function
The most important piece for training a GAN. We are all familiar with the minimax objective, where the generator tries to minimize and the discriminator tries to maximize, which looks something like this:
$$\min_G \max_D \; \mathbb{E}_{(\boldsymbol{\mathrm{s}},\boldsymbol{\mathrm{x}})}[\log D(\boldsymbol{\mathrm{s}},\boldsymbol{\mathrm{x}})]+\mathbb{E}_{\boldsymbol{\mathrm{s}}}[\log (1-D(\boldsymbol{\mathrm{s}},G(\boldsymbol{\mathrm{s}})))]$$
Now we extend this loss function with a feature matching loss. What do I mean? The plain GAN loss compares only the discriminator's final output on real and fake images, but we can also compare the feature maps the discriminator computes at several scales and sum up these comparisons.
This stabilizes training, because the generator has to produce natural statistics at multiple scales: we extract features from multiple layers of the discriminator and learn to match these intermediate representations between the real and the synthesized images. A closely related idea extracts the features from a pretrained VGG model instead of the discriminator; this variant is called the perceptual loss. The code makes it easier to understand.
class VGGLoss(nn.Module):
    def __init__(self):
        super().__init__()
        # VGG19 here is a wrapper that returns a list of intermediate feature
        # maps; a sketch of it is given below
        self.vgg = VGG19().cuda()
        self.criterion = nn.L1Loss()
        # Deeper features get larger weights
        self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]

    def forward(self, x, y):
        x_vgg, y_vgg = self.vgg(x), self.vgg(y)
        loss = 0
        for i in range(len(x_vgg)):
            # detach() stops gradients from flowing into the target (real) features
            loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())
        return loss
So we take the two images, real and synthesized, and pass them through the VGG network. We compare the intermediate feature maps to compute the loss. We could also use a ResNet, but VGG works quite well, and its earlier layers are generally good at extracting low-level image features.
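The VGG19 wrapper used above is not defined in this post; here is a minimal sketch of one, following the pix2pixHD version, which returns activations from five slices of torchvision's pretrained VGG19:

import torchvision

class VGG19(nn.Module):
    def __init__(self):
        super().__init__()
        features = torchvision.models.vgg19(pretrained=True).features
        # Five slices ending at relu1_1, relu2_1, relu3_1, relu4_1, relu5_1
        self.slices = nn.ModuleList([features[:2], features[2:7], features[7:12],
                                     features[12:21], features[21:30]])
        for p in self.parameters():
            p.requires_grad = False  # the loss network stays frozen

    def forward(self, x):
        outs = []
        for s in self.slices:
            x = s(x)
            outs.append(x)
        return outs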
This is not the complete loss function. Below I show my implementation without the perceptual loss. I strongly recommend looking at the loss function implementation Nvidia used for this project, as it combines the loss above as well, and it provides a general guideline on how to train GANs in 2019.
class GANLoss(nn.Module):
    def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0,
                 tensor=torch.FloatTensor):
        super().__init__()
        self.real_label = target_real_label
        self.fake_label = target_fake_label
        self.real_label_var = None
        self.fake_label_var = None
        self.Tensor = tensor
        if use_lsgan:
            # LSGAN replaces the cross-entropy with a least-squares loss
            self.loss = nn.MSELoss()
        else:
            self.loss = nn.BCELoss()

    def get_target_tensor(self, input, target_is_real):
        # Cache the target tensor and rebuild it only when the input size changes
        target_tensor = None
        if target_is_real:
            create_label = ((self.real_label_var is None) or
                            (self.real_label_var.numel() != input.numel()))
            if create_label:
                real_tensor = self.Tensor(input.size()).fill_(self.real_label)
                self.real_label_var = real_tensor.requires_grad_(False)
            target_tensor = self.real_label_var
        else:
            create_label = ((self.fake_label_var is None) or
                            (self.fake_label_var.numel() != input.numel()))
            if create_label:
                fake_tensor = self.Tensor(input.size()).fill_(self.fake_label)
                self.fake_label_var = fake_tensor.requires_grad_(False)
            target_tensor = self.fake_label_var
        return target_tensor

    def __call__(self, input, target_is_real):
        target_tensor = self.get_target_tensor(input, target_is_real)
        return self.loss(input, target_tensor.to(input.device))
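To make the detach placement concrete, here is a rough sketch of one training step. The names gen, disc, opt_g, opt_d, gan_loss and vgg_loss are mine, not from the original code; they stand for the networks, their optimizers, and instances of GANLoss and VGGLoss.

fake = gen(z, seg)

# Discriminator step: detach the fake image so no gradients reach the generator
loss_d = gan_loss(disc(real_img, seg), True) + gan_loss(disc(fake.detach(), seg), False)
opt_d.zero_grad()
loss_d.backward()
opt_d.step()

# Generator step: no detach, gradients flow through the discriminator into the generator
loss_g = gan_loss(disc(fake, seg), True) + vgg_loss(fake, real_img)
opt_g.zero_grad()
loss_g.backward()
opt_g.step()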
Weight Init

def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        # Conv weights ~ N(0, 0.02)
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        # BatchNorm scale ~ N(1, 0.02), bias = 0
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
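Usage is a one-liner, since nn.Module.apply visits every submodule recursively:

gen = SPADEGenerator(args).apply(weights_init)
disc = SPADEDiscriminator(args).apply(weights_init)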
Image Encoder
The image encoder downsamples the style image six times and outputs the mean and variance of the Gaussian distribution from which the generator's input is sampled, as discussed earlier.

def conv_inst_lrelu(in_chan, out_chan):
    return nn.Sequential(
        nn.Conv2d(in_chan, out_chan, kernel_size=(3,3), stride=2, bias=False, padding=1),
        nn.InstanceNorm2d(out_chan),
        nn.LeakyReLU(inplace=True)
    )

class SPADEEncoder(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.layer1 = conv_inst_lrelu(3, 64)
        self.layer2 = conv_inst_lrelu(64, 128)
        self.layer3 = conv_inst_lrelu(128, 256)
        self.layer4 = conv_inst_lrelu(256, 512)
        self.layer5 = conv_inst_lrelu(512, 512)
        self.layer6 = conv_inst_lrelu(512, 512)
        # 8192 = 512 * 4 * 4, i.e. a 256x256 input halved six times
        self.linear_mean = nn.Linear(8192, args.gen_input_size)
        self.linear_var = nn.Linear(8192, args.gen_input_size)

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.layer5(x)
        x = self.layer6(x)
        x = x.view(x.size(0), -1)  # flatten to (batch, 8192)
        return self.linear_mean(x), self.linear_var(x)
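A sketch of how the encoder output could feed the generator; treating linear_var's output as a log-variance and sampling with the reparameterization trick is an assumption borrowed from the usual VAE setup, not something spelled out in the code above:

encoder = SPADEEncoder(args)
style_img = torch.randn(1, 3, 256, 256)

mean, logvar = encoder(style_img)       # interpreting the second output as log-variance
std = torch.exp(0.5 * logvar)
z = mean + std * torch.randn_like(std)  # reparameterization trick
fake = gen(z, seg)                      # the generator input now carries the style image's statistics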
Why Spectral Normalization?
Spectral Normalization Explained by Christian Cosgrove discusses spectral norm in detail, with all the math behind it. Ian Goodfellow has even commented on spectral normalization and considers it an important tool.
The reason we need spectral norm is that when we are generating images, say for the 1000 categories of ImageNet, training can become very unstable. Spectral norm helps by stabilizing the training of the discriminator. There are theoretical justifications for why this works, but all of that is beautifully explained in the blog post linked above.
To use spectral norm in your model, just apply spectral_norm to all the convolutional layers in your generator and discriminator.
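In PyTorch this is a one-line wrapper around the layer, as already used throughout the code above:

from torch.nn.utils import spectral_norm

# The wrapper re-normalizes the conv weight by its largest singular value on every forward pass
conv = spectral_norm(nn.Conv2d(64, 128, kernel_size=3, padding=1))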
Batch Normalization uses the complete batch to compute the mean and std, and then normalizes every image in the batch with that single mean and std. This is fine for classification, but when we are generating images we want the normalization of each image to stay independent.
One simple reason: if one image in my batch is a blue sky while another is a road, then normalizing both with the same mean and std clearly leaks statistics from one image into the other and adds noise, which makes training worse. So instance norm is used here instead of batch normalization.
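The difference is easy to see from the normalization axes (a small sketch):

x = torch.randn(2, 3, 8, 8)

# BatchNorm2d: one mean/std per channel, shared across the whole batch
bn_out = nn.BatchNorm2d(3)(x)

# InstanceNorm2d: one mean/std per channel per image, so the "sky" image and
# the "road" image are normalized independently
in_out = nn.InstanceNorm2d(3)(x)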