All you need for Photorealistic Style Transfer in PyTorch
A walkthrough of photorealistic style transfer in PyTorch using a high-resolution generation network.
- What is style transfer?
- Why another paper?
- Gram Matrix
- High-Resolution Models
- Style transfer details
- Hi-Res Generation Network
- Implementation code
- Loss functions
- Difficult part
- Conclusion
Link to jupyter notebook, paper
Our aim is to transfer the style of a style image onto a content image. The result looks something like this.
Why another paper?
Earlier work on style transfer, although successful, was not able to maintain the structure of the content image. In contrast, compare Fig2 with the original content image in Fig1: the curves and structure of the content image are not distorted, and the output image has the same structure as the content image.
Gram Matrix
The main idea behind the paper is using the Gram matrix for style transfer. The two papers below showed that the Gram matrix of the feature maps of a convolutional neural network (CNN) can represent the style of an image, and they proposed the neural style transfer algorithm for image stylization.
- Texture Synthesis Using Convolutional Neural Networks by Gatys et al. 2015
- Image Style Transfer Using Convolutional Neural Networks by Gatys et al. 2016
Details about the Gram matrix can be found on Wikipedia. Mathematically, given a matrix $V$ whose columns are a set of vectors, their Gram matrix is computed as $$G=V^TV$$ so that each entry $G_{ij}$ is the dot product between the $i$-th and $j$-th vectors.
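To make this concrete, here is a minimal sketch (the tensor shapes are just for illustration) of computing the Gram matrix of a CNN feature map by flattening each channel into a vector; the same computation appears later in get_gram_matrix.

import torch

# A fake feature map: batch=1, 64 channels, 32x32 spatial size (illustrative numbers)
feature_map = torch.randn(1, 64, 32, 32)

b, c, h, w = feature_map.size()
flat = feature_map.view(b * c, h * w)  # each channel is a row vector of length h*w
gram = torch.mm(flat, flat.t())        # dot products between all pairs of channels
print(gram.shape)                      # torch.Size([64, 64])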
High-Resolution Models
The high-resolution model comes from a recent research paper accepted at CVPR 2019. Generally, in CNNs we first decrease the image size while increasing the number of filters, and then increase the size of the image back to the original resolution.
This forces our model to generate output images from a very small resolution, which results in a loss of finer details and structure. The high-resolution model was introduced to counter this.
A high-resolution network is designed to maintain high-resolution representations through the whole process while continuously exchanging information with the lower-resolution branches. So we train our model at the original resolution.
An example of this model is covered below. You can refer to the original papers for more details; I will cover this topic in detail in next week's blog post.
Style transfer details
There are three things that a style transfer model needs:
- Generation model:- This generates the output images. In Fig4 this is the 'Hi-Res Generation Network'.
- Loss functions:- The correct choice of loss functions is very important if you want to achieve good results.
- Loss network:- You need a pretrained CNN that can extract good features from the images. In our case, it is VGG19 pretrained on ImageNet.
So we load the VGG model. The complete code is available at my GitHub repo.
import torch
from torchvision.models import vgg19

if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    raise Exception('GPU is not available')

# Load VGG19 features. We do not need the last linear layers,
# only the CNN layers are needed
vgg = vgg19(pretrained=True).features
vgg = vgg.to(device)

# We don't want to train VGG
for param in vgg.parameters():
    param.requires_grad_(False)

torch.backends.cudnn.benchmark = True
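It can help to see the numeric index of each layer inside vgg, because those indices are what we will use later to pick out the style and content layers (this is just an optional inspection step):

print(vgg)  # a Sequential container listing every conv/ReLU/pool layer with its numeric index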
Next we load our images from disk. My images are stored as src/imgs/content.png and src/imgs/style.png.
import os
import matplotlib.pyplot as plt

# args holds the argparse arguments defined in the repo (img_root, content_img, style_img)
content_img = load_image(os.path.join(args.img_root, args.content_img), size=500)
content_img = content_img.to(device)

style_img = load_image(os.path.join(args.img_root, args.style_img))
style_img = style_img.to(device)

# Show content and style image
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 10))
ax1.imshow(im_convert(content_img))
ax2.imshow(im_convert(style_img))
plt.show()
# Utility functions
import numpy as np
from PIL import Image
from torchvision import transforms

def im_convert(img):
    """
    Convert img from a PyTorch tensor to a NumPy array, so we can plot it.
    It follows the standard method of denormalizing the img and clipping
    the outputs.
    Input:
        img :- (batch, channel, height, width)
    Output:
        img :- (height, width, channel)
    """
    img = img.to('cpu').clone().detach()
    img = img.numpy().squeeze(0)
    img = img.transpose(1, 2, 0)
    img = img * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
    img = img.clip(0, 1)
    return img

def load_image(path, size=None):
    """
    Load an image, optionally resize it to (size, size) where size is an int,
    and normalize it using the ImageNet statistics.
    """
    img = Image.open(path).convert('RGB')
    if size is not None:
        img = img.resize((size, size))
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])
    img = transform(img).unsqueeze(0)
    return img
Detail:- When we load our images, what sizes should we use? Your content image size should be divisible by 4, as our model downsamples the image twice (each time by a factor of 2). Do not resize the style image; use its original resolution. Here the content image is 500x500x3 and the style image is 800x800x3.
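As a quick sanity check (an illustrative snippet, assuming the tensors loaded above), you can verify the shapes and the divisibility constraint:

print(content_img.shape, style_img.shape)  # Expected: torch.Size([1, 3, 500, 500]) torch.Size([1, 3, 800, 800])

# The content image's height and width should both be divisible by 4
assert content_img.shape[-2] % 4 == 0 and content_img.shape[-1] % 4 == 0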
Hi-Res Generation Network
The model is quite simple: we start with 500x500x3 images and maintain this resolution throughout the model. We downsample to 250x250 and 125x125, and then fuse these branches back together with the 500x500 feature maps.
Details:-
- No pooling is used (as pooling causes loss of information). Instead, strided convolutions (i.e. stride=2) are used.
- No dropout is used. If you need regularization, you can use weight decay.
- 3x3 conv kernels are used everywhere, with padding=1.
- Only zero padding is used. Reflection padding was tested, but the results were not good.
- For upsampling, 'bilinear' mode is used.
- For downsampling, conv layers are used.
- InstanceNorm is used.
import torch.nn as nn
import torch.nn.functional as F

# Downsampling function
def conv_down(in_c, out_c, stride=2):
    return nn.Conv2d(in_c, out_c, kernel_size=3, stride=stride, padding=1)

# Upsampling function
def upsample(input, scale_factor):
    return F.interpolate(input=input, scale_factor=scale_factor, mode='bilinear', align_corners=False)
# Helper class for BottleneckBlock
class ConvLayer(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride=1):
super().__init__()
# We have to keep the size of images same, so choose padding accordingly
num_pad = int(np.floor(kernel_size / 2))
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding=num_pad)
def forward(self, x):
return self.conv(x)
class BottleneckBlock(nn.Module):
"""
Bottleneck layer similar to resnet bottleneck layer. InstanceNorm is used
instead of BatchNorm because when we want to generate images, we normalize
all the images independently.
(In batch norm you compute mean and std over complete batch, while in instance
norm you compute mean and std for each image channel independently). The reason for
doing this is, the generated images are independent of each other, so we should
not normalize them using a common statistic.
    If you are confused about the bottleneck architecture, refer to the official
    PyTorch ResNet implementation and the ResNet paper.
"""
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1):
super().__init__()
self.in_c = in_channels
self.out_c = out_channels
self.identity_block = nn.Sequential(
ConvLayer(in_channels, out_channels//4, kernel_size=1, stride=1),
nn.InstanceNorm2d(out_channels//4),
nn.ReLU(),
ConvLayer(out_channels//4, out_channels//4, kernel_size, stride=stride),
nn.InstanceNorm2d(out_channels//4),
nn.ReLU(),
ConvLayer(out_channels//4, out_channels, kernel_size=1, stride=1),
nn.InstanceNorm2d(out_channels),
nn.ReLU(),
)
self.shortcut = nn.Sequential(
ConvLayer(in_channels, out_channels, 1, stride),
nn.InstanceNorm2d(out_channels),
)
def forward(self, x):
out = self.identity_block(x)
if self.in_c == self.out_c:
residual = x
else:
residual = self.shortcut(x)
out += residual
out = F.relu(out)
return out
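As a quick check (an illustrative snippet, not part of the repo code), a bottleneck block with stride=1 keeps the spatial size and only changes the number of channels:

block = BottleneckBlock(16, 32)
x = torch.randn(1, 16, 500, 500)
print(block(x).shape)  # Expected: torch.Size([1, 32, 500, 500])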
Now we are ready to implement our style transfer model, which we call HRNet (after the paper). Use Fig5 as a reference.
class HRNet(nn.Module):
"""
For model reference see Figure 2 of the paper https://arxiv.org/pdf/1904.11617v1.pdf.
Naming convention used.
I refer to vertical layers as a single layer, so from left to right we have 8 layers
excluding the input image.
E.g. layer 1 contains the 500x500x16 block
layer 2 contains 500x500x32 and 250x250x32 blocks and so on
self.layer{x}_{y}:
x :- the layer number, as explained above
y :- the index number for that function starting from 1. So if layer 3 has two
downsample functions I write them as `downsample3_1`, `downsample3_2`
"""
def __init__(self):
super().__init__()
self.layer1_1 = BottleneckBlock(3, 16)
self.layer2_1 = BottleneckBlock(16, 32)
self.downsample2_1 = conv_down(16, 32)
self.layer3_1 = BottleneckBlock(32, 32)
self.layer3_2 = BottleneckBlock(32, 32)
self.downsample3_1 = conv_down(32, 32)
self.downsample3_2 = conv_down(32, 32, stride=4)
self.downsample3_3 = conv_down(32, 32)
self.layer4_1 = BottleneckBlock(64, 64)
self.layer5_1 = BottleneckBlock(192, 64)
self.layer6_1 = BottleneckBlock(64, 32)
self.layer7_1 = BottleneckBlock(32, 16)
self.layer8_1 = conv_down(16, 3, stride=1) # Needed conv layer so reused conv_down function
def forward(self, x):
map1_1 = self.layer1_1(x)
map2_1 = self.layer2_1(map1_1)
map2_2 = self.downsample2_1(map1_1)
map3_1 = torch.cat((self.layer3_1(map2_1), upsample(map2_2, 2)), 1)
map3_2 = torch.cat((self.downsample3_1(map2_1), self.layer3_2(map2_2)), 1)
map3_3 = torch.cat((self.downsample3_2(map2_1), self.downsample3_3(map2_2)), 1)
map4_1 = torch.cat((self.layer4_1(map3_1), upsample(map3_2, 2), upsample(map3_3, 4)), 1)
out = self.layer5_1(map4_1)
out = self.layer6_1(out)
out = self.layer7_1(out)
out = self.layer8_1(out)
return out
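Before moving on, it is worth sanity-checking the shapes (again, just an illustrative snippet): the output should have the same resolution as the input content image.

model = HRNet().to(device)
dummy = torch.randn(1, 3, 500, 500).to(device)
print(model(dummy).shape)  # Expected: torch.Size([1, 3, 500, 500])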
To compute the losses, you take the outputs of intermediate conv layers of the loss network. For example, in the VGG figure above, you could take the output of the second '3x3 conv, 64' layer and then the '3x3 conv, 128' layer.
To extract features from VGG we use the following code.
def get_features(img, model, layers=None):
"""
Use VGG19 to extract features from the intermediate layers.
"""
if layers is None:
layers = {
'0' : 'conv1_1', # style layer
'5' : 'conv2_1', # style layer
'10': 'conv3_1', # style layer
'19': 'conv4_1', # style layer
'28': 'conv5_1', # style layer
'21': 'conv4_2' # content layer
}
features = {}
x = img
for name, layer in model._modules.items():
x = layer(x)
if name in layers:
features[layers[name]] = x
return features
We extract features from five style layers (conv1_1 through conv5_1) plus conv4_2, which is the only layer used for the content loss.
Referring to Fig4, we pass the output image from HRNet, as well as the original content and style images, through VGG.
Loss functions
There are two losses:
- Content Loss
- Style Loss
Content Loss
The content image and the output image should have similar feature representations, as computed by the loss network VGG, because we are only changing the style without changing the structure of the image. For the content loss, we use the normalized squared Euclidean distance between feature maps, as shown by the formula
$$l_{content}^{\phi,j}(y,\hat{y})=\frac{1}{C_jH_jW_j}\left\|\phi_j(\hat{y})-\phi_j(y)\right\|_2^2$$
$\phi_j$ refers to the activations of the $j$-th layer of the loss network, and $C_j$, $H_j$, $W_j$ are the dimensions of that feature map. In code it looks like this.
style_net = HRNet().to(device)
target = style_net(content_img).to(device)
target.requires_grad_(True)
content_features = get_features(content_img, vgg)  # features of the content image, computed once
target_features = get_features(target, vgg)
content_loss = torch.mean((target_features['conv4_2'] - content_features['conv4_2']) ** 2)
Style Loss
We use the Gram matrix for this: the style of an image is captured by its Gram matrix. Our aim is to make the styles of the two images close, so we compute the difference between the Gram matrices of the style image and the output image and take the squared Frobenius norm of that difference.
$$l_{style}^{\phi,j}(y,\hat{y})=\left\|G_j^{\phi}(y)-G_j^{\phi}(\hat{y})\right\|_F^2$$
def get_gram_matrix(img):
"""
Compute the gram matrix by converting to 2D tensor and doing dot product
img: (batch, channel, height, width)
"""
b, c, h, w = img.size()
img = img.view(b*c, h*w)
gram = torch.mm(img, img.t())
return gram
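# Note: the loop below assumes a few objects that are precomputed elsewhere in the
# repo. A minimal sketch of one way to define them (the layer weights here are
# illustrative values, not necessarily the ones used in the repo):
style_weights = {'conv1_1': 1.0, 'conv2_1': 0.8, 'conv3_1': 0.5, 'conv4_1': 0.3, 'conv5_1': 0.1}

# Gram matrices of the style image only need to be computed once
style_features = get_features(style_img, vgg)
style_gram_matrixs = {layer: get_gram_matrix(style_features[layer]) for layer in style_weights}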
# There are 5 style layers; we compute the style loss for each layer and sum them up
style_loss = 0
for layer in style_weights:
    target_feature = target_features[layer]
    target_gram_matrix = get_gram_matrix(target_feature)
    # We already computed the gram matrices for our style image
    style_gram_matrix = style_gram_matrixs[layer]
    layer_style_loss = style_weights[layer] * torch.mean((target_gram_matrix - style_gram_matrix) ** 2)
    b, c, h, w = target_feature.shape
    style_loss += layer_style_loss / (c * h * w)
content_loss = content_weight * content_loss
style_loss = style_weight * style_loss
Difficult part
The difficulty comes in setting the values of content_weight and style_weight. If you want a particular output, you will have to test different values before you get the desired result.
To build your own intuition, you can choose two images and try a range of values. I am working on providing a summary of this; it will be available in my repo README.
The paper recommends content_weight in the range [50, 100] and style_weight in the range [1, 10].
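To show how the pieces fit together, here is a minimal sketch of a training loop. It is not the exact code from the repo: the optimizer choice (Adam), the learning rate, the number of steps, and the weight values are illustrative assumptions; everything else (style_net, vgg, get_features, get_gram_matrix, style_weights, style_gram_matrixs, content_features) was defined above.

import torch.optim as optim

# Illustrative hyperparameters (see the recommended ranges above)
content_weight = 100
style_weight = 1
steps = 1000

optimizer = optim.Adam(style_net.parameters(), lr=5e-3)

for step in range(steps):
    target = style_net(content_img)
    target_features = get_features(target, vgg)

    # Content loss on conv4_2
    content_loss = torch.mean((target_features['conv4_2'] - content_features['conv4_2']) ** 2)

    # Style loss summed over the style layers
    style_loss = 0
    for layer in style_weights:
        target_feature = target_features[layer]
        target_gram_matrix = get_gram_matrix(target_feature)
        style_gram_matrix = style_gram_matrixs[layer]
        layer_style_loss = style_weights[layer] * torch.mean((target_gram_matrix - style_gram_matrix) ** 2)
        b, c, h, w = target_feature.shape
        style_loss += layer_style_loss / (c * h * w)

    total_loss = content_weight * content_loss + style_weight * style_loss

    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    if step % 100 == 0:
        print(f'step {step}, total loss {total_loss.item():.4f}')

# Plot the final stylized image
plt.imshow(im_convert(target))
plt.show()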
Conclusion
Well, congratulations, you made it to the end. You can now implement style transfer. Read the paper for more details.
Check out my repo README; it contains complete instructions on how to use the code in the repo, along with the steps to train your own model.