How to Visualize Your Recurrent Neural Network with Attention in Keras

A technical discussion and tutorial

Zafarali Ahmed
Jun 29, 2017
Figure 1: Attention map for the freeform date “5 Jan 2016”. We can see that the neural network used “16” to decide that the year was 2016, “Ja” to decide that the month was 01, and the beginning of the input to decide the day of the month.

What is in this Tutorial

If you want to jump directly to the code, skip ahead to the Code section below.

What you need to know

Recurrent Neural Networks (RNN)

Figure 2: An example of an RNN layer. The RNN cell is applied to each element x_i in the original sequence to get the corresponding element h_i in the output sequence. The RNN cell uses a previous hidden state (propagated between the RNN cells) and the sequence element x_i to do its calculations.
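To make the picture concrete, here is a minimal Keras sketch of such a layer (the sequence length and sizes are illustrative assumptions, not values from the article):

from keras.layers import Input, LSTM
from keras.models import Model

# A sequence of 10 timesteps, each a 32-dimensional vector
x = Input(shape=(10, 32))
# return_sequences=True yields one hidden state h_i per input element x_i
h = LSTM(64, return_sequences=True)(x)
rnn_layer = Model(inputs=x, outputs=h)  # output shape: (batch, 10, 64)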

A General Framework for seq2seq: The Encoder-Decoder Setup

Figure 3: Set up of the encoder-decoder architecture. The encoder network processes the input sequence into an encoded sequence which is subsequently used by the decoder network to produce the output.
Figure 4: Use of a summary state in the encoder-decoder architecture.
Figure 5: Use of the complete encoded sequence in the decoder network.
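The difference between Figures 4 and 5 comes down to what the encoder hands to the decoder. In Keras terms (a rough sketch with illustrative sizes, not the article’s code), return_sequences controls whether you get a single summary state or the whole encoded sequence:

from keras.layers import Input, LSTM

x = Input(shape=(10, 32))

# Figure 4: keep only the final hidden state, a single summary of the input
summary_state = LSTM(64)(x)                             # shape: (batch, 64)

# Figure 5: keep the full encoded sequence, one vector per input element,
# which is what an attention mechanism needs to choose from
encoded_sequence = LSTM(64, return_sequences=True)(x)   # shape: (batch, 10, 64)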
Figure 6: Attending to objects in an image during caption generation. The white regions indicate where the attention mechanism focused on during the generation of the underlined word. From Xu, Kelvin, et al. “Show, attend and tell: Neural image caption generation with visual attention.” International Conference on Machine Learning. 2015.

The Encoder

from keras.layers import Bidirectional, LSTM

# Returns the full encoded sequence (one vector per character) for the attention mechanism to use
BLSTM = Bidirectional(LSTM(encoder_units, return_sequences=True))
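In context, this layer would be applied to an embedded character sequence along these lines (the input length, vocabulary size, and embedding size below are illustrative assumptions, not the article’s exact values):

from keras.layers import Input, Embedding

input_length = 50   # padded length of the input date string (assumed)
vocab_size = 64     # number of distinct input characters (assumed)

chars = Input(shape=(input_length,), dtype='int32')
embedded = Embedding(vocab_size, 32)(chars)
# One encoding per character, read in both directions
encoded_sequence = BLSTM(embedded)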

The Decoder

Figure 7: Overview of the attention mechanism in an encoder-decoder setup. The attention mechanism creates context vectors. The decoder network uses these context vectors as well as the previous prediction to make the next one. The red arrows highlight which characters the attention mechanism weights most heavily when producing the output characters “1” and “6”.

Equations

Equation 1: A feed-forward neural network that calculates the unnormalized importance of character j in predicting character t. Equation 2: The softmax operation that normalizes these importances into probabilities.
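Following Bahdanau et al.’s formulation, with h_j the encoding of input character j, s_{t-1} the previous decoder state, and v_a, W_a, U_a learned weights (standard notation, not names from the article’s code), these two equations typically read:

e_{tj} = v_a^\top \tanh(W_a s_{t-1} + U_a h_j)

\alpha_{tj} = \frac{\exp(e_{tj})}{\sum_k \exp(e_{tk})}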
Equation 3: Calculation of the context vector for the t-th character.
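In other words, the context vector is the attention-weighted sum of the encoded characters:

c_t = \sum_j \alpha_{tj} h_j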
Equation 4: Reset gate. Equation 5: Update gate. Equation 6: Proposal hidden state. Equation 7: New hidden state.
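These are the gates of a GRU-style decoder cell, extended so that the context vector c_t enters alongside the previous prediction y_{t-1} and the previous state s_{t-1} (the weight names follow the usual convention from Bahdanau et al., not the article’s code):

r_t = \sigma(W_r y_{t-1} + U_r s_{t-1} + C_r c_t)

z_t = \sigma(W_z y_{t-1} + U_z s_{t-1} + C_z c_t)

\tilde{s}_t = \tanh(W y_{t-1} + U (r_t \odot s_{t-1}) + C c_t)

s_t = (1 - z_t) \odot s_{t-1} + z_t \odot \tilde{s}_t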
Equation 8: A simple neural network to predict the next character.
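One common form of this output layer (with illustrative weight names) combines the new state, the context vector, and the previous prediction into a softmax over the output characters:

y_t = \mathrm{softmax}(W_o s_t + U_o c_t + V_o y_{t-1})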

Code

Training

Data

Model

Visualization

from models.NMT import simpleNMT

# Model that outputs the translated characters
predictive_model = simpleNMT(..., return_probabilities=False)
predictive_model.load_weights(...)

# Same architecture, but outputs the attention probabilities instead
probability_model = simpleNMT(..., return_probabilities=True)
probability_model.load_weights(...)
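With both models loaded, the attention map is simply the output of probability_model. A minimal plotting sketch, assuming the model takes a (1, input_length) array of character indices and returns a (1, output_length, input_length) array of attention weights, could look like this:

import matplotlib.pyplot as plt

def plot_attention_map(encoded_input, input_text, output_text):
    # One row per predicted character, one column per input character
    attention = probability_model.predict(encoded_input)[0]
    plt.imshow(attention, interpolation='nearest', cmap='gray')
    plt.xticks(range(len(input_text)), list(input_text))
    plt.yticks(range(len(output_text)), list(output_text))
    plt.xlabel('Input characters')
    plt.ylabel('Predicted characters')
    plt.show()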

Example visualizations

Example 1: The model has learned to ignore “Saturday” during translation. We can observe that “9” is used in the prediction of “-09” (the day). The letter “M” is used to predict “-05” (the month). The last two digits of 2018 are used to predict the year.
Example 2: We can see that the oddly formatted date “January 2016 5” is incorrectly translated as 2016-01-02, where the “02” comes from the “20” in 2016.

Conclusion

Acknowledgements
