Understanding RNN implementation in PyTorch
RNNs and other recurrent variants like GRUs and LSTMs are among the most commonly used PyTorch modules. In this post, I go through the different parameters of the RNN module and how they impact the computation and the resultant output.
The full notebook for this post is available at : https://github.com/rsk2327/DL-Experiments/blob/master/Understanding_RNNs.ipynb
Basic Example
RNN Hyperparameters
The key parameters for an RNN cell block are:
- input_size - Defines the number of features that define each element (time-stamp) of the input sequence
- hidden_size - Defines the size of the hidden state. Therefore, if hidden_size is set to 4, then the hidden state at each time step is a vector of length 4
- num_layers - Allows the user to construct stacked RNNs. The concept of stacked RNNs and how they work is explained later
- bias - Whether or not to include a bias term in the RNN cell
- bidirectional - Whether the RNN layer is bi-directional or not
- batch_first - Defines the input format. If True, then the input sequence is expected in the format (batch, sequence, features)
To keep things simple, for the basic example, we set input_size, hidden_size and num_layers to 1, and bidirectional to False.
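A minimal sketch of this setup (the seed and input values are my own choices, not from the notebook):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # for reproducible weights

# Basic RNN: 1 input feature, hidden state of length 1, single layer
rnn = nn.RNN(input_size=1, hidden_size=1, num_layers=1,
             bias=False, bidirectional=False, batch_first=True)

# One sequence of 3 elements, each described by a single feature
x = torch.randn(1, 3, 1)

total_output, final_output = rnn(x)
print(total_output.shape)  # torch.Size([1, 3, 1])
print(final_output.shape)  # torch.Size([1, 1, 1])
```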
RNN output
The RNN module in PyTorch always returns 2 outputs
- Total Output - Contains the hidden states associated with all elements (time-stamps) in the input sequence
- Final Output - Contains the hidden state for the very last element of the input sequence.
In most cases, the Final Output doesn't provide any new information that the Total Output doesn't already contain, and can be constructed directly from the Total Output. However, there are a few cases where that's not possible.
For our example, Total Output has a size of [1,3,1]. This can be broken down as
- 1 : Number of sequences
- 3 : Number of elements in the sequence
- 1 : Number of features that define the hidden state. Directly governed by the hidden_size parameter
Final Output, understandably, has a size of [1,1,1] since it contains the hidden state of only the last element of the sequence.
RNN Parameters
The RNN module has 2 types of parameters, weights and biases. The actual number of parameters changes with the different hyperparameters that are used to define the RNN layer.
In this example, we only have 2 parameters, Wih and Whh.
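The registered parameters can be listed directly from the module (a sketch; with bias set to False, only the two weight matrices appear):

```python
import torch.nn as nn

rnn = nn.RNN(input_size=1, hidden_size=1, bias=False, batch_first=True)

# With bias=False, only weight_ih_l0 (Wih) and weight_hh_l0 (Whh) exist
for name, param in rnn.named_parameters():
    print(name, tuple(param.shape))
# weight_ih_l0 (1, 1)
# weight_hh_l0 (1, 1)
```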
Manual Computation
Given the RNN formula and the RNN layer weights, we manually compute the RNN outputs. This gives a better understanding of how the hidden states are computed internally by the RNN
module.
For the very first element, which has no preceding hidden state, we set the hidden state to be 0.
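A sketch of this manual check, applying h_t = tanh(x_t·Wih^T + h_(t-1)·Whh^T) element by element (variable names are my own):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
rnn = nn.RNN(input_size=1, hidden_size=1, bias=False, batch_first=True)
x = torch.randn(1, 3, 1)
total_output, _ = rnn(x)

w_ih = rnn.weight_ih_l0  # shape (1, 1)
w_hh = rnn.weight_hh_l0  # shape (1, 1)

h_prev = torch.zeros(1)  # no preceding hidden state for the first element
manual = []
for t in range(x.shape[1]):
    # h_t = tanh(x_t @ Wih.T + h_(t-1) @ Whh.T)
    h_prev = torch.tanh(x[0, t] @ w_ih.T + h_prev @ w_hh.T)
    manual.append(h_prev)

manual = torch.stack(manual).unsqueeze(0)  # shape (1, 3, 1)
print(torch.allclose(manual, total_output))  # True
```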
So through this basic example, we can observe that :
- The RNN applies the same basic computation repeatedly to every element of the given sequence
- The output at a particular timestamp depends on the output of the previous timestamp.
Adding More Features
In the next iteration, we add more features to the input sequence elements. Instead of 1, each element is now represented by a 3-element vector. The input sequence now has a shape of [1,4,3].
Given that the number of features has changed, we make the necessary changes to the RNN layer definition by setting input_size to 3.
In addition to the above change, I have also set bias to True. This helps demonstrate how the bias is included in the hidden state computation.
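The updated layer definition might look like this (the input values are arbitrary):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Each element now has 3 features; bias terms are enabled
rnn = nn.RNN(input_size=3, hidden_size=1, bias=True, batch_first=True)

# One sequence of 4 elements, each a 3-element vector
x = torch.randn(1, 4, 3)
total_output, final_output = rnn(x)
print(total_output.shape)  # torch.Size([1, 4, 1])
print(final_output.shape)  # torch.Size([1, 1, 1])
```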
Computing Outputs
The shapes of Total Output and Final Output are [1,4,1] and [1,1,1] respectively. The only change here, from the previous example, comes from the length of the sequences being different. The length of the input feature vector (input_size) therefore has no impact on the size of the output.
As before, manually computing the resultant RNN hidden state values helps us confirm the internal computation that the RNN module performs.
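A sketch of the manual computation with the bias terms now included:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
rnn = nn.RNN(input_size=3, hidden_size=1, bias=True, batch_first=True)
x = torch.randn(1, 4, 3)
total_output, _ = rnn(x)

h_prev = torch.zeros(1)
manual = []
for t in range(x.shape[1]):
    # h_t = tanh(x_t @ Wih.T + b_ih + h_(t-1) @ Whh.T + b_hh)
    h_prev = torch.tanh(x[0, t] @ rnn.weight_ih_l0.T + rnn.bias_ih_l0
                        + h_prev @ rnn.weight_hh_l0.T + rnn.bias_hh_l0)
    manual.append(h_prev)

manual = torch.stack(manual).unsqueeze(0)  # shape (1, 4, 1)
print(torch.allclose(manual, total_output))  # True
```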
Increasing Hidden Size
In the next iteration, we build on the previous example, increase the hidden_size parameter to 2, and explore its effect on the computation and the final output.
Increasing the hidden state size of an RNN layer increases the complexity of the RNN model and potentially allows it to capture more complex decision boundaries. It also allows for more expressive hidden states. A hidden state represented by a vector of length 10 can capture a lot more information than a vector of length 1.
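The updated definition, as a sketch:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# The hidden state is now a vector of length 2
rnn = nn.RNN(input_size=3, hidden_size=2, bias=True, batch_first=True)
x = torch.randn(1, 4, 3)
total_output, final_output = rnn(x)
print(total_output.shape)  # torch.Size([1, 4, 2])
print(final_output.shape)  # torch.Size([1, 1, 2])
```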
Computing Outputs
The first change that we can observe is the change in the shape of the output variables. Total Output and Final Output now have shapes of [1,4,2] and [1,1,2] respectively. This is essentially due to the fact that the hidden state of each element is now represented by a vector of length 2.
For the manual computation of hidden states, the computation remains mostly the same. The only difference in code is the use of Torch's matmul operator instead of the dot operator which we used previously.
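A sketch of the manual check using torch.matmul, which handles the vector-matrix products now that hidden_size > 1:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
rnn = nn.RNN(input_size=3, hidden_size=2, bias=True, batch_first=True)
x = torch.randn(1, 4, 3)
total_output, _ = rnn(x)

h_prev = torch.zeros(2)
manual = []
for t in range(x.shape[1]):
    # matmul replaces dot: both x_t @ Wih.T and h_prev @ Whh.T are now
    # vector-matrix products producing a length-2 hidden state
    h_prev = torch.tanh(torch.matmul(x[0, t], rnn.weight_ih_l0.T) + rnn.bias_ih_l0
                        + torch.matmul(h_prev, rnn.weight_hh_l0.T) + rnn.bias_hh_l0)
    manual.append(h_prev)

manual = torch.stack(manual).unsqueeze(0)  # shape (1, 4, 2)
print(torch.allclose(manual, total_output))  # True
```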
Using BiDirectional RNN
BiDirectional RNNs mark a significant change from the examples that we have seen so far. While the basic RNN formula remains the same, there are some changes in computation that become much clearer on analyzing the manual computation code.
As the name suggests, a BiDirectional RNN involves RNN being applied to the input sequence in both directions. There are a lot of posts that cover, in detail, the concept behind bidirectional RNNs and why they are useful, so I won't be covering that.
The key point to keep in mind is that the bidirectional RNN computation involves 2 runs through the sequence. For ease of understanding, I refer to them as the Forward and Backward runs.
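As a sketch, the bidirectional setup and the resulting output shapes:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
rnn = nn.RNN(input_size=3, hidden_size=2, bias=True,
             bidirectional=True, batch_first=True)
x = torch.randn(1, 4, 3)
total_output, final_output = rnn(x)
print(total_output.shape)  # torch.Size([1, 4, 4]) - forward + backward concatenated
print(final_output.shape)  # torch.Size([2, 1, 2]) - one final state per run
```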
Computing Outputs
The first significant difference to notice is the change in the output shapes. Total Output and Final Output now have shapes of [1,4,4] and [2,1,2] respectively.
For Total output, its shape can be broken down into
- 1 : Number of sequences
- 4 : Number of elements in the sequence
- 4 : Size of the hidden state of each element.
The last shape element, denoting the size of the hidden state, is 4 because of the bidirectional nature of the RNN layer. In a bidirectional RNN, the hidden states computed by both the Forward and Backward runs are concatenated to produce the final hidden state for each element. Therefore, if the hidden_size parameter is 3, then the final hidden state would be of length 6.
For Final Output, its shape can be broken down into
- 2 : Total number of Forward/Backward runs (the number of directions times the number of layers)
- 1 : Number of sequences in the batch
- 2 : Size of the hidden state for a single run. This equals hidden_size
In Final Output, the RNN module outputs the hidden state computed at the end of each run. Therefore, since we have a bidirectional layer, there are 2 runs and hence 2 final hidden states. Each of these hidden states has a length equal to the hidden_size parameter.
Model Parameters
When bidirectional is set to True, the RNN module also gets new parameters to differentiate between the Forward and Backward runs. The primary nomenclature of weights and biases remains the same. However, a new set of parameters with the same names as the previous parameters, but with an additional '_reverse' suffix, is added to the module. This essentially doubles the number of parameters in the RNN layer.
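Listing the parameters shows the new '_reverse' set (a sketch):

```python
import torch.nn as nn

rnn = nn.RNN(input_size=3, hidden_size=2, bias=True,
             bidirectional=True, batch_first=True)

# Each forward parameter now has a '_reverse' counterpart
for name, _ in rnn.named_parameters():
    print(name)
# weight_ih_l0, weight_hh_l0, bias_ih_l0, bias_hh_l0,
# then the same four names with a '_reverse' suffix
```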
Manual Computation
We start off with the Forward computation, essentially using the same procedure that we have been using till now. The output of this run matches exactly with the first half (first 2 elements of each row) of the Total Output.
For the Backward run, the procedure remains the same as before. The only difference is that we now start from the very last element and move towards the first element of the sequence.
Once we have the results from both the runs, we can simply concatenate both outputs to get a resultant output that matches the Total Output accurately.
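A sketch of both runs and the final concatenation (the Backward run walks the sequence in reverse and stores its hidden state back at the corresponding position):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
rnn = nn.RNN(input_size=3, hidden_size=2, bias=True,
             bidirectional=True, batch_first=True)
x = torch.randn(1, 4, 3)
total_output, _ = rnn(x)
T = x.shape[1]

# Forward run: first element -> last element
h = torch.zeros(2)
fwd = []
for t in range(T):
    h = torch.tanh(x[0, t] @ rnn.weight_ih_l0.T + rnn.bias_ih_l0
                   + h @ rnn.weight_hh_l0.T + rnn.bias_hh_l0)
    fwd.append(h)

# Backward run: last element -> first element, using the '_reverse' parameters
h = torch.zeros(2)
bwd = [None] * T
for t in reversed(range(T)):
    h = torch.tanh(x[0, t] @ rnn.weight_ih_l0_reverse.T + rnn.bias_ih_l0_reverse
                   + h @ rnn.weight_hh_l0_reverse.T + rnn.bias_hh_l0_reverse)
    bwd[t] = h

# Concatenate the forward and backward states for each element
manual = torch.stack([torch.cat([f, b]) for f, b in zip(fwd, bwd)]).unsqueeze(0)
print(torch.allclose(manual, total_output))  # True
```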
Stacked RNNs
With Stacked RNNs, we explore the num_layers parameter of the RNN module. Stacked RNNs can be thought of as individual RNN modules stacked together, with the output of one module acting as the input to the next RNN module.
For this example, I have set bidirectional to False in order to better explain the computation related to stacked layers. Other parameters include input_size = 3, hidden_size = 3 and num_layers = 2.
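The stacked setup, as a sketch:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Two stacked RNN layers, single direction
rnn = nn.RNN(input_size=3, hidden_size=3, num_layers=2,
             bias=True, bidirectional=False, batch_first=True)
x = torch.randn(1, 4, 3)
total_output, final_output = rnn(x)
print(total_output.shape)  # torch.Size([1, 4, 3]) - hidden states of the last layer
print(final_output.shape)  # torch.Size([2, 1, 3]) - last hidden state of each layer
```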
Model Parameters
Since stacked RNNs can be seen as individual modules stacked together, a stacked RNN module consists of weights and biases for each of the layers, with suffixes indicating which layer each parameter corresponds to. Since num_layers has been set to 2, the stacked RNN module has a total of 8 parameters - 4 weight and 4 bias parameters.
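Listing the parameters shows the layer suffixes (a sketch):

```python
import torch.nn as nn

rnn = nn.RNN(input_size=3, hidden_size=3, num_layers=2, batch_first=True)

# 'l0' / 'l1' suffixes indicate the layer; 8 parameters in total
for name, param in rnn.named_parameters():
    print(name, tuple(param.shape))
```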
Computing Outputs
Total Output has a shape of [1,4,3]. This is similar to the output of a single RNN module. A point to note is that, in a stacked RNN module, the Total Output corresponds to the hidden states computed by the very last RNN layer.
Final Output has a shape of [2,1,3]. Final Output contains the hidden state of the last element of the sequence, computed by each of the layers in the RNN module. Since we have 2 layers (and a single direction), the first dimension of Final Output is of length 2, while the second dimension corresponds to the single sequence in the batch. If the RNN module had 3 layers, the first dimension would be of length 3, and a batch of 2 sequences would make the second dimension of length 2.
Manual Computation
For the very first layer, using the corresponding layer parameters, we can easily compute the hidden states for each of the elements using the same procedure that we have been using till now.
For the 2nd layer, and all subsequent layers, the input vector x is replaced by the hidden states computed by the previous layer.
For Layer 1,
h_current = torch.tanh(matmul(x[i], wih_l0.T) + bih_l0 + matmul(h_previous, whh_l0.T) + bhh_l0)
For Layer 2,
h_current = torch.tanh(matmul(output_1[i], wih_l1.T) + bih_l1 + matmul(h_previous, whh_l1.T) + bhh_l1)
Here, output_1 represents the hidden states computed in Layer 1.
By comparing the manually computed outputs, we can confirm that Total Output contains the hidden states computed by Layer 2 while Final Output contains the hidden state of the last element computed by Layer 1 and Layer 2.
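Putting the two layer computations together, one runnable sketch of the full check (variable names are my own):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
rnn = nn.RNN(input_size=3, hidden_size=3, num_layers=2, batch_first=True)
x = torch.randn(1, 4, 3)
total_output, final_output = rnn(x)
T = x.shape[1]

# Layer 1: consumes the input sequence
h = torch.zeros(3)
output_1 = []
for t in range(T):
    h = torch.tanh(x[0, t] @ rnn.weight_ih_l0.T + rnn.bias_ih_l0
                   + h @ rnn.weight_hh_l0.T + rnn.bias_hh_l0)
    output_1.append(h)

# Layer 2: consumes the hidden states produced by Layer 1
h = torch.zeros(3)
output_2 = []
for t in range(T):
    h = torch.tanh(output_1[t] @ rnn.weight_ih_l1.T + rnn.bias_ih_l1
                   + h @ rnn.weight_hh_l1.T + rnn.bias_hh_l1)
    output_2.append(h)

# Total Output matches Layer 2's hidden states
print(torch.allclose(torch.stack(output_2).unsqueeze(0), total_output))  # True

# Final Output stacks the last hidden state from each layer
manual_final = torch.stack([output_1[-1], output_2[-1]]).unsqueeze(1)
print(torch.allclose(manual_final, final_output))  # True
```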