The Attention Mechanism Zoo
Attention is the ability to focus on what is important. Unlike memory, it is not about retaining as much as possible but about discarding what is not needed. 🕵️
In ML, we say that an algorithm “uses an attention mechanism” when it explicitly learns to either:
- Up-weight relevant input features (and down-weight irrelevant ones)
- Relate parts of the input within itself (or with the output)
Implicit attention is not enough!
But wait ✋, aren’t all ANN-based models already kind of doing this? Good question, and yes. Given some inputs, common ANN architectures (MLPs, ConvNets, RNNs…) automatically learn “what to focus on”. This is known as implicit attention; GLOW is a great example of it. Despite the merit of these results, designing an architecture that explicitly promotes feature up-weighting presents several advantages:
- Higher computational efficiency: ANNs can be smaller and trained faster
- Greater scalability: Architecture size is independent of input size
- Sequential processing of static data: the model can look at different areas of a static input one at a time and just remember the important information from each.
- Easier interpretation: since these architectures are specifically designed to attend, their attention weights show directly what the model is focusing on.
Sounds like a good deal! 📈 How can we make this attention explicit, then? Before jumping into attention mechanisms, it is key to be on the same page about how the input is partitioned and represented (the stuff that will get up-weighted or associated depending on the task).
The concept of “glimpses”
Glimpses are representations of partitions of the input data (aka input encoding). For instance, glimpses of common input formats could be:
- Text input: an array of one-hot encodings of its characters or, alternatively, an array of its word embeddings.
- Image input: different tiles (patches) of the image.
- Video input: the different frames that compose the video.
- Audio input: the sound wave representation at each time step.
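To make glimpses concrete, here is a minimal NumPy sketch that turns an image into non-overlapping tile glimpses and a string into one-hot character glimpses. The function names, the tile size, and the toy alphabet are illustrative assumptions, not part of any particular library.

```python
import numpy as np

def image_to_glimpses(image: np.ndarray, tile: int) -> np.ndarray:
    """Split an (H, W, C) image into non-overlapping tile x tile patches.

    Returns shape (num_tiles, tile * tile * C): one flattened glimpse per row.
    Assumes H and W are multiples of `tile`.
    """
    h, w, c = image.shape
    assert h % tile == 0 and w % tile == 0, "H and W must be multiples of tile"
    # (H/t, t, W/t, t, C) -> (H/t, W/t, t, t, C): group pixels by tile
    tiles = image.reshape(h // tile, tile, w // tile, tile, c).transpose(0, 2, 1, 3, 4)
    # flatten each tile into a single feature vector
    return tiles.reshape(-1, tile * tile * c)

def text_to_glimpses(text: str, alphabet: str) -> np.ndarray:
    """One-hot encode each character: one glimpse per character."""
    index = {ch: i for i, ch in enumerate(alphabet)}
    glimpses = np.zeros((len(text), len(alphabet)))
    for pos, ch in enumerate(text):
        glimpses[pos, index[ch]] = 1.0
    return glimpses

image = np.arange(4 * 4 * 3, dtype=float).reshape(4, 4, 3)   # toy 4x4 RGB image
image_glimpses = image_to_glimpses(image, tile=2)              # shape (4, 12)
text_glimpses = text_to_glimpses("abca", alphabet="abc")       # shape (4, 3)
```

Each row of the result is one glimpse, which is exactly the "database" the attention mechanisms below search through.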
Attention as a search engine
Attention mechanisms get as input a collection of these glimpses. Their job is to select the most task-significant ones at each moment. In essence, they are just “searching” for the most relevant glimpses in a “glimpse database” (aka glimpse representation of the given input). As such, search-engine terminology is often used:
- 🔑 Keys: Description of the glimpse features. They encode the glimpse content and/or position.
- 🔒 Queries: Each glimpse has an associated query that encodes how that glimpse relates to the keys of all other glimpses. The relationship between keys and queries is defined by some similarity function S(k, q), which is particular to each attention mechanism (a standard choice is the dot product). Thus, the queries tensor contains as many elements as there are glimpses considered.
- 💰 Values: Glimpse features that get up-weighted (or combined) and will be forwarded to subsequent tasks.
As an (oversimplified) analogy, when you type a query into YouTube, it gets compared to each video title (the keys). The highest-matching results (highest similarity score between the query and the keys) get shown to you. In this case, the values would be the videos themselves.
Notice: Each glimpse has an associated key, an associated query, and an associated value. The “magic” of the attention mechanism is how these values are combined to provide an output.
Other important concepts:
- Attention Vector: Encodes how relevant each glimpse is w.r.t. the desired task.
- Attention Matrix: Encodes how relevant each glimpse is w.r.t. all the others. This is useful to extract sequence relations.
- Context: Output of the attention module. It is obtained by applying the attention vector/matrix to the values.
This might be a bit confusing now 🥴 but it becomes much clearer after seeing how everything is computed.
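As a toy illustration of these terms (made-up numbers, with an unscaled dot product as S(k, q)), a single query is compared with every key, the scores are normalized with a softmax into an attention vector, and the context is the resulting weighted sum of values. Stacking one such vector per query gives an attention matrix. This is a generic sketch, not a specific mechanism from the literature.

```python
import numpy as np

def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    # subtract the max for numerical stability; the result sums to 1 along `axis`
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

# 3 glimpses with 2-d keys and 1-d values (toy numbers)
keys = np.array([[1.0, 0.0],
                 [0.0, 1.0],
                 [1.0, 1.0]])
values = np.array([[10.0], [20.0], [30.0]])

query = np.array([5.0, -5.0])           # matches the first key, opposes the second

scores = keys @ query                   # S(k, q): one dot product per key -> [5, -5, 0]
attention_vector = softmax(scores)      # relevance of each glimpse, sums to 1
context = attention_vector @ values     # weighted combination of the values

# One query per glimpse (here the keys double as queries) gives an attention matrix
attention_matrix = softmax(keys @ keys.T, axis=-1)
```

The context lands close to the first value (about 10.1) because the query mostly "matches" the first key, but every value still contributes a little.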
Types of attention
We can classify attention mechanisms depending on how the context is created and presented to subsequent tasks:
- 🔨 Hard attention (aka non-differentiable attention): Focuses on a single glimpse at each iteration. As it is a “hard” (discrete) decision, gradients can’t be back-propagated through it, so we rely on reinforcement learning techniques to train the glimpse-choosing policy.
- 🐑 Soft attention (aka differentiable attention): Allows end-to-end backprop training by smoothly combining all glimpses (through the attention vector or matrix).
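A minimal sketch contrasting the two, assuming the glimpse scores come from some policy network (here just fixed numbers): soft attention averages all values, while hard attention samples a single glimpse. The last line computes the REINFORCE score-function term, one common RL choice used here only as an example, which would replace the gradient that cannot flow through the discrete choice.

```python
import numpy as np

rng = np.random.default_rng(0)

def softmax(x: np.ndarray) -> np.ndarray:
    e = np.exp(x - x.max())
    return e / e.sum()

values = np.array([[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]])  # one row per glimpse
scores = np.array([2.0, 0.5, 1.0])                        # hypothetical policy outputs
probs = softmax(scores)

# Soft attention: differentiable weighted average over ALL glimpses
soft_context = probs @ values

# Hard attention: sample ONE glimpse; this discrete step blocks backprop
chosen = rng.choice(len(values), p=probs)
hard_context = values[chosen]

# REINFORCE-style signal: d log p(chosen) / d scores for a softmax policy.
# In training it would be multiplied by a reward to update the policy.
grad_log_prob = np.eye(len(scores))[chosen] - probs
```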
Independently of these subgroups, we can also divide the methods by what the mechanism focuses on when “searching” for glimpses:
- 🧭 Location attention: The attention mechanism focuses on where to look in the input based on position.
- 🔐Associative attention (aka content-based attention): Focuses on what to look for. Relates the glimpses by their features (content), not their position.
- 🧘Introspective attention: The concept of attention is brought to the internal state of the network (aka memory).
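A hedged sketch of the difference between location and associative attention on toy data. The Gaussian window (in the spirit of DRAW) and the sharpened cosine similarity (in the spirit of Neural Turing Machines) are just example choices for illustration, not the only ways to implement each type.

```python
import numpy as np

def softmax(x: np.ndarray) -> np.ndarray:
    e = np.exp(x - x.max())
    return e / e.sum()

glimpses = np.array([[0.9, 0.1],
                     [0.1, 0.9],
                     [0.8, 0.2],
                     [0.2, 0.8]])        # 4 glimpses, 2 features each (toy data)
positions = np.arange(len(glimpses))

def location_weights(mu: float, sigma: float) -> np.ndarray:
    """WHERE to look: Gaussian window centred on position `mu`; ignores content."""
    w = np.exp(-0.5 * ((positions - mu) / sigma) ** 2)
    return w / w.sum()

def content_weights(query: np.ndarray, beta: float) -> np.ndarray:
    """WHAT to look for: cosine similarity to `query`, sharpened by `beta`; ignores position."""
    norms = np.linalg.norm(glimpses, axis=1) * np.linalg.norm(query)
    return softmax(beta * (glimpses @ query) / norms)

loc = location_weights(mu=1.0, sigma=0.5)               # peaks at position 1
con = content_weights(np.array([1.0, 0.0]), beta=10.0)  # peaks at glimpses 0 and 2
```

Location attention lands on glimpse 1 because of where it sits, while associative attention picks glimpses 0 and 2 because of what they contain, wherever they are.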
In this post we mainly focus on soft attention methods, using associative relations. I might develop the other types in future posts.
Some Attention Mechanisms
Scaled Dot-Product attention
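The similarity function used to associate keys and queries is the dot product. The attention matrix holds the dot product of every query-key pair, Aᵢⱼ = qᵢ ⋅ kⱼ, normalized row-wise with a softmax so that each glimpse’s weights sum to 1. Then, for each glimpse, the context vector is the weighted combination of all glimpse values, using that glimpse’s row of the attention matrix as weights.
- Scaled because the dot products are divided by √dₖ, the square root of the key/query feature dimension, before the softmax. This keeps the scores from growing with dₖ and saturating the softmax. Altogether: Attention(Q, K, V) = softmax(QKᵀ / √dₖ) V.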
- Usually, in time-dependent tasks, a mask is used to prevent the model from learning from future time steps: at inference time, future information won’t be available. The mask sets the scores of all relations beyond the current time step to −∞ before the softmax, so their attention weights become zero and the model cannot use them.
- Notice that there aren’t any learnable parameters: this mechanism is more a building block than a stand-alone thing.
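The computation above fits in a few lines of NumPy. This is a minimal single-head sketch (function and variable names are assumptions): scores are scaled by √dₖ, an optional causal mask hides future time steps by setting their scores to −∞ before the softmax, and there are no learnable parameters.

```python
import numpy as np

def scaled_dot_product_attention(q, k, v, causal=False):
    """q: (n_q, d_k), k: (n_k, d_k), v: (n_k, d_v).

    Returns the context (n_q, d_v) and the attention matrix (n_q, n_k).
    """
    d_k = q.shape[-1]
    scores = q @ k.T / np.sqrt(d_k)                  # A_ij = q_i . k_j / sqrt(d_k)
    if causal:
        # glimpse i may only attend to glimpses j <= i (assumes n_q == n_k)
        future = np.triu(np.ones(scores.shape, dtype=bool), k=1)
        scores = np.where(future, -np.inf, scores)
    scores = scores - scores.max(axis=-1, keepdims=True)      # numerical stability
    weights = np.exp(scores)                                  # exp(-inf) = 0 when masked
    weights = weights / weights.sum(axis=-1, keepdims=True)   # row-wise softmax
    return weights @ v, weights

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 8))          # 4 glimpses, 8 features each
context, attention_matrix = scaled_dot_product_attention(x, x, x, causal=True)
```

With the causal mask the attention matrix is lower-triangular, so the first glimpse can only attend to itself.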
Multi-Head Attention
Steps:
- Project keys, queries, and values into n different tensors (n heads) using good old fully connected (FC) layers.
- Apply scaled dot-product attention on each head (oops, I mean projection 🤭).
- Concatenate the contexts of all heads (usually followed by one last FC layer) and we are done :)
- The idea is that projecting into different embeddings allows each attention head to “focus” on different aspects of the input.
- This mechanism is often used within the self-attention paradigm: the same input x is forwarded as key, query, and value. Then, it is the FC layers’ job to learn the key, query, and value embeddings. The attention matrix encodes how the parts of the input relate to one another. In the context vector, each glimpse projection gets combined with the related ones within the same input.
- Other times it is used within the cross-attention paradigm: keys and values come from one tensor (often the output of an encoder) and queries come from a different one. You query something from the values, and the keys help you focus.
- We call the different scaled dot-product attentions, which are computed in parallel and later concatenated, “heads”.
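A minimal forward-pass sketch of multi-head attention in NumPy. Random matrices stand in for the trained FC layers, biases are omitted, and, following Attention Is All You Need, a final FC layer mixes the concatenated heads. The class and parameter names are hypothetical; real implementations also batch over examples and compute all heads at once instead of looping.

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def scaled_dot_product(q, k, v):
    return softmax(q @ k.T / np.sqrt(q.shape[-1])) @ v

class MultiHeadAttention:
    """Forward pass only; weights are random stand-ins for trained FC layers."""

    def __init__(self, d_model: int, n_heads: int, seed: int = 0):
        assert d_model % n_heads == 0, "d_model must split evenly across heads"
        rng = np.random.default_rng(seed)
        self.n_heads, self.d_head = n_heads, d_model // n_heads
        scale = 1.0 / np.sqrt(d_model)
        # one bias-free FC layer per role (query, key, value, output)
        self.w_q, self.w_k, self.w_v, self.w_o = (
            rng.normal(scale=scale, size=(d_model, d_model)) for _ in range(4))

    def __call__(self, query, key, value):
        q, k, v = query @ self.w_q, key @ self.w_k, value @ self.w_v
        heads = []
        for h in range(self.n_heads):
            cols = slice(h * self.d_head, (h + 1) * self.d_head)  # this head's projection
            heads.append(scaled_dot_product(q[:, cols], k[:, cols], v[:, cols]))
        # concatenate the head contexts, then mix them with the output FC layer
        return np.concatenate(heads, axis=-1) @ self.w_o

mha = MultiHeadAttention(d_model=8, n_heads=2)
rng = np.random.default_rng(1)
x = rng.normal(size=(5, 8))        # 5 glimpses, e.g. decoder states
memory = rng.normal(size=(7, 8))   # 7 glimpses from a hypothetical encoder

self_ctx = mha(x, x, x)             # self-attention: same input as query, key, value
cross_ctx = mha(x, memory, memory)  # cross-attention: queries from x, keys/values from memory
```

In both cases the output has one row per query, so cross-attention returns 5 context vectors even though it searched over 7 encoder glimpses.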
In future posts, we’re gonna pay attention to transformer-based architectures (pun intended), which are essentially stacks of these attention mechanisms.
Relative-Position Multi-Head Attention
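This mechanism injects positional information into the attention scores: besides comparing each query with the other glimpses’ keys (content), it also compares it with an embedding of the relative distance between the two glimpses. Thus, the network jointly considers both the content and the relative position of the glimpses.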
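One published way to do this comes from Self-Attention with Relative Position Representations (Shaw et al., cited in the final notes): learned embeddings of the clipped relative distance j − i are added to the keys when computing scores and to the values when combining them. The single-head NumPy sketch below follows that formulation with random stand-in weights; other relative-position schemes (e.g. Transformer-XL style) differ in detail, and all names here are illustrative.

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def relative_position_attention(x, w_q, w_k, w_v, rel_k, rel_v, max_dist):
    """Single-head self-attention with relative position representations.

    x: (n, d_model); w_*: (d_model, d); rel_k, rel_v: (2 * max_dist + 1, d),
    one learned embedding per clipped relative distance j - i.
    """
    n = x.shape[0]
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    idx = np.arange(n)
    # relative distance j - i, clipped to [-max_dist, max_dist], shifted to a table index
    rel = np.clip(idx[None, :] - idx[:, None], -max_dist, max_dist) + max_dist  # (n, n)
    a_k, a_v = rel_k[rel], rel_v[rel]            # (n, n, d) embedding per (i, j) pair
    # content term q_i . k_j plus position term q_i . a_ij
    scores = (q @ k.T + np.einsum("id,ijd->ij", q, a_k)) / np.sqrt(q.shape[-1])
    weights = softmax(scores)
    # values also receive the relative embedding: sum_j w_ij (v_j + a_v_ij)
    return weights @ v + np.einsum("ij,ijd->id", weights, a_v)

rng = np.random.default_rng(0)
n, d_model, d, max_dist = 6, 8, 4, 2
x = rng.normal(size=(n, d_model))
w_q, w_k, w_v = (rng.normal(size=(d_model, d)) for _ in range(3))
rel_k = rng.normal(size=(2 * max_dist + 1, d))
rel_v = rng.normal(size=(2 * max_dist + 1, d))
out = relative_position_attention(x, w_q, w_k, w_v, rel_k, rel_v, max_dist)
```

Setting both relative embedding tables to zeros recovers plain scaled dot-product self-attention, which is a handy sanity check.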
Additive Attention
This mechanism combines all the glimpse values into a single context vector. The similarity between the keys and a query is computed by passing each through its own FC layer, adding the two embeddings, applying a tanh, and projecting the result to a scalar score with a learned vector.
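A minimal NumPy sketch of additive (Bahdanau-style) attention for a single query, with random matrices standing in for trained FC layers. The parameter names w_q, w_k, and v_a are assumptions.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

def additive_attention(query, keys, values, w_q, w_k, v_a):
    """Score e_j = v_a . tanh(query @ w_q + k_j @ w_k) for every key k_j.

    query: (d_q,), keys: (n, d_k), values: (n, d_v)
    w_q: (d_q, d_att), w_k: (d_k, d_att), v_a: (d_att,)
    """
    hidden = np.tanh(query @ w_q + keys @ w_k)   # FC(query) + FC(keys), broadcast over keys
    scores = hidden @ v_a                         # one scalar score per glimpse, shape (n,)
    attention_vector = softmax(scores)
    context = attention_vector @ values           # all values folded into one vector (d_v,)
    return context, attention_vector

rng = np.random.default_rng(0)
n, d_q, d_k, d_att = 5, 3, 4, 6
keys = rng.normal(size=(n, d_k))      # e.g. encoder states, reused as values
query = rng.normal(size=d_q)          # e.g. the previous decoder state
context, attention_vector = additive_attention(
    query, keys, keys,
    rng.normal(size=(d_q, d_att)), rng.normal(size=(d_k, d_att)), rng.normal(size=d_att))
```

Unlike the dot product, the key and query dimensions do not have to match here, since each goes through its own projection.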
Location-Aware Attention
This adds an RNN flavor to Additive Attention. At each time step a new attention vector is computed, and it gets passed on to the next step. Notice that this previous-step attention is processed by a 1D conv with f kernels (f being the feature dimension), and the resulting location features are added inside the additive score together with the projected key and query.
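A rough sketch in the spirit of Attention-Based Models for Speech Recognition: the previous attention vector is convolved with a few 1D filters (the text's f kernels correspond to `k` filters here), and the resulting location features join the additive score. Everything below (the hand-rolled convolution, parameter names, the uniform initial alignment, random decoder states) is an illustrative assumption rather than the paper's exact setup.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

def conv1d_same(signal, filters):
    """Cross-correlate a 1-D signal (n,) with each row of `filters` (k, width).

    Zero padding keeps the output length equal to n; returns shape (n, k).
    """
    width = filters.shape[1]
    left = width // 2
    padded = np.pad(signal, (left, width - 1 - left))
    windows = np.stack([padded[i:i + width] for i in range(len(signal))])  # (n, width)
    return windows @ filters.T

def location_aware_step(state, keys, prev_attention, p):
    """One step: e_j = w . tanh(W s + V h_j + U f_j + b), with f = conv(prev_attention)."""
    f = conv1d_same(prev_attention, p["F"])                   # (n, k) location features
    hidden = np.tanh(state @ p["W"] + keys @ p["V"] + f @ p["U"] + p["b"])
    attention = softmax(hidden @ p["w"])                      # new attention vector (n,)
    return attention @ keys, attention                        # context, attention

rng = np.random.default_rng(0)
n, d_s, d_h, d_att, k, width = 8, 3, 4, 5, 2, 3
keys = rng.normal(size=(n, d_h))                 # e.g. encoder outputs (also the values)
p = {"W": rng.normal(size=(d_s, d_att)), "V": rng.normal(size=(d_h, d_att)),
     "U": rng.normal(size=(k, d_att)), "b": np.zeros(d_att),
     "w": rng.normal(size=d_att), "F": rng.normal(size=(k, width))}

attention = np.full(n, 1.0 / n)                  # assumed uniform initial alignment
for step in range(3):
    state = rng.normal(size=d_s)                 # stand-in for the decoder RNN state
    context, attention = location_aware_step(state, keys, attention, p)
```

The loop is where the RNN flavor shows up: each step's attention vector feeds the convolution of the next step.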
Final notes
- Most of the content is from this excellent lecture by Alex Graves and these papers: Attention Is All You Need; Neural Machine Translation by Jointly Learning to Align and Translate; Attention-Based Models for Speech Recognition; Self-Attention with Relative Position Representations.
- I recommend checking out these attention implementations in PyTorch
Stay tuned! 📻 In the next episode, we’re gonna combine these attention mechanisms to create SOTA speech-to-text models 👂. Wow, I just discovered that you can italicize emojis!! 🤯 🤯 This shit makes me happy :)