A Look Under the Hood of PyTorch’s Recurrent Neural Network Module

A Guide to Weight Initializations and Matrix Multiplications inside PyTorch’s RNN Module

Maximinusjoshus
Geek Culture
9 min read · Jun 19, 2022


Photo by Alfons Morales on Unsplash

Neural networks have been doing a magnificent job of mimicking the human brain to recognize objects, segment images and even interact with people. Speaking of interaction, decades of research have gone into the field of Natural Language Processing (a subfield of Artificial Intelligence concerned with the interaction between computers and human beings) to make interacting with computers a more seamless experience. For years, many state-of-the-art Natural Language Processing systems were built on one neural network architecture: the Recurrent Neural Network (RNN).

In this article, we will skim over the internal architecture of a vanilla Recurrent Neural Network and then take a deep dive into two internal mechanisms of PyTorch’s RNN module: weight initialization and matrix multiplication.

What makes RNNs so special?

Before RNNs, feedforward networks were used to solve NLP tasks. The network was fed text encoded as integers and trained the usual way, by adjusting its weights and biases through backpropagation. Though these networks produced decent results, they could not capture the contextual meaning of the input, because they processed each input independently, with no memory of what came before.

RNNs, on the other hand, have the inherent ability to remember the sequential context of the input data. Thus, when they make a prediction (for example, in a sequence prediction problem), the prediction is based on the entire sequence fed into the model, not just the last character or word of the input sequence.

A Look Into the Internal Architecture

The basic architecture of a Recurrent Neural Network looks like this:

This is a compact, folded representation of an RNN. The process becomes easier to follow when the recurrent loop is unfolded over time, as in the image below.

When an input is fed into the RNN, a hidden state is calculated. This hidden state is fed back into the RNN cell, which simultaneously takes in the next input. The hidden state and the new input are then used to calculate a new hidden state, which is again fed back into the RNN cell. This cycle repeats once for each element of the input sequence. Each recursive step that produces a hidden state in an RNN is called a time step.

For example, consider these sentences: “Alicia loves pets. She has adopted a dog named Oliver. She always carries a shoulder bag with her, wherever she goes. When Alicia went on a vacation, she was unable to carry her pet ____ with her”. Here, the neural network should be able to remember information from the second sentence to fill the blank in the fourth sentence. So when these sentences are fed word by word into the RNN, it carries the information gathered at each time step forward through its hidden states.
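To make the unrolling concrete, here is a minimal pure-Python sketch with a one-dimensional input and hidden state. It assumes tanh as the activation. The weights w_x and w_h are arbitrary illustrative numbers, not values from a trained model:

```python
import math

def unroll_scalar_rnn(inputs, w_x, w_h, h0=0.0):
    """Run a 1-D RNN over `inputs` and return the hidden state of every time step."""
    h = h0
    hidden_states = []
    for x in inputs:                      # one loop iteration = one time step
        h = math.tanh(w_x * x + w_h * h)  # new state from current input + previous state
        hidden_states.append(h)
    return hidden_states

states = unroll_scalar_rnn([1.0, 0.5, -1.0], w_x=0.8, w_h=0.5)
print(len(states))  # 3: one hidden state per element of the sequence
```

Each hidden state is computed from the one before it. That chain is how information from early inputs reaches later time steps.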

RNNs come in different variants, each suited to different use cases.

As you can see in the image above, the number of inputs and outputs of an RNN can vary according to specific needs. Andrej Karpathy in his blog The Unreasonable Effectiveness of Recurrent Neural Networks gives these examples for the RNNs seen above. From left to right: (1) Vanilla mode of processing without RNN, from fixed-sized input to fixed-sized output (e.g. image classification). (2) Sequence output (e.g. image captioning takes an image and outputs a sentence of words). (3) Sequence input (e.g. sentiment analysis where a given sentence is classified as expressing positive or negative sentiment). (4) Sequence input and sequence output (e.g. Machine Translation: an RNN reads a sentence in English and then outputs a sentence in French). (5) Synced sequence input and output (e.g. video classification where we wish to label each frame of the video).

The Math

Sequence prediction using an RNN can be broken down into these five steps:

  • An input is fed into the RNN cell (the first time step)
  • The RNN cell calculates the hidden state for this time step.
  • This hidden state is fed back into the RNN cell (the second time step), and simultaneously the RNN takes in the next sequential input.
  • The hidden state of the first time step and the new input are multiplied with their respective weight matrices and summed.
  • An activation function is applied to the sum obtained to produce the hidden state for this time step. This hidden state is fed back into the RNN cell and the cycle repeats until the required output is produced.

The hidden state at each time step is given by the formula below
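This can be written in the notation commonly used for vanilla RNNs. Biases are omitted, and W_xh, W_hh and W_hy are our own names for the input, hidden and output weight matrices. With those choices, the hidden state and the output at time step t are:

```latex
h_t = \tanh\left(W_{hh}\, h_{t-1} + W_{xh}\, x_t\right), \qquad y_t = W_{hy}\, h_t
```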

At each time step, the new hidden state is obtained by multiplying the previous time step’s hidden state and the new sequential input by their respective weight matrices, summing the results, and passing the sum through an activation function.

These hidden states can also be used to get the output of the RNN at each time step by multiplying them with an output weight matrix.
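Both computations, the hidden-state update and the output projection, fit in a short pure-Python sketch. The sizes (2 input features, 3 hidden units, 1 output), the weight values, and the helper names matvec and rnn_step are all hypothetical, and biases are left out:

```python
import math

def matvec(W, v):
    """Multiply matrix W (a list of rows) by vector v."""
    return [sum(w_ij * v_j for w_ij, v_j in zip(row, v)) for row in W]

def rnn_step(x_t, h_prev, W_xh, W_hh, W_hy):
    """One time step: the new hidden state h_t and the output y_t (no biases)."""
    pre_activation = [a + b for a, b in zip(matvec(W_hh, h_prev), matvec(W_xh, x_t))]
    h_t = [math.tanh(z) for z in pre_activation]
    y_t = matvec(W_hy, h_t)  # output = output weight matrix times hidden state
    return h_t, y_t

W_xh = [[0.1, -0.2], [0.4, 0.3], [-0.5, 0.2]]               # (hidden, input)
W_hh = [[0.2, 0.0, 0.1], [0.0, 0.3, 0.0], [0.1, 0.1, 0.1]]  # (hidden, hidden)
W_hy = [[1.0, -1.0, 0.5]]                                   # (output, hidden)

h = [0.0, 0.0, 0.0]                 # initial hidden state
for x in [[1.0, 2.0], [0.5, -1.0]]:  # a two-step input sequence
    h, y = rnn_step(x, h, W_xh, W_hh, W_hy)
print(len(h), len(y))  # 3 1
```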

Weight Initialization and Matrix Multiplications

All we have seen so far is theory. But it’s nearly impossible to understand a network thoroughly without getting our hands dirty with code. This wonderful repository has a full code walkthrough explaining how to build an RNN from scratch using PyTorch. You can also follow this blog post to get a better intuition for the code.

This article does not rebuild an RNN from scratch. The blog mentioned above already does a great job of explaining that code, and redoing it here would be redundant.

But there are some internal initialization settings and mechanics in PyTorch that one should be aware of. Before getting into the details, let us define six variables:

  • input: the input features
  • w_ih: the weight matrix for the input
  • b_ih: the bias for the input
  • h0: the initial hidden state
  • w_hh: the weight matrix for the hidden state
  • b_hh: the bias for the hidden state

Weight initialization

Consider the following example. An rnn object is created from PyTorch’s nn.RNN class. An input and an initial hidden state are then defined and passed into it.

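The constructor arguments for the RNN class are (input_size, hidden_size, num_layers). These are the number of input features, the number of hidden dimensions and the number of stacked RNN layers. Our example uses 10 input features, 20 hidden dimensions and a single RNN layer.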

The input should have the shape (batch size, length of the input sequence, number of input features). The batch size comes first because we set batch_first=True when creating the rnn object; by default, the sequence length comes first.

The initial hidden state should have the shape (number of RNN layers, batch size, hidden dimensions). Note that batch_first does not apply to the hidden state, so the batch size stays in the second position.
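The example’s code appeared only as a screenshot. From the description, it presumably created something like rnn = nn.RNN(10, 20, 1, batch_first=True). It then called rnn(input, h0) with an input of shape (3, 5, 10) and an h0 of shape (1, 3, 20); that reconstruction is our assumption.

The helper below, expected_rnn_shapes, is a hypothetical name and not part of PyTorch. It is a pure-Python sketch that spells out the shape rules from the nn.RNN documentation, so they can be checked without installing PyTorch:

```python
def expected_rnn_shapes(batch_size, seq_len, input_size, hidden_size,
                        num_layers=1, bidirectional=False, batch_first=False):
    """Tensor shapes for nn.RNN's inputs and outputs, following the layout in its docs."""
    d = 2 if bidirectional else 1  # number of directions
    if batch_first:
        inp = (batch_size, seq_len, input_size)
        out = (batch_size, seq_len, d * hidden_size)
    else:  # PyTorch's default layout puts the sequence dimension first
        inp = (seq_len, batch_size, input_size)
        out = (seq_len, batch_size, d * hidden_size)
    # batch_first does not affect the hidden state: always (layers * directions, batch, hidden)
    hidden = (d * num_layers, batch_size, hidden_size)
    return {"input": inp, "h0": hidden, "output": out, "h_n": hidden}

shapes = expected_rnn_shapes(batch_size=3, seq_len=5, input_size=10,
                             hidden_size=20, num_layers=1, batch_first=True)
print(shapes)
# {'input': (3, 5, 10), 'h0': (1, 3, 20), 'output': (3, 5, 20), 'h_n': (1, 3, 20)}
```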

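Now, as per our RNN formula, at each time step the hidden state is calculated as

h_t = tanh(x_t · w_ih^T + b_ih + h_(t-1) · w_hh^T + b_hh)

where x_t is the input at time step t, h_(t-1) is the hidden state from the previous time step (h0 for the first step), and ^T denotes a transpose.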

We pass the input and the initial hidden state into our RNN model. The weights for the input and the hidden state are created when the rnn object is constructed, and at each time step they are plugged into the formula above to calculate the next hidden state. These calculations are done under the hood, within PyTorch’s nn module. Let us see how PyTorch initializes the weight matrices for the input and hidden states.

This is a screenshot from RNN.py (lines 84 to 94) in PyTorch’s official GitHub repository.

This is the part where the shapes of the weights are defined. In the third and fourth lines, two variables, real_hidden_size and layer_input_size, are defined.

Since we are constructing a vanilla RNN, there is no projection: the proj_size option is only supported for LSTMs and defaults to 0. So real_hidden_size equals hidden_size (the number of hidden dimensions), which is 20 in our example.

We have one RNN layer in our model. As the code screenshot above shows, the weight shapes are set up by looping over the layers of the model. In the first iteration, layer is 0, so layer_input_size equals the input size, which is 10 in our example. For any deeper layer, layer_input_size would instead be the hidden size times the number of directions, since that layer reads the previous layer’s output.

To set up the shape of the input weight, an empty PyTorch tensor is created with the shape (gate_size, layer_input_size). The gate size is set in lines 71–80 of RNN.py.

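As we are using tanh as our activation function, gate_size equals hidden_size, which is 20 in our example. For comparison, an LSTM uses 4 × hidden_size and a GRU uses 3 × hidden_size, because their weight matrices stack several gates.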
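The shape logic from the last few paragraphs fits in a short pure-Python sketch. It is a simplified re-creation for illustration, not PyTorch’s source:

  • rnn_param_shapes is a hypothetical name.
  • The gate multipliers (4 for LSTM, 3 for GRU, 1 for tanh/ReLU RNNs) reflect how PyTorch sizes gate_size in each mode.
  • The LSTM-only projection weights are left out.

The dictionary keys mirror the parameter names nn.RNN exposes, such as weight_ih_l0:

```python
def rnn_param_shapes(input_size, hidden_size, num_layers=1, mode="RNN_TANH",
                     proj_size=0, bidirectional=False):
    """Simplified re-creation of how PyTorch sizes RNN parameters (illustrative only)."""
    gate_multiplier = {"LSTM": 4, "GRU": 3, "RNN_TANH": 1, "RNN_RELU": 1}[mode]
    gate_size = gate_multiplier * hidden_size
    num_directions = 2 if bidirectional else 1
    real_hidden_size = proj_size if proj_size > 0 else hidden_size
    shapes = {}
    for layer in range(num_layers):
        for direction in range(num_directions):
            # Layer 0 reads the raw input; deeper layers read the previous layer's output.
            layer_input_size = input_size if layer == 0 else real_hidden_size * num_directions
            suffix = f"_l{layer}" + ("_reverse" if direction == 1 else "")
            shapes["weight_ih" + suffix] = (gate_size, layer_input_size)
            shapes["weight_hh" + suffix] = (gate_size, real_hidden_size)
            shapes["bias_ih" + suffix] = (gate_size,)
            shapes["bias_hh" + suffix] = (gate_size,)
    return shapes

for name, shape in rnn_param_shapes(input_size=10, hidden_size=20).items():
    print(name, shape)
# weight_ih_l0 (20, 10)
# weight_hh_l0 (20, 20)
# bias_ih_l0 (20,)
# bias_hh_l0 (20,)
```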

Finally, the shapes of the weights and biases for our input and hidden states will be:

  • w_ih: (20, 10)
  • w_hh: (20, 20)
  • b_ih: (20,)
  • b_hh: (20,)

Now that the shapes are set, it is time to assign initial values to these matrices. All of these parameters, weights and biases alike, are initialized from a uniform distribution ranging from -k to k, where k = 1/sqrt(hidden_size). In our example, k = 1/sqrt(20) ≈ 0.2236. This is defined in the reset_parameters function of RNN.py (lines 193–196).
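A minimal pure-Python sketch of this initialization scheme follows. init_uniform is a hypothetical helper; PyTorch itself does this on tensors inside reset_parameters. The same bounds apply to weights and biases:

```python
import math
import random

def init_uniform(shape, hidden_size, rng):
    """Fill a 1-D or 2-D parameter with samples from U(-k, k), k = 1/sqrt(hidden_size)."""
    k = 1.0 / math.sqrt(hidden_size)
    if len(shape) == 1:
        return [rng.uniform(-k, k) for _ in range(shape[0])]
    rows, cols = shape
    return [[rng.uniform(-k, k) for _ in range(cols)] for _ in range(rows)]

rng = random.Random(0)  # fixed seed so the sketch is reproducible
hidden_size = 20
w_ih = init_uniform((20, 10), hidden_size, rng)  # weight: (hidden_size, input_size)
b_ih = init_uniform((20,), hidden_size, rng)     # bias: (hidden_size,)
k = 1 / math.sqrt(hidden_size)
print(round(k, 4))  # 0.2236
```

Because k depends only on hidden_size, a wider hidden layer gets a narrower initial range for every parameter, including those multiplying the input.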

Okay, it looks like we’ve reached a milestone. But it would be even better to also understand the matrix multiplications occurring under the hood of PyTorch.

The Calculations

Our input has the shape (3, 5, 10). Here 3 is the batch size: our example batch holds 3 sequences, each 5 time steps long with 10 features per step. If we take the first sequence on its own, its shape is (1, 5, 10), which can also be treated as (5, 10). The weight matrix for the input (w_ih) has the shape (20, 10). For two matrices to be multipliable, the number of columns (dim 1) of the first matrix must equal the number of rows (dim 0) of the second. To satisfy that, we transpose w_ih, giving it the shape (10, 20). The product (5, 10) × (10, 20) then gives a weighted input of shape (5, 20).
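The same computation in plain Python is shown below. matmul and transpose are hypothetical helpers, and random values stand in for the real data and weights:

```python
import random

def matmul(A, B):
    """Plain matrix product of A (n x m) and B (m x p)."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def transpose(M):
    """Swap rows and columns of M."""
    return [list(col) for col in zip(*M)]

rng = random.Random(1)
x = [[rng.uniform(-1, 1) for _ in range(10)] for _ in range(5)]      # one sequence: (5, 10)
w_ih = [[rng.uniform(-1, 1) for _ in range(10)] for _ in range(20)]  # (hidden, input) = (20, 10)

weighted_input = matmul(x, transpose(w_ih))  # (5, 10) @ (10, 20) -> (5, 20)
print(len(weighted_input), len(weighted_input[0]))  # 5 20
```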

The calculation for the weighted hidden state follows the same sequence of events. The hidden state has the shape (1, 3, 20), where 3 is the batch size. For a single sequence, the hidden state has the shape (1, 1, 20), which can also be treated as (1, 20). The weight matrix for the hidden state has the shape (20, 20). Because w_hh is square, the shapes would allow the multiplication even without transposing. But w_hh is still transposed, exactly as w_ih was: the formula uses the transpose of both weight matrices, since row i of each matrix holds the weights that feed hidden unit i. Transposing a square matrix leaves its shape unchanged but changes which weight multiplies which value, so skipping the transpose would give a different, wrong result.
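A tiny example shows why the transpose matters even when the shapes already line up. It uses a hypothetical 2-unit hidden layer with hand-picked weights:

```python
def matmul(A, B):
    """Plain matrix product of A (n x m) and B (m x p)."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def transpose(M):
    """Swap rows and columns of M."""
    return [list(col) for col in zip(*M)]

# Row i of w_hh holds the weights feeding hidden unit i (PyTorch's layout).
w_hh = [[1.0, 2.0],
        [3.0, 4.0]]
h = [[1.0, 0.0]]  # previous hidden state as a (1, 2) row

with_transpose = matmul(h, transpose(w_hh))  # [[1.0, 3.0]]
without_transpose = matmul(h, w_hh)          # [[1.0, 2.0]]
print(with_transpose, without_transpose)
```

Only the transposed product gives each hidden unit the weighted sum of its own row of w_hh, which is what the RNN formula asks for.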

Now the weighted input and the weighted hidden state can be added, and the tanh activation applied to the sum, to get our new hidden state. But wait a minute: the weighted input (5, 20) and the weighted hidden state (1, 20) have different shapes, and element-wise addition normally requires matching shapes!

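To add these matrices, PyTorch uses a technique called broadcasting. Our weighted hidden state has 20 elements, and our weighted input has 5 × 20 = 100 elements, arranged as 5 rows (one per time step) of 20 elements each along dim 1. With broadcasting, the (1, 20) weighted hidden state is added to each of the 5 rows of the weighted input. You can learn more about broadcasting here.

One caveat: in the actual recurrence, the hidden state changes at every time step. Only the first row of this sum pairs the initial hidden state with the input it belongs to. Applying tanh to that row gives our new hidden state, which is fed back into the RNN as the hidden state for its next time step. There it is multiplied by the transposed w_hh and added to the next row of the weighted input, and so on along the sequence.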
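To see what broadcasting does, here is a toy version. It uses a hypothetical add_row_broadcast helper and small sizes (3 rows of 4 instead of 5 rows of 20), so the result can be checked by eye. Each row of the weighted input receives the same weighted hidden state. In PyTorch, writing a + b with tensors of shapes (5, 20) and (1, 20) performs this row-wise addition automatically:

```python
def add_row_broadcast(matrix, row):
    """Add a (1, n) row to every row of an (m, n) matrix, as broadcasting does."""
    return [[a + b for a, b in zip(r, row)] for r in matrix]

weighted_input = [[float(t * 10 + j) for j in range(4)] for t in range(3)]  # toy (3, 4)
weighted_hidden = [100.0, 200.0, 300.0, 400.0]                              # toy (1, 4) row
print(add_row_broadcast(weighted_input, weighted_hidden))
```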

We haven’t discussed biases in our calculations for the sake of simplicity. They are one-dimensional tensors whose length equals the number of hidden dimensions, i.e. shape (20,) in our example. At each time step, they are added to the weighted input and the weighted hidden state using broadcasting.
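Putting the pieces together, here is a minimal pure-Python sketch of what a single-layer, tanh, batch_first nn.RNN computes in its forward pass. It follows the formula PyTorch documents for nn.RNN:

h_t = tanh(x_t · w_ih^T + b_ih + h_(t-1) · w_hh^T + b_hh)

It illustrates the math, not PyTorch’s implementation, which dispatches to optimized kernels. The function names are ours. h0 is all zeros here, which is what PyTorch uses when no initial hidden state is passed:

```python
import math
import random

def linear_rows(X, W, b):
    """Return X @ W^T + b, where X is a list of rows and b is broadcast to every row."""
    return [[sum(x * w for x, w in zip(row, w_row)) + b_i for w_row, b_i in zip(W, b)]
            for row in X]

def rnn_forward(inputs, h0, w_ih, w_hh, b_ih, b_hh):
    """Single-layer tanh RNN over a batch_first input of shape (batch, seq_len, input_size).

    h0 has shape (batch, hidden_size): the leading num_layers dimension is dropped.
    Returns the hidden state of every time step, shaped (batch, seq_len, hidden_size),
    and the final hidden state, shaped (batch, hidden_size).
    """
    batch, seq_len = len(inputs), len(inputs[0])
    h = h0
    outputs = [[] for _ in range(batch)]
    for t in range(seq_len):
        x_t = [inputs[n][t] for n in range(batch)]     # (batch, input_size)
        weighted_input = linear_rows(x_t, w_ih, b_ih)  # x_t @ w_ih^T + b_ih
        weighted_hidden = linear_rows(h, w_hh, b_hh)   # h_(t-1) @ w_hh^T + b_hh
        h = [[math.tanh(p + q) for p, q in zip(r_in, r_h)]
             for r_in, r_h in zip(weighted_input, weighted_hidden)]
        for n in range(batch):
            outputs[n].append(h[n])
    return outputs, h

# Sizes from the article's example; values are random, not from a trained model.
rng = random.Random(0)
batch, seq_len, input_size, hidden_size = 3, 5, 10, 20
k = 1 / math.sqrt(hidden_size)

def uniform(rows, cols=None):
    """Sample a vector or matrix from U(-k, k)."""
    if cols is None:
        return [rng.uniform(-k, k) for _ in range(rows)]
    return [[rng.uniform(-k, k) for _ in range(cols)] for _ in range(rows)]

w_ih, w_hh = uniform(hidden_size, input_size), uniform(hidden_size, hidden_size)
b_ih, b_hh = uniform(hidden_size), uniform(hidden_size)
inputs = [[[rng.gauss(0, 1) for _ in range(input_size)] for _ in range(seq_len)]
          for _ in range(batch)]
h0 = [[0.0] * hidden_size for _ in range(batch)]

outputs, h_n = rnn_forward(inputs, h0, w_ih, w_hh, b_ih, b_hh)
print(len(outputs), len(outputs[0]), len(outputs[0][0]))  # 3 5 20
```

The hidden-state term is recomputed inside the loop at every time step, while each bias vector is broadcast over the rows of the batch.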

Thank You

And that’s a wrap! This is how PyTorch initializes the weights of its RNN module and multiplies its matrices under the hood.
