Understanding recurrent networks (Part 1 — Simple RNN)

Jeremy · 7 min read · Oct 8, 2023


Recurrent neural networks (RNNs) work well on problems where temporal relationships are important. Examples include stock market prediction, language translation, and generating music given some initial notes. In each case, the input contains critical temporal data which informs what the most likely output should be.

Why build a recurrent network?

Suppose you have a task: predict the next stock price given an initial sequence of prices. While you could pass the entire sequence to the first layer of the network, the model would have no built-in notion of ordering. We want to present this ordering to the model, as certain positions (such as the last position in this case) may serve as strong predictive signals.

This problem has parallels to convolutional neural networks, where there are spatial relationships in the input that we want to preserve for the model.

Basics of RNNs

RNNs address this loss of positional information by passing each input_t to the network sequentially. An RNN consists of multiple cells, each of which contains one or more layers that transform its input(s) into one or more outputs.

In addition, a hidden state is passed to each cell cell_i from the previous cell cell_(i-1). This hidden state is a vector which encapsulates useful information obtained from previous cells given input_t.

After the last cell is reached, the next input, input_(t+1), is passed to the RNN and the computation is repeated. During this second iteration, information about input_t reaches each cell through the hidden state that the same cell produced in the first iteration.

Therefore, each cell learns to influence the final prediction given its input and the hidden state carried over from previous iterations and previous cells. This repeated application of the same cells across iterations is why the architecture is known as a recurrent neural network.

The final layer is then able to make a prediction based on the final input and the network’s representation of all previous inputs.
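To make this flow concrete, here is a minimal sketch of the recurrence in plain Python. The names run_rnn, cell, and predict are hypothetical placeholders rather than code from this article:

def run_rnn(inputs, cell, predict, h0):
    # Each step combines the current input with the hidden state carried over
    # from the previous step; the final hidden state feeds the prediction layer.
    h = h0
    for x_t in inputs:      # inputs are presented one time step at a time
        h = cell(x_t, h)    # the hidden state summarizes everything seen so far
    return predict(h)       # the prediction uses the final hidden state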

This is loosely analogous to the idea of recursion in programming where at each level some task is performed and the algorithm relies on previously executed tasks to solve the entire problem.

Let’s walk through an example

There are a lot of diagrams out there for what the architecture looks like. After implementing it from scratch, I found inconsistencies between the number of parameters I expected and the number libraries such as PyTorch report for their RNNs. So pay close attention!

For RNN_cell_1 at t=1, the inputs are input_t1 and a zero-vector hidden state. The output of this cell then serves as the input to RNN_cell_2, while the hidden state serves as the hidden state of the same cell RNN_cell_1 when it is subsequently passed input_t2 at t=2. In the case of the simple RNN, the output and hidden state are identical, but this differs in more complex recurrent networks such as the LSTM.

At the last time step, the output from the last cell is a vector which is passed through one or more linear layers to produce the final prediction. In a regression problem the output dimension of the last linear layer will be 1, while in a multi-class classification problem it will be the number of classes.
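As a hedged sketch of these two options (the hidden size of 5 matches the model built later in this article, while the 3 classes are purely hypothetical):

from torch import nn

regression_head = nn.Linear(5, 1)      # hidden_size -> 1 predicted value (e.g. next price)
classification_head = nn.Linear(5, 3)  # hidden_size -> one logit per class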

Simple RNN diagram — by Jeremy

RNN single cell implementation

Let’s implement this idea using PyTorch to better understand what is going on.

Simple RNN Cell

First we will define a cell in the network. This is the fundamental building block of our RNN. While there are different definitions of a recurrent cell, the cell below consists of two separate dense layers. The layer ih is passed the original input (if the cell is cell_1) or the output from the previous cell. The layer hh is passed the hidden state from the same cell at the previous time step. Having two dense layers allows the model to learn separate transformations for the input and the hidden state.

import torch
from torch import nn


class RecurrentCell(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()

        # Transforms the cell's input (the original input or the previous cell's output).
        self.ih = nn.Linear(input_size, hidden_size)
        # Transforms the hidden state from the same cell at the previous time step.
        self.hh = nn.Linear(hidden_size, hidden_size)

        self.activation = nn.ReLU()

    def forward(self, x, h):
        # Both layers output a tensor of size hidden_size, so their results can be summed.
        res = self.activation(self.ih(x) + self.hh(h))
        # For a simple RNN the cell output and the new hidden state are the same tensor.
        return res, res

As we have designed both layers to output a tensor of size hidden_size, we can sum their outputs. This is a form of sum pooling, combining the model's key information from the transformed input and the transformed previous hidden state.

We then use a ReLU activation to enable the model to learn non-linear complex transformations.

In the case of our simple recurrent network, both the cell’s output and hidden state output are the same. Hence we return res twice via a tuple.
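For comparison, and as an observation of my own rather than something from this article, the cell above computes essentially the same update as PyTorch's built-in nn.RNNCell when its non-linearity is set to ReLU:

import torch
from torch import nn

# nn.RNNCell computes h' = relu(W_ih x + b_ih + W_hh h + b_hh), mirroring the
# two dense layers and ReLU in RecurrentCell.
builtin_cell = nn.RNNCell(input_size=1, hidden_size=5, nonlinearity='relu')

x = torch.randn(1, 1)        # (batch, input_size)
h = torch.zeros(1, 5)        # (batch, hidden_size)
h_next = builtin_cell(x, h)  # shape (1, 5), the new hidden state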

Stacking multiple RNN cells

To implement the sequential flow between cells, we stack 3 RecurrentCells in a ModuleList. We need to treat the first cell differently since it is the only cell which takes the original input.

In addition, since there is no hidden state at t=1, in our forward function we initially set h as a tensor of zeroes. At subsequent time steps, cells will be passed the updated h which will be the hidden state from the same cell at the previous time step.

Assuming a batch size of 1, we then take the final output of the ModuleList which is a 1 by 5 tensor and pass it to a dense layer to get our final 1 by 1 tensor output.

class Model(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers):
        super().__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # The first cell takes the original input; all others take the previous cell's output.
        layers = []
        if num_layers >= 1:
            layers.append(RecurrentCell(input_size, hidden_size))
        for _ in range(1, num_layers):
            layers.append(RecurrentCell(hidden_size, hidden_size))

        self.recurrent_cells = nn.ModuleList(layers)
        # Maps the final hidden representation to the prediction (here a single value).
        self.dense = nn.Linear(hidden_size, input_size)

    def forward(self, x):
        # x has shape (batch_size, seq_len, input_size).
        batch_size, seq_len = x.size(0), x.size(1)

        # One hidden state per cell, initialized to zeroes at t=1.
        h = torch.zeros(self.num_layers, batch_size, self.hidden_size)

        # Stores each cell's output at the current time step.
        rnn_out = torch.zeros(batch_size, self.num_layers, self.hidden_size)

        for t in range(seq_len):
            out = x[:, t, :]
            for i in range(len(self.recurrent_cells)):
                h_prev = h[i, :, :]
                out, h_updated = self.recurrent_cells[i](out, h_prev)
                # Clone before updating in place so h_prev (a view of the old h),
                # which the cell just consumed, is not overwritten under autograd.
                h = h.clone()
                h[i, :, :] = h_updated
                rnn_out[:, i, :] = out

        # Final prediction from the last cell's output at the last time step.
        out = self.dense(rnn_out[:, -1, :])

        return out


model = Model(1, 5, 3)

To inspect the parameters of the model we will summarize it with an input of shape (1, 3, 1): a batch containing a single sequence of 3 time steps with 1 feature each. You can imagine this to be stock prices on 3 consecutive days.

from torchinfo import summary  # the layer summary below appears to come from the torchinfo package

summary(model, (1, 3, 1))

==========================================================================================
Layer (type:depth-idx) Output Shape Param #
==========================================================================================
Model [1, 1] --
├─ModuleList: 1-1 -- --
│ └─RecurrentCell: 2-1 [1, 5] --
│ │ └─Linear: 3-1 [1, 5] 10
│ │ └─Linear: 3-2 [1, 5] 30
│ │ └─ReLU: 3-3 [1, 5] --
│ └─RecurrentCell: 2-2 [1, 5] --
│ │ └─Linear: 3-4 [1, 5] 30
│ │ └─Linear: 3-5 [1, 5] 30
│ │ └─ReLU: 3-6 [1, 5] --
│ └─RecurrentCell: 2-3 [1, 5] --
│ │ └─Linear: 3-7 [1, 5] 30
│ │ └─Linear: 3-8 [1, 5] 30
│ │ └─ReLU: 3-9 [1, 5] --
│ └─RecurrentCell: 2-4 [1, 5] (recursive)
│ │ └─Linear: 3-10 [1, 5] (recursive)
│ │ └─Linear: 3-11 [1, 5] (recursive)
│ │ └─ReLU: 3-12 [1, 5] --
│ └─RecurrentCell: 2-5 [1, 5] (recursive)
│ │ └─Linear: 3-13 [1, 5] (recursive)
│ │ └─Linear: 3-14 [1, 5] (recursive)
│ │ └─ReLU: 3-15 [1, 5] --
│ └─RecurrentCell: 2-6 [1, 5] (recursive)
│ │ └─Linear: 3-16 [1, 5] (recursive)
│ │ └─Linear: 3-17 [1, 5] (recursive)
│ │ └─ReLU: 3-18 [1, 5] --
│ └─RecurrentCell: 2-7 [1, 5] (recursive)
│ │ └─Linear: 3-19 [1, 5] (recursive)
│ │ └─Linear: 3-20 [1, 5] (recursive)
│ │ └─ReLU: 3-21 [1, 5] --
│ └─RecurrentCell: 2-8 [1, 5] (recursive)
│ │ └─Linear: 3-22 [1, 5] (recursive)
│ │ └─Linear: 3-23 [1, 5] (recursive)
│ │ └─ReLU: 3-24 [1, 5] --
│ └─RecurrentCell: 2-9 [1, 5] (recursive)
│ │ └─Linear: 3-25 [1, 5] (recursive)
│ │ └─Linear: 3-26 [1, 5] (recursive)
│ │ └─ReLU: 3-27 [1, 5] --
├─Linear: 1-2 [1, 1] 6
==========================================================================================
Total params: 166
Trainable params: 166
Non-trainable params: 0
Total mult-adds (M): 0.00
==========================================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.00
Estimated Total Size (MB): 0.00
==========================================================================================

We see the first RecurrentCell consists of 40 parameters due to the shape of the original input, while subsequent cells consist of 60 parameters as their inputs are already of size hidden_size. The number of parameters for the first cell is: input_size*hidden_size + hidden_size + hidden_size*hidden_size + hidden_size = 1*5 + 5 + 5*5 + 5 = 40. The number of parameters for every other cell is: 2*(hidden_size*hidden_size + hidden_size) = 2*(5*5 + 5) = 60.

The parameters for the final dense layer are computed as: hidden_size*1 + 1 = 5*1 + 1 = 6.

Therefore this gives us 40 + 60 + 60 + 6 = 166 total parameters for our simple 3-layer RNN.
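As a quick sanity check (a small sketch that reuses the model instance created above), we can compare this hand count against the number of parameters PyTorch reports:

# Sum the number of elements across every parameter tensor in the model.
total_params = sum(p.numel() for p in model.parameters())
print(total_params)  # expected to print 166, matching the summary above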

You can find the notebook implementing this RNN architecture here.

The problem with simple RNNs

RNNs similar to the one shown above tend to struggle in problem domains where long-range dependencies are critical. This is because information from earlier time steps is re-transformed at every subsequent time step, so some of it may be diluted or lost by the time the final prediction is made.
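As a toy illustration of this effect (my own simplified example, not code from this article), repeatedly scaling a value by a recurrent weight with magnitude below 1 shows how the contribution of an early input can shrink towards zero:

import torch

h = torch.tensor([1.0])  # pretend this encodes information from the first time step
w = torch.tensor([0.5])  # a recurrent weight with magnitude below 1

for _ in range(20):      # 20 further time steps
    h = w * h            # the early information is scaled down at every step

print(h)                 # tensor([9.5367e-07]): the original signal has all but vanished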

One analogy for this is to consider cells at each time step playing a game of Chinese whispers. Each cell wants to pass on the information it thinks is important but by the final time step, parts of the original information may be lost.


An example of a long-range dependency could be generating a paragraph where, early on, the subject is identified as male. Several sentences later, the model needs to remember this so it will use the correct pronoun ‘he’ when referring to the subject.

Let’s look at how LSTMs aimed to tackle this problem.
