Backpropagating an LSTM: A Numerical Example

Aidan Gomez
5 min read · Apr 18, 2016


Let’s do this…

We all know LSTMs are super powerful, so we should know how they work and how to use them.

Syntactic notes

  • ⨀ is the element-wise product (Hadamard product)
  • Inner products will be represented as ⋅
  • Outer products will be represented as ⨂
  • σ represents the sigmoid function

The forward components

The gates are defined as:
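In the standard formulation, writing W for the input weights, U for the recurrent weights, and b for the biases:

a_t = tanh(W_a ⋅ x_t + U_a ⋅ out_{t-1} + b_a)
i_t = σ(W_i ⋅ x_t + U_i ⋅ out_{t-1} + b_i)
f_t = σ(W_f ⋅ x_t + U_f ⋅ out_{t-1} + b_f)
o_t = σ(W_o ⋅ x_t + U_o ⋅ out_{t-1} + b_o)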

Which leads to:
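Namely, the usual cell-state and output updates:

state_t = a_t ⨀ i_t + f_t ⨀ state_{t-1}
out_t = tanh(state_t) ⨀ o_t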

Note for simplicity we define:
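A common shorthand, used in the gradients below, is to stack the four gates and their parameters:

gates_t = [a_t, i_t, f_t, o_t],   W = [W_a; W_i; W_f; W_o],   U = [U_a; U_i; U_f; U_o],   b = [b_a; b_i; b_f; b_o]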

The backward components

Given:

  • Δ_t, the output difference as computed by any subsequent layers (i.e. the rest of your network), and;
  • Δout_t, the output difference as computed by the next time-step LSTM (the equation for Δout_{t-1} is below).

Find:
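Working the chain rule back through the forward equations gives the standard per-timestep gradients (with δgates_t stacking the four pre-activation gradients):

δout_t = Δ_t + Δout_t
δstate_t = δout_t ⨀ o_t ⨀ (1 - tanh²(state_t)) + δstate_{t+1} ⨀ f_{t+1}
δa_t = δstate_t ⨀ i_t ⨀ (1 - a_t²)
δi_t = δstate_t ⨀ a_t ⨀ i_t ⨀ (1 - i_t)
δf_t = δstate_t ⨀ state_{t-1} ⨀ f_t ⨀ (1 - f_t)
δo_t = δout_t ⨀ tanh(state_t) ⨀ o_t ⨀ (1 - o_t)
δx_t = W^T ⋅ δgates_t
Δout_{t-1} = U^T ⋅ δgates_t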

The final updates to the internal parameters are computed as:
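Accumulated over the whole sequence (out_t feeds the gates at step t+1, hence the index shift on δU):

δW = Σ_t δgates_t ⨂ x_t
δU = Σ_t δgates_{t+1} ⨂ out_t
δb = Σ_t δgates_t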

Putting this all together we can begin…

The Example

Let us begin by defining our internal weights:

And now input data:

* Mohamed Challal pointed out to me that a label of 1.25 makes no sense since the outputs are a product of a tanh and sigmoid. Mohamed is completely correct!

I’m using a sequence length of two here to demonstrate the unrolling over time of RNNs.
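If you’d like to follow along in code, here is a minimal NumPy sketch of this setup. The dimensions, weights, and inputs below are illustrative placeholders rather than the exact numbers used in this example.

```python
import numpy as np

np.random.seed(0)

# Keep everything one-dimensional so each quantity is easy to inspect by hand.
input_dim, hidden_dim = 1, 1

# Stacked parameters for the four gates [a, i, f, o], matching the shorthand above.
# These are placeholder values, not the weights from the worked example.
W = 0.5 * np.random.randn(4 * hidden_dim, input_dim)   # input weights
U = 0.5 * np.random.randn(4 * hidden_dim, hidden_dim)  # recurrent weights
b = np.zeros(4 * hidden_dim)                           # biases

# A length-two input sequence and a label per step (placeholders; as the footnote
# above notes, a target like 1.25 can never be reached, since out_t lies in (-1, 1)).
xs = [np.array([1.0]), np.array([0.5])]
labels = [np.array([0.5]), np.array([1.25])]
```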

Forward @ t=0
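Continuing the sketch above, a forward step that mirrors the gate and state equations might look like this:

```python
def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_forward_step(x, out_prev, state_prev, W, U, b):
    """One forward step; returns the new output and state plus a cache for backprop."""
    z = W @ x + U @ out_prev + b              # stacked pre-activations for [a, i, f, o]
    za, zi, zf, zo = np.split(z, 4)
    a, i, f, o = np.tanh(za), sigmoid(zi), sigmoid(zf), sigmoid(zo)
    state = a * i + f * state_prev            # element-wise (Hadamard) products
    out = np.tanh(state) * o
    cache = (x, out_prev, state_prev, a, i, f, o, state)
    return out, state, cache

# Forward @ t=0, starting from a zero output and a zero state.
out0, state0, cache0 = lstm_forward_step(xs[0], np.zeros(hidden_dim),
                                         np.zeros(hidden_dim), W, U, b)
```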

From here, we can pass forward our state and output and begin the next time-step.

Forward @ t=1
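In the sketch, the second step is just another call with the carried-over output and state:

```python
# Forward @ t=1, feeding in the output and state from t=0.
out1, state1, cache1 = lstm_forward_step(xs[1], out0, state0, W, U, b)
```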

And since we’re done with our sequence, we have everything we need to begin backpropagating.

Backward @ t=1

First, we’ll need to compute the difference between our output and the expected output (the label).

Note for this we’ll be using L2 Loss:
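Taking the conventional ½ scaling so the derivative comes out clean:

L(out_t, label_t) = ½ (out_t - label_t)²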

The derivative w.r.t. x is:
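Which is just the difference between prediction and label:

∂L/∂x = x - label,  i.e.  Δ_t = out_t - label_t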

So, Δ_1 is simply our prediction at t=1 minus its label.

Now we can pass back our Δout and continue on computing…
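Here is how that backward step might look in the NumPy sketch, mirroring the gradient equations above (the t+1 terms are zero at the last step of the sequence):

```python
def lstm_backward_step(delta_t, delta_out_next, d_state_next, f_next, cache, W, U):
    """One backward step.

    delta_t        -- gradient arriving from the loss / layers above at this step
    delta_out_next -- Δout handed back from the t+1 LSTM step (zero at the last step)
    d_state_next   -- δstate_{t+1} (zero at the last step)
    f_next         -- forget gate at t+1 (zero at the last step)
    """
    x, out_prev, state_prev, a, i, f, o, state = cache
    d_out = delta_t + delta_out_next
    d_state = d_out * o * (1 - np.tanh(state) ** 2) + d_state_next * f_next
    d_a = d_state * i * (1 - a ** 2)
    d_i = d_state * a * i * (1 - i)
    d_f = d_state * state_prev * f * (1 - f)
    d_o = d_out * np.tanh(state) * o * (1 - o)
    d_gates = np.concatenate([d_a, d_i, d_f, d_o])
    d_x = W.T @ d_gates                # only needed if something trainable feeds the LSTM
    delta_out_prev = U.T @ d_gates     # the Δout we pass back to step t-1
    # Per-step parameter gradients, summed over the sequence later.
    grads = (np.outer(d_gates, x), np.outer(d_gates, out_prev), d_gates)
    return delta_out_prev, d_state, f, grads

# Backward @ t=1 (the last step, so every "next step" term is zero).
delta1 = out1 - labels[1]
zero = np.zeros(hidden_dim)
dout_to_0, d_state1, f1, grads1 = lstm_backward_step(delta1, zero, zero, zero, cache1, W, U)
```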

Backward @ t=0
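And in the sketch, t=0 folds in the Δout, δstate, and forget gate handed back from t=1:

```python
# Backward @ t=0, using the quantities passed back from t=1.
delta0 = out0 - labels[0]
dout_to_prev, d_state0, f0, grads0 = lstm_backward_step(delta0, dout_to_0,
                                                        d_state1, f1, cache0, W, U)
```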

And we’re done with the backward step!

Now we’ll need to update our internal parameters according to whatever solving algorithm you’ve chosen. I’m going to use a simple Stochastic Gradient Descent (SGD) update with learning rate λ = 0.1.

We’ll need to compute how much our weights are going to change by:
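In the sketch, that is just a sum of the per-step gradients:

```python
# Accumulate the parameter gradients over both time steps.
dW_total = grads0[0] + grads1[0]
dU_total = grads0[1] + grads1[1]
db_total = grads0[2] + grads1[2]
```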

And updating our parameters based on the SGD update function:
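Which, for plain SGD, looks like this in the sketch:

```python
# Plain SGD step with learning rate λ = 0.1.
lr = 0.1
W -= lr * dW_total
U -= lr * dU_total
b -= lr * db_total
```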

And that completes one iteration of solving an LSTM cell!

Errata and Frequently Asked Questions:

  • Q: in `d state_t` did you mean to use `tanh²(state_{t-1})`?
    A: No.
  • Q: you compute `d x` but never use it. Why?
    A: you would use it if there were LSTMs stacked beneath, or any trainable component leading into the LSTM. Since `x` is the input data in my example, we don’t really care about that particular gradient.
  • Q: under Backwards @ t=0: you use `delta out_{-1} = U^T d gates_1`, but it should use `gates_0`.
    A: Nice catch!

Of course, this whole process is sequential in nature, and a small error will render all subsequent calculations useless; so if you catch something, email me at hello@aidangomez.ca.

Please feel free to share with the machine learning enthusiasts in your life!
