How LSTM networks solve the problem of vanishing gradients
A simple, straightforward mathematical explanation
When diving into the theory behind Recurrent Neural Networks (RNNs) and Long Short-Term Memory (LSTM) networks, two main questions arise:
1. Why do RNNs suffer from vanishing gradients?
2. How do LSTM networks keep the gradients from vanishing?
When I tried answering these questions, I searched for formal mathematical proofs to deepen my understanding, but had a hard time finding proofs that were clear and understandable enough for me. After reading the leading papers and blog posts on these questions, I wrote a pair of proofs that worked for me and made me feel I really understood the problem and the solution.
The proofs in this work are only semi-formal: I tried to keep things as clear and focused as possible, add a little intuition, and leave out parts that I felt were not essential to the main argument.
RNNs and vanishing gradients
RNNs enable the modeling of time-dependent and sequential data tasks, such as stock market prediction or machine translation and text generation.
However, RNNs suffer from the problem of vanishing gradients, which hampers learning over long data sequences. The gradients carry the information used in the RNN parameter updates; when the gradient becomes smaller and smaller, the parameter updates become insignificant, which means no real learning is done.
I assume you know what an RNN looks like, but let's have a short reminder. We will work with a simple single-hidden-layer RNN that produces a single output. The network looks like this:
The network has an input sequence of vectors [X1, X2, …, Xk]; at time step t the network receives the input vector Xt. Past information and learned knowledge are encoded in the network state vectors [C1, C2, …, Ck-1]; at time step t the network receives the previous state vector Ct-1. The input vector Xt and the state vector Ct-1 are concatenated to form the complete input vector at time step t, [Ct-1, Xt].
The network has two parameter matrices Wc, Wx connecting the input vector [Ct-1,Xt] to the hidden layer. For simplicity, we leave out the bias vectors in our computations, and we denote W = [Wc,Wx].
The hyperbolic tangent is used as the activation function in the hidden layer.
The network outputs a single vector Hk at the last time step k (RNNs can be modeled to output a vector at each time step, but we'll use this simpler-to-analyze model).
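As a concrete sketch (not part of the original derivation; the function name `rnn_step` is assumed here), a single step of this RNN can be written in a few lines of NumPy:

```python
import numpy as np

def rnn_step(x_t, c_prev, Wc, Wx):
    """One step of the simple tanh RNN described above (biases omitted).

    Computes C_t = tanh(Wc @ C_{t-1} + Wx @ X_t)."""
    return np.tanh(Wc @ c_prev + Wx @ x_t)
```

Running the cell over a sequence just means feeding each output state back in as `c_prev` for the next step.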
Backpropagation through time in RNNs
After the RNN outputs the prediction vector Hk, we compute the prediction error Ek and use the Backpropagation Through Time (BPTT) algorithm to compute the gradient:
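That is, the derivative of the error with respect to the parameter matrix:

```latex
\frac{\partial E_k}{\partial W}
```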
The gradient is used to update the model parameters by:
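Writing α for the learning rate (a symbol introduced here for the update rule), the basic gradient-descent update is:

```latex
W \leftarrow W - \alpha\,\frac{\partial E_k}{\partial W}
```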
And we continue the learning process using the Gradient Descent (GD) algorithm (we use the basic version of GD, which is sufficient for this work).
Let’s compute the gradient used to update the network parameters for a learning task that includes k time steps, we have:
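Unrolling the chain rule backwards through the k time steps, from the error Ek at the output down to the first state vector, gives equation (1):

```latex
\frac{\partial E_k}{\partial W}
  = \frac{\partial E_k}{\partial H_k}\,
    \frac{\partial H_k}{\partial C_k}
    \left( \prod_{t=2}^{k} \frac{\partial C_t}{\partial C_{t-1}} \right)
    \frac{\partial C_1}{\partial W}
\tag{1}
```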
Notice that since W=[Wc,Wx], Ct can be written as:
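Since the complete input at step t is [Ct-1, Xt] and W = [Wc, Wx]:

```latex
C_t = \tanh\!\left( W_c\, C_{t-1} + W_x\, X_t \right)
```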
Compute the derivative of Ct and get:
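Differentiating with respect to Ct-1 (written compactly, with tanh′ denoting the derivative of tanh) gives equation (2):

```latex
\frac{\partial C_t}{\partial C_{t-1}}
  = \tanh'\!\left( W_c\, C_{t-1} + W_x\, X_t \right) W_c
\tag{2}
```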
Plug (2) into (1) and get:
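The gradient then contains a product of k − 1 tanh-derivative terms:

```latex
\frac{\partial E_k}{\partial W}
  = \frac{\partial E_k}{\partial H_k}\,
    \frac{\partial H_k}{\partial C_k}
    \left( \prod_{t=2}^{k} \tanh'\!\left( W_c\, C_{t-1} + W_x\, X_t \right) W_c \right)
    \frac{\partial C_1}{\partial W}
```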
The last expression tends to vanish when k is large. This is due to the derivative of the activation function tanh, which is always smaller than or equal to 1; multiplying many such terms drives the product toward zero.
So the network’s weights update will be:
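With learning rate α, the update degenerates to:

```latex
W \leftarrow W - \alpha\,\frac{\partial E_k}{\partial W} \approx W - \alpha \cdot 0 = W
```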
And no significant learning will be done.
How do Long Short-Term Memory (LSTM) networks solve this?
I recommend reading Colah’s blog for an in-depth review of LSTM networks since we are only going to have a short reminder here.
An LSTM network has an input vector [Xt,Ht-1] at time step t. The network cell state is denoted by Ct. The output vectors passed through the network between consecutive time steps t, t+1 are denoted by Ht.
The LSTM has three gates that update and control the cell state: the forget gate, the input gate, and the output gate. The gates use sigmoid and hyperbolic tangent activation functions.
The forget gate controls what information in the cell state to forget, given new information that has entered the network.
Notice the forget gate's output, which is given by:
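In the standard formulation (with σ the sigmoid function and Wf the forget gate's weight matrix; biases omitted, as before):

```latex
f_t = \sigma\!\left( W_f \,[H_{t-1}, X_t] \right)
```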
We will use this in our proof of how LSTMs prevent the gradients from vanishing.
The input gate controls what new information will be encoded into the cell state, given the new input information.
The input gate’s output has the form:
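In the standard formulation (the weight matrix names Wi and Wg are assumed here; biases omitted):

```latex
i_t = \sigma\!\left( W_i \,[H_{t-1}, X_t] \right), \qquad
\tilde{C}_t = \tanh\!\left( W_g \,[H_{t-1}, X_t] \right)
```

The gate's contribution to the cell state is the element-wise product of the sigmoid activation i_t and the candidate state C̃_t.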
The output gate controls what information encoded in the cell state is sent to the network as input at the following time step; this is done via the output vector Ht.
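Putting the three gates together, one LSTM time step can be sketched in NumPy as follows (a minimal sketch, not a reference implementation; the function name `lstm_step` and the weight-matrix names are assumptions, and biases are omitted as in the text):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x_t, h_prev, c_prev, Wf, Wi, Wg, Wo):
    """One LSTM time step (biases omitted, as in the text)."""
    v = np.concatenate([h_prev, x_t])   # complete input [H_{t-1}, X_t]
    f_t = sigmoid(Wf @ v)               # forget gate
    i_t = sigmoid(Wi @ v)               # input gate
    g_t = np.tanh(Wg @ v)               # candidate cell state
    o_t = sigmoid(Wo @ v)               # output gate
    c_t = f_t * c_prev + i_t * g_t      # new cell state
    h_t = o_t * np.tanh(c_t)            # new output vector
    return h_t, c_t
```

Note how the new cell state is a gated sum: the forget gate scales the old state, the input gate scales the candidate.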
Backpropagation through time in LSTM networks
As in our RNN model, we assume that our LSTM network outputs a single prediction vector Hk on the final k-th time step. The knowledge encoded in the state vectors Ct captures long-term dependencies and relations existing in the sequential data.
The length of the data sequences can be hundreds and even thousands of time steps, making it extremely hard to learn using a basic RNN.
As in the RNN case, we compute the gradient used to update the network parameters over a learning task with k time steps; the gradient has the form:
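This is the same chain-rule product as equation (1):

```latex
\frac{\partial E_k}{\partial W}
  = \frac{\partial E_k}{\partial H_k}\,
    \frac{\partial H_k}{\partial C_k}
    \left( \prod_{t=2}^{k} \frac{\partial C_t}{\partial C_{t-1}} \right)
    \frac{\partial C_1}{\partial W}
```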
In an LSTM, the state vector Ct has the form:
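Here ∘ denotes the element-wise product and C̃t is the candidate cell state produced by the input gate:

```latex
C_t = f_t \circ C_{t-1} + i_t \circ \tilde{C}_t
```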
Compute the derivative of Ct and get:
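Applying the product rule (written loosely, element-wise; note that f_t, i_t, and C̃_t all depend on C_{t-1} through H_{t-1}):

```latex
\frac{\partial C_t}{\partial C_{t-1}}
  = f_t
  + \frac{\partial f_t}{\partial C_{t-1}} \circ C_{t-1}
  + \frac{\partial \left( i_t \circ \tilde{C}_t \right)}{\partial C_{t-1}}
```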
For simplicity, we leave out the computation of the remaining derivative terms. This is of little importance to our proof: as we will see, for the gradients not to vanish it is enough that the forget gate's activations are greater than 0.
So we just write:
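That is, the derivative is approximated by the forget gate's activation alone:

```latex
\frac{\partial C_t}{\partial C_{t-1}} \approx f_t
```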
We plug this in equation (1) and get:
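The gradient then contains a product over the forget-gate activations, giving equation (2):

```latex
\frac{\partial E_k}{\partial W}
  \approx \frac{\partial E_k}{\partial H_k}\,
    \frac{\partial H_k}{\partial C_k}
    \left( \prod_{t=2}^{k} f_t \right)
    \frac{\partial C_1}{\partial W}
\tag{2}
```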
Now, notice that equation (2) means the gradient behaves similarly to the forget gate: if the forget gate decides that a certain piece of information should be remembered, the gate will be open, with activation values close to 1, allowing the gradient to flow through.
For simplicity, we can think of the forget gate’s action as:
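For information that should be remembered, the gate stays open, with activations close to a vector of ones:

```latex
f_t \approx \mathbf{1}
```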
So we get:
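The product of forget-gate activations stays close to 1 and drops out of the gradient:

```latex
\frac{\partial E_k}{\partial W}
  \approx \frac{\partial E_k}{\partial H_k}\,
          \frac{\partial H_k}{\partial C_k}\,
          \frac{\partial C_1}{\partial W}
```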
And the gradients do not vanish!
Summing up, we have seen that RNNs suffer from vanishing gradients caused by a long series of multiplications of small values, which diminishes the gradient and causes the learning process to degenerate.
LSTMs solve the problem by creating a connection between the forget gate's activations and the computation of the gradients. This connection creates a path for information flow through the forget gate, for information that the LSTM should not forget.
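A small numerical sketch (not from the original article; all values here are illustrative assumptions) makes the contrast concrete: over many time steps, the product of tanh derivatives collapses toward zero, while a product of near-open forget-gate activations stays far from zero.

```python
import numpy as np

np.random.seed(0)
k = 100  # number of time steps

# RNN-style factor: the product of tanh'(z_t) terms from the gradient.
# Pre-activations z_t are sampled at random; tanh'(z) = 1 - tanh(z)^2 <= 1.
z = np.random.randn(k)
tanh_deriv = 1.0 - np.tanh(z) ** 2
rnn_grad_factor = np.prod(tanh_deriv)

# LSTM-style factor: forget-gate activations close to 1 for remembered info.
forget = np.full(k, 0.99)  # gate nearly open at every step (assumed value)
lstm_grad_factor = np.prod(forget)

print(f"RNN  gradient factor over {k} steps: {rnn_grad_factor:.3e}")
print(f"LSTM gradient factor over {k} steps: {lstm_grad_factor:.3e}")
```

The RNN product is vanishingly small, while the LSTM product remains a usable scale for the gradient, which is the whole point of the forget-gate path.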