SPADE: State of the art in Image-to-Image Translation by Nvidia
Nvidia released a new paper, Semantic Image Synthesis with Spatially-Adaptive Normalization. The official code can be found here (PyTorch), and my implementation can be found here (PyTorch).
Nvidia has been pushing the state of the art in GANs for quite some time now. This paper builds on their earlier work, pix2pixHD, and pushes it even further. To give some motivation for this paper, see the demo released by Nvidia.
Table of Contents
- What is Semantic Image Synthesis:- Brief overview of the field.
- New things in the paper
- How to train my model?:- How Semantic Image Synthesis models work
- Then I dive into the different modules that make up the SPADE project, namely SPADE and SPADEResBlk. After that I introduce the Generator and Discriminator models and the Encoder model for style transfer.
- The loss function is discussed in some detail, and perceptual loss is also introduced with code. The original Nvidia code for the loss function can be found here.
- There is also a discussion on how to resize segmentation maps and how to initialize my model using He initialization.
- What is spectral normalization?:- When to use this normalization and a discussion on instance normalization.
What is Semantic Image Synthesis?
It is the opposite of image segmentation. Here we take a segmentation map (seg map) and our aim is to produce a colored picture that matches that segmentation map. As in segmentation tasks, each color value in the seg map corresponds to a particular class.
New things in the paper
The SPADE paper introduces a new normalization technique called spatially-adaptive normalization. Earlier models fed the seg map only to the input layer, and because the seg map was available in just one layer, the information it contained washed away in the deeper layers. SPADE solves this problem: we give the seg map as input to all the intermediate layers.
How to train my model?
Before getting into the details of the model, I will discuss how models are trained for a task like Semantic Image Synthesis.
The core idea behind the model training is a GAN. Why is a GAN needed? Because when we want to generate something that looks photorealistic, or more technically, something closer to the distribution of the real output images, adversarial training is the usual choice.
So for a GAN we need three things: 1) Generator 2) Discriminator 3) Loss Function. The Generator needs some random values as input. You can simply sample these from a standard normal distribution. But if you want your output image to resemble some other image, i.e. take the style of some image and add it to your output image, you will also need an image encoder, which provides the mean and variance values of the Gaussian distribution that the random input is sampled from.
For the loss function, we will use the loss function from the pix2pixHD paper with some modifications. I will also discuss perceptual loss, a technique where we extract features from a VGG model and compute the loss on those features.
SPADE is the basic block that we will use.
How to resize seg map?
Every pixel value in your seg map corresponds to a class, and you cannot introduce new pixel values. When we use the defaults in various libraries for resizing, we do some form of interpolation such as linear, which can change the pixel values and produce values that were not there before. To solve this problem, whenever you have to resize your segmentation map, use ‘nearest’ as the upsampling or downsampling method.
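To make the difference concrete, here is a minimal sketch using `torch.nn.functional.interpolate`. The 4x4 label map and its three class ids are made up for illustration.

```python
import torch
import torch.nn.functional as F

# Toy 4x4 label map with three class ids (0, 1, 2), shaped (N, C, H, W).
seg = torch.tensor([[0, 0, 1, 1],
                    [0, 0, 1, 1],
                    [2, 2, 1, 1],
                    [2, 2, 1, 1]], dtype=torch.float32).view(1, 1, 4, 4)

# Nearest-neighbour resizing copies existing labels, so no new class ids appear.
seg_nearest = F.interpolate(seg, size=(8, 8), mode="nearest")

# Bilinear resizing blends neighbouring labels and produces fractional values
# that do not correspond to any class.
seg_bilinear = F.interpolate(seg, size=(8, 8), mode="bilinear", align_corners=False)

print(seg_nearest.unique().tolist())       # [0.0, 1.0, 2.0]
print(seg_bilinear.unique().numel() > 3)   # True
```

The same rule applies when downsampling the seg map to the smaller feature maps inside the generator.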
How do we use it? Consider some layer in your model whose output you want to combine with the information from the segmentation map. That is done using SPADE.
SPADE first resizes your seg map to match the size of the features and then applies a conv layer to the resized seg map to extract features from it. The feature map is first normalized using BatchNorm and then denormalized using the scale and shift values we get from the seg map.
Below I present my PyTorch implementation of the model. You should check the official implementation for the model also.
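As a rough illustration of the steps described above, here is a minimal sketch of a SPADE layer. It is not the official code. The hidden width of 128, the one-hot seg map input, and the `1 + gamma` form of the scale (chosen so the layer starts close to plain normalization) are assumptions made for this example.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SPADE(nn.Module):
    """Spatially-adaptive normalization (illustrative sketch)."""

    def __init__(self, num_features, num_classes, hidden=128):
        super().__init__()
        # Parameter-free normalization: scale and shift come from the seg map.
        self.norm = nn.BatchNorm2d(num_features, affine=False)
        self.shared = nn.Sequential(
            nn.Conv2d(num_classes, hidden, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        self.gamma = nn.Conv2d(hidden, num_features, kernel_size=3, padding=1)
        self.beta = nn.Conv2d(hidden, num_features, kernel_size=3, padding=1)

    def forward(self, x, seg):
        # seg is a one-hot map of shape (N, num_classes, H, W).
        # Resize it with 'nearest' so that it stays one-hot.
        seg = F.interpolate(seg, size=x.shape[2:], mode="nearest")
        h = self.shared(seg)
        # Denormalize with a per-pixel, per-channel scale and shift.
        return self.norm(x) * (1 + self.gamma(h)) + self.beta(h)


x = torch.randn(2, 64, 16, 16)                  # features of some layer
labels = torch.randint(0, 5, (2, 32, 32))       # made-up label map, 5 classes
seg = F.one_hot(labels, num_classes=5).permute(0, 3, 1, 2).float()
out = SPADE(num_features=64, num_classes=5)(x, seg)
print(out.shape)  # torch.Size([2, 64, 16, 16])
```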
Just like ResNet, where we combine conv layers into a ResNet block, we combine SPADE layers into a SPADEResBlk.
The idea is simple: we are just extending the ResNet block. The skip connection is important because it allows us to train deeper networks without suffering from vanishing gradients.
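A sketch of how such a block could be put together is shown here. It repeats a compact SPADE layer so that the example runs on its own. The LeakyReLU slope of 0.2, the choice of `min(in_ch, out_ch)` for the middle width, and the learned 1x1 shortcut used when the channel count changes are assumptions of this sketch.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SPADE(nn.Module):
    # Compact version of the normalization layer described earlier.
    def __init__(self, num_features, num_classes, hidden=64):
        super().__init__()
        self.norm = nn.BatchNorm2d(num_features, affine=False)
        self.shared = nn.Sequential(nn.Conv2d(num_classes, hidden, 3, padding=1), nn.ReLU())
        self.gamma = nn.Conv2d(hidden, num_features, 3, padding=1)
        self.beta = nn.Conv2d(hidden, num_features, 3, padding=1)

    def forward(self, x, seg):
        seg = F.interpolate(seg, size=x.shape[2:], mode="nearest")
        h = self.shared(seg)
        return self.norm(x) * (1 + self.gamma(h)) + self.beta(h)


class SPADEResBlk(nn.Module):
    """Residual block in which every normalization is a SPADE layer."""

    def __init__(self, in_ch, out_ch, num_classes):
        super().__init__()
        mid_ch = min(in_ch, out_ch)
        self.norm1 = SPADE(in_ch, num_classes)
        self.conv1 = nn.Conv2d(in_ch, mid_ch, 3, padding=1)
        self.norm2 = SPADE(mid_ch, num_classes)
        self.conv2 = nn.Conv2d(mid_ch, out_ch, 3, padding=1)
        # A learned shortcut is only needed when the channel count changes.
        self.learned_skip = in_ch != out_ch
        if self.learned_skip:
            self.norm_s = SPADE(in_ch, num_classes)
            self.conv_s = nn.Conv2d(in_ch, out_ch, 1, bias=False)

    def forward(self, x, seg):
        skip = self.conv_s(self.norm_s(x, seg)) if self.learned_skip else x
        h = self.conv1(F.leaky_relu(self.norm1(x, seg), 0.2))
        h = self.conv2(F.leaky_relu(self.norm2(h, seg), 0.2))
        return skip + h


x = torch.randn(2, 64, 8, 8)
seg = F.one_hot(torch.randint(0, 5, (2, 32, 32)), num_classes=5).permute(0, 3, 1, 2).float()
out = SPADEResBlk(64, 32, num_classes=5)(x, seg)
print(out.shape)  # torch.Size([2, 32, 8, 8])
```

Note that the seg map is passed to every SPADE layer in the block, including the one on the shortcut path.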
Now that we have our basic blocks, we can start coding up our GAN. Again, the three things that we need for a GAN are 1) Generator 2) Discriminator 3) Loss Function.
The loss function is the most important piece for training a GAN. We are all familiar with the standard GAN objective, which the generator tries to minimize and the discriminator tries to maximize.
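For reference, the standard conditional GAN objective can be written as follows. The symbols are notation introduced here: s is the seg map, x the real image, z the random input, G the generator and D the discriminator. This is the textbook minimax form. The SPADE paper itself uses a hinge variant of the adversarial term.

```latex
\min_G \max_D \;
  \mathbb{E}_{(s,x)}\big[\log D(s, x)\big]
  + \mathbb{E}_{s,z}\big[\log\big(1 - D(s, G(s, z))\big)\big]
```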
Now we extend this loss function with a feature matching loss. What do I mean? So far the loss is computed at a single, fixed image size. What if we compute the losses at several image sizes and then sum them all?
This loss stabilizes training, because the generator has to produce natural statistics at multiple scales. To do so, we extract features from multiple layers of the discriminator and learn to match these intermediate representations between the real and the synthesized images. The same comparison can also be done on features taken from a pretrained VGG model, which is called perceptual loss.
So we take the two images, real and synthesized, and pass them through the VGG network. We compare the intermediate feature maps to compute the loss. We could also use ResNet, but VGG works pretty well, and the earlier layers of VGG are generally good at extracting the features of an image.
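The comparison of intermediate feature maps can be sketched as below. To keep the example runnable offline, a tiny randomly initialized network stands in for the pretrained VGG. In practice the stages would be frozen slices of a pretrained VGG. The names `FeatureExtractor` and `perceptual_loss` and the per-layer weights are hypothetical choices for this sketch.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FeatureExtractor(nn.Module):
    """Runs a stack of stages and returns the output of every stage."""

    def __init__(self, stages):
        super().__init__()
        self.stages = nn.ModuleList(stages)

    def forward(self, x):
        feats = []
        for stage in self.stages:
            x = stage(x)
            feats.append(x)
        return feats


def perceptual_loss(extractor, fake, real, weights):
    # Real features are targets only, so no gradient is needed for them.
    with torch.no_grad():
        real_feats = extractor(real)
    fake_feats = extractor(fake)
    # Weighted sum of L1 distances between matching feature maps.
    return sum(w * F.l1_loss(f, r) for w, f, r in zip(weights, fake_feats, real_feats))


torch.manual_seed(0)
# Stand-in for VGG slices (assumption: a real setup would load pretrained weights).
extractor = FeatureExtractor([
    nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU()),
    nn.Sequential(nn.MaxPool2d(2), nn.Conv2d(8, 16, 3, padding=1), nn.ReLU()),
]).eval()
for p in extractor.parameters():
    p.requires_grad_(False)  # the feature network is never trained

real = torch.rand(2, 3, 32, 32)
fake = torch.rand(2, 3, 32, 32, requires_grad=True)  # stands in for generator output
loss = perceptual_loss(extractor, fake, real, weights=[0.5, 1.0])
loss.backward()  # gradients flow back to the synthesized image only
```

The feature matching loss has the same shape. The only difference is that the feature maps come from the layers of the discriminator instead of VGG.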
This is not the complete loss function. Below I show my implementation without the perceptual loss. I strongly recommend looking at the loss function implementation used by Nvidia themselves for this project, as it also includes the above loss and provides a general guideline on how to train GANs in 2019.
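As a minimal sketch of the adversarial part alone, assuming the hinge form of the GAN loss that the SPADE paper uses, the two loss terms could look like this. The function names and the toy logits are made up for illustration.

```python
import torch
import torch.nn.functional as F


def d_hinge_loss(real_logits, fake_logits):
    # The discriminator wants real logits above +1 and fake logits below -1.
    return F.relu(1.0 - real_logits).mean() + F.relu(1.0 + fake_logits).mean()


def g_hinge_loss(fake_logits):
    # The generator wants the discriminator to score its images as high as possible.
    return -fake_logits.mean()


real_logits = torch.tensor([2.0, 0.5])
fake_logits = torch.tensor([-2.0, 0.0])
print(d_hinge_loss(real_logits, fake_logits).item())  # 0.75
print(g_hinge_loss(fake_logits).item())               # 1.0
```

In a full training loop the feature matching and perceptual terms would be added to the generator loss with their own weights.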
In the paper, they used Glorot initialization (another name for Xavier initialization). I prefer to use He initialization.
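One way to apply He initialization in PyTorch is with `nn.init.kaiming_normal_` and `Module.apply`. This is a sketch. The slope `a=0.2` assumes LeakyReLU(0.2) activations, and the small model is only a placeholder.

```python
import torch.nn as nn


def init_he(module):
    # He (Kaiming) initialization for conv and linear layers, zero biases.
    if isinstance(module, (nn.Conv2d, nn.Linear)):
        nn.init.kaiming_normal_(module.weight, a=0.2, nonlinearity="leaky_relu")
        if module.bias is not None:
            nn.init.zeros_(module.bias)


# Placeholder model; apply() visits every submodule recursively.
model = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1),
    nn.LeakyReLU(0.2),
    nn.Conv2d(16, 3, 3, padding=1),
)
model.apply(init_he)
```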
The encoder is the final part of our model. It is used if you want to transfer style from one image to the output of SPADE. It works by outputting the mean and variance values from which we sample the random Gaussian noise that we input to the generator.
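The sampling step can be sketched with the usual reparameterization trick. The small `StyleEncoder` here is a hypothetical stand-in with made-up layer sizes, not the encoder from the paper. It predicts the log-variance rather than the variance, which is a common choice for numerical stability.

```python
import torch
import torch.nn as nn


class StyleEncoder(nn.Module):
    """Maps a style image to the mean and log-variance of a Gaussian."""

    def __init__(self, z_dim=256):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, 3, stride=2, padding=1), nn.InstanceNorm2d(32), nn.LeakyReLU(0.2),
            nn.Conv2d(32, 64, 3, stride=2, padding=1), nn.InstanceNorm2d(64), nn.LeakyReLU(0.2),
            nn.AdaptiveAvgPool2d(1), nn.Flatten(),
        )
        self.fc_mu = nn.Linear(64, z_dim)
        self.fc_logvar = nn.Linear(64, z_dim)

    def forward(self, img):
        h = self.features(img)
        return self.fc_mu(h), self.fc_logvar(h)


def reparameterize(mu, logvar):
    # z = mu + eps * std keeps the sampling step differentiable.
    std = torch.exp(0.5 * logvar)
    return mu + torch.randn_like(std) * std


style_img = torch.rand(2, 3, 64, 64)
mu, logvar = StyleEncoder()(style_img)
z = reparameterize(mu, logvar)  # this is what the generator receives
print(z.shape)  # torch.Size([2, 256])
```

Without a style image, z is simply drawn from a standard normal distribution of the same size.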
Why Spectral Normalization?
Spectral Normalization Explained by Christian Cosgrove: this article discusses spectral norm in detail, with all the maths behind it. Ian Goodfellow even commented on spectral normalization and considers it to be an important tool.
The reason we need spectral norm is that when we are generating images, it can become difficult to train our model to generate images of, say, the 1000 categories of ImageNet. Spectral norm helps by stabilizing the training of the discriminator. There are theoretical justifications for why this should be done, but all of that is beautifully explained in the blog post linked above.
To use spectral norm in your model, just apply spectral_norm to all convolutional layers in your generator and discriminator.
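A small helper along these lines wraps every conv layer of an existing model with `torch.nn.utils.spectral_norm`. The helper name `add_spectral_norm` and the toy network are assumptions of this sketch.

```python
import torch
import torch.nn as nn
from torch.nn.utils import spectral_norm


def add_spectral_norm(model):
    # Collect the conv layers first, then wrap each one in place.
    convs = [m for m in model.modules() if isinstance(m, nn.Conv2d)]
    for conv in convs:
        spectral_norm(conv)
    return model


# Toy discriminator-like network, used only to show the call.
net = add_spectral_norm(nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1),
    nn.LeakyReLU(0.2),
    nn.Conv2d(8, 1, 3, padding=1),
))
out = net(torch.randn(1, 3, 16, 16))
print(out.shape)  # torch.Size([1, 1, 16, 16])
```

After wrapping, each conv layer keeps its raw weight as `weight_orig` and rescales it by the estimated spectral norm on every forward pass.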
Brief Discussion on Instance normalization
Batch Normalization uses the complete batch to compute the mean and std of each channel and then normalizes every image in the batch with those shared values. This is good when we are doing classification, but when we are generating images, we want to keep the normalization of each image independent.
One simple reason is this: if one image in my batch is a blue sky and another is a road, then normalizing both with the same mean and std adds extra noise to the images, which makes training worse. So instance norm is used instead of batch normalization here.
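The effect is easy to check numerically. In this sketch the first image of a batch is kept fixed while the second one is swapped. The random tensors are placeholders for real feature maps.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch_norm = nn.BatchNorm2d(3, affine=False).train()
instance_norm = nn.InstanceNorm2d(3)

a = torch.randn(1, 3, 8, 8)             # fixed first image, e.g. the "sky"
b1 = torch.randn(1, 3, 8, 8)            # one possible second image
b2 = 10 * torch.randn(1, 3, 8, 8) + 5   # a very different second image, e.g. the "road"


def first_output(norm, other):
    # Normalize the batch [a, other] and return the result for image a only.
    return norm(torch.cat([a, other]))[0]


bn_changed = not torch.allclose(first_output(batch_norm, b1), first_output(batch_norm, b2))
in_changed = not torch.allclose(first_output(instance_norm, b1), first_output(instance_norm, b2))
print(bn_changed, in_changed)  # True False
```

With batch norm the normalized sky image depends on what else is in the batch. With instance norm it does not.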
- SPADE Paper (link)
- Official Implementation (link)
- My Implementation (link)
- pix2pixHD (link)
- Spectral Normalization paper (link)
- Spectral Norm blog (link)
- Instance Normalization paper (link)
- Instance norm other resources, blog, stack overflow