Guide to RNNs, GRUs and LSTMs with diagrams and equations

Bobby Cheng
10 min read · Sep 12, 2023

1. Recurrent Neural Networks

1.1. What are they?

Before understanding RNNs, let us recall how traditional neural networks operate. Traditional neural networks are known as feedforward networks because they take in an input vector and produce a fixed-size output vector representing the probabilities of different classes. During training, each input vector is processed independently.

For illustration purposes, let us imagine using a feedforward network to predict the sentiment of movie reviews. An example review would be “That was outrageous. I think Christopher Nolan did an awfully good job with Oppenheimer”, with the target label “positive”. When this review (the input) is fed into a feedforward network, it is encoded as a single vector and processed in one go.

However, when we read the movie review sentence ourselves, we process it word by word from left to right. Every new word gives us new information: heuristics, memories and feelings that shape what we think the sentiment of the sentence is. Hence, as new information arrives, we subconsciously update our assessment of the sentiment. This behaviour of 1) incrementally processing new information and 2) using it to update our assessment of the input is exactly what a recurrent neural network (RNN) does.

RNNs are a type of neural network that processes input values sequentially, in multiple steps rather than one. As the network iterates through each input value, it updates a ‘state’, an entity that holds information about the inputs it has read so far.

Let us make our understanding concrete by comparing how a character language model built with a feedforward network works differently from an RNN.

1.2. How do they work?

A character language model predicts the next character in a given sequence. Let’s say we have a given sequence that reads ‘Neura’. We can use a character language model to predict the next character ‘l’ to form the word ‘Neural’.

If we use an N-gram model, it predicts the next character from the previous N − 1 characters. So a trigram model (N = 3) would use the letters ‘r’ and ‘a’ to make a prediction. But imagine using our human deduction skills to predict the next character of ‘Neura’ with only the last two characters: it would be impossible unless we knew the full sequence. We could instead set N = 6, so the model conditions on the entire given sequence and predicts ‘l’ more reliably. However, the number of possible N-grams grows exponentially with N, so N-gram models consume a lot of RAM and storage. They are therefore not an efficient method to scale or to train on large datasets.
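To make this concrete, here is a minimal sketch of a count-based character model with a two-character context (a trigram). The toy corpus and function names are illustrative assumptions, not from any particular library:

```python
from collections import Counter, defaultdict

def train_char_ngram(corpus, n):
    """Count how often each character follows each n-character context."""
    counts = defaultdict(Counter)
    for i in range(len(corpus) - n):
        context, nxt = corpus[i:i + n], corpus[i + n]
        counts[context][nxt] += 1
    return counts

def predict_next(counts, context):
    """Return the most frequently observed next character for a context."""
    return counts[context].most_common(1)[0][0]

counts = train_char_ngram("neural networks are neural nets", n=2)
print(predict_next(counts, "ra"))  # 'l', learned from occurrences of 'neural'
```

Notice that the table of counts grows with every distinct context the model must remember, which is exactly the memory problem described above.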

RNNs, on the other hand, are more suitable. Here’s why.

The given character sequence ‘Neura’ is still treated as a single input to the RNN. However, it is no longer processed in a single step. Instead, the network internally loops over each character. At each iteration, it combines the previous ‘state’ (hₜ₋₁) with the current input (xₜ) to obtain the current output (yₜ) and the current ‘state’ (hₜ). This current ‘state’ (hₜ) is then used in the next time step, t+1, and the process repeats until the end of the input sequence.

Here are 2 diagrams to further unpack the operations in an RNN. The first diagram mathematically and visually depicts these operations in a single time step t of an RNN.

  • First, hₜ₋₁ and xₜ are concatenated and multiplied by a weight matrix parameter Wₕ. The result is passed through an activation function to produce hₜ.
  • Then, hₜ is multiplied by another weight matrix parameter Wᵧₕ. The result is passed through another activation function to produce yₜ.
Figure 1 — Architecture and equations of a vanilla RNN

This second diagram zooms out from the operations happening in a single time step t, to show how the inputs and outputs of an RNN work sequentially across t-1, t and t+1. Notice how information is carried through multiple steps for a single input.

Figure 2 — Sequential steps of a single input in an RNN
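Here is a minimal NumPy sketch combining both diagrams: the rnn_step function implements the single-step operations of Figure 1, and the loop plays them out across the sequence as in Figure 2. The toy dimensions, random weights, and tanh/softmax activation choices are assumptions for illustration, and biases are omitted for brevity:

```python
import numpy as np

rng = np.random.default_rng(0)
vocab = "Neura"                       # toy vocabulary: five distinct characters
d_x, d_h = len(vocab), 8              # one-hot input size and hidden state size

W_h = rng.normal(0, 0.1, (d_h, d_h + d_x))   # acts on the concatenation [h_{t-1}; x_t]
W_yh = rng.normal(0, 0.1, (d_x, d_h))        # projects the state to an output

def rnn_step(h_prev, x_t):
    """One time step: combine the previous state with the current input."""
    h_t = np.tanh(W_h @ np.concatenate([h_prev, x_t]))  # current state
    logits = W_yh @ h_t
    y_t = np.exp(logits) / np.exp(logits).sum()         # softmax output
    return h_t, y_t

h = np.zeros(d_h)                     # initial state h_0
for ch in "Neura":                    # the network loops over each character
    x = np.eye(d_x)[vocab.index(ch)]  # one-hot encode the current character
    h, y = rnn_step(h, x)             # h carries information to the next step
print(y.round(3))                     # distribution over the next character
```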

As such, an RNN is essentially a ‘for loop’ that reuses values calculated in the previous iteration of the sequence. This unique operation means RNNs can track dependencies that span further back than the fixed context window of an N-gram model, while using far less RAM and storage.

To wrap up this section, here is a picture that visually contrasts the operations of an N-gram character model and an RNN character model. It also shows, intuitively, how an RNN captures each letter’s dependencies and can even remember characters all the way back to the beginning of the sequence. The colours show how much information about each character is retained as the sequence gets longer.

Figure 3 — The differences in Bi-gram and RNN character model operations

1.3. Cost Function of RNNs

Recall that for three or more prediction classes, we use the cross entropy loss function. When it comes to RNNs, we adapt the cross entropy loss to measure the average cost across all time steps.

This cost is typically computed using a training technique called teacher forcing. Teacher forcing trains an RNN by feeding in the ground-truth value from the previous time step as input, rather than the model’s own output. This makes training much faster: otherwise, errors made at early time steps would compound through every subsequent time step and drag out the training process.
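Here is a minimal sketch of this per-time-step cost, assuming the model’s predictions were produced with teacher forcing. The toy probabilities and shapes are illustrative:

```python
import numpy as np

def sequence_cross_entropy(probs, targets):
    """Average cross entropy over all time steps.

    probs:   (T, V) array of the model's next-character distributions y_t,
             produced with teacher forcing (the input at step t is the
             ground-truth character, not the model's previous prediction).
    targets: (T,) array of true next-character indices.
    """
    T = len(targets)
    # Probability assigned to the correct character at each step,
    # averaged as negative log-likelihoods across time.
    return -np.mean(np.log(probs[np.arange(T), targets]))

# Toy example: 3 time steps, vocabulary of 4 characters.
probs = np.array([[0.7, 0.1, 0.1, 0.1],
                  [0.2, 0.5, 0.2, 0.1],
                  [0.1, 0.1, 0.1, 0.7]])
targets = np.array([0, 1, 3])
print(sequence_cross_entropy(probs, targets))  # ~0.469
```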

1.4. Shortcomings of RNNs

While RNNs capture short-range dependencies and are lightweight, they struggle to capture long-term dependencies (notice how little information is retained in Figure 3) and are prone to vanishing or exploding gradients.

Backpropagation through time uses the chain rule and partial derivatives to compute the gradient of the cost function with respect to the weights, i.e. a long product of partial derivatives across time steps. If these factors are consistently smaller than 1, the product shrinks towards zero (vanishing gradients). If they are consistently larger than 1, the product grows without bound (exploding gradients).
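A quick numerical illustration of why such long products are dangerous (the per-step factors 0.9 and 1.1 are arbitrary assumptions):

```python
# A gradient flowing back through 100 time steps is, roughly, a product
# of 100 per-step factors.
print(0.9 ** 100)  # ~2.7e-05: shrinks towards zero (vanishing)
print(1.1 ** 100)  # ~1.4e+04: blows up (exploding)
```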

When gradients vanish or explode, training becomes very difficult: the network either stops learning or becomes unstable, leading to long training times and poor accuracy. This is where Gated Recurrent Units (GRUs) and Long Short-Term Memory (LSTM) networks come into the picture. They are variants of the RNN with additional operations that overcome these shortcomings.

In the next sections, we will cover the operations in a GRU and an LSTM.

2. Gated Recurrent Unit (GRU)

GRUs are a modified version of the RNN. They have additional operations known as the update and relevance gates, which allow the network to retain and retrieve relevant information even over long sequences. Gates are simply mathematical operations that control the flow of information.

2.1. How do GRUs work?

Figure 4 — Architecture of a GRU
Figure 5 — Mathematical equations in a GRU

Like an RNN, a GRU takes in 2 inputs at a time: the current input (xₜ) and the previous hidden state (hₜ₋₁). As mentioned, it also has 2 types of gates: the update gate (Γᵤ) and the relevance gate (Γᵣ).

The first 2 equations, in darker blue, are the update and relevance gates. Each is computed with a sigmoid activation function, producing a vector of values between 0 and 1. These 2 gates determine which information from the previous hidden state (hₜ₋₁) should be updated with current information (xₜ), and which information from the previous hidden state is still relevant.

Then comes the 3rd equation, in lighter blue. It produces the hidden state candidate (h’ₜ), calculated using the relevance gate (Γᵣ) to filter which past information to keep. This is generally called the memory component of a GRU cell.

Subsequently, the 4th equation, in red, takes place. It produces the new hidden state (hₜ), which depends on the update gate (Γᵤ) and the hidden state candidate (h’ₜ). When the update gate is close to 0, the information from the previous state (hₜ₋₁) is kept. When it is close to 1, the previous state is forgotten and replaced with values from the hidden state candidate. This is how the most relevant information gets passed from one time step to the next.

Lastly, a ŷ prediction is computed using the current hidden state (hₜ).

Put together, these operations allow the GRU to learn what information to keep and what to override at each time step, helping to preserve important information for longer.
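Putting the four equations together, here is a minimal NumPy sketch of a single GRU step. The weight names mirror the gates described above; biases are omitted for brevity, and the whole function is an illustrative assumption rather than any library’s implementation:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def gru_step(h_prev, x_t, W_u, W_r, W_h, W_y):
    """One GRU time step, following the four equations described above."""
    concat = np.concatenate([h_prev, x_t])
    gamma_u = sigmoid(W_u @ concat)          # update gate
    gamma_r = sigmoid(W_r @ concat)          # relevance (reset) gate
    # Hidden state candidate: the relevance gate filters the previous state
    h_cand = np.tanh(W_h @ np.concatenate([gamma_r * h_prev, x_t]))
    # Interpolate: gate near 1 takes the candidate, gate near 0 keeps h_{t-1}
    h_t = gamma_u * h_cand + (1 - gamma_u) * h_prev
    # Prediction computed from the new hidden state
    logits = W_y @ h_t
    y_t = np.exp(logits) / np.exp(logits).sum()
    return h_t, y_t
```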

3. Long Short-Term Memory (LSTM)

LSTM is another modified version of the RNN. However, while a GRU has 2 types of gates, an LSTM has 3: the forget, input and output gates. In addition, it has one more information state, called the cell state.

François Chollet, the author of ‘Deep Learning with Python’, describes this cell state as follows:

‘imagine this as a conveyor belt running parallel to the sequence you’re processing. Information from the sequence can jump onto the conveyor belt at any point, be transported to a later time step, and jump off, intact, when you need it’.

By carrying information across many time steps, the cell state saves information for later and prevents older information from gradually fading away.

Here’s how to differentiate the cell state from the hidden state of an LSTM. The cell state stores the internal, long-term memory of the LSTM; it does not necessarily carry information from the immediately previous event. The hidden state, on the other hand, carries information from the immediately previous event and is overwritten at every step.

3.1. How do LSTMs work?

Figure 6 — Architecture of an LSTM
Figure 7 — Mathematical equations in an LSTM

As mentioned above, LSTMs take in 3 inputs at a time — the current input, the previous hidden state, and the previous cell state.

The 1st equation is the forget gate (fₜ). This gate is computed with a sigmoid activation function, producing a vector of values between 0 and 1 that determines what information to discard or retain from the previous cell state (cₜ₋₁).

The 2nd and 3rd equations are the input gate (iₜ) and the candidate cell state (gₜ) respectively. Together, they decide what new information should be stored in the new cell state (cₜ). The input gate decides which information from the current input (xₜ) is relevant to add. The candidate cell state is a new set of values that could be added to the cell state; it is passed through a tanh activation, which squashes values to between -1 and 1 and helps keep the network’s activations regulated.

The 4th equation updates the old cell state (cₜ₋₁) to the new cell state (cₜ) using the previous three equations. Multiplying the old cell state element-wise by the forget gate (fₜ) discards certain information from it. Then, we add the element-wise product of the input gate (iₜ) and the candidate cell state (gₜ), injecting the new, relevant information.

The 5th equation is the output gate (oₜ). The output gate determines what the next hidden state (hₜ) should be, and its calculation is very similar to that of the forget and input gates.

The 6th equation gives us the new hidden state. It passes the new cell state (cₜ) through a tanh activation and multiplies the result element-wise by the output gate (oₜ), determining what information the hidden state (hₜ) should carry.

Lastly, a ŷ prediction is computed using the current hidden state (hₜ).

Put together, these operations allow the LSTM to choose which information is relevant to remember and forget during sequence processing.
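Here is the corresponding minimal NumPy sketch of a single LSTM step, following the six equations above. As before, biases are omitted and the weight names are illustrative assumptions:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(h_prev, c_prev, x_t, W_f, W_i, W_g, W_o, W_y):
    """One LSTM time step, following the six equations described above."""
    concat = np.concatenate([h_prev, x_t])
    f_t = sigmoid(W_f @ concat)      # forget gate: what to keep of c_{t-1}
    i_t = sigmoid(W_i @ concat)      # input gate: what new info to admit
    g_t = np.tanh(W_g @ concat)      # candidate cell state
    c_t = f_t * c_prev + i_t * g_t   # update the 'conveyor belt' cell state
    o_t = sigmoid(W_o @ concat)      # output gate
    h_t = o_t * np.tanh(c_t)         # new hidden state
    # Prediction computed from the new hidden state
    logits = W_y @ h_t
    y_t = np.exp(logits) / np.exp(logits).sum()
    return h_t, c_t, y_t
```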

Conclusion

In this article, I covered what RNNs are and how they function, listed the shortcomings of vanilla RNNs, highlighted how GRUs and LSTMs address those shortcomings, and explained how GRUs and LSTMs operate. It is worth remembering that the defining property of GRUs and LSTMs is their gating mechanism. It is these gates that make them highly useful in deep learning applications such as music composition, image captioning, speech recognition, and more.

In my next article, I’ll show how I built a text generator with a GRU to generate fictional bible verses.

About the Author

Bobby Cheng is passionate about using data and software to create value for others. He has experience using ML (i.e. geospatial analytics, natural language processing, ensemble classifications) to research and build solutions for citizens.

In addition, as the previous lead data scientist for the ASEAN region of Amazon Web Services (AWS), he has experience developing software programs and engineering the new generation of ML solutions for AWS’s Commercial Sales business.

Please reach out to Bobby on LinkedIn.

Resources

Chollet, F. (2017). Deep learning with Python. Manning Publications.

Karpathy, A. (2015). The Unreasonable Effectiveness of Recurrent Neural Networks.

DeepLearning.AI. Natural Language Processing with Sequence Models.

Sherstinsky, A. (2020). Fundamentals of Recurrent Neural Network (RNN) and Long Short-Term Memory (LSTM) network.

Brownlee, J. (2017). What is Teacher Forcing for Recurrent Neural Networks?

Olah, C. (2015). Understanding LSTM Networks.
