Only Numpy: Deriving Forward feed and Back Propagation in Gated Recurrent Neural Networks (GRU) — Empirical Evaluation of Gated Recurrent Neural Networks on Sequence Modeling — Part 1
Today, I will derive the forward feed process and back propagation for Gated Recurrent Neural Networks. I recommend reading the paper “Empirical Evaluation of Gated Recurrent Neural Networks on Sequence Modeling” first (link to the paper here).
The left of the screen shows all of the equations that you need. The blue boxed symbol is called the ‘Candidate Activation’ (I will denote it as CA for short). For now, ignore the red boxed symbol; it becomes very important when performing back propagation.
R(t) → Reset Gate at Time Stamp t
CA(t) → Candidate Activation Function at Time Stamp t
Z(t) → Update Gate at Time Stamp t
S(t) → State at Time Stamp t
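For readers who prefer code, the four symbols above can be sketched as one NumPy function. The exact form of each equation is assumed here: it follows the formulation in the paper, with scalar weights and no bias terms, and may differ in small details from the figure. The function name `gru_step` and the weight dictionary `w` are hypothetical helpers; the six weight names are the ones used later in this post.

```python
import numpy as np

def sigmoid(x):
    """Logistic function used by both gates."""
    return 1.0 / (1.0 + np.exp(-x))

def gru_step(x, s_prev, w):
    """One GRU time stamp (assumed form, scalar weights, no biases).

    x      : input at time stamp t
    s_prev : state S(t-1)
    w      : dict holding the six weights
    """
    z = sigmoid(w["Wz"] * x + w["Wrecz"] * s_prev)            # update gate Z(t)
    r = sigmoid(w["Wr"] * x + w["Wrecr"] * s_prev)            # reset gate R(t)
    ca = np.tanh(w["Wca"] * x + w["Wrecca"] * (r * s_prev))   # candidate activation CA(t)
    s = (1.0 - z) * s_prev + z * ca                           # state S(t)
    return z, r, ca, s

# Example call with arbitrary illustrative weights.
w = {"Wz": 0.3, "Wrecz": -0.2, "Wca": 0.5, "Wrecca": 0.4, "Wr": -0.1, "Wrecr": 0.6}
z, r, ca, s = gru_step(x=0.8, s_prev=0.35, w=w)
print(z, r, ca, s)
```

Note that the reset gate only touches the previous state inside the candidate activation, while the update gate decides how much of the old state is kept.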
As usual, I will use the L2 cost function, which is the very bottom equation shown on the left of the screen.
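As a small sketch, the L2 cost and its derivative with respect to the state could look like the code below. The 1/2 scaling factor is an assumption (it is a common convention that makes the derivative simply `s - y`); the figure may use a different constant. The names `l2_cost` and `l2_cost_grad` are hypothetical.

```python
import numpy as np

def l2_cost(s, y):
    """L2 cost between state s and target y (assumed 1/2 scaling)."""
    return 0.5 * np.sum((s - y) ** 2)

def l2_cost_grad(s, y):
    """Derivative of the cost above with respect to s."""
    return s - y

s = np.array([1.0, 3.0])
y = np.array([0.0, 1.0])
print(l2_cost(s, y), l2_cost_grad(s, y))
```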
One very important thing to note is the number of weights we have across all of the equations. As seen above, there are a total of 6 weights. This means that, when performing back propagation, we need to calculate the error rate with respect to each of the 6 weights shown above!
So we can expect to have a total of 6 equations.
The two images above are graphical representations of GRUs.
Forward Feed Process
As seen above, the forward feed operation is very simple and easy to compute. At each time stamp we can calculate the error rate using the L2 cost function. As denoted at the top of the screen, TS stands for Time Stamp.
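The forward feed over several time stamps can be sketched as a simple loop. This assumes the GRU equations in the form given in the paper (scalar weights, no biases) and an L2 cost with a 1/2 factor; the function name `forward`, the starting state `s0`, and the sample numbers are all hypothetical.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def forward(xs, ys, w, s0=0.0):
    """Run the GRU over every time stamp, recording states and costs."""
    states = [s0]   # states[t] holds S(t); states[0] is the starting state
    costs = []      # costs[t-1] holds the L2 cost at time stamp t
    for x, y in zip(xs, ys):
        s_prev = states[-1]
        z = sigmoid(w["Wz"] * x + w["Wrecz"] * s_prev)            # update gate
        r = sigmoid(w["Wr"] * x + w["Wrecr"] * s_prev)            # reset gate
        ca = np.tanh(w["Wca"] * x + w["Wrecca"] * (r * s_prev))   # candidate activation
        s = (1.0 - z) * s_prev + z * ca                           # new state
        states.append(s)
        costs.append(0.5 * (s - y) ** 2)                          # assumed 1/2 scaling
    return states, costs

# Two time stamps (TS 1 and TS 2) with arbitrary illustrative values.
w = {"Wz": 0.3, "Wrecz": -0.2, "Wca": 0.5, "Wrecca": 0.4, "Wr": -0.1, "Wrecr": 0.6}
states, costs = forward(xs=[0.5, -1.0], ys=[0.2, 0.7], w=w, s0=0.1)
print(states, costs)
```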
Back Propagation at Time Stamp 2
As usual with every RNN, back propagation at the most recent time stamp is relatively easy to calculate with respect to every weight (there are 6 of them, as I said).
1 → Error Rate at Time Stamp 2 with respect to weight Wz
2 → Error Rate at Time Stamp 2 with respect to weight Wrecz
3 → Error Rate at Time Stamp 2 with respect to weight Wca
4 → Error Rate at Time Stamp 2 with respect to weight Wrecca
5 → Error Rate at Time Stamp 2 with respect to weight Wr
6 → Error Rate at Time Stamp 2 with respect to weight Wrecr
Note: I did not derive the 6th equation; try it out yourself.
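If you want to check your own derivation, the sketch below computes all six gradients at time stamp 2 with the chain rule and compares them against finite differences. It is one possible working, not a copy of the figures: it assumes the GRU equations as given in the paper (scalar weights, no biases), an L2 cost with a 1/2 factor, and it treats the previous state S(1) as a constant. The function names and sample numbers are hypothetical.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def forward_ts2(w, x2, s1):
    """Forward feed at time stamp 2, given the previous state S(1)."""
    z = sigmoid(w["Wz"] * x2 + w["Wrecz"] * s1)
    r = sigmoid(w["Wr"] * x2 + w["Wrecr"] * s1)
    ca = np.tanh(w["Wca"] * x2 + w["Wrecca"] * (r * s1))
    s2 = (1.0 - z) * s1 + z * ca
    return z, r, ca, s2

def cost_ts2(w, x2, s1, y2):
    """L2 cost at time stamp 2 (assumed 1/2 scaling)."""
    return 0.5 * (forward_ts2(w, x2, s1)[3] - y2) ** 2

def grads_ts2(w, x2, s1, y2):
    """Chain-rule gradients of the time stamp 2 cost, S(1) held constant."""
    z, r, ca, s2 = forward_ts2(w, x2, s1)
    d_s = s2 - y2                                    # cost w.r.t. S(2)
    d_z = d_s * (ca - s1) * z * (1.0 - z)            # back through the update gate
    d_ca = d_s * z * (1.0 - ca ** 2)                 # back through the tanh
    d_r = d_ca * w["Wrecca"] * s1 * r * (1.0 - r)    # back through the reset gate
    return {
        "Wz": d_z * x2,   "Wrecz": d_z * s1,
        "Wca": d_ca * x2, "Wrecca": d_ca * r * s1,
        "Wr": d_r * x2,   "Wrecr": d_r * s1,
    }

w = {"Wz": 0.3, "Wrecz": -0.2, "Wca": 0.5, "Wrecca": 0.4, "Wr": -0.1, "Wrecr": 0.6}
x2, s1, y2 = 0.8, 0.35, 1.0
analytic = grads_ts2(w, x2, s1, y2)

# Central finite differences: nudge one weight at a time.
eps = 1e-6
numeric = {}
for name in w:
    hi = dict(w); hi[name] += eps
    lo = dict(w); lo[name] -= eps
    numeric[name] = (cost_ts2(hi, x2, s1, y2) - cost_ts2(lo, x2, s1, y2)) / (2 * eps)

for name in w:
    print(name, analytic[name], numeric[name])
```

The two columns printed for each weight should agree closely; a mismatch usually points to a missed factor in the chain rule.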
Back Propagation at Time Stamp 1 with respect to Wz
As seen above, and as we can already expect, the back propagation process for any RNN gets messy and complicated very early on. Again, to get the total error rate at time stamp 1, we need the error rate from time stamp 1 as well as the error rate from time stamp 2, because the state at time stamp 1 also feeds into time stamp 2.
Also, note the red underlined symbols: all of the back propagation equations share that variable. This means that if we calculate the shared variable first, we will be able to save computation time.
Red Box → Error Rate when Time Stamp is 1
Green Box → Error Rate when Time Stamp is 2
The sum of these two terms is the total gradient we need to update the weight.
The mathematical symbol is shown above, and the actual equation is shown below.
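To make the red box plus green box idea concrete, here is a minimal NumPy sketch for Wz over two time stamps. It is one reading of the decomposition, under assumptions: the GRU equations as given in the paper (scalar weights, no biases), an L2 cost with a 1/2 factor, and hypothetical names and sample numbers. The term `direct2` is the time stamp 2 gradient from the earlier section, added so the sum can be checked against a finite difference of the total cost.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def step(x, s_prev, w):
    """One GRU time stamp; returns Z, R, CA and the new state S."""
    z = sigmoid(w["Wz"] * x + w["Wrecz"] * s_prev)
    r = sigmoid(w["Wr"] * x + w["Wrecr"] * s_prev)
    ca = np.tanh(w["Wca"] * x + w["Wrecca"] * (r * s_prev))
    return z, r, ca, (1.0 - z) * s_prev + z * ca

def total_cost(w, xs, ys, s0):
    """Sum of the L2 costs at every time stamp (assumed 1/2 scaling)."""
    s, cost = s0, 0.0
    for x, y in zip(xs, ys):
        s = step(x, s, w)[3]
        cost += 0.5 * (s - y) ** 2
    return cost

def grad_wz(w, xs, ys, s0):
    """Gradient pieces with respect to Wz for two time stamps."""
    z1, r1, ca1, s1 = step(xs[0], s0, w)
    z2, r2, ca2, s2 = step(xs[1], s1, w)
    # How S(2) reacts to S(1): S(1) enters Z(2), R(2), CA(2) and the blend itself.
    dz2 = z2 * (1.0 - z2) * w["Wrecz"]
    dr2 = r2 * (1.0 - r2) * w["Wrecr"]
    dca2 = (1.0 - ca2 ** 2) * w["Wrecca"] * (r2 + s1 * dr2)
    ds2_ds1 = (1.0 - z2) + (ca2 - s1) * dz2 + z2 * dca2
    # Shared factor, computed once: how S(1) reacts to Wz through Z(1).
    ds1_dwz = (ca1 - s0) * z1 * (1.0 - z1) * xs[0]
    red = (s1 - ys[0]) * ds1_dwz                 # error from time stamp 1
    green = (s2 - ys[1]) * ds2_ds1 * ds1_dwz     # error from time stamp 2, sent back to 1
    direct2 = (s2 - ys[1]) * (ca2 - s1) * z2 * (1.0 - z2) * xs[1]  # Z(2) path at time stamp 2
    return red, green, direct2

w = {"Wz": 0.3, "Wrecz": -0.2, "Wca": 0.5, "Wrecca": 0.4, "Wr": -0.1, "Wrecr": 0.6}
xs, ys, s0 = [0.5, -1.0], [0.2, 0.7], 0.1
red, green, direct2 = grad_wz(w, xs, ys, s0)

# Finite-difference check on the total cost.
eps = 1e-6
w_hi = dict(w, Wz=w["Wz"] + eps)
w_lo = dict(w, Wz=w["Wz"] - eps)
numeric = (total_cost(w_hi, xs, ys, s0) - total_cost(w_lo, xs, ys, s0)) / (2 * eps)
print(red + green + direct2, numeric)
```

Because `ds1_dwz` appears in both the red and green terms, computing it once is a small example of the time saving mentioned above.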
So that’s Part 1 for now, I will give you guys an update for part 2 and 3 soon… (I hope so). If any errors are found, please email me at firstname.lastname@example.org.