Seq2Seq RNNs with Attention: A Quick Intuition for Understanding the Basics
Exploring the Role of the Encoder and Decoder in Seq2Seq RNNs with Attention: Quick key concepts
Seq2seq RNNs (recurrent neural networks) are a type of architecture that is used for tasks such as machine translation, language modeling, and text summarization. They consist of two RNNs: an encoder and a decoder.
The encoder processes the input sequence and produces a fixed-length vector, known as the context vector, which summarizes the input. The decoder then processes this context vector to generate the output sequence.
I’ll give you a quick, thorough, and intuitive understanding of seq2seq models and of attention within them.
Prerequisites
This blog is suitable both for beginners who want a real understanding of the underlying ideas and for experts who wish to refresh or brush up on the material quickly.
Basic understanding of
- how ANNs, RNNs, and CNNs work
- mathematics (vectorization, dot products)
What is Attention?
Simply put, you can think of attention as any technique that lets a model focus on a certain element while suppressing unwanted noise in a sequence of data/features.
(Technically) masking or soft-weighting the values in a sequence.
Attention is a mechanism that allows a model to focus on specific parts of the input when generating the output. It is often used in seq2seq models for tasks such as machine translation and text summarization.
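As a toy illustration of this soft-weighting idea, here is a minimal NumPy sketch; the scores and feature vectors are made up purely for demonstration:

```python
import numpy as np

def softmax(scores):
    # Exponentiate and normalise so the weights sum to 1
    exp = np.exp(scores - scores.max())
    return exp / exp.sum()

# Hypothetical relevance scores for a 4-token sequence
scores = np.array([0.1, 2.5, 0.3, -1.0])
weights = softmax(scores)          # soft weights, roughly [0.07, 0.81, 0.09, 0.02]
features = np.random.randn(4, 8)   # one 8-dimensional feature vector per token
focused = weights @ features       # weighted sum: the "attended" representation
print(weights, focused.shape)
```

The token with the highest score dominates the weighted sum, which is exactly the focus-while-suppressing-noise behaviour described above.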
Seq2Seq RNN
Seq2Seq is an encoder-decoder RNN architecture used for machine translation, POS tagging, NER, etc.
The core idea of a plain seq2seq model is to encode the word at every timestep through an RNN/LSTM to generate hidden states (Hi). The decoder takes Hf (the final hidden state, a function of all input words), or the stack of all encoder states, as its initial hidden state, and predicts each word as a function of the previous decoder hidden state (S[t-1]) and the previous output word (ŷ[t-1]) (figure 1).
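To make this concrete, here is a rough PyTorch sketch of such a plain encoder-decoder (the vocabulary sizes, dimensions, and the <sos> token id are illustrative assumptions, not values from the figures):

```python
import torch
import torch.nn as nn

SRC_VOCAB, TGT_VOCAB, EMB, HID = 1000, 1000, 64, 128  # assumed sizes

class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(SRC_VOCAB, EMB)
        self.rnn = nn.LSTM(EMB, HID, batch_first=True)

    def forward(self, src):                 # src: (batch, src_len)
        outputs, (h, c) = self.rnn(self.embed(src))
        return outputs, (h, c)              # outputs = all Hi, (h, c) = final state Hf

class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(TGT_VOCAB, EMB)
        self.rnn = nn.LSTM(EMB, HID, batch_first=True)
        self.out = nn.Linear(HID, TGT_VOCAB)

    def forward(self, prev_token, state):   # prev_token: (batch, 1) = ŷ[t-1]
        output, state = self.rnn(self.embed(prev_token), state)  # state plays the role of S[t-1]
        return self.out(output), state      # logits for the next word

# Usage: encode the source once, then decode one token at a time
enc, dec = Encoder(), Decoder()
src = torch.randint(0, SRC_VOCAB, (2, 7))   # a batch of 2 sentences, length 7
_, state = enc(src)                          # Hf initialises the decoder
tok = torch.zeros(2, 1, dtype=torch.long)    # assumed <sos> token id 0
logits, state = dec(tok, state)              # predicts the first output word
```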
However, as is well known, long input sequences are difficult for RNNs/LSTMs to handle (vanishing gradients, i.e. the long-term dependency problem). Beyond a certain sequence length, their performance begins to decline (figure 2). To sustain and improve performance, we must apply an attention mechanism.
Attention in Seq2Seq RNN
In this model, at every timestep (t) we feed a relevant context vector (Ct) to the decoder RNN/LSTM, along with the previous hidden state (S[t-1]), for its prediction.
The context vector at timestep (t) is the soft-weighted sum of all the encoder hidden states (Hj): Ct = Σj α(t, j) · Hj (figure 4). These weights are denoted by alpha (α). Note that α(t, j) is two-dimensional: for every decoder timestep (t) we have to calculate an α for every encoder hidden state (Hj).
The attention module takes all encoder hidden states (h1, h2, h3, h4) and the previous decoder hidden state (S[t-1]) as input, and calculates α for every hi at timestep t via a softmax on top of a separate ANN layer that is trained alongside the rest of the model.
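Here is a minimal sketch of that attention step, assuming Bahdanau-style additive scoring in PyTorch (the module and variable names are mine, chosen for illustration):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Attention(nn.Module):
    def __init__(self, hid):
        super().__init__()
        # The separate ANN layer trained with the rest of the model
        self.W = nn.Linear(2 * hid, hid)
        self.v = nn.Linear(hid, 1, bias=False)

    def forward(self, s_prev, enc_states):
        # s_prev: (batch, hid) = S[t-1]; enc_states: (batch, src_len, hid) = all Hj
        src_len = enc_states.size(1)
        s_rep = s_prev.unsqueeze(1).expand(-1, src_len, -1)        # repeat S[t-1] for each Hj
        scores = self.v(torch.tanh(self.W(torch.cat([s_rep, enc_states], dim=-1))))
        alpha = F.softmax(scores.squeeze(-1), dim=-1)              # α(t, j), sums to 1 over j
        context = torch.bmm(alpha.unsqueeze(1), enc_states)        # Ct = Σj α(t, j) · Hj
        return context.squeeze(1), alpha                           # (batch, hid), (batch, src_len)

attn = Attention(hid=128)
H = torch.randn(2, 7, 128)         # encoder hidden states h1..h7 for a batch of 2
s_prev = torch.randn(2, 128)       # previous decoder hidden state S[t-1]
context, alpha = attn(s_prev, H)   # context Ct is fed to the decoder at timestep t
```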
Thanks to the attention mechanism, we can observe how the RNN/LSTM produces the translation while concentrating on a small portion of the input sentence, as shown in the alpha weight heatmap (figure 5). Notice how the higher-intensity alphas (α) show the model attending to the related portion of the English text when predicting each French word.
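If you want to reproduce such a heatmap for your own model, a minimal matplotlib sketch (with placeholder random weights standing in for the real α matrix collected during decoding) could look like this:

```python
import matplotlib.pyplot as plt
import numpy as np

alphas = np.random.rand(5, 7)                   # placeholder (output_len, input_len) α matrix
alphas /= alphas.sum(axis=1, keepdims=True)     # each row sums to 1, like softmax output

plt.imshow(alphas, cmap="viridis")
plt.xlabel("English (input) tokens")
plt.ylabel("French (output) tokens")
plt.colorbar(label="α weight")
plt.show()
```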
Pros and Cons
Pros-
- Seq2seq models can handle input and output sequences of variable length, making them suitable for tasks such as machine translation and text summarization.
- Seq2seq models can be trained using both supervised (Machine translation, Text summarization) and unsupervised learning (Language modeling, Autoencoding) techniques.
- Seq2seq models can incorporate attention mechanisms, which allow the model to improve performance on long sequences.
Cons-
- Seq2seq models can be slower to train than other types of models, such as transformer models, because they process the input and output sequences one element at a time.
- Seq2seq models can struggle to maintain context over long sequences, especially when the input and output sequences are very different in length.
- Seq2seq models are not as well suited to parallelization as other types of models, such as transformer models, because the output of each element in the sequence depends on the output of the previous element.
For Mastering the Transformer Blocks: A Fast-Track Introduction, read the blog below.
Thanks