Breakdown of PyTorch’s CNN Tutorial

Sean Yi
Published in Analytics Vidhya
5 min read · Oct 5, 2019

In this article, I’ll write down what I learned while going through the (very short) convolutional neural network (CNN) tutorial offered by PyTorch.

Hey all. My name’s Sean, and I’m currently a Master’s student studying computer science at Korea University. Specifically, my research is focused on deep learning methods.

I’ve been trying to study PyTorch, since I need to become familiar with it quickly in order to conduct research and experiments for the projects I’m involved in. And what better place to start than the official tutorials in the documentation? The model is very short and simple, but I learned a lot by going over it line by line.

Basic Info

Data

The tutorial walks us through using a CNN to classify CIFAR-10 images. CIFAR-10 consists of 60,000 labeled images across 10 classes. Each image has shape (height=32, width=32, channels=3), and the data is split into 50,000 training images and 10,000 test images.

The dataset is distributed as 5 training batches of 10,000 images each, plus one test batch of 10,000 images.

Model Architecture

The overall model architecture is as follows:

Step by step:

  1. We receive the input image.
  2. We pass the image through a convolutional layer with 6 filters, a (5, 5) kernel, and stride 1, followed by a ReLU activation.
  3. We perform max pooling with a (2, 2) window and stride 2.
  4. We perform another convolution (again followed by ReLU), but this time with 16 filters.
  5. We perform another max pooling operation.
  6. The resulting feature maps are flattened into a vector of length 16 * 5 * 5 = 400 and passed through the first fully connected layer.
  7. The vector is then passed through two more fully connected layers, with a ReLU after each hidden layer.
  8. The final output has shape (1, 10) for each image (one score per class), and the image is classified as the class with the highest score.
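The eight steps above map one-to-one onto an nn.Sequential container. This is a sketch, assuming PyTorch 1.2 or later (for nn.Flatten). The hidden sizes 120 and 84 and the ReLU placement come from the official tutorial rather than from the list itself:

```python
import torch
import torch.nn as nn

# Each layer mirrors one numbered step above (hidden sizes taken from the official tutorial)
model = nn.Sequential(
    nn.Conv2d(3, 6, kernel_size=5),   # step 2: 6 filters, 5x5 kernel, stride 1
    nn.ReLU(),
    nn.MaxPool2d(2, 2),               # step 3: 2x2 max pooling, stride 2
    nn.Conv2d(6, 16, kernel_size=5),  # step 4: 16 filters
    nn.ReLU(),
    nn.MaxPool2d(2, 2),               # step 5: another 2x2 max pooling
    nn.Flatten(),                     # step 6: flatten to 16 * 5 * 5 = 400 values
    nn.Linear(16 * 5 * 5, 120),       # step 6: first fully connected layer
    nn.ReLU(),
    nn.Linear(120, 84),               # step 7: second fully connected layer
    nn.ReLU(),
    nn.Linear(84, 10),                # step 7/8: last layer, one raw score per class
)

scores = model(torch.randn(1, 3, 32, 32))  # one fake RGB image, stored channels-first
print(scores.shape)  # torch.Size([1, 10])
```

The tutorial itself writes the model as an nn.Module subclass instead. That makes it easier to add custom logic to forward, but for a plain stack of layers like this one the two styles compute the same thing.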

Please take a look at the website that I used to draw the network if you’re curious!

The Code

The entire code is as follows:

Now, let’s break this down part by part.

Step 1: Basic configuration for the data.

This step is relatively simple. We download the CIFAR-10 dataset using the torchvision library’s functions and then apply some basic preprocessing (e.g., normalization) to the data. The variables trainloader and testloader may seem confusing if you’re new to them (they were confusing to me), but they’re just DataLoader objects: tools that hand us the data in mini-batches. This will become clear if you look at the inner for loop in Step 5, which iterates over trainloader.

Step 2: Visualizing some samples.

I skipped this step because the server I’m using didn’t allow visualization for some reason. It wasn’t a big deal, though, since seeing what the data looks like isn’t the important part of this tutorial.

Step 3: Defining the CNN to be used.

If you look at the Net class, you’ll notice that we’ve defined the building blocks of our CNN (convolutional, pooling, and fully connected layers) in __init__ using PyTorch’s built-in modules. The forward method is what performs the actual computation, passing the input through those blocks in order.
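For reference, here is a sketch of a Net class along the lines of the official tutorial. It is reconstructed, so details may differ from the code above. The layers with learnable weights are created in __init__, and forward describes the computation:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # Building blocks: created once, their weights are learned during training
        self.conv1 = nn.Conv2d(3, 6, 5)        # 3 input channels -> 6 filters, 5x5 kernel
        self.pool = nn.MaxPool2d(2, 2)         # 2x2 window, stride 2 (reused after each conv)
        self.conv2 = nn.Conv2d(6, 16, 5)       # 6 -> 16 filters, 5x5 kernel
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)           # one raw score per class

    def forward(self, x):
        # The actual computation: how a batch flows through the building blocks
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)             # flatten each sample to 400 values
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

net = Net()
out = net(torch.randn(4, 3, 32, 32))  # net(...) invokes forward() through nn.Module.__call__
print(out.shape)  # torch.Size([4, 10])
```

Calling net(inputs) rather than net.forward(inputs) is the idiomatic choice, because nn.Module.__call__ also runs any registered hooks around forward.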

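If you’re confused about how we came up with the specific numbers passed to these layers, just remember the output-size formula for a convolutional layer:

output size = (W - K + 2P) / S + 1

where W is the input height (or width), K is the kernel size, P is the padding, and S is the stride. The division is rounded down when it isn’t exact.

Recall the shape of our data (starting from (32, 32, 3)) and the shape of our kernel (5x5). We don’t use any padding, and we use a stride of 1 (for pooling layers we use a stride of 2). Pooling layers aren’t convolutional layers, but they slide a window over the input in the same way. The same equation therefore applies to them, with the (2, 2) pooling window playing the role of the kernel.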
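Plugging our numbers into the standard output-size formula reproduces every spatial size in the shape list. Here O is the output height or width, W the input height or width, K the kernel (or pooling window) size, P the padding, and S the stride:

```latex
\begin{aligned}
O &= \left\lfloor \frac{W - K + 2P}{S} \right\rfloor + 1 \\
\text{conv1: } O &= \frac{32 - 5 + 2 \cdot 0}{1} + 1 = 28 \\
\text{pool: } O &= \frac{28 - 2 + 2 \cdot 0}{2} + 1 = 14 \\
\text{conv2: } O &= \frac{14 - 5 + 2 \cdot 0}{1} + 1 = 10 \\
\text{pool: } O &= \frac{10 - 2 + 2 \cdot 0}{2} + 1 = 5
\end{aligned}
```

The channel count is simply the number of filters in the most recent convolution (6, then 16). Pooling leaves it unchanged.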

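Anyway, here’s how the shape of the data changes as we pass it through the network. Shapes are written as (height, width, channels). Note that PyTorch itself stores images channels-first, so each input image is actually a (3, 32, 32) tensor.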

  1. (32, 32, 3): Input Data
  2. (28, 28, 6): Data after conv1
  3. (14, 14, 6): Data after pool
  4. (10, 10, 16): Data after conv2
  5. (5, 5, 16): Data after pool
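The same trace can be checked in code. This sketch computes each size with the formula and asserts that it matches what PyTorch produces for a dummy input. The helper conv_output_size is hypothetical. The ReLU activations are left out because they don't change shapes. PyTorch reports shapes channels-first, as (batch, channels, height, width):

```python
import torch
import torch.nn as nn

def conv_output_size(size, kernel, stride=1, padding=0):
    """Hypothetical helper: output height/width of a conv or pooling layer (rounded down)."""
    return (size - kernel + 2 * padding) // stride + 1

# (name, layer, kernel size, stride) for each shape-changing layer in the network
layers = [
    ("conv1", nn.Conv2d(3, 6, 5), 5, 1),
    ("pool", nn.MaxPool2d(2, 2), 2, 2),
    ("conv2", nn.Conv2d(6, 16, 5), 5, 1),
    ("pool", nn.MaxPool2d(2, 2), 2, 2),
]

x = torch.randn(1, 3, 32, 32)  # one dummy image: (batch, channels, height, width)
size = 32
shapes = [tuple(x.shape)]
for name, layer, kernel, stride in layers:
    x = layer(x)
    size = conv_output_size(size, kernel, stride)
    assert x.shape[-1] == size  # the formula agrees with PyTorch's actual output
    shapes.append(tuple(x.shape))
    print(f"{name}: {tuple(x.shape)}")
```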

In case you’re confused, here’s a diagram of the model architecture as a refresher:

Now, notice that the output of the last pooling layer has shape (5, 5, 16), and that in the forward function of our Net class we flatten the data with x.view(-1, 16 * 5 * 5). If you’re not sure what the -1 is for, it tells PyTorch to infer the size of that dimension from the total number of elements and the other dimensions (here, it works out to the batch size). Check out the documentation for numpy.reshape() for a more detailed explanation.
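A small sketch of the -1 inference, using a fake batch of 4 feature maps shaped like the output of the second pooling layer:

```python
import torch

# Fake batch of 4 feature maps, shaped like the second pooling layer's output
features = torch.randn(4, 16, 5, 5)

flat = features.view(-1, 16 * 5 * 5)  # -1 is inferred: (4 * 16 * 5 * 5) / 400 = 4
print(flat.shape)  # torch.Size([4, 400])

# The known sizes must divide the total element count (1600) evenly
try:
    features.view(-1, 300)
except RuntimeError as err:
    print("cannot infer -1:", err)
```

Writing -1 instead of a hard-coded 4 lets the same forward function handle any batch size. One caveat: if the feature maps ever have an unexpected size, -1 can silently absorb the mismatch into the batch dimension whenever the element count still happens to divide evenly. torch.flatten(x, 1), which keeps dimension 0 fixed, is a common alternative.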

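If you’re not familiar with the view function, I highly recommend checking out this Stack Overflow question. view is similar to reshape, but the two differ in a technical sense. view never copies data and only works when the tensor’s memory layout is compatible with the new shape. reshape returns a view when it can and a copy otherwise. Conceptually, though, it helps to think of them as the same.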
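The technical difference shows up with non-contiguous tensors, such as a transposed matrix. A short sketch:

```python
import torch

a = torch.arange(6).view(2, 3)  # contiguous: [[0, 1, 2], [3, 4, 5]]
b = a.t()                       # transposed: shares a's memory but is not contiguous
print(b.is_contiguous())        # False

try:
    b.view(6)                   # view() never copies, so this layout can't be flattened
except RuntimeError as err:
    print("view failed:", err)

c = b.reshape(6)                # reshape() copies when a view is impossible
print(c.tolist())               # [0, 3, 1, 4, 2, 5]

# A view always shares storage with its source: writing through it changes a
d = a.view(6)
d[0] = 100
print(a[0, 0].item())           # 100
```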

Step 4: Defining the optimizer and loss details.

This step isn’t that complicated. We’re simply defining the loss criterion and the optimizer we’ll use. As you can see, we’ll be using cross-entropy loss and stochastic gradient descent (SGD) with momentum. If there’s something you’re not familiar with, please check out the documentation!
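A sketch of this step follows. The learning rate of 0.001 and momentum of 0.9 are the values used in the official tutorial. The nn.Linear model is a hypothetical stand-in for Net so that the snippet runs on its own. The sketch also checks a detail worth knowing: nn.CrossEntropyLoss takes raw scores and applies log-softmax internally, which is why Net’s forward method ends without a softmax.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

torch.manual_seed(0)
model = nn.Linear(400, 10)  # hypothetical stand-in for Net; any nn.Module is used the same way

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# CrossEntropyLoss takes raw scores (logits) plus integer class labels
scores = model(torch.randn(4, 400))
labels = torch.tensor([3, 1, 0, 9])
loss = criterion(scores, labels)

# Equivalent to log-softmax followed by the negative log-likelihood loss
manual = F.nll_loss(F.log_softmax(scores, dim=1), labels)
print(torch.allclose(loss, manual))  # True

# One optimization step: the optimizer updates the parameters it was given
before = model.weight.detach().clone()
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(torch.equal(before, model.weight))  # False
```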

Step 5: Training the network.

This step is when we actually train the network. For each epoch, we loop over the mini-batches in trainloader, and for each one we:

  1. Extract the input data and corresponding labels: inputs, labels = data.
  2. Zero out the gradients of the parameters the optimizer manages, since PyTorch accumulates gradients by default: optimizer.zero_grad().
  3. Calculate the output scores with our model: outputs = net(inputs).
  4. Calculate the cross-entropy loss between our predictions and the actual labels: loss = criterion(outputs, labels).
  5. Calculate the gradient for every parameter with requires_grad=True: loss.backward().
  6. Update the model’s parameters using the gradients computed by loss.backward(): optimizer.step().
  7. Add this mini-batch’s loss to running_loss (using loss.item()).
  8. Periodically print the average loss (the tutorial does this every 2,000 mini-batches) and reset running_loss.
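Put together, these steps form the training loop below. It is a sketch in the style of the official tutorial, wrapped in a hypothetical train function. The usage lines at the end run it on fake CIFAR-10-shaped data with a tiny stand-in model so that it works offline. The losses it prints therefore say nothing about real CIFAR-10 performance.

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

def train(net, trainloader, epochs=2, print_every=2000):
    """Hypothetical wrapper around the tutorial-style loop; returns the printed averages."""
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
    history = []
    for epoch in range(epochs):                     # loop over the dataset several times
        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):   # one mini-batch per iteration
            inputs, labels = data                   # 1. unpack inputs and labels
            optimizer.zero_grad()                   # 2. reset accumulated gradients
            outputs = net(inputs)                   # 3. forward pass: raw class scores
            loss = criterion(outputs, labels)       # 4. cross-entropy loss
            loss.backward()                         # 5. gradients for requires_grad=True params
            optimizer.step()                        # 6. update the parameters
            running_loss += loss.item()             # 7. accumulate as a Python float
            if i % print_every == print_every - 1:  # 8. report the average loss periodically
                avg = running_loss / print_every
                print(f"[{epoch + 1}, {i + 1:5d}] loss: {avg:.3f}")
                history.append(avg)
                running_loss = 0.0
    return history

# Usage with fake data and a tiny hypothetical model, so it runs without downloading CIFAR-10
torch.manual_seed(0)
fake = TensorDataset(torch.randn(40, 3, 32, 32), torch.randint(0, 10, (40,)))
tiny = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))
history = train(tiny, DataLoader(fake, batch_size=4, shuffle=True), epochs=2, print_every=5)
```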

Once you run this code, you’ll notice that the loss gradually decreases!

Conclusion

This tutorial doesn’t walk you through an entire deep learning pipeline, but I believe it’s very helpful for learning how PyTorch works. One thing you should keep in mind, though, is to make a habit of looking at the documentation whenever you’re puzzled by a line of code.

My objective whenever I study is to always make sure that I can teach the material to someone else. It’s very time-consuming, but I really feel like it’s the right way to go.

I hope this tutorial was helpful to whoever has found it. Happy learning!


Machine learning engineer currently working on NLP solutions in the fashion/e-commerce industry.