Torch.NN: Walk Through & Commentary | Part I of II

Alexander Rofail
Nov 11, 2021

While coding along with Charles Frye’s video on “working through ‘what is torch.nn really’”, I figured I would write up my own blog post about it. Writing helps me retain information, and the post might help some folks as supplemental material.

Tools: Pure Python (3.x+), NumPy, PyTorch, Weights and Biases.

Dataset: MNIST Digits.

Goal: build a neural network from scratch, without torch.nn, to classify MNIST digits, then refactor it step by step with torch.nn. This task is often called the “Hello World” of deep learning.

I really love the way Charles structures the code and notebook. We start from barebones Python and import a library only when it’s time to use one. This lets us see very clearly how each library gets pulled in to handle exactly what we need once pure Python won’t cut it.

At the most basic level, all we need to create neural networks ourselves is an understanding of tensors and tensor math (dot products, matrix multiplies, and derivatives).
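To make that concrete, here is a tiny sketch of the three operations in PyTorch. It uses toy numbers that are not from the original post and covers a dot product, a matrix multiply, and a derivative computed by autograd.

```python
import torch

# Dot product of two 1-D tensors: 1*4 + 2*5 + 3*6 = 32
a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([4.0, 5.0, 6.0])
dot = a @ b

# Matrix multiply: (2x3) @ (3x2) -> (2x2); multiplying by ones sums each row
m1 = torch.arange(6.0).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]
m2 = torch.ones(3, 2)
prod = m1 @ m2                        # [[3, 3], [12, 12]]

# Derivative: d/dx (x**2) at x = 3 is 2 * 3 = 6
x = torch.tensor(3.0, requires_grad=True)
y = x ** 2
y.backward()  # autograd fills in x.grad

print(dot.item())     # 32.0
print(prod)
print(x.grad.item())  # 6.0
```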

This post will reiterate many of the topics from my Linear Regression post, but with some fancy PyTorch sprinkled in. Let’s start with what a simple forward pass through a model looks like under the hood. We will use PyTorch to split our data into training and testing sets. Everything else we will do from scratch: initialize parameters, define an activation function, define the model, and define loss and accuracy functions.

Simple Forward Pass

Regardless of how complex the architecture or deep the network, there are a few key ingredients we need to be able to run a proper forward pass of data through a model.

The ingredients are:

- a train/test split of our data
- initialized weights and biases
- an activation function (for networks more than one layer deep)
- a loss function
- a way to track accuracy
- the model itself, i.e. what it actually computes

Ingredients for a proper forward pass.
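The original code for these ingredients isn’t reproduced in this text, so here is a minimal sketch modeled on the PyTorch “What is torch.nn really?” tutorial. It rests on a few assumptions:

- Random tensors stand in for the flattened 28x28 MNIST images, so the snippet runs without a download.
- The shuffled 800/200 split via torch.randperm is arbitrary. The post doesn’t say which PyTorch helper it used.
- The from-scratch softmax is written as a log-softmax (named log_softmax here), because the NLL loss expects log-probabilities.

```python
import math
import torch

torch.manual_seed(0)

# Stand-in data (assumption): 1,000 fake 784-pixel "images", labels 0-9.
x = torch.rand(1000, 784)
y = torch.randint(0, 10, (1000,))

# 1. Train/test split: shuffle the row indices, then take 800 / 200.
perm = torch.randperm(1000)
x_train, y_train = x[perm[:800]], y[perm[:800]]
x_valid, y_valid = x[perm[800:]], y[perm[800:]]

# 2. Parameter initialization (Xavier-style: scale by 1/sqrt(n_inputs)).
weights = torch.randn(784, 10) / math.sqrt(784)
weights.requires_grad_()
bias = torch.zeros(10, requires_grad=True)

# 3. Output activation, from scratch: log-softmax over the class dimension.
def log_softmax(x):
    return x - x.exp().sum(-1).log().unsqueeze(-1)

# 4. The model: one linear layer followed by log-softmax.
def model(xb):
    return log_softmax(xb @ weights + bias)

# 5. Loss: negative log-likelihood of the correct class, batch-averaged.
def nll(log_probs, target):
    return -log_probs[range(target.shape[0]), target].mean()

# 6. Accuracy: fraction of rows whose top-scoring class matches the label.
def accuracy(out, yb):
    preds = torch.argmax(out, dim=1)
    return (preds == yb).float().mean()

# One forward pass on a mini-batch of 64.
xb, yb = x_train[:64], y_train[:64]
preds = model(xb)
loss = nll(preds, yb)
acc = accuracy(preds, yb)
print(preds.shape)  # torch.Size([64, 10])
print(loss.item(), acc.item())
```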

Now that we have the ingredients set, we can get cooking with a full training loop. We will need a couple of hyperparameters: learning rate and epochs. I was scared when I first learned about hyperparameters and didn’t know what they were. You can think of them as settings we choose ourselves, rather than values the model learns. They control how the parameters of our model (weights and biases) get trained.

NLL Function

Originally I was going to do a side note on what exactly the NLL loss function is doing and how it is tied to the softmax function (output layer). Well, that started to get really long and warrants a post of its own. So I decided it’s probably a better idea to just link to one of the clearest walkthroughs of softmax and NLL I have ever read.

I don’t know who Lj Miranda is, but they wrote this and I think it’s awesome: https://ljvmiranda921.github.io/notebook/2017/08/13/softmax-and-the-negative-log-likelihood/#nll
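As a tiny worked example, with made-up logits not taken from the linked post: softmax turns each row of scores into probabilities. NLL then takes the negative log of the probability assigned to the correct class and averages it over the batch.

```python
import torch

# Made-up logits for 2 samples and 3 classes.
logits = torch.tensor([[2.0, 1.0, 0.1],
                       [0.5, 2.5, 0.3]])
targets = torch.tensor([0, 1])  # correct class for each sample

# Softmax: exponentiate, then normalize so each row sums to 1.
probs = logits.exp() / logits.exp().sum(dim=1, keepdim=True)

# NLL: -log(probability of the correct class), averaged over the batch.
correct_probs = probs[torch.arange(2), targets]
loss = -correct_probs.log().mean()

print(probs.sum(dim=1))  # each row sums to 1
print(loss)
```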

Training Loop: Logistic Regression

We have no hidden layers, so this model is plain (multinomial) logistic regression.

The two hyperparameters here are learning rate and epochs. The learning rate scales how big a step we take in gradient descent. Epochs control how many full passes over the training set we make.
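Here is a sketch of what that training loop might look like, continuing the from-scratch setup. Autograd (loss.backward()) computes the gradients, while the update of the weights and bias is written out by hand. It assumes:

- random stand-in data whose labels come from a hidden linear rule, so there is something learnable
- a batch size of bs = 64
- lr = 0.5 and epochs = 2, as in the PyTorch tutorial

```python
import math
import torch

torch.manual_seed(0)

# Stand-in data (assumption): labels come from a hidden linear rule.
n = 800
x_train = torch.randn(n, 784) * 0.1
y_train = (x_train @ torch.randn(784, 10)).argmax(dim=1)

weights = torch.randn(784, 10) / math.sqrt(784)
weights.requires_grad_()
bias = torch.zeros(10, requires_grad=True)

def log_softmax(x):
    return x - x.exp().sum(-1).log().unsqueeze(-1)

def model(xb):
    return log_softmax(xb @ weights + bias)

def nll(log_probs, target):
    return -log_probs[range(target.shape[0]), target].mean()

loss_func = nll
bs = 64      # batch size
lr = 0.5     # learning rate
epochs = 2   # full passes over the training set

with torch.no_grad():
    loss_before = loss_func(model(x_train), y_train).item()

for epoch in range(epochs):
    for i in range((n - 1) // bs + 1):
        # Manual batching: slice rows [start_i, end_i).
        start_i = i * bs
        end_i = start_i + bs
        xb = x_train[start_i:end_i]
        yb = y_train[start_i:end_i]
        pred = model(xb)
        loss = loss_func(pred, yb)

        loss.backward()  # autograd fills weights.grad and bias.grad
        with torch.no_grad():
            # Hand-written gradient descent step, then reset the gradients.
            weights -= weights.grad * lr
            bias -= bias.grad * lr
            weights.grad.zero_()
            bias.grad.zero_()

with torch.no_grad():
    loss_after = loss_func(model(x_train), y_train).item()
print(loss_before, loss_after)
```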

Obviously this training loop is not 100% from scratch, since we rely on PyTorch’s autograd (loss.backward()) to compute the gradients of our parameters (weights and biases). Now let’s upgrade our training loop by refactoring it with PyTorch’s torch.nn module. In the spirit of Uncle Bob, we want code that is concise, clear, and flexible.

Refactoring with nn.functional

Let’s start by importing the functional module (conventionally as F), which contains the stateless functions we need for our training loop. We can start shortening our code with F.cross_entropy. It combines our from-scratch soft_max() and nll() functions into a single call, applying log-softmax followed by negative log-likelihood. The most important thing to note: we don’t want to use cross_entropy just because it is shorter. It is also more flexible and safer, because the big brains at PyTorch implemented it in a numerically stable way, so it won’t overflow or break on extreme values.

This gets us the same result as our from-scratch version.
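Here is a quick check of both claims on made-up numbers. First, F.cross_entropy applied to raw logits should match our from-scratch log-softmax plus NLL. Second, with an extreme logit the naive version overflows (exp(1000.) is inf in float32), while F.cross_entropy stays finite. The log_softmax and nll helpers are sketches of the from-scratch versions, not the post’s exact code.

```python
import torch
import torch.nn.functional as F

def log_softmax(x):
    # Naive from-scratch version: exp() can overflow for large inputs.
    return x - x.exp().sum(-1).log().unsqueeze(-1)

def nll(log_probs, target):
    return -log_probs[range(target.shape[0]), target].mean()

torch.manual_seed(0)
logits = torch.randn(4, 10)
targets = torch.tensor([3, 1, 0, 7])

ours = nll(log_softmax(logits), targets)
builtin = F.cross_entropy(logits, targets)  # takes raw logits, not probabilities
print(torch.allclose(ours, builtin))

# Extreme values: the naive version overflows, the built-in one does not.
big = torch.tensor([[1000.0, 0.0, -1000.0]])
label = torch.tensor([0])
naive = nll(log_softmax(big), label)  # exp(1000.) -> inf, so the loss is inf
stable = F.cross_entropy(big, label)  # finite, essentially 0
print(naive, stable)
```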

Refactoring with nn.Module

Now we will start to dabble in some OOP concepts. nn.Module is a base class whose attributes and methods cover the convenience needs common to most neural network development. We will subclass nn.Module so that we can keep track of the state of our network.

Now that we have created a subclass of nn.Module for our MNIST logistic regression model, we have to instantiate the model before we can proceed. The starkest difference between the former version and this new subclassed version is in how we optimize our parameters. Previously we had to manually write out how to update the weights and the bias. Now that the module keeps track of the state of our network, we can loop over its parameters and zero out the gradients much more easily.

Just a little shorter now in the torch.no_grad() block.
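Here is a sketch of the subclassed model and a fit() loop, following the structure of the PyTorch tutorial. The class name Mnist_Logistic is borrowed from the tutorial, and the stand-in data and hyperparameters are assumptions as before. Since forward() returns raw logits, the loss is F.cross_entropy.

```python
import math
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)

# Stand-in data (assumption): labels come from a hidden linear rule.
n = 800
x_train = torch.randn(n, 784) * 0.1
y_train = (x_train @ torch.randn(784, 10)).argmax(dim=1)

bs, lr, epochs = 64, 0.5, 2
loss_func = F.cross_entropy  # takes raw logits, so forward() skips softmax

class Mnist_Logistic(nn.Module):
    def __init__(self):
        super().__init__()
        # nn.Parameter registers each tensor with the module's state.
        self.weights = nn.Parameter(torch.randn(784, 10) / math.sqrt(784))
        self.bias = nn.Parameter(torch.zeros(10))

    def forward(self, xb):
        return xb @ self.weights + self.bias

model = Mnist_Logistic()  # instantiate before training

def fit():
    for epoch in range(epochs):
        for i in range((n - 1) // bs + 1):
            start_i = i * bs
            end_i = start_i + bs
            xb = x_train[start_i:end_i]
            yb = y_train[start_i:end_i]
            loss = loss_func(model(xb), yb)

            loss.backward()
            with torch.no_grad():
                # One loop over the registered parameters replaces the
                # separate per-tensor updates.
                for p in model.parameters():
                    p -= p.grad * lr
                model.zero_grad()  # resets every parameter's gradient

with torch.no_grad():
    loss_before = loss_func(model(x_train), y_train).item()
fit()
with torch.no_grad():
    loss_after = loss_func(model(x_train), y_train).item()
print(loss_before, loss_after)
```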

Now let’s work on refactoring our class. Instead of manually defining and initializing our parameters and manually doing the matrix multiply and bias addition, we can use PyTorch’s nn.Linear class.

Everything else remains the same, we can still call our fit() function and print the loss.
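A minimal sketch of the nn.Linear version of the class, again borrowing the tutorial’s Mnist_Logistic name. The check at the end shows how nn.Linear works internally. It stores its weight as (out_features, in_features) and computes xb @ weight.T + bias, the same affine map we previously wrote by hand.

```python
import torch
from torch import nn

class Mnist_Logistic(nn.Module):
    def __init__(self):
        super().__init__()
        # nn.Linear creates, registers, and initializes its own weight and bias.
        self.lin = nn.Linear(784, 10)

    def forward(self, xb):
        return self.lin(xb)

torch.manual_seed(0)
model = Mnist_Logistic()
xb = torch.rand(64, 784)  # stand-in batch (assumption)
out = model(xb)

# Same affine map we wrote by hand, with the stored weight transposed.
manual = xb @ model.lin.weight.T + model.lin.bias
print(out.shape)               # torch.Size([64, 10])
print(model.lin.weight.shape)  # torch.Size([10, 784])
```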

Optimizer

In the spirit of making our code shorter, let’s use PyTorch’s torch.optim package. It uses the computed gradients to update the parameters for us.

Now we have gone from five lines of parameter optimization to three! We also created a nice little get_model() function in the name of flexibility. The cool part is that we can now plug and play a variety of different optimizers inside get_model(). In our case we are still using Stochastic Gradient Descent.
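Here is a sketch of the torch.optim version with a get_model() helper, using the same kind of stand-in data as before. For plain SGD (no momentum or weight decay), opt.step() performs the same p -= lr * p.grad update we wrote by hand, and opt.zero_grad() clears the gradients. The class name and data are assumptions.

```python
import torch
import torch.nn.functional as F
from torch import nn, optim

torch.manual_seed(0)
n = 800
x_train = torch.randn(n, 784) * 0.1  # stand-in data (assumption)
y_train = (x_train @ torch.randn(784, 10)).argmax(dim=1)
bs, lr, epochs = 64, 0.5, 2
loss_func = F.cross_entropy

class Mnist_Logistic(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(784, 10)

    def forward(self, xb):
        return self.lin(xb)

def get_model():
    model = Mnist_Logistic()
    # Swap optim.SGD for another torch.optim optimizer here to experiment.
    return model, optim.SGD(model.parameters(), lr=lr)

model, opt = get_model()
with torch.no_grad():
    loss_before = loss_func(model(x_train), y_train).item()

for epoch in range(epochs):
    for i in range((n - 1) // bs + 1):
        start_i = i * bs
        end_i = start_i + bs
        xb = x_train[start_i:end_i]
        yb = y_train[start_i:end_i]
        loss = loss_func(model(xb), yb)

        loss.backward()
        opt.step()       # plain SGD: p -= lr * p.grad for every parameter
        opt.zero_grad()  # clear gradients before the next batch

with torch.no_grad():
    loss_after = loss_func(model(x_train), y_train).item()
print(loss_before, loss_after)
```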

Datasets

Another PyTorch class we want to use is Dataset. The key characteristic of a Dataset is that it has a __len__ and a __getitem__ method. The next phase of refactoring is stepping through the dataset. Our current loop clunkily establishes a starting and ending index to slice through the inputs and targets of the training dataset. PyTorch’s TensorDataset is a Dataset wrapping tensors, where “each sample will be retrieved by indexing tensors along the first dimension,” according to the docs. With TensorDataset we can combine our xb and yb slicing into a single line. This gives us a shorter and cleaner way to index through the first dimension of our tensors.
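Here is a small sketch of TensorDataset on stand-in tensors (random inputs and labels, an assumption). It shows __len__, single-sample indexing, and the batch slice that replaces our two separate xb/yb slicing lines.

```python
import torch
from torch.utils.data import TensorDataset

torch.manual_seed(0)
x_train = torch.randn(800, 784) * 0.1   # stand-in inputs (assumption)
y_train = torch.randint(0, 10, (800,))  # stand-in labels (assumption)
bs = 64

train_ds = TensorDataset(x_train, y_train)
print(len(train_ds))  # 800, via __len__

# One sample: __getitem__ indexes both tensors along the first dimension.
x0, y0 = train_ds[0]
print(x0.shape, y0)

# A slice returns a whole batch in one line.
i = 0
xb, yb = train_ds[i * bs : i * bs + bs]
print(xb.shape, yb.shape)
```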

Our training loop is getting shorter, which means it’s more readable. That matters if we need to explain it to someone or bring on another teammate: it will be easier for everyone to understand.

DataLoader

The final part of the training loop to refactor is batch management (notice we have been working from the bottom of the code upwards). Instead of managing batches from scratch with index arithmetic like we currently do, we can use PyTorch’s DataLoader class to iterate over batches of data. A DataLoader can be created from any Dataset.

We have now officially refactored a lengthy 12-line training loop all the way down to 7 lines, over 40% shorter! Uncle Bob would be proud.
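Here is a sketch of the final loop with DataLoader, assuming stand-in data as before. To keep it short, a bare nn.Linear(784, 10) replaces the Mnist_Logistic class, since that class only wraps a single nn.Linear and computes the same thing. The inner training loop comes out at seven lines.

```python
import torch
import torch.nn.functional as F
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)
x_train = torch.randn(800, 784) * 0.1  # stand-in data (assumption)
y_train = (x_train @ torch.randn(784, 10)).argmax(dim=1)
bs, lr, epochs = 64, 0.5, 2
loss_func = F.cross_entropy

model = nn.Linear(784, 10)  # same computation as the one-layer Mnist_Logistic
opt = optim.SGD(model.parameters(), lr=lr)

train_ds = TensorDataset(x_train, y_train)
train_dl = DataLoader(train_ds, batch_size=bs)  # yields (xb, yb) batches in order
print(len(train_dl))  # 13: twelve full batches plus a final batch of 32

with torch.no_grad():
    loss_before = loss_func(model(x_train), y_train).item()

# The whole training loop, with no index arithmetic.
for epoch in range(epochs):
    for xb, yb in train_dl:
        pred = model(xb)
        loss = loss_func(pred, yb)
        loss.backward()
        opt.step()
        opt.zero_grad()

with torch.no_grad():
    loss_after = loss_func(model(x_train), y_train).item()
print(loss_before, loss_after)
```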

Summarizing Part I

  1. We used nn.Module to create a callable that behaves like a function but can also hold state, which is important for keeping track of model parameters (weights & biases).
  2. We used nn.functional to pull in stateless functions for things like matrix multiplies, activation functions, and loss functions, rather than writing all of these from scratch and clogging up our notebook.
  3. We used torch.optim to pull in optimizers like SGD for our parameter-update step.
  4. We used Dataset and DataLoader to easily iterate through our data and manage batch sizes/iterations.

Up Next

We will dive into a number of the features available on the Weights & Biases platform.


Alexander Rofail

Things I like: Startups, AI, Bow Hunting, Powerlifting, Martial Arts.