Understanding Transformers Part 2

Dhruv Kabra
Published in Version 1
Jul 7, 2023

In the last blog we discussed the architecture of the Transformer encoder in depth; here we will discuss how the Transformer decoder works. As shown, the decoder is very similar to the encoder part of the Transformer but has some slight modifications.
Our spotlight today is on the decoder. Let's start with a simplified diagram of the Transformer architecture broken into its elemental components, encoders and decoders. As discussed previously, the features produced by the encoders are passed into the decoder.

Photo by Author

The decoder in the Transformer model can be visualized as a high-rise building of multiple layers. Each layer is a replica of the others, comprising two principal components: a ‘masked multi-head self-attention mechanism’ and a ‘position-wise fully connected feed-forward network.’ The number of these layers could vary — typically six or twelve — depending on the complexity of the task at hand.

Masked Multi-Head Self-Attention: The decoder's masked multi-head attention mechanism in a Transformer model has the unique ability to selectively concentrate on varying parts of the sequence generated so far while producing the output step by step. But just as we would not want a fortune teller peering into a crystal ball for answers during a quiz, it is essential that the decoder does not access information from future tokens in the target sequence. Doing so would compromise the integrity of the learning process and result in less accurate output.

How does the decoder keep itself honest? It uses a process known as masked attention. This technique effectively blinds the decoder to future words in the sequence when calculating the score matrix, which helps determine the attention weights. These weights quantify how important each input token is in generating the output.

Imagine playing a game of chess where the moves of future rounds are obscured, and all you have is the current and previous moves to strategize your next step. That's exactly what masking does: it marks the current and previous words with a '1' in the mask, while the scores of future words in the matrix are set to '-inf'. This ensures that when we apply the softmax function to normalize the attention weights into a probability distribution, the future words in the sequence are effectively muted (their weights turn to zero), while the rest of the words are retained. Through this process, the masked entries for future words in the score matrix ensure the decoder plays fair, with information available only up to the current token in the sequence.

(1, 0, 0, 0, 0, …, 0) => (<SOS>)

(1, 1, 0, 0, 0, …, 0) => (<SOS>, ‘Bonjour’)

(1, 1, 1, 0, 0, …, 0) => (<SOS>, ‘Bonjour’, ‘le’)

(1, 1, 1, 1, 0, …, 0) => (<SOS>, ‘Bonjour’, ‘le’, ‘monde’)

(1, 1, 1, 1, 1, …, 0) => (<SOS>, ‘Bonjour’, ‘le’, ‘monde’, ‘!’)

The mask ensures future words are never used: each row exposes one more token than the last, so the visible context grows gradually as decoding proceeds, as the sketch below illustrates.
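As a concrete illustration, here is a minimal sketch, not taken from the post's code, of how such a causal mask can be built and applied to a score matrix in PyTorch; the tensor names are illustrative only.

import torch

seq_len = 5
# Raw attention scores for a 5-token sequence (random values for illustration)
scores = torch.rand(seq_len, seq_len)

# Lower-triangular mask: 1 for the current and previous positions, 0 for future ones
mask = torch.tril(torch.ones(seq_len, seq_len))

# Replace the scores of future positions with -inf before the softmax
masked_scores = scores.masked_fill(mask == 0, float('-inf'))

# Softmax turns the -inf entries into zero probability, muting future words
attention_weights = torch.softmax(masked_scores, dim=-1)
print(attention_weights)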

Position-wise Feed-Forward Networks: These networks serve to transform the output of the self-attention layer into a form that can be used either by the next layer or for the final output. They consist of two consecutive linear transformations interspersed with a ReLU activation function, applied uniformly across each position.
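As a rough sketch, assuming the usual dimensions of d_model = 512 and an inner dimension of 2048, such a block could be written in PyTorch as follows (an illustration, not the code used later in this post):

import torch.nn as nn

class PositionwiseFeedForward(nn.Module):
    def __init__(self, d_model=512, ff_dim=2048):
        super().__init__()
        # Two linear transformations with a ReLU in between,
        # applied identically at every position in the sequence
        self.net = nn.Sequential(
            nn.Linear(d_model, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, d_model),
        )

    def forward(self, x):
        return self.net(x)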

In addition to these two components, there’s an extra layer of multi-head attention that takes the encoder stack’s output, forming a liaison between the encoder and the decoder.
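A minimal sketch of this encoder-decoder (cross) attention using PyTorch's nn.MultiheadAttention; the variable names are assumptions for illustration:

import torch
import torch.nn as nn

cross_attention = nn.MultiheadAttention(embed_dim=512, num_heads=8)

decoder_state = torch.rand(10, 32, 512)   # queries come from the decoder
encoder_memory = torch.rand(10, 32, 512)  # keys and values come from the encoder stack

# The decoder attends over the encoder's output
attended, attn_weights = cross_attention(decoder_state, encoder_memory, encoder_memory)
print(attended.shape)  # torch.Size([10, 32, 512])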

An important feature of these components is a residual connection enveloping them, followed by layer normalization. These residual connections act as a safety net, helping mitigate the vanishing-gradient problem that can plague deeper networks.

The result of each of these components is computed as LayerNorm(x + Sublayer(x)), where Sublayer(x) signifies the function implemented by the component. To support these residual connections, every sub-layer, as well as the embedding layers, produces outputs of dimension d_model = 512.
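In code, that wrapper can be sketched roughly like this (a simplified illustration, not the post's implementation):

import torch.nn as nn

d_model = 512
layer_norm = nn.LayerNorm(d_model)

def sublayer_connection(x, sublayer):
    # Residual connection followed by layer normalization: LayerNorm(x + Sublayer(x))
    return layer_norm(x + sublayer(x))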

Finally, the last rung of our Transformer decoder is a linear layer followed by a softmax function. The linear layer assigns scores to every potential next word in the output sequence, while the softmax function translates these scores into probabilities, giving us an indicator of the likelihood of each word being the next one in the output sequence.
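A minimal sketch of this final step, assuming an illustrative vocabulary size of 10,000:

import torch
import torch.nn as nn

d_model, vocab_size = 512, 10_000

# The linear layer assigns a score to every word in the vocabulary
projection = nn.Linear(d_model, vocab_size)

decoder_output = torch.rand(10, 32, d_model)   # (trg_len, batch_size, d_model)
logits = projection(decoder_output)            # (trg_len, batch_size, vocab_size)
probabilities = torch.softmax(logits, dim=-1)  # likelihood of each next word
print(probabilities.shape)  # torch.Size([10, 32, 10000])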

In essence, the decoder in the Transformer model is a complex yet efficient tool that utilizes the powers of self-attention and feed-forward neural networks to create meaningful output based on both the encoded input and its prior outputs. By focusing on the relevant parts of the input sequence and enabling parallel processing, the decoder delivers high-quality outputs, setting new standards in machine translation.

Let's write some simple Python code for the decoder architecture using PyTorch's built-in modules.

import torch
import torch.nn as nn

class TransformerDecoder(nn.Module):
    def __init__(self, d_model, heads, num_layers, ff_dim):
        super().__init__()
        self.num_layers = num_layers
        # Stack of identical decoder layers
        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            self.layers.append(
                nn.TransformerDecoderLayer(d_model, heads, ff_dim)
            )

    def forward(self, trg, memory, trg_mask=None, memory_mask=None):
        # Pass the target sequence and the encoder output through each layer in turn
        for i in range(self.num_layers):
            trg = self.layers[i](trg, memory, trg_mask, memory_mask)
        return trg

# Hyperparameters
d_model = 512
heads = 8
num_layers = 6
ff_dim = 2048

# Initialize Transformer decoder
decoder = TransformerDecoder(d_model, heads, num_layers, ff_dim)

# Dummy input for testing
trg = torch.rand(10, 32, 512) # (trg_len, batch_size, d_model)
memory = torch.rand(10, 32, 512) # (src_len, batch_size, d_model)

# Forward pass
output = decoder(trg, memory)
print(output.shape) # torch.Size([10, 32, 512])

In this code, d_model is the dimension of the input vectors (often words are represented as 512-dimensional vectors), heads is the number of attention heads in the multi-head attention mechanism, num_layers is the number of stacked layers in the decoder, and ff_dim is the dimension of the feedforward network model.

Each nn.TransformerDecoderLayer is a layer in the Transformer decoder, consisting of two multi-head attention layers (one being a masked multi-head attention layer and the other one being a standard multi-head attention layer taking the encoder's output as key and value input) and a feedforward neural network.

The forward function propagates the target sequence (trg) and the output of the encoder (memory) through each layer of the decoder.

Please note that PyTorch's built-in nn.TransformerDecoderLayer and nn.TransformerDecoder do not create the causal mask for you: to stop the decoder from 'peeking into the future' during training, you need to build a subsequent mask yourself (for example with nn.Transformer.generate_square_subsequent_mask) and pass it as trg_mask. It is omitted above purely to keep the example minimal.
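For completeness, here is one way the causal mask could be created and passed, building on the decoder, trg, and memory defined above (a sketch, not the only option):

# Causal (subsequent) mask: position i may only attend to positions <= i
trg_mask = nn.Transformer.generate_square_subsequent_mask(trg.size(0))

# Forward pass with the mask applied to the decoder's masked self-attention
masked_output = decoder(trg, memory, trg_mask=trg_mask)
print(masked_output.shape)  # torch.Size([10, 32, 512])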

Unlike neural networks such as Recurrent Neural Networks (RNNs) and Convolutional Neural Networks (CNNs), which rely on recurrent or convolutional structures, the Transformer model leverages a self-attention mechanism. This mechanism allows the Transformer to capture the relationships among all tokens in the input sequence, regardless of where they appear in it.

One of the Transformer model's distinguishing characteristics is its adaptability. It is agnostic to the specific nature of the data, making it a versatile tool for any kind of data that can be converted into a sequence. This adaptability has allowed Transformers to be applied beyond their initial purpose of natural language processing, finding utility in fields like computer vision, where they process sequences of image patches, and reinforcement learning, where they handle sequences of states, actions, and rewards.

The inaugural version of the Transformer model was created with translation tasks in mind, yet the technique it employs is broader than that application. It learns to enrich input embeddings, such as words or tokens, with critical, context-sensitive information. This ensures that every token is not just an isolated entity but carries a semantic understanding derived from its surrounding context, an invaluable property for tasks like translation.

About the Author:
Dhruv Kabra is a Python Developer here at Version 1.
