Decoding Transformers: The Secret of Scaled Dot-Product Attention

Himanshu Kale


Hello everyone! Welcome back to our series on Decoding Transformers. If you read the last blog, you know we uncovered the secrets behind the self-attention mechanism. This time, we're diving even deeper to explore the scaling step inside self-attention. Curious how it all fits together? Buckle up as we take a closer look at the scaling operation that powers these incredible models. Let's start with a short recap!


In the last blog on self-attention, we saw how contextual embeddings are calculated from the static embeddings of the sentence "money bank grows". We took three trainable matrices Wq, Wk and Wv to compute the query, key and value for each token in the sentence. We then multiplied the query matrix with the transpose of the key matrix, passed the result through a softmax, and finally multiplied it with the value matrix to obtain the contextual embeddings. Phew!

When you look at the complete procedure, you might notice that it is really just one mathematical operation on Q, K and V (where K' denotes the transpose of K):
Attention(Q, K, V) = Softmax(Q.K').V

But if you look at the original "Attention Is All You Need" paper, you will see something slightly different:

Attention(Q, K, V) = Softmax(Q.K' / √dk).V

OK, OK, wait a minute! What is this √dk, and why is it there? Have these questions knocked on your door too? Let's try to dig up the answers!

What we can see is that the scaling happens after the product of Q and K' and before the softmax, so the answer must lie there. The dimensions of the embedding vectors are generally large, and it turns out that when the dimension is large, the dot products of such vectors behave like a random variable with very high variance. In simple terms, if the embedding dimension is small the variance is small, and as the dimension increases the variance tends to increase. Let's see whether that is actually true!

import numpy as np
import matplotlib.pyplot as plt

def plot_hist(dim, ax, color):
    # Draw 1000 random vector pairs of the given dimension
    num_pairs = 1000
    arrays1 = np.random.rand(num_pairs, dim)
    arrays2 = np.random.rand(num_pairs, dim)

    # Row-wise dot products: one dot product per pair
    dot_products = np.einsum('ij,ij->i', arrays1, arrays2)

    ax.hist(dot_products, bins=30, color=color, edgecolor='black')
    ax.set_title(f'Histogram of Dot Products, {dim} dimensions')
    ax.set_xlabel('Dot Product')
    ax.set_ylabel('Frequency')

fig, axs = plt.subplots(1, 3, figsize=(18, 5))
colors = ['green', 'blue', 'red']

# Compare the spread of dot products at increasing dimensionality
plot_hist(3, axs[0], colors[0])
plot_hist(100, axs[1], colors[1])
plot_hist(1000, axs[2], colors[2])

plt.tight_layout()
plt.show()

How cool! As the dimension increased from 3 to 100 to 1000, the spread (variance) of the dot products clearly grew. But what harm can this high variance cause?
After the product of Q and K' we send the result to the softmax, and the softmax outputs probabilities. If the variance of its inputs is high, some of those probabilities will be very close to 1 and the rest very close to 0, and that leads to bad training. During backpropagation, the parameters behind the high-probability scores receive almost all of the gradient, while the parameters behind the low-probability scores face a vanishing-gradient problem and their weights barely get updated, which is not desirable.
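
To see this concretely, here is a small illustrative sketch. The dimension dk = 512 and the ten random keys are assumptions standing in for one row of Q.K'. Without scaling, the softmax output is almost a one-hot vector and its gradients are nearly zero; after dividing by √dk the distribution is much smoother.

import numpy as np

def softmax(x):
    x = x - x.max()                 # numerical stability
    e = np.exp(x)
    return e / e.sum()

np.random.seed(0)
d_k = 512

q = np.random.randn(d_k)            # one query
keys = np.random.randn(10, d_k)     # ten keys
scores = keys @ q                   # high-variance logits (variance ~ d_k)

p_raw = softmax(scores)                     # without scaling: close to one-hot
p_scaled = softmax(scores / np.sqrt(d_k))   # with scaling: much smoother

print(np.round(p_raw, 3))
print(np.round(p_scaled, 3))

# Diagonal of the softmax Jacobian, d p_i / d score_i = p_i * (1 - p_i):
# near zero when p_i saturates at 0 or 1, i.e. vanishing gradients without scaling
print(np.round(p_raw * (1 - p_raw), 4))
print(np.round(p_scaled * (1 - p_scaled), 4))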

Now that we have understood why we need scaling, let's figure out how to do it.
The solution is simply to reduce the variance of the dot products. That can be achieved by scaling them down, but the scaling factor is the next mystery.

Since the variance was increasing with the dimensionality, there might be a relation between the two. Take a look at the experiment below, where I try to capture the nature of the relation between dimensionality and the variance of the dot products.

def calculate(array_length):
    # Variance of all pairwise dot products of 100 random vectors of this dimension
    num_arrays = 100
    arrays = np.random.rand(num_arrays, array_length)

    dot_products = []
    for i in range(num_arrays):
        for j in range(i + 1, num_arrays):
            dot_products.append(np.dot(arrays[i], arrays[j]))

    return np.var(np.array(dot_products))

# Variance of the dot products for dimensionality 1 to 512
X = [calculate(i) for i in range(1, 513)]
dim = np.arange(1, len(X) + 1)

# Fit a low-degree polynomial to reveal the trend
degree = 2
coefs = np.polyfit(dim, X, degree)
poly = np.poly1d(coefs)

x_fit = np.linspace(1, len(X), 100)
y_fit = poly(x_fit)

plt.scatter(dim, X, color='red', label='Original Data')
plt.plot(x_fit, y_fit, color='blue', label='Fitted Curve')
plt.xlabel('Dimension')
plt.ylabel('Variance of Dot Products')
plt.title('Variance of Dot Products vs Dimension')
plt.legend()
plt.show()

From the plot you can clearly see that the variance increases roughly linearly with the dimensionality. This can also be shown mathematically: if the components of the vectors are independent and identically distributed, the dot product at dimensionality n is a sum of n independent single-component products. So if X is the variance at dimensionality 1, the variance at dimensionality n is close to nX.

This brings us to a very important property of variance that we all studied in school: for a constant k, Var(k·Z) = k² · Var(Z).

Now, what we want is that even if the dimensionality increases, the variance of the dot products stays the same, ideally equal to the variance at a single dimension. From our experiment, Var(Q.K') ≈ dk · Var(single term), where dk is the embedding dimension. Using the property above, dividing the dot products by √dk scales their variance by 1/dk, which brings it back down to Var(single term).
So this is how we end up with a scaling factor of √dk.
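
As a quick sanity check of this conclusion, here is a minimal sketch, assuming vector components drawn from a standard normal distribution (mean 0, variance 1) as in the paper's argument: the unscaled variance grows roughly like dk, while after dividing by √dk it stays close to 1 for every dimension.

import numpy as np

np.random.seed(0)
num_pairs = 10000

for d_k in [3, 100, 1000]:
    a = np.random.randn(num_pairs, d_k)
    b = np.random.randn(num_pairs, d_k)
    dots = np.einsum('ij,ij->i', a, b)   # one dot product per pair

    # Variance before scaling grows roughly like d_k; after scaling it stays near 1
    print(d_k, round(float(np.var(dots)), 1), round(float(np.var(dots / np.sqrt(d_k))), 2))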

And that wraps up our exploration of this topic. I hope you found the blog both informative and engaging. Keep reading, keep learning and stay curious! See you in the next episode of this series. Thank you!

