Understanding Scaled Dot-Product Attention in Transformer Models

Prashant S
7 min read · Jun 3, 2024


The Transformer model has become a game-changer in natural language processing (NLP). Its key ingredient? A mechanism called self-attention, computed with scaled dot-product attention. Self-attention lets the model focus on the most relevant parts of the input sequence when processing each word, instead of giving every word the same weight. In this article, we’ll break down how self-attention works step by step, using a small worked example to make the concepts easier to grasp.


Embeddings: Representing Words as Vectors

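In a Transformer model, each word is represented as a vector of numbers, known as an embedding. In a trained model, these embeddings are learned so that they capture aspects of each word’s meaning; the three-dimensional values below are made-up toy numbers for illustration. Let’s consider a simple example with the following embeddings: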

import numpy as np

embeddings = {
    'the': np.array([0.1, 0.2, 0.3]),
    'cat': np.array([0.4, 0.5, 0.6]),
    'sat': np.array([0.7, 0.8, 0.9]),
    'on': np.array([1.0, 1.1, 1.2]),
    'mat': np.array([1.3, 1.4, 1.5])
}

Suppose our input sentence is “the cat sat on the mat”. The corresponding embedded tokens would be:

embedded_tokens = np.array([embeddings[word] for word in 
['the', 'cat', 'sat', 'on', 'the', 'mat']])

This results in:

embedded_tokens = np.array([
[0.1, 0.2, 0.3],
[0.4, 0.5, 0.6],
[0.7, 0.8, 0.9],
[1.0, 1.1, 1.2],
[0.1, 0.2, 0.3],
[1.3, 1.4, 1.5]
])
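Under the hood, an embedding layer works like a lookup table: each word is mapped to an integer ID, and the IDs index rows of an embedding matrix. A minimal NumPy sketch of that view, assuming a hypothetical word-to-ID mapping `vocab` built from the five words above:

```python
import numpy as np

# Embedding matrix: one row per vocabulary word (same toy values as above)
embedding_matrix = np.array([
    [0.1, 0.2, 0.3],  # the
    [0.4, 0.5, 0.6],  # cat
    [0.7, 0.8, 0.9],  # sat
    [1.0, 1.1, 1.2],  # on
    [1.3, 1.4, 1.5],  # mat
])

# Hypothetical word-to-ID mapping (an assumption for illustration)
vocab = {'the': 0, 'cat': 1, 'sat': 2, 'on': 3, 'mat': 4}

sentence = ['the', 'cat', 'sat', 'on', 'the', 'mat']
token_ids = np.array([vocab[word] for word in sentence])

# Integer-array indexing gathers one embedding row per token
embedded_tokens = embedding_matrix[token_ids]
print(token_ids)              # [0 1 2 3 0 4]
print(embedded_tokens.shape)  # (6, 3)
```

Both occurrences of “the” map to ID 0 and therefore get the identical row, which is why rows 0 and 4 of `embedded_tokens` match.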

Self-Attention Mechanism

The goal of the self-attention mechanism is to determine which words in the input sequence are relevant to each word. This involves four steps:

  1. Compute dot products between queries and keys.
  2. Scale the dot products.
  3. Apply softmax to obtain attention weights.
  4. Use the attention weights to compute a weighted sum of the values.
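For reference, the four steps fit in a few lines of NumPy. This is a minimal sketch rather than the TensorFlow version used later in the article; the function name `attention_numpy` is our own:

```python
import numpy as np

def attention_numpy(Q, K, V):
    """Scaled dot-product attention for 2-D arrays of shape (seq_len, d)."""
    d_k = K.shape[-1]
    scores = Q @ K.T                                         # 1. query/key dot products
    scores = scores / np.sqrt(d_k)                           # 2. scale by sqrt(d_k)
    scores = scores - scores.max(axis=-1, keepdims=True)     # numerical stability only
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)  # 3. row-wise softmax
    return weights @ V, weights                              # 4. weighted sum of values

X = np.array([[0.1, 0.2, 0.3],
              [0.4, 0.5, 0.6],
              [0.7, 0.8, 0.9],
              [1.0, 1.1, 1.2],
              [0.1, 0.2, 0.3],
              [1.3, 1.4, 1.5]])
output, weights = attention_numpy(X, X, X)
print(weights.sum(axis=-1))  # each row sums to 1
```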

Detailed Step-by-Step Explanation

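1. Queries, Keys, and Values: In this simplified example, we skip the learned projections that a full Transformer applies and use the same embeddings for queries (Q), keys (K), and values (V). We convert them to a float32 tensor so that they match the float32 scaling factor computed in step 3 (TensorFlow does not mix float64 and float32 tensors in one operation):
Q = K = V = tf.constant(embedded_tokens, dtype=tf.float32)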
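In a full Transformer, Q, K, and V are not the raw embeddings: each is obtained by multiplying the embeddings by its own learned weight matrix (often written W_Q, W_K, W_V). The sketch below uses random matrices as stand-ins for trained weights, an assumption made only to show the shapes involved; real values come from training:

```python
import numpy as np

rng = np.random.default_rng(seed=0)

# embedded_tokens from the example, shape (6, 3)
X = np.array([[0.1, 0.2, 0.3],
              [0.4, 0.5, 0.6],
              [0.7, 0.8, 0.9],
              [1.0, 1.1, 1.2],
              [0.1, 0.2, 0.3],
              [1.3, 1.4, 1.5]])

d_model, d_k = X.shape[1], 3

# Random stand-ins for the learned projection matrices (trained in practice)
W_q = rng.standard_normal((d_model, d_k))
W_k = rng.standard_normal((d_model, d_k))
W_v = rng.standard_normal((d_model, d_k))

Q = X @ W_q  # queries, shape (6, d_k)
K = X @ W_k  # keys,    shape (6, d_k)
V = X @ W_v  # values,  shape (6, d_k)
print(Q.shape, K.shape, V.shape)
```

Because W_Q and W_K differ, the score for the pair (i, j) generally differs from the score for (j, i). The simplified Q = K = V setup in this article makes the score matrix symmetric.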

2. Matrix Multiplication (Dot Product): We compute the dot product of the query matrix Q and the transpose of the key matrix K:

matmul_qk = tf.matmul(Q, K, transpose_b=True)
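Each entry `matmul_qk[i, j]` is simply the dot product of query i with key j. A quick NumPy check for the pair (“the”, “mat”):

```python
import numpy as np

the = np.array([0.1, 0.2, 0.3])
mat = np.array([1.3, 1.4, 1.5])

# Element-wise products, summed: 0.1*1.3 + 0.2*1.4 + 0.3*1.5
manual = 0.1 * 1.3 + 0.2 * 1.4 + 0.3 * 1.5
print(np.dot(the, mat), manual)  # both ≈ 0.86

# The same value appears in the matrix of all query/key dot products
Q = K = np.stack([the, mat])
matmul_qk = Q @ K.T  # shape (2, 2)
print(matmul_qk[0, 1])  # ≈ 0.86
```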

3. Scaling the Dot Products: We scale the dot products by dividing by the square root of the dimension of the key vectors (dk = 3):

dk = tf.cast(tf.shape(K)[-1], tf.float32)
scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)

This gives the following matrix (rounded to three decimal places):

scaled_attention_logits = [
[0.081, 0.185, 0.289, 0.393, 0.081, 0.497],
[0.185, 0.445, 0.704, 0.964, 0.185, 1.224],
[0.289, 0.704, 1.120, 1.536, 0.289, 1.951],
[0.393, 0.964, 1.536, 2.107, 0.393, 2.679],
[0.081, 0.185, 0.289, 0.393, 0.081, 0.497],
[0.497, 1.224, 1.951, 2.679, 0.497, 3.406]
]
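The scaled logits can be reproduced with plain NumPy. This sketch mirrors the TensorFlow lines above but is not the article’s own code:

```python
import numpy as np

# The six embedded tokens from the example
embedded_tokens = np.array([
    [0.1, 0.2, 0.3],  # the
    [0.4, 0.5, 0.6],  # cat
    [0.7, 0.8, 0.9],  # sat
    [1.0, 1.1, 1.2],  # on
    [0.1, 0.2, 0.3],  # the
    [1.3, 1.4, 1.5],  # mat
])

Q = K = embedded_tokens
d_k = K.shape[-1]  # 3

# Same computation as matmul_qk / sqrt(dk) in the TensorFlow version
scaled_attention_logits = (Q @ K.T) / np.sqrt(d_k)
print(np.round(scaled_attention_logits, 3))
```

The matrix is symmetric because Q and K are the same array here, and rows 0 and 4 are identical because both belong to the word “the”.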

4. Applying Softmax: After computing the dot products and scaling them, the next step in the attention mechanism is to apply the softmax function to these scaled values to obtain the attention weights. Let’s break down each term and the process in detail.

What are Logits?

In the context of neural networks, logits refer to the raw, unnormalized scores output by a model. These scores are typically the result of a linear transformation applied to the input features before applying an activation function.

In our case, the logits are the results of the dot products between the query and key vectors. These raw scores indicate the similarity between the query and each key, but they are not yet probabilities.

Scaling the Logits

Before applying the softmax function, we scale the logits. The reason for scaling is to prevent the softmax function from producing extremely small gradients, which can happen when the logits are large in magnitude: softmax then saturates, pushing one weight toward 1 and the rest toward 0. The scaling divides each logit by the square root of the dimension of the key vectors, d_k:

$$\text{scaled logits} = \frac{QK^{\top}}{\sqrt{d_k}}$$

This scaling helps stabilize the gradients during training.
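To see why the factor matters, suppose query and key components are independent with mean 0 and variance 1. Their dot product then has variance d_k, so for large d_k the raw logits spread out and softmax collapses toward a one-hot distribution. The sketch below uses random Gaussian vectors and d_k = 512, both assumptions chosen for illustration (real queries and keys are not Gaussian):

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row max for numerical stability (does not change the result)
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(seed=0)
d_k = 512                             # a large key dimension, chosen for illustration
q = rng.standard_normal((1000, d_k))  # 1000 random queries
k = rng.standard_normal((10, d_k))    # 10 random keys

raw_logits = q @ k.T                  # variance is about d_k
scaled_logits = raw_logits / np.sqrt(d_k)

print("std of raw logits:   ", raw_logits.std())     # roughly sqrt(512), about 22.6
print("std of scaled logits:", scaled_logits.std())  # roughly 1

# Average weight that softmax gives to the top-scoring key in each row
raw_peak = softmax(raw_logits).max(axis=-1).mean()
scaled_peak = softmax(scaled_logits).max(axis=-1).mean()
print(raw_peak, scaled_peak)  # raw is close to 1 (nearly one-hot); scaled is much flatter
```

When softmax is nearly one-hot, small changes to the logits barely change its output, which is the small-gradient problem the scaling is meant to avoid.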

Applying the Softmax Function

The softmax function is used to convert the logits into probabilities. It takes a vector of raw scores (logits) and transforms them into a probability distribution. The softmax function is defined as:

$$\mathrm{softmax}(z)_i = \frac{e^{z_i}}{\sum_{j} e^{z_j}}$$

where z_i is the i-th logit, and the denominator is the sum of the exponentials of all logits.
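Implemented literally, this formula can overflow when logits are large (exp(1000) is infinite in float64). A common trick, assumed here rather than taken from the article’s code, is to subtract the maximum logit first. Since softmax(z) = softmax(z − c) for any constant c, the result is unchanged:

```python
import numpy as np

def softmax(z):
    """Numerically stable softmax of a 1-D array."""
    z = np.asarray(z, dtype=float)
    shifted = z - z.max()       # largest exponent becomes exp(0) = 1
    exps = np.exp(shifted)
    return exps / exps.sum()    # normalize so the outputs sum to 1

print(softmax([1.0, 2.0, 3.0]))           # ≈ [0.090, 0.245, 0.665]
print(softmax([1001.0, 1002.0, 1003.0]))  # same values, no overflow
```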

For our scaled attention logits, softmax is applied to each row separately, so the scores in every row sum to 1. This normalization lets us interpret each row as a probability distribution over the tokens; these values are the attention weights.

Detailed Step-by-Step Process

Let’s revisit the scaled attention logits from our example:

scaled_attention_logits = [
[0.081, 0.185, 0.289, 0.393, 0.081, 0.497],
[0.185, 0.445, 0.704, 0.964, 0.185, 1.224],
[0.289, 0.704, 1.120, 1.536, 0.289, 1.951],
[0.393, 0.964, 1.536, 2.107, 0.393, 2.679],
[0.081, 0.185, 0.289, 0.393, 0.081, 0.497],
[0.497, 1.224, 1.951, 2.679, 0.497, 3.406]
]

Apply Softmax Function: We apply the softmax function to each row of the scaled logits to get the attention weights. For the first row, this would be:

attention_weights[0] = softmax([0.081, 0.185, 0.289, 0.393, 0.081, 0.497])

Computing this step by step:

  • Compute the exponentials:
exp(0.081) ≈ 1.084, exp(0.185) ≈ 1.203, exp(0.289) ≈ 1.335, exp(0.393) ≈ 1.481, exp(0.081) ≈ 1.084, exp(0.497) ≈ 1.643

  • Sum the exponentials:
1.084 + 1.203 + 1.335 + 1.481 + 1.084 + 1.643 ≈ 7.830

  • Divide each exponential by the sum (7.830) to get the attention weights:
attention_weights[0] ≈ [0.138, 0.154, 0.170, 0.189, 0.138, 0.210]

This process is repeated for each row in the scaled attention logits to get the full attention weight matrix.
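Repeating the calculation row by row is exactly what a softmax over the last axis does in one call. A NumPy sketch (our own, mirroring `tf.nn.softmax(..., axis=-1)` in the full program) that produces the full 6 × 6 weight matrix:

```python
import numpy as np

embedded_tokens = np.array([
    [0.1, 0.2, 0.3],  # the
    [0.4, 0.5, 0.6],  # cat
    [0.7, 0.8, 0.9],  # sat
    [1.0, 1.1, 1.2],  # on
    [0.1, 0.2, 0.3],  # the
    [1.3, 1.4, 1.5],  # mat
])
logits = embedded_tokens @ embedded_tokens.T / np.sqrt(embedded_tokens.shape[-1])

# Row-wise softmax: normalize each query's scores over all keys
exps = np.exp(logits - logits.max(axis=-1, keepdims=True))
attention_weights = exps / exps.sum(axis=-1, keepdims=True)

print(np.round(attention_weights, 3))
print(attention_weights.sum(axis=-1))  # every row sums to 1
```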

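These weights show how much attention the first token “the” pays to each token in the sequence, including itself. The token “mat” receives the highest weight (0.210). In this toy setup, that mainly reflects “mat” having the largest embedding values (dot products grow with vector magnitude) rather than a learned sense of relevance.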

5. Weighted Sum of Values: Finally, we compute the output by multiplying the attention weights by the value matrix V:

output = tf.matmul(attention_weights, V)

For the first token “the”, the output is:

output[0] = 0.138*[0.1, 0.2, 0.3] + 0.154*[0.4, 0.5, 0.6] + 0.170*[0.7, 0.8, 0.9] + 0.189*[1.0, 1.1, 1.2] + 0.138*[0.1, 0.2, 0.3] + 0.210*[1.3, 1.4, 1.5]
output[0] ≈ [0.670, 0.770, 0.870]
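The weighted sum for “the” can be checked directly. A minimal NumPy sketch that recomputes the first row of attention weights from unrounded logits and then forms the weighted sum two ways:

```python
import numpy as np

V = np.array([[0.1, 0.2, 0.3],   # the
              [0.4, 0.5, 0.6],   # cat
              [0.7, 0.8, 0.9],   # sat
              [1.0, 1.1, 1.2],   # on
              [0.1, 0.2, 0.3],   # the
              [1.3, 1.4, 1.5]])  # mat

# Scaled logits and softmax weights for the query "the" (Q = K = V here)
logits_0 = V[0] @ V.T / np.sqrt(V.shape[-1])
weights_0 = np.exp(logits_0) / np.exp(logits_0).sum()

# Weighted sum of the value vectors as an explicit loop ...
output_0_loop = sum(w * v for w, v in zip(weights_0, V))
# ... and as a single vector-matrix product
output_0 = weights_0 @ V

print(np.round(output_0, 3))  # ≈ [0.67, 0.77, 0.87]
```

Because the weights are non-negative and sum to 1, the output is a weighted average of the value vectors, so it always lies between the smallest and largest values in each column.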

Interpreting Attention Weights

The attention weights indicate how much focus each word gives to every other word in the sequence. Higher weights mean higher relevance. For example, in our case, the word “the” pays the most attention to the word “mat” (with a weight of 0.210).
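One caveat when reading these weights: with raw dot products, vectors with larger norms produce larger scores. In this toy example every component is positive and “mat” has the largest entries, so it receives the highest weight from every query, not only from “the”. A quick NumPy check:

```python
import numpy as np

tokens = ['the', 'cat', 'sat', 'on', 'the', 'mat']
X = np.array([[0.1, 0.2, 0.3],
              [0.4, 0.5, 0.6],
              [0.7, 0.8, 0.9],
              [1.0, 1.1, 1.2],
              [0.1, 0.2, 0.3],
              [1.3, 1.4, 1.5]])

norms = np.linalg.norm(X, axis=-1)  # length of each embedding vector
logits = X @ X.T / np.sqrt(X.shape[-1])
weights = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)

# Which token does each query attend to most?
for token, row in zip(tokens, weights):
    print(f"{token:>4} attends most to: {tokens[row.argmax()]}")
print("largest-norm token:", tokens[norms.argmax()])
```

In a trained model the learned projections and embeddings change this picture, so the weights there carry more than vector magnitude alone.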

Here is the full program for your reference:

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import seaborn as sns

# Define toy word embeddings (3-dimensional, for illustration only)
embeddings = {
    'the': np.array([0.1, 0.2, 0.3]),
    'cat': np.array([0.4, 0.5, 0.6]),
    'sat': np.array([0.7, 0.8, 0.9]),
    'on': np.array([1.0, 1.1, 1.2]),
    'mat': np.array([1.3, 1.4, 1.5])
}

# Define input sentence
sentence = ['the', 'cat', 'sat', 'on', 'the', 'mat']

# Convert sentence to embeddings, shape (6, 3)
embedded_tokens = np.array([embeddings[word] for word in sentence])

# Scaled dot-product attention
def scaled_dot_product_attention(q, k, v, mask=None):
    # Dot product of every query with every key: shape (seq_len, seq_len)
    matmul_qk = tf.matmul(q, k, transpose_b=True)
    # Divide by sqrt(d_k) to keep the logits in a reasonable range
    dk = tf.cast(tf.shape(k)[-1], tf.float32)
    scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)

    # Positions where mask == 1 get a large negative logit (near-zero weight after softmax)
    if mask is not None:
        scaled_attention_logits += (mask * -1e9)

    # Softmax over the key axis, so each row sums to 1
    attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)
    # Weighted sum of the value vectors
    output = tf.matmul(attention_weights, v)
    return output, attention_weights

# Q = K = V for self-attention (float32 to match the scaling factor)
Q = K = V = tf.constant(embedded_tokens, dtype=tf.float32)

# Apply self-attention
output, attention_weights = scaled_dot_product_attention(Q, K, V)

# Print attention weights
print("Attention Weights:")
print(attention_weights.numpy())

# Print output
print("Output:")
print(output.numpy())

# Visualize attention weights: rows are queries, columns are keys
tokens = sentence
plt.figure(figsize=(10, 8))
sns.heatmap(attention_weights.numpy(), xticklabels=tokens, yticklabels=tokens, cmap='viridis', annot=True)
plt.xlabel('Key tokens (attended to)')
plt.ylabel('Query tokens (attending)')
plt.title('Attention Weights Heatmap')
plt.show()
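The function above accepts a `mask` argument, but the program never uses it. Positions where the mask is 1 get -1e9 added to their logits, so softmax gives them (almost) zero weight. One common use is a causal (look-ahead) mask, which stops each token from attending to later tokens. Below is a NumPy sketch that mirrors the TensorFlow function; the helper name `scaled_dot_product_attention_np` and the use of `np.triu` to build the mask are our own choices, not the article’s:

```python
import numpy as np

def scaled_dot_product_attention_np(q, k, v, mask=None):
    """Hypothetical NumPy mirror of the TensorFlow function above."""
    logits = q @ k.T / np.sqrt(k.shape[-1])
    if mask is not None:
        logits = logits + mask * -1e9               # hidden positions become very negative
    logits = logits - logits.max(axis=-1, keepdims=True)
    weights = np.exp(logits)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v, weights

X = np.array([[0.1, 0.2, 0.3],
              [0.4, 0.5, 0.6],
              [0.7, 0.8, 0.9],
              [1.0, 1.1, 1.2],
              [0.1, 0.2, 0.3],
              [1.3, 1.4, 1.5]])
seq_len = X.shape[0]

# 1 above the diagonal marks "future" positions that should be hidden
causal_mask = np.triu(np.ones((seq_len, seq_len)), k=1)

output, weights = scaled_dot_product_attention_np(X, X, X, mask=causal_mask)
print(np.round(weights, 3))  # lower-triangular: no weight on later tokens
```

With this mask, the first token can only attend to itself, so its row becomes [1, 0, 0, 0, 0, 0].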

Visualizing Attention Weights

To understand the model’s attention mechanism better, we can visualize the attention weights as a heatmap (this repeats the final part of the full program). Here’s the snippet, which uses Seaborn’s heatmap on top of Matplotlib:

import matplotlib.pyplot as plt
import seaborn as sns

tokens = ['the', 'cat', 'sat', 'on', 'the', 'mat']
plt.figure(figsize=(10, 8))
# Convert the TensorFlow tensor to a NumPy array before plotting
sns.heatmap(attention_weights.numpy(), xticklabels=tokens, yticklabels=tokens, cmap='viridis', annot=True)
plt.xlabel('Key tokens (attended to)')
plt.ylabel('Query tokens (attending)')
plt.title('Attention Weights Heatmap')
plt.show()

Here is the output:

[Figure: Attention Weights Heatmap]

Conclusion

Overall, the scaled dot-product attention mechanism allows the Transformer model to focus on the most relevant parts of the input for each word. By examining the attention weights, we can see which words each position draws on, which gives some insight into what the model is doing, although attention weights alone are not a complete explanation of its behavior. This mechanism is a powerful tool for capturing long-range dependencies and improving the model’s ability to process complex sequences.


Prashant S

Engineer with 15+ years exp. Passionate about tech - Python, ML, CV, generative AI, NLP. Sharing insights to empower others. Let's explore together!