Recurrent Neural Networks, the Vanishing Gradient Problem, and Long Short-Term Memory

Pranav Pillai
Jul 16, 2019

Many things we’d like to model consist of sequential data. Stock prices, sentences in a book, daily views of a video: each of these is a sequence of data points. Importantly, our understanding of a particular stock price, a word in a sentence, or the number of views on a particular day would be incomplete without looking at the preceding data points. Today’s stock price depends heavily on yesterday’s; a sentence can lack meaning without the previous paragraph; and when a video is trending, its views over the previous 2–3 days are often the best indicator of what is to come. Thus, to accurately predict the next point in a sequence, a good model needs some notion of memory: a way to include prior data points in the current prediction.

Traditional ‘feed-forward’ neural nets have no memory from prediction to prediction. Information flows in only one direction, and each input is evaluated in isolation. Recurrent Neural Networks (RNNs) address our need to include past data points in new predictions by maintaining a ‘hidden state’ that summarizes the previous data points. This hidden state is combined with the current input to produce each new prediction. At a high level, RNNs are neural networks designed to model sequential data with dependencies across time. Stock prices can be estimated from both current covariates and the previous prices of a stock, text can be generated with the context of the last paragraph written, and a model of video views can account for when a video is trending.

In this article, I will first provide a technical overview of RNNs. Next, I’ll discuss the biggest limitation of conventional RNNs, termed ‘The Vanishing Gradient Problem’, before concluding with Long Short-Term Memory (LSTM), an architecture designed to prevent vanishing gradients. Let’s dive in!

Review of Feed-Forward Neural Networks

Feed-forward neural networks are made up of layers of nodes, each having its own weight vector and activation function. Inputs are evaluated in the following way:

  1. For each node in the first layer, the input vector x is multiplied by the node’s weight vector W. A bias term b is also often added.
  2. The product of x with W (plus b) is passed through an ‘activation function’ to produce an output z. Activation functions let the network extract non-linear features from x; without them, stacked layers collapse into a single linear transformation, and you might as well be performing linear regression.
  3. The resulting output z is then sent to the next layer of the neural network, along with the outputs of the other nodes in the first layer. Each of the nodes in the next layer will also have weight vectors and activation functions — thus, the process repeats for each layer.
  4. The final layer of the neural network combines the features output by the second-to-last layer. Depending on the task, this last layer may apply a ‘squashing function’ (for example, a softmax that turns the previous layer’s outputs into a probability distribution), or it may simply output a fixed-size vector or number. This is the final output y: your guess at what the next stock price will be, a probability distribution for the next word in a sentence, or your estimate of the number of views a video will have.
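The four steps above can be sketched in a few lines of plain Python. This is a minimal illustration, not a production implementation: the helper names (`dense`, `feed_forward`), the 2-3-1 layer sizes, the ReLU hidden activation, and the hand-picked weights are all hypothetical. The output layer uses the identity function, as is common for a regression target such as a price; a classifier would typically use a softmax instead.

```python
import math

def dense(x, weights, biases, activation):
    """One fully connected layer: each node computes activation(w . x + b)."""
    return [activation(sum(w_i * x_i for w_i, x_i in zip(w, x)) + b)
            for w, b in zip(weights, biases)]

def relu(z):
    return max(0.0, z)

def identity(z):
    return z

def feed_forward(x, layers):
    """Pass x through each (weights, biases, activation) layer in order."""
    for weights, biases, activation in layers:
        x = dense(x, weights, biases, activation)
    return x

# Hypothetical 2-3-1 network with hand-picked weights (illustrative only).
layers = [
    ([[0.5, -1.0], [1.0, 1.0], [-0.5, 0.25]], [0.0, 0.1, 0.0], relu),  # hidden layer
    ([[1.0, 2.0, -1.0]], [0.5], identity),                             # output layer
]
y = feed_forward([1.0, 2.0], layers)
print(y)
```

With these weights the hidden layer outputs [0, 3.1, 0] (the first and third nodes are clipped to 0 by ReLU), so the single output is 2 × 3.1 + 0.5 = 6.7. Note that nothing from this call is remembered by the next one.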

Importantly, in feed-forward neural networks, data flows in only one direction. The outputs from one run of the network are quickly forgotten as new inputs come in.

Recurrent Neural Networks: The Basics

Figure 1: A Recurrent Neural Network

In a recurrent network, decisions made on earlier inputs affect the decisions made on later ones. To accomplish this, information from previous inputs is stored in a ‘hidden state’, a form of memory that summarizes all inputs seen so far. To evaluate a new input x(t) at time t, we combine x(t) with the previous hidden state h(t-1) to produce a new hidden state h(t). In Figure 1, the new hidden state h2 is generated from both its corresponding input x2 and the previous hidden state h1, which carries any dependencies on earlier inputs. The output y2 is then a function of the hidden state h2. Mathematically, the current input and the previous hidden state are typically combined as:

$$h(t) = \tanh\big(W\,x(t) + U\,h(t-1)\big)$$

where W is the weight matrix applied to the input x(t), U is the weight matrix applied to the previous hidden state h(t-1), and tanh is a common choice of activation function (a bias term is often added inside the parentheses as well).
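The recurrence described above can be sketched as a loop that carries the hidden state from one step to the next. This is a minimal pure-Python illustration under assumptions of my own: tanh as the activation, no bias term, a 1-dimensional input, a 2-dimensional hidden state, and hand-picked values for W and U. The function names `matvec` and `rnn_forward` are hypothetical.

```python
import math

def matvec(M, v):
    """Multiply matrix M (a list of rows) by vector v."""
    return [sum(m_ij * v_j for m_ij, v_j in zip(row, v)) for row in M]

def rnn_forward(xs, W, U, h0):
    """Apply h(t) = tanh(W x(t) + U h(t-1)) over a sequence; return every hidden state."""
    h = h0
    states = []
    for x in xs:
        pre = [a + b for a, b in zip(matvec(W, x), matvec(U, h))]
        h = [math.tanh(p) for p in pre]
        states.append(h)
    return states

# Hypothetical sizes: 1-dimensional inputs, 2-dimensional hidden state.
W = [[0.5], [-0.3]]            # input-to-hidden weights
U = [[0.8, 0.0], [0.1, 0.9]]   # hidden-to-hidden weights
xs = [[1.0], [0.0], [0.0]]     # a single "impulse" followed by zeros
states = rnn_forward(xs, W, U, h0=[0.0, 0.0])
for t, h in enumerate(states, start=1):
    print(t, [round(v, 3) for v in h])
```

Even though the last two inputs are zero, the hidden state at step 3 is still non-zero: the impulse at step 1 is carried forward through U, which is exactly the memory a feed-forward network lacks.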

Training a Recurrent Neural Network

‘Training’ a neural network means finding the weights that minimize some error function. As with feed-forward neural networks, an RNN is trained with a variant of backpropagation, called Back Propagation Through Time (BPTT). Backpropagation works by taking the partial derivative (or gradient) of the error function with respect to each weight, noting the direction of the gradient (i.e. whether increasing or decreasing the weight reduces the error), and adjusting each weight accordingly. The name ‘backpropagation’ comes from the order of computation: gradients are computed at the output layer first and propagated backward via the chain rule, because the gradients for earlier layers depend on those for later layers. For a further explanation of the concepts (and mathematics) behind backpropagation, see the sources at the end of this article.

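BPTT differs from generic backpropagation because an RNN’s output depends on both the current input and the previous hidden state, which in turn depends on all prior inputs. BPTT therefore unrolls the network across time steps and, because the same weights are reused at every step, sums the gradient contributions from all of them. As a result, the gradient at the current step relies on gradients flowing back through every earlier step.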
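To make this concrete, here is a minimal sketch of BPTT for a scalar RNN, h(t) = tanh(w·x(t) + u·h(t-1)) with h(0) = 0. The squared-error loss on the final hidden state, the input sequence, and the values of w, u, and the target are assumptions chosen for illustration. The backward loop carries dE/dh one step back at a time, and because the same w is used at every step, its gradient is the sum of one contribution per step. A finite-difference estimate is computed alongside as a sanity check.

```python
import math

def forward(xs, w, u):
    """Scalar RNN: h(t) = tanh(w*x(t) + u*h(t-1)), with h(0) = 0."""
    hs = [0.0]
    for x in xs:
        hs.append(math.tanh(w * x + u * hs[-1]))
    return hs

def loss(xs, target, w, u):
    # Squared error on the final hidden state only (an assumed, illustrative loss).
    return 0.5 * (forward(xs, w, u)[-1] - target) ** 2

def bptt_grad_w(xs, target, w, u):
    """dE/dw by walking backwards through time, summing each step's contribution."""
    hs = forward(xs, w, u)
    dE_dh = hs[-1] - target                  # dE/dh(T)
    grad = 0.0
    for t in range(len(xs), 0, -1):
        dh_dpre = 1.0 - hs[t] ** 2           # tanh'(pre-activation) = 1 - h(t)^2
        grad += dE_dh * dh_dpre * xs[t - 1]  # direct effect of w at step t
        dE_dh = dE_dh * dh_dpre * u          # carry the gradient back to h(t-1)
    return grad

xs, target, w, u = [0.5, -1.0, 2.0], 0.3, 0.7, 0.9
analytic = bptt_grad_w(xs, target, w, u)
eps = 1e-6
numeric = (loss(xs, target, w + eps, u) - loss(xs, target, w - eps, u)) / (2 * eps)
print(analytic, numeric)
```

The two numbers should agree to several decimal places. Notice that each step further back multiplies `dE_dh` by another factor of `(1 - h(t)^2) * u`, which is the source of the problem discussed next.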

The Vanishing Gradient Problem

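While training via BPTT, the gradient of the error with respect to the input weight W can be written as follows, where E_3 is the error at the third time step, ŷ_3 is the prediction at that step, and h_k is the hidden state at step k:

$$\frac{\partial E_3}{\partial W} = \sum_{k=0}^{3} \frac{\partial E_3}{\partial \hat{y}_3}\,\frac{\partial \hat{y}_3}{\partial h_3}\,\frac{\partial h_3}{\partial h_k}\,\frac{\partial h_k}{\partial W}$$

The sum appears because W is shared across all time steps, so every earlier step k contributes to the error at step 3 (here ∂h_k/∂W denotes only the direct dependence of h_k on W at step k, with h_{k-1} held fixed). This expression includes the derivative of the current hidden state h_3 with respect to each earlier hidden state h_k, which must itself be evaluated via the chain rule, e.g.

$$\frac{\partial h_3}{\partial h_1} = \frac{\partial h_3}{\partial h_2}\,\frac{\partial h_2}{\partial h_1}$$

Consequently, the gradient can be expanded as

$$\frac{\partial E_3}{\partial W} = \sum_{k=0}^{3} \frac{\partial E_3}{\partial \hat{y}_3}\,\frac{\partial \hat{y}_3}{\partial h_3}\left(\prod_{j=k+1}^{3} \frac{\partial h_j}{\partial h_{j-1}}\right)\frac{\partial h_k}{\partial W}$$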
Each factor ∂h_j/∂h_{j-1} in that product depends on the derivative of the activation function used to produce the hidden state, often tanh or sigmoid. tanh maps its input to values between -1 and 1, and sigmoid to values between 0 and 1; more importantly, their derivatives are small: the derivative of tanh is at most 1 (reached only at 0), and the derivative of sigmoid is at most 0.25. Each factor is this derivative multiplied by the recurrent weights U, so unless U is large, its magnitude is typically less than 1. Because the contribution from step k is a product of one such factor for every step between k and the current step, that contribution shrinks roughly exponentially as the distance between the two steps grows. The overall gradient is then dominated by nearby time steps, and weight updates carry almost no signal about dependencies between distant inputs.
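The shrinking product can be observed directly. For a scalar RNN h(t) = tanh(w·x(t) + u·h(t-1)), each factor is ∂h(j)/∂h(j-1) = u·(1 - h(j)²). The sketch below, assuming w = 1.0, u = 0.9, and an arbitrary sine-wave input of 50 steps (all illustrative choices), computes ∂h(T)/∂h(T - gap) for growing gaps.

```python
import math

def hidden_states(xs, w, u):
    """Scalar RNN states h(0..T), with h(0) = 0."""
    hs = [0.0]
    for x in xs:
        hs.append(math.tanh(w * x + u * hs[-1]))
    return hs

def dhT_dhk(hs, u, k):
    """Product of dh(j)/dh(j-1) = u * (1 - h(j)^2) for j = k+1 .. T."""
    prod = 1.0
    for j in range(k + 1, len(hs)):
        prod *= u * (1.0 - hs[j] ** 2)
    return prod

xs = [math.sin(0.3 * t) for t in range(50)]  # an arbitrary input sequence
hs = hidden_states(xs, w=1.0, u=0.9)
T = len(hs) - 1
gaps = (1, 5, 10, 25, 49)
vals = [dhT_dhk(hs, 0.9, T - gap) for gap in gaps]
for gap, v in zip(gaps, vals):
    print(f"gap {gap:2d}: dh(T)/dh(T-{gap}) = {v:.2e}")
```

Since every factor here is below 0.9, the derivative across 49 steps is smaller than 0.9^49 (about 0.006), so an error signal from step T barely reaches step 1. The factor includes u as well as the activation derivative, which is why the recurrent weights matter too.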

This issue of gradients approaching 0 is aptly termed the ‘vanishing gradient problem’, and for many years it was the chief issue with recurrent neural networks, limiting their ability to learn correlations between inputs that were temporally distant from one another. One partial remedy is to use activation functions whose derivatives do not shrink the gradient, a popular choice being ReLU, whose derivative is exactly 1 for positive inputs. Another option is to use Long Short-Term Memory units (LSTMs).

Long Short-Term Memory Overview

LSTM networks are RNNs designed to learn long-term dependencies by mitigating the vanishing gradient problem. Christopher Olah has written an excellent explanation of them, “Understanding LSTM Networks” (source 1 below).

In a conventional RNN, the module repeated at each time step is typically simple, such as a single tanh layer.

Figure 2: A Conventional RNN

In LSTMs, there are four neural network layers per module.

Figure 3: An LSTM RNN

The key to LSTMs is ‘cell state’ — that is, the horizontal line at the top of Figure 3. For the most part, the purpose of the cell state is to pass information down the chain without transforming it. To change information in the cell state, you have to go through gates.

Gates optionally let information through. Each is made of a sigmoid-activated neural network layer followed by a point-wise multiplication. The sigmoid function maps its inputs to numbers between 0 and 1, each representing the proportion of the corresponding piece of information allowed to pass through. LSTMs have three such gates.
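A gate in isolation is just a sigmoid followed by an element-wise product. In a real LSTM the gate’s inputs (its “logits”) come from a learned layer applied to the previous hidden state and current input; in this minimal sketch they are hand-picked, and the function name `gate` is hypothetical.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def gate(values, gate_logits):
    """Scale each value by sigmoid(logit), a number in (0, 1): ~0 blocks it, ~1 lets it through."""
    return [v * sigmoid(z) for v, z in zip(values, gate_logits)]

out = gate([2.0, 2.0, 2.0], [-10.0, 0.0, 10.0])
print(out)  # first entry nearly blocked, second halved, third almost untouched
```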

LSTMs: Updating Cell State

The first network layer (the leftmost in Figure 3) in an LSTM is termed the ‘forget gate layer’. It decides what information to remove from the cell state. For each number in the cell state vector, the forget gate outputs a number between 0 and 1. These values are later point-wise multiplied with the previous cell state, determining the proportion of each cell state value that is kept for later time steps. Each proportion is computed from both the current input and the previous hidden state (along with the usual weight matrix and bias term).
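Written as an equation, using the notation of Olah’s post (source 1) as an assumption rather than symbols defined in this article: f_t is the forget gate’s output, σ is the sigmoid, [h_{t-1}, x_t] is the concatenation of the previous hidden state and the current input, and W_f and b_f are the forget gate’s learned weight matrix and bias.

$$f_t = \sigma\left(W_f \cdot [h_{t-1}, x_t] + b_f\right)$$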

The second and third network layers in an LSTM are responsible for adding features from the current input to cell state. The first of these layers, the ‘input gate layer’, is another sigmoid layer that produces values between 0 and 1 for each number in the cell state, determining which components of cell state will be added to. The second of these layers employs a tanh activation function to produce new values to be added to cell state.

Figure 4: Calculating New Values to Add to Cell State

After values are produced by the sigmoid and tanh, they are point-wise multiplied. The effect of this is to weight the values produced by the tanh layer according to how relevant they are.
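In equations, again borrowing the notation of Olah’s post (source 1) as an assumption: i_t is the input gate’s output, \tilde{C}_t is the tanh layer’s vector of candidate values, [h_{t-1}, x_t] is the concatenation of the previous hidden state and the current input, and W_i, b_i, W_C, b_C are learned weights and biases. The quantity added to the cell state is their element-wise (⊙) product.

$$i_t = \sigma\left(W_i \cdot [h_{t-1}, x_t] + b_i\right), \qquad \tilde{C}_t = \tanh\left(W_C \cdot [h_{t-1}, x_t] + b_C\right), \qquad \text{update} = i_t \odot \tilde{C}_t$$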

Figure 5: Updating Cell State

At this point, we can update the cell state. Per the diagram above, to go from cell state C(t-1) to C(t), we point-wise multiply C(t-1) by the values generated by the forget gate, then add the new values produced by the second and third layers (the input gate’s output multiplied point-wise with the tanh layer’s output).
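As a single equation, writing C_t for C(t) and, following Olah’s notation (source 1) as an assumption, f_t for the forget gate’s output, i_t for the input gate’s output, \tilde{C}_t for the tanh layer’s candidate values, and ⊙ for element-wise multiplication:

$$C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t$$

The update is additive: old memory is scaled by f_t rather than passed through another squashing non-linearity, which is what lets information (and gradient) survive many steps when f_t stays close to 1.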

The last gate determines which cell state values are funneled out as the new hidden state. Unsurprisingly, a sigmoid layer outputs values between 0 and 1, determining which components of the cell state will be part of the next hidden state (based on relevance to the input at hand). For example, if we are interpreting text and have just seen a noun, the hidden state might expose information useful for predicting a verb. Conventionally, a tanh function is first applied element-wise to the cell state, so that values in the new hidden state lie between -1 and 1, and the result is multiplied point-wise by the sigmoid layer’s output. This decoupling of cell state and hidden state is noteworthy, because it means that we can remember features in cell state for long periods of time without including them in the hidden state that affects our current prediction.
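In equations, with notation assumed from Olah’s post (source 1): o_t is the output gate’s value, [h_{t-1}, x_t] is the concatenation of the previous hidden state and current input, W_o and b_o are learned weights and bias, C_t is the updated cell state, and ⊙ is element-wise multiplication.

$$o_t = \sigma\left(W_o \cdot [h_{t-1}, x_t] + b_o\right), \qquad h_t = o_t \odot \tanh(C_t)$$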

Figure 6: Calculating Hidden State from Cell State
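Putting the three gates and the tanh layer together, one LSTM time step can be sketched in plain Python. This is a minimal illustration, not how a library implements it: the helper names (`layer`, `lstm_step`), the parameter dictionary layout, the single hidden unit, and the hand-picked weights are all hypothetical. The forget gate’s bias is set to 5 so that it stays near 1, which shows the cell state holding on to an input seen several steps earlier.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def layer(params, h_prev, x, act):
    """act(W . [h_prev, x] + b) per unit; params is (W, b) with one row of W per unit."""
    W, b = params
    hx = h_prev + x  # list concatenation [h(t-1), x(t)]
    return [act(sum(w * v for w, v in zip(row, hx)) + b_i) for row, b_i in zip(W, b)]

def lstm_step(x, h_prev, c_prev, p):
    """One LSTM time step; returns the new (hidden state, cell state)."""
    f = layer(p["f"], h_prev, x, sigmoid)          # forget gate
    i = layer(p["i"], h_prev, x, sigmoid)          # input gate
    c_tilde = layer(p["c"], h_prev, x, math.tanh)  # candidate values
    o = layer(p["o"], h_prev, x, sigmoid)          # output gate
    c = [f_k * c_k + i_k * ct_k for f_k, c_k, i_k, ct_k in zip(f, c_prev, i, c_tilde)]
    h = [o_k * math.tanh(c_k) for o_k, c_k in zip(o, c)]
    return h, c

# Hypothetical parameters: 1 hidden unit, 1 input feature, W rows are [w_h, w_x].
params = {
    "f": ([[0.0, 0.0]], [5.0]),  # forget gate ~0.993: keep almost all of c
    "i": ([[0.0, 1.0]], [0.0]),
    "c": ([[0.0, 1.0]], [0.0]),
    "o": ([[0.0, 0.0]], [0.0]),  # output gate fixed at 0.5
}
h, c = [0.0], [0.0]
for x in [3.0, 0.0, 0.0, 0.0]:   # one impulse, then silence
    h, c = lstm_step([x], h, c, params)
print(h, c)
```

After the impulse, the input gate and candidate contribute nothing (tanh(0) = 0), so the cell state simply decays by the forget gate’s factor of about 0.993 per step and is still above 0.9 three steps later.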

Purpose of LSTMs

The mathematics behind LSTMs can be intimidating, but the benefits are significant. LSTMs are able to “preserve the error that can be backpropagated through time and layers” (source 2). In other words, LSTMs largely mitigate the vanishing gradient problem, keeping the gradient from shrinking toward 0 as the distance between related time steps grows. Without them, many applications of RNNs that rely on long-term dependencies would be impractical. For a mathematical treatment of exactly how LSTMs achieve this, showing how the derivative changes form from a product of potentially very small numbers to a sum of these numbers, see source 4 below.
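A back-of-the-envelope comparison illustrates the difference. Along the cell state’s direct path, the derivative of C(t) with respect to C(t-1) is just the forget gate’s value, so the gradient over many steps is a product of forget-gate values rather than of u·tanh′ terms. The sketch below assumes a constant plain-RNN factor (u = 0.9, pre-activation fixed at 0.5) and a forget gate saturated at sigmoid(5). These are illustrative simplifications: a real LSTM also passes gradient through its hidden state, and a forget gate near 0 would make the cell path decay too.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

steps = 50
# Plain RNN: each step multiplies the gradient by u * tanh'(pre); assumed u = 0.9, pre = 0.5.
rnn_factor = 0.9 * (1.0 - math.tanh(0.5) ** 2)
# LSTM cell-state path: dC(t)/dC(t-1) = f(t); forget gate assumed saturated near 1.
lstm_factor = sigmoid(5.0)

print(f"RNN  over {steps} steps: {rnn_factor ** steps:.1e}")
print(f"LSTM over {steps} steps: {lstm_factor ** steps:.2f}")
```

Under these assumptions the plain RNN’s gradient across 50 steps is on the order of 1e-8, while the cell-state path keeps roughly 70% of it.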

Conclusion

RNNs have broad application in machine learning. While the field of deep learning is still in development, a solid understanding of recurrent neural networks will serve as a building block towards more sophisticated, complex models. If you have any questions or corrections, please let me know; I can be reached at pranavp802@gmail.com. Appreciate the read!

Sources

  1. https://colah.github.io/posts/2015-08-Understanding-LSTMs/
  2. https://skymind.ai/wiki/lstm#long
  3. http://www.wildml.com/2015/09/recurrent-neural-networks-tutorial-part-1-introduction-to-rnns/
  4. https://medium.com/datadriveninvestor/how-do-lstm-networks-solve-the-problem-of-vanishing-gradients-a6784971a577
  5. https://www.youtube.com/watch?v=Ilg3gGewQ5U
