Lesson 7 of 7: ResNet, U-Net, GANs

Notes from Practical Deep Learning for Coders 2019 Lesson 7 (Part 1)

Julia Wu
Aug 18, 2019 · 13 min read

Other lessons: Lesson 1 / Lesson 2 / Lesson 3 / Lesson 4 / Lesson 5 / Lesson 6

Quick links: course page / Lecture / Jupyter Notebooks



By now, we should be pretty familiar with the process of loading image data and creating a DataBlock (like ImageList):

1. Specify the path of the image data
2. Load in the data

Each item in the item list il is one of the images you loaded, so you can index into the list and view an image’s content.

3. Split into training and validation.

Remember that you can choose how the data is split, including not splitting it at all with no_split.

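In this case we’re splitting by folder, and the validation folder is testing.

Why not call it a test set? In Kaggle terms, a validation set has labels and a test set has no labels. This testing folder is labeled, so it serves as our validation set.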

4. Assign labels

As expected, x holds the images and y holds the categories.

If we index into the contents, we see that x is the actual image and y is the label:

5. Apply transformations

We’re only going to add some padding. For small images like these, random padding is enough augmentation.

6. Pick a batch size and create a DataBunch

7. Inspect the prepared data

The training dataset now has data augmentation from our transforms.

The _plot function takes an element of the training set from disk and transforms it on the fly. plot_multi plots the result of calling some function for each cell of a rows-by-columns grid. Because the transform is applied on every call, each cell can show a different random augmentation.

Use one_batch() to grab a single batch of training data.
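Steps 3 to 5 can be pictured without any fastai code. The sketch below is plain Python and is not fastai’s implementation: it splits file paths by folder, labels each file by its parent folder, and does random padding in the sense of zero-padding and then cropping back to the original size at a random offset. The mnist_png/training/... and mnist_png/testing/... layout and all helper names are hypothetical, and an image is just a list of lists of pixel values.

```python
import random
from pathlib import PurePosixPath


def split_by_folder_sketch(paths, train="training", valid="testing"):
    """Split file paths by which folder (train or valid) they live under."""
    train_items, valid_items = [], []
    for p in map(PurePosixPath, paths):
        if train in p.parts:
            train_items.append(p)
        elif valid in p.parts:
            valid_items.append(p)
    return train_items, valid_items


def label_from_folder(path):
    """Use the name of the immediate parent folder (e.g. the digit) as the label."""
    return PurePosixPath(path).parent.name


def rand_pad_sketch(img, padding=3, rng=random):
    """Zero-pad a 2D image by `padding` pixels, then randomly crop back to its original size."""
    h, w = len(img), len(img[0])
    padded_w = w + 2 * padding
    padded = [[0] * padded_w for _ in range(padding)]
    padded += [[0] * padding + list(row) + [0] * padding for row in img]
    padded += [[0] * padded_w for _ in range(padding)]
    top = rng.randint(0, 2 * padding)   # random vertical offset of the crop
    left = rng.randint(0, 2 * padding)  # random horizontal offset of the crop
    return [row[left:left + w] for row in padded[top:top + h]]


paths = [
    "mnist_png/training/3/001.png",
    "mnist_png/training/7/002.png",
    "mnist_png/testing/3/003.png",
]
train_items, valid_items = split_by_folder_sketch(paths)
print([label_from_folder(p) for p in train_items])  # ['3', '7']
print([label_from_folder(p) for p in valid_items])  # ['3']

img = [[1] * 4 for _ in range(4)]  # a tiny 4x4 "image" of ones
aug = rand_pad_sketch(img, padding=1, rng=random.Random(0))
print(len(aug), len(aug[0]))  # 4 4: same size as the input
```

Because the crop offset is random, the digit lands in a slightly different position each time the image is loaded, which is the augmentation effect.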

Basic CNN with batchnorm

Jeremy creates a helper function that always returns the same kind of convolutional layer: a kernel size of 3 with a stride of 2.

Details on the model:

With a stride of 2, each convolution skips over every other pixel, jumping two steps at a time, so it halves the grid size.

After the first convolution, we have an 8x14x14 tensor.

Then we’ll do a batch norm, then we’ll do ReLU.

The number of input filters to each conv has to equal the number of output filters from the previous conv, and we can keep increasing the number of channels as we go.

And because every conv has stride 2, the grid size keeps decreasing.

So it’s always the same three steps, repeated: conv, batch norm, ReLU.

When we’re down to a 1x1 grid, we have a 10x1x1 feature map: a rank 3 tensor (per image) with one channel per digit class.

Our loss functions generally expect a vector, not a rank 3 tensor, so we finish with a Flatten() layer, which gets rid of the unit axes.
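The shapes in this model follow from the usual convolution output-size formula, out = (n + 2*padding - kernel) // stride + 1. Here is a plain-Python sketch that traces a 28x28 MNIST digit through five 3x3, stride 2 convs. The padding of 1 and the channel sequence 8, 16, 32, 16, 10 are assumptions for illustration (only the 8x14x14 and 10x1x1 shapes come from the notes); batch norm and ReLU don’t change shapes, so they are left out.

```python
import math


def conv_out_size(n, kernel=3, stride=2, padding=1):
    """Spatial output size of a convolution along one dimension."""
    return (n + 2 * padding - kernel) // stride + 1


def trace_shapes(size, channels, stride=2):
    """(channels, height, width) after each conv; batch norm and ReLU keep the shape."""
    shapes = []
    for out_ch in channels:
        size = conv_out_size(size, stride=stride)
        shapes.append((out_ch, size, size))
    return shapes


shapes = trace_shapes(28, [8, 16, 32, 16, 10])
for s in shapes:
    print(s)
# (8, 14, 14)
# (16, 7, 7)
# (32, 4, 4)
# (16, 2, 2)
# (10, 1, 1)

# Flatten keeps the batch axis and flattens the rest: (bs, 10, 1, 1) -> (bs, 10)
bs = 128  # assumed batch size
flat = (bs, math.prod(shapes[-1]))
print(flat)  # (128, 10)
```

The same formula with stride=1 gives conv_out_size(28, stride=1) == 28, which is why a stride 1 conv leaves the grid size unchanged.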

Now we put this CNN model into a Learner, along with the data, the loss function, and the metrics to print out.

Taking a look at the learner with learn.summary():

Notice the output shapes: 14x14, 7x7, 4x4, 2x2, and finally 1x1.

Earlier, we grabbed a minibatch, xb.

We can move it onto the GPU with xb.cuda(), pass the batch through the model, and check the shape of the output.

The next step is also familiar: lr_find() followed by learn.recorder.plot().

fit_one_cycle() trains this model from scratch, and we already get pretty good accuracy:


Rather than writing out the three steps (conv, batch norm, ReLU) every time, fastai already has a function called conv_layer() that bundles a conv, a ReLU and a batch norm into one layer.

So in the new CNN factory function, we can use conv_layer.

Model definition:

Again, create the learner and train it with fit_one_cycle():

Accuracy is pretty good:


How can we improve this? We want to create a deeper network. After every stride 2 conv, add a stride 1 conv. The stride 1 conv doesn’t change the feature map size at all, so you can add as many as you like.

But there’s a problem: deeper networks can end up with higher training and test errors than shallower ones, as shown in the famous paper Deep Residual Learning for Image Recognition.

So instead of output = conv2(conv1(x)), we want output = x + conv2(conv1(x)). The convolutions then only have to learn a residual on top of the input, and a block can behave like the identity simply by learning a residual of zero.

This idea helped Kaiming He and the team behind that paper easily win ImageNet that year.

This is the concept of a ResBlock, and we can define a ResBlock layer in our own code as well:

We create an nn.Module with two conv layers (remember, a conv_layer is Conv2d, ReLU, batch norm). Then in forward() we compute conv1(x), apply conv2() to the result, and add x.
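To see why output = x + conv2(conv1(x)) helps, here is a toy, framework-free sketch of a residual block acting on a 1D signal instead of an image. The 3-tap convolution, the ReLU placement and the function names are simplifications for illustration, not fastai’s res_block. It shows the key property: with all conv weights at zero, the plain block wipes out its input, while the residual block passes it through unchanged.

```python
def conv1d_same(x, kernel):
    """3-tap 1D convolution (cross-correlation), zero padding 1, stride 1: length is unchanged."""
    padded = [0.0] + list(x) + [0.0]
    return [sum(k * padded[i + j] for j, k in enumerate(kernel)) for i in range(len(x))]


def relu(x):
    return [max(0.0, v) for v in x]


def plain_block(x, k1, k2):
    """output = conv2(conv1(x)), with a ReLU after each conv."""
    return relu(conv1d_same(relu(conv1d_same(x, k1)), k2))


def res_block(x, k1, k2):
    """output = x + conv2(conv1(x)): the convs only need to learn a residual."""
    return [a + b for a, b in zip(x, plain_block(x, k1, k2))]


x = [1.0, -2.0, 3.0, 0.5]
zeros = [0.0, 0.0, 0.0]
print(plain_block(x, zeros, zeros))  # [0.0, 0.0, 0.0, 0.0]: the input is lost
print(res_block(x, zeros, zeros))    # [1.0, -2.0, 3.0, 0.5]: identity
```

The addition only works because the stride 1, padding 1 convs keep the output the same size as x; the same constraint applies to the 2D version.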

There’s also a res_block() function in fastai, so you can just use that.

We then move that into a function, conv_and_res(), which combines a stride 2 conv with a res_block.


Update the learner:

Next, use lr_find() and fit_one_cycle() as usual

After fit_one_cycle(), the accuracy is 99.45%, and we trained this model completely from scratch.


From this article:

The main idea behind CNN is to learn the feature mapping of an image and exploit it to make more nuanced feature mapping. While converting an image into a vector, we already learned the feature mapping of the image so why not use the same mapping to convert it again to image. This is the recipe behind UNet.

fastai has an implementation of the U-Net. The key thing that goes into it is the encoder: the downsampling part of the U-Net, which in our case is a ResNet-34.

So the layers of our U-Net are the encoder, then batch norm, then ReLU, and then middle_conv, which is just two conv_layers in a row (conv_layer, conv_layer).

Remember, in fastai a conv_layer is a conv, ReLU, batch norm. So middle_conv is these two extra steps at the bottom of the U:
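To see how the encoder, middle_conv and skip connections fit together, here is a plain-Python sketch that only tracks (channels, grid size) through a simplified U-Net. The encoder shapes are those of a ResNet-34 (stem plus four stages) on a 224x224 input. The assumptions that middle_conv leaves the shape unchanged and that each decoder step doubles the grid, halves the channels and then concatenates the matching encoder activations are simplifications for illustration, not fastai’s exact channel counts.

```python
# (channels, grid size) after each ResNet-34 stage for a 3x224x224 input.
# All but the last stage are kept as skip connections.
encoder_stages = [(64, 112), (64, 56), (128, 28), (256, 14), (512, 7)]


def unet_shapes(stages):
    """Trace (channels, size) through a simplified U-Net decoder."""
    skips = stages[:-1]
    ch, size = stages[-1]
    # middle_conv (two conv_layers at the bottom of the U) is assumed to
    # leave the channel count and grid size unchanged overall.
    decoder = []
    for skip_ch, skip_size in reversed(skips):
        # Upsample: double the grid size and halve the channels (assumption).
        ch, size = ch // 2, size * 2
        assert size == skip_size, "skip connection must match the grid size"
        # Concatenate the encoder activations from the same level of the U.
        ch += skip_ch
        decoder.append((ch, size))
    return decoder


for shape in unet_shapes(encoder_stages):
    print(shape)
# (512, 14)
# (384, 28)
# (256, 56)
# (192, 112)
```

In a full U-Net, further layers (not shown) bring the grid back to the input size and produce the output image.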

Image Restoration


We want to start with low-resolution, poor-quality JPEGs, so we take nice images and crappify them:

The implementation for the crappifier:
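The real crappifier works on image files. As a framework-free stand-in for its "lose detail" part only, the sketch below degrades a grayscale image stored as a list of lists: it averages each 2x2 block and then blows the result back up with nearest-neighbour upsampling, so fine detail disappears while the size stays the same. This is a swapped-in simplification (the notebook’s crappifier also uses a random JPEG quality and writes text on the image), and the function names are hypothetical.

```python
def downsample2x(img):
    """Average each 2x2 block (assumes even height and width)."""
    return [
        [
            (img[r][c] + img[r][c + 1] + img[r + 1][c] + img[r + 1][c + 1]) / 4
            for c in range(0, len(img[0]), 2)
        ]
        for r in range(0, len(img), 2)
    ]


def upsample2x(img):
    """Nearest-neighbour upsampling: repeat each pixel in a 2x2 block."""
    out = []
    for row in img:
        wide = [v for v in row for _ in range(2)]
        out.append(wide)
        out.append(list(wide))
    return out


def crappify_sketch(img):
    """Lose fine detail but keep the original size in this sketch."""
    return upsample2x(downsample2x(img))


checker = [[(r + c) % 2 * 255 for c in range(4)] for r in range(4)]
print(crappify_sketch(checker))  # every pixel becomes 127.5: the fine pattern is gone
```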

Now we’re going to use a U-Net that takes a crappy image and makes it better.

Pre-train generator

We have a function for transforming the data and turning it into a databunch:

Use this function and inspect the data:

Specify a weight decay wd, a y_range for the output, and a loss function loss_gen (pixel MSE loss).
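For intuition on y_range and the pixel loss, here is a plain-Python sketch. Here y_range squashes the model’s raw outputs into a fixed interval with a scaled sigmoid, and pixel MSE averages the squared difference between predicted and target pixel values. The interval (-3, 3), the sample values and the helper names are assumptions for illustration.

```python
import math


def sigmoid_range(x, low, high):
    """Squash a raw activation into (low, high) with a scaled sigmoid."""
    return low + (high - low) / (1 + math.exp(-x))


def pixel_mse(pred, target):
    """Mean squared error over all pixels of two same-sized images (lists of rows)."""
    diffs = [(p - t) ** 2 for prow, trow in zip(pred, target) for p, t in zip(prow, trow)]
    return sum(diffs) / len(diffs)


y_range = (-3.0, 3.0)  # assumed range, for illustration only
raw = [[-10.0, 0.0], [2.0, 10.0]]  # raw network outputs
pred = [[sigmoid_range(v, *y_range) for v in row] for row in raw]
target = [[-3.0, 0.0], [1.0, 3.0]]
print(round(pixel_mse(pred, target), 4))
```

Even extreme raw outputs (-10 and 10) land just inside the interval, so the generator can never produce values outside y_range.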

Next, we create a generator learner (a U-Net learner) with the parameters above.

Initialize it and train it with fit_one_cycle().

Remember, the pretrained part of a U-Net is the downsampling path: the part of the U shape that goes downward.

Let’s unfreeze that pretrained path and train a little more:

We save this model as 'gen-pre2' with the learner’s save() method.

To do better, we need a loss function that does a better job than pixel MSE loss. But what should it be?

There’s a very general way of answering that question: using a Generative Adversarial Network (GAN). GANs use a loss function that calls another model:

So far, the generator produces a high-res image, and we compare that image with the real high-res image using pixel MSE.

We can also train another model, called the discriminator or critic, that takes both generated images and real high-res images and tries to classify which is which. It’s just a binary classification model.

Then we can fine-tune the generator. Instead of using pixel MSE as the loss, the loss will now be “how good are we at fooling the critic?” Can we create generated images that the critic thinks are real? The generator learns to create images that the critic can’t tell are fake. But at some point the generator gets pretty good, so we need to boost the critic’s ability to discriminate by training it some more on these newly generated images.

And we’ll improve the generator based on the better critic. And improve the critic with better generated images. Ping pong back and forth.

Training the critic

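First, we use the pretrained generator to create images and save them to disk.

We create a function that takes the filenames of the input images and saves each generated image under the same filename in a new image_gen folder.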
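The bookkeeping part of that function can be sketched without any deep learning code: each generated image is written under the input’s filename in a separate image_gen folder, so every generated image can be paired with its original and the folder name later becomes the critic’s label. The data/crappy and data/images layout and the helper name gen_path_for are hypothetical.

```python
from pathlib import PurePosixPath


def gen_path_for(fn, out_dir="image_gen"):
    """Hypothetical helper: where the generated version of input file `fn` is saved."""
    fn = PurePosixPath(fn)
    # Same file name, in an out_dir folder next to the input's folder.
    return fn.parent.parent / out_dir / fn.name


inputs = ["data/crappy/0001.jpg", "data/crappy/0002.jpg"]
for fn in inputs:
    print(gen_path_for(fn))
# data/image_gen/0001.jpg
# data/image_gen/0002.jpg

# The critic's classes then come straight from the folder names.
real = PurePosixPath("data/images/0001.jpg")
fake = gen_path_for(inputs[0])
print(real.parent.name, fake.parent.name)  # images image_gen
```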

Train critic

Here’s a way to make Python run garbage collection, so we don’t have to restart the notebook when we run out of memory:
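The general pattern is to drop your references to objects you no longer need (for example, by setting a learner variable to None) and then call Python’s gc.collect() to force a collection. Here is a minimal stdlib sketch. BigLearner is a hypothetical stand-in for a learner, and the weak reference is only there to check that the object was actually freed.

```python
import gc
import weakref


class BigLearner:
    """Hypothetical stand-in for a learner holding a large model."""

    def __init__(self):
        self.weights = [0.0] * 1_000_000
        self.me = self  # a reference cycle, so reference counting alone won't free it


learn_gen = BigLearner()
probe = weakref.ref(learn_gen)  # lets us check whether the object still exists

learn_gen = None  # drop our reference to the object
gc.collect()      # force the cycle collector to run now
print(probe() is None)  # True: the object and its weights have been freed
```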

Let’s pretrain the critic on crappy vs. not crappy. It’s going to be an ImageList from a folder, just as usual, and the classes will be image_gen and images.

We do a random split, because we want a validation set to know how well we’re doing with the critic.

Label it from the folder names in the usual way, apply some transforms, and call databunch() and normalize() as usual.
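A random split just shuffles the items and holds a fraction out for validation. Here is a minimal plain-Python sketch. The default fraction, the seed, the file names and the function name are assumptions for illustration, not fastai’s implementation.

```python
import random


def split_by_rand_pct_sketch(items, valid_pct=0.1, seed=42):
    """Shuffle with a fixed seed and hold out valid_pct of the items for validation."""
    rng = random.Random(seed)
    shuffled = list(items)
    rng.shuffle(shuffled)
    n_valid = int(len(shuffled) * valid_pct)
    return shuffled[n_valid:], shuffled[:n_valid]


items = [f"images/{i}.jpg" for i in range(5)] + [f"image_gen/{i}.jpg" for i in range(5)]
train, valid = split_by_rand_pct_sketch(items, valid_pct=0.2)
print(len(train), len(valid))  # 8 2
```

Fixing the seed makes the split reproducible, so the critic’s validation numbers stay comparable between runs.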

Totally standard classifier. Initialize and inspect:

We’re going to use binary cross entropy (BCE) as usual:
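The critic outputs one raw score (a logit) per image, so BCE is usually computed straight from that logit in a numerically stable form. Here is a stdlib-only sketch. Labelling real high-res images 1 and generated images 0 is our assumption.

```python
import math


def bce_with_logits(logit, target):
    """Binary cross entropy computed directly from a raw score (logit).

    Equivalent to -(t*log(sigmoid(z)) + (1-t)*log(1-sigmoid(z))), rearranged so
    that exp() never overflows for large |z|.
    """
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


# target 1 = real high-res image, target 0 = generated image (assumed labelling)
for logit, target in [(4.0, 1), (-4.0, 0), (4.0, 0)]:
    print(logit, target, round(bce_with_logits(logit, target), 4))
```

Confident, correct scores cost almost nothing, while a confident wrong answer (a large positive score for a generated image) costs a lot.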

However, we’re not going to use a ResNet here. To make GANs work, the critic needs something called spectral normalization, which we get by calling fastai’s gan_critic().
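Spectral normalization divides a layer’s weight matrix by its largest singular value (its spectral norm), which bounds how much that layer can stretch its input and tends to stabilize GAN training. The spectral norm is usually estimated with a few steps of power iteration. Here is a stdlib-only sketch on a small matrix. It illustrates the technique, not the code inside gan_critic(), and the matrix is made up.

```python
import math
import random


def matvec(W, v):
    return [sum(w * x for w, x in zip(row, v)) for row in W]


def transpose(W):
    return [list(col) for col in zip(*W)]


def spectral_norm(W, n_iters=50, seed=0):
    """Estimate the largest singular value of W by power iteration on W^T W."""
    rng = random.Random(seed)
    v = [rng.uniform(-1, 1) for _ in range(len(W[0]))]
    Wt = transpose(W)
    for _ in range(n_iters):
        v = matvec(Wt, matvec(W, v))  # one step of power iteration on W^T W
        norm = math.sqrt(sum(x * x for x in v))
        v = [x / norm for x in v]
    # ||W v|| for the (approximate) top right-singular vector v is the spectral norm
    return math.sqrt(sum(x * x for x in matvec(W, v)))


def spectrally_normalize(W, **kwargs):
    sigma = spectral_norm(W, **kwargs)
    return [[w / sigma for w in row] for row in W]


W = [[3.0, 0.0], [0.0, 1.0]]
print(round(spectral_norm(W), 6))     # 3.0
W_sn = spectrally_normalize(W)
print(round(spectral_norm(W_sn), 6))  # 1.0
```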

Initialize the learner and fit:

Save it as 'critic-pre2' with the learner’s save() method.


Now we can combine the pretrained models in a GAN.

Let’s initialize the critic:

Generative learner:

To define a GAN learner, we specify the learner objects for the generator and critic

This AdaptiveGANSwitcher is a callback that decides when to switch from discriminator to generator and vice versa.

We train the discriminator for as many iterations as needed to get its loss back below 0.5, then do one iteration on the generator.
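The switching rule can be simulated in plain Python. This sketch follows the description above: keep training the critic until its loss is below the threshold, then give the generator a single step. It is not the AdaptiveGANSwitcher source; starting with the critic is an assumption, and the loss values are made up.

```python
def schedule(losses, critic_thresh=0.5, gen_thresh=None):
    """Simulate which model trains at each step.

    losses[i] is the loss of whichever model trained at step i. The critic keeps
    training until its loss drops below critic_thresh; with gen_thresh=None the
    generator always gets exactly one step before switching back.
    """
    critic_mode = True  # assumption: start by training the critic
    steps = []
    for loss in losses:
        steps.append("C" if critic_mode else "G")
        thresh = critic_thresh if critic_mode else gen_thresh
        if thresh is None or loss < thresh:
            critic_mode = not critic_mode  # switch to the other model
    return steps


losses = [0.9, 0.7, 0.45, 2.1, 0.8, 0.4, 1.9, 0.3]
print("".join(schedule(losses)))  # CCCGCCGC
```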

The loss of the generator is the weighted sum (with the weights in weights_gen) of:

1. learn_crit.loss_func applied to the batch of fakes (passed through the critic to become predictions), with a target of 1
2. learn_gen.loss_func applied to the output (the batch of fakes) and the target (the corresponding batch of hi-res images)
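Here is a plain-Python sketch of that weighted sum, with the critic term first and the pixel term second, as in the description above. BCE with logits for the critic term and MSE for the pixel term match the losses used earlier in these notes. The weights (1.0, 50.0), the sample values and every function name are assumptions for illustration.

```python
import math


def bce_with_logits(logit, target):
    """Numerically stable binary cross entropy on a raw critic score."""
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


def mse(pred, target):
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def generator_loss(critic_logits, fakes, hi_res, weights_gen=(1.0, 50.0)):
    """weights_gen[0] * critic loss on the fakes with target 1
    + weights_gen[1] * pixel loss between the fakes and the real hi-res images."""
    crit = sum(bce_with_logits(z, 1.0) for z in critic_logits) / len(critic_logits)
    pix = sum(mse(f, h) for f, h in zip(fakes, hi_res)) / len(fakes)
    return weights_gen[0] * crit + weights_gen[1] * pix


fakes = [[0.2, 0.4], [0.9, 0.1]]   # two flattened generated "images"
hi_res = [[0.2, 0.5], [1.0, 0.1]]  # the matching real images
fooled = generator_loss([5.0, 5.0], fakes, hi_res)    # critic thinks the fakes are real
caught = generator_loss([-5.0, -5.0], fakes, hi_res)  # critic spots the fakes
print(fooled < caught)  # True
```

The target of 1 means the generator is rewarded when the critic scores its fakes as real, while the pixel term keeps the output close to the actual hi-res image.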

GANs don’t like momentum, so we set it to 0.

The loss numbers are pretty meaningless: when the critic gets better, it’s harder for the generator, and when the generator gets better, it’s harder for the critic. So the losses should stay about the same, which makes it hard to tell how training is going.

One way to tell how they’re doing is to just look at the results.

Super resolution

To do better, we’re going to go a step further again and find a more interesting loss function: perceptual losses. Jeremy’s preferred name is “feature losses”, so in the fastai library you’ll see them referred to as feature losses.

It shares something with GANs: a generator network, which the paper calls the “image transform net”. It’s a lot like a U-Net in that it has this U shape: the downsampling path is the encoder, and the upsampling path is the decoder.

The loss function asks, “is this thing that was generated like the thing that we want?”

We put the prediction y_hat through a network pretrained on ImageNet, which in the paper was VGG. On its way to the final classification, the prediction goes through a lot of layers (color-coded in the paper’s diagram).

They look at the prediction partway through and pick out some activations from a layer in the middle. For example, the activations might be a feature map with 256 channels of 28x28.

Then we take the target (i.e. the actual y value), put it through the same pretrained VGG network, and pull out the activations of the same layer.

Then we do a mean squared error comparison. So it’ll say “in the real image, grid cell (1, 1) of that 28 by 28 feature map is furry and blue and round shaped. And in the generated image, it’s furry and blue and not round shaped.” So it’s an okay match.

That ought to go a long way towards fixing our eyeball problem, because in this case the feature map is going to say “there are eyeballs here (in the target), but there aren’t here (in the generated version), so do a better job of that please. Make better eyeballs.” So that’s the idea. That’s what we call feature losses, or what Johnson et al. called perceptual losses.
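The idea can be sketched in plain Python. Run both the prediction and the target through the same frozen network, grab the activations after each chosen layer, and add up the MSE between them, together with a plain pixel loss. A tiny fixed network made of 1D convolutions stands in for the pretrained VGG. The layer weights, the loss weights and the sample signals are assumptions for illustration.

```python
def conv1d_same(x, kernel):
    """3-tap 1D convolution with zero padding, so the length is unchanged."""
    padded = [0.0] + list(x) + [0.0]
    return [sum(k * padded[i + j] for j, k in enumerate(kernel)) for i in range(len(x))]


def relu(x):
    return [max(0.0, v) for v in x]


def mse(a, b):
    return sum((p - q) ** 2 for p, q in zip(a, b)) / len(a)


# Frozen "pretrained" layers: hypothetical fixed weights standing in for VGG.
LAYERS = [[-1.0, 2.0, -1.0], [1.0, 1.0, 1.0]]


def activations(x):
    """Activations after each layer of the fixed network."""
    feats = []
    for kernel in LAYERS:
        x = relu(conv1d_same(x, kernel))
        feats.append(x)
    return feats


def feature_loss(pred, target, layer_weights=(1.0, 1.0), pixel_weight=1.0):
    """Pixel MSE plus weighted MSE between the activations of pred and target."""
    loss = pixel_weight * mse(pred, target)
    for w, fp, ft in zip(layer_weights, activations(pred), activations(target)):
        loss += w * mse(fp, ft)
    return loss


target = [0.0, 0.0, 1.0, 0.0, 0.0]  # a sharp bright spot
pred_a = [0.0, 0.0, 0.5, 0.0, 0.0]  # spot too dim
pred_b = [0.0, 0.5, 1.0, 0.0, 0.0]  # spot right, slight smear
print(mse(pred_a, target), mse(pred_b, target))  # 0.05 0.05
print(feature_loss(pred_a, target), feature_loss(pred_b, target))
```

Both predictions have the same pixel MSE, but the feature loss is much lower for pred_b, whose spot is as bright as the target’s: comparing activations picks up a difference that pixel MSE treats as equal.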

We’ll apply that in lesson7-superres.ipynb. Note from Jeremy:

“I wrote this notebook a little bit before the GAN notebook — before I came up with the idea of like putting text on it and having a random JPEG quality, so the JPEG quality is always 60, there’s no text written on top, and it’s 96 by 96. And before I realized what a great word “crappify” is, so it’s called resize_one."

So our goal is to create a loss function which does what we just described above.

Perform the initial steps of getting the data (ImageList, resize_one, etc.), then use show_batch().

Recurrent Neural Networks (RNNs)

A rectangle means an input, of shape batch_size x number of inputs.

An arrow is a layer, and a circle is activations. The output will be batch_size x number of activations.

Another arrow means another layer. Here, the output of this layer goes through a softmax.

The final output (a triangle) will be batch_size x number of classes.

Each time we go from rectangle to circle, we’re doing an embedding, which is just a particular kind of matrix multiply where the input is one-hot encoded.
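The claim that an embedding is a matrix multiply with a one-hot input is easy to check in plain Python. Multiplying a one-hot row vector by the embedding matrix just selects one row, which is why embedding layers are usually implemented as an index lookup instead. The embedding matrix here is made-up data.

```python
def one_hot(idx, n):
    return [1.0 if i == idx else 0.0 for i in range(n)]


def vecmat(v, M):
    """Row vector times matrix."""
    return [sum(v[i] * M[i][j] for i in range(len(v))) for j in range(len(M[0]))]


# Hypothetical embedding matrix: a vocabulary of 4 words, 3 activations each.
emb = [
    [0.1, 0.2, 0.3],
    [0.4, 0.5, 0.6],
    [0.7, 0.8, 0.9],
    [1.0, 1.1, 1.2],
]
word = 2
print(vecmat(one_hot(word, 4), emb))  # [0.7, 0.8, 0.9]
print(emb[word])                      # [0.7, 0.8, 0.9]: the same row, by lookup
```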

Each time we go from circle to circle, we’re basically taking one piece of hidden state (a.k.a. activations) and turning it into another set of activations, saying we’re now at the next word.

Then, when we go from circle to triangle, we’re doing something else again: converting the hidden state (i.e. these activations) into an output.
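Putting the three kinds of arrows together, here is a plain-Python sketch of an unrolled RNN for one sequence. It adds each word’s embedding to the hidden state, transforms the result with the same hidden-to-hidden weights at every step, and finally turns the hidden state into class probabilities. This is a simplified illustration (no batch axis, no batch norm), and all weights are hypothetical.

```python
import math


def matvec(M, v):
    return [sum(m * x for m, x in zip(row, v)) for row in M]


def relu(v):
    return [max(0.0, x) for x in v]


def softmax(v):
    m = max(v)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in v]
    total = sum(exps)
    return [e / total for e in exps]


def rnn_forward(word_ids, emb, W_hh, W_ho):
    """Simplified unrolled RNN for a single sequence."""
    h = [0.0] * len(W_hh)
    for w in word_ids:
        # rectangle -> circle: add the embedding of the current word (an embedding lookup)
        h = [a + b for a, b in zip(h, emb[w])]
        # circle -> circle: the same hidden-to-hidden weights are reused at every step
        h = relu(matvec(W_hh, h))
    # circle -> triangle: hidden state to class scores, then softmax
    return softmax(matvec(W_ho, h))


emb = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]   # hypothetical 3-word vocab, 2 activations
W_hh = [[1.0, 0.0], [0.0, 1.0]]              # hidden-to-hidden (identity, for readability)
W_ho = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # hidden-to-output, 3 classes
probs = rnn_forward([0, 1, 2], emb, W_hh, W_ho)
print([round(p, 3) for p in probs])
```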


Engineer thinking about fintech, AI, China, and our civilization | @Apple, @Microsoft, @BrownUniversity
