Abstractive text summarization with Transformer model using Pytorch

Koyela Chakrabarti
21 min read · Oct 18, 2023


Text summarization is one of the many applications of natural language processing. There are two types of summarization models: extractive and abstractive. In extractive text summarization, the stop words are first removed from the text and the frequency of occurrence of each word token is calculated. Each token is assigned a weight based on its frequency: the higher the frequency, the greater the weight. Next, each sentence is assigned a weight by summing the weights of the tokens it contains. The sentences are ranked by weight, and the k highest-ranking sentences are presented as the summary. Such a model needs no training, but it is rule based. In my personal experience, longer sentences are often selected, and the summary frequently fails to capture the context of the text. This article is therefore about abstractive text summarization, a supervised learning model, built here using a transformer.
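The frequency-based extractive approach described above can be sketched in plain Python. This is a minimal sketch, not code from the article: the tiny stop-word set, the regular-expression sentence splitter and the function name `extractive_summary` are assumptions made for illustration.

```python
import re
from collections import Counter

# Hypothetical, deliberately tiny stop-word list (NLTK's English list would normally be used)
STOP_WORDS = {"the", "a", "an", "is", "it", "and", "of", "to", "this", "i", "was", "are"}

def content_words(text):
    """Lower-cased alphabetic tokens with stop words removed."""
    return [w for w in re.findall(r"[a-z]+", text.lower()) if w not in STOP_WORDS]

def extractive_summary(text, k=1):
    """Return the k highest-scoring sentences, kept in their original order."""
    sentences = [s.strip() for s in re.split(r"(?<=[.!?])\s+", text) if s.strip()]
    freq = Counter(content_words(text))  # weight of each token = its frequency
    # Sentence weight = sum of the weights of the tokens it contains
    scores = [sum(freq[w] for w in content_words(s)) for s in sentences]
    top = sorted(range(len(sentences)), key=lambda i: scores[i], reverse=True)[:k]
    return " ".join(sentences[i] for i in sorted(top))

review = "The coffee is strong. I love strong coffee. Shipping was slow."
print(extractive_summary(review, k=1))  # I love strong coffee.
```

Because sentence scores are plain sums, every additional content word raises a sentence's score, which is why longer sentences tend to be selected.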

A transformer uses the self-attention mechanism as its basis. If you are not familiar with the attention mechanism, I have covered the basics in my YouTube video here. Abstractive text summarization is a sequence-to-sequence learning problem, where a variable-length input sequence (the text, in this case) is fed into the network to produce another variable-length sequence (here, the summary). To deal with this type of problem, we need an encoder-decoder architecture: the encoder encodes the variable-length input sequence, and the decoder generates a variable-length output sequence from that encoding. (In the classic RNN formulation the encoding is a single fixed-length vector; a transformer encoder instead produces one vector per input position.) There are several ways to model this encoder-decoder. For sequential data such as text, the model can be built using RNNs only, or RNNs combined with an attention mechanism such as Bahdanau attention to capture the context information better. Transformers bypass RNNs completely.

In the self-attention mechanism used by transformers, the same token serves as the query, key and value of the attention model, and the output is calculated from them with some further processing. The major advantage of self-attention over an RNN, which is strictly sequential, is that it can be computed in parallel. Suppose there is a sequence [x1 x2 x3 x4 x5]. An RNN processes the tokens one after another, so at any time step the probability of generating x3 depends on what was computed for x2, which in turn depends on x1. An indirect relationship between x3 and x1 does exist, but it can only pass through x2; there is no direct path from x1 to x3. In self-attention, by contrast, each token in turn acts as the query, and its relationships with every token in the sequence (acting as the keys) are computed in parallel.
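The parallelism argument can be checked numerically. In this minimal sketch the queries, keys and values are the raw token vectors themselves (identity projections, an assumption for simplicity; the model learns separate projections), and the toy sequence is random. One matrix product produces the outputs for all five positions, and recomputing them one query at a time gives the same result, because no query depends on the output for the previous position.

```python
import math
import torch

torch.manual_seed(0)
X = torch.randn(5, 8)  # toy sequence [x1 ... x5], each token an 8-dimensional vector
d = X.shape[-1]

# Self-attention with identity projections (Q = K = V = X) for simplicity;
# the real model first applies learned query, key and value projections.
scores = X @ X.T / math.sqrt(d)          # (5, 5) scores for every pair of positions at once
weights = torch.softmax(scores, dim=-1)  # row t: how strongly x_t attends to x1..x5
parallel_out = weights @ X               # all five outputs from one matrix product

# The same outputs computed one query at a time. No query needs the result of the
# previous one, unlike an RNN, whose hidden state at step t needs step t - 1.
loop_out = torch.stack([torch.softmax(X[t] @ X.T / math.sqrt(d), dim=-1) @ X for t in range(5)])
print(torch.allclose(parallel_out, loop_out, atol=1e-5))  # True
print(weights[2, 0])  # direct weight from x3 to x1, with no path through x2
```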
Multi-head attention extends the idea of assigning weights to tokens by using several sets of attention weights, or “heads”. Each head works on a different learned linear projection (subspace) of the queries, keys and values, so different heads can capture different types of relationships in the data. Self-attention is the case in which the queries, keys and values all come from the same sequence. Attention helps the model decide which parts of the input are most relevant to each part of the output. Mathematically, each head can be written as:

$$H_i = \mathrm{softmax}\left(\frac{(Q W_i^{q})(K W_i^{k})^{\top}}{\sqrt{d}}\right) V W_i^{v}, \qquad \mathrm{MultiHead}(Q, K, V) = [H_1, \ldots, H_h]\, W_o$$

where H_i denotes the i-th head; K, Q and V denote the keys, queries and values respectively; W_i^v, W_i^k and W_i^q are the corresponding learnable weight parameters; and d is the dimension of each projected query and key. The heads are concatenated and fed into a fully connected layer (W_o) for a final linear transformation. In a nutshell, self-attention provides a short path between any pair of sequence positions, which makes it easier to learn long-range dependencies within the sequence. Let us dive deeper into the concept while getting our hands dirty with the model building. As always, the first step is to import the necessary libraries.

```python
import numpy as np
import pandas as pd
import nltk
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize, sent_tokenize
import re
import warnings
warnings.filterwarnings("ignore")
from sklearn.model_selection import train_test_split as tts
import torch
import collections
from collections import Counter
from torch import nn
from torch.utils.data import Dataset, DataLoader
from bs4 import BeautifulSoup
import math

nltk.download('stopwords', quiet=True)  # the stop-word list has to be downloaded once
```

Next, let us get some data for model training and validation. I have used this dataset from Kaggle. It contains a collection of Amazon product reviews with the columns Id, ProductId, UserId, ProfileName, HelpfulnessNumerator, HelpfulnessDenominator, Score, Time, Summary and Text. Only the relevant columns, Summary and Text, are selected; the rest of the data is ignored. To keep the training time reasonable, I have used only the first 10,000 rows. You might consider using more.

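```python
df = pd.read_csv('Reviews.csv', nrows=10000)
# .copy() avoids pandas' SettingWithCopyWarning when columns are added later
data = df[['Text', 'Summary']].copy()
data.dropna(subset=['Text', 'Summary'], inplace=True)  # text_cleaner expects strings, not NaN
data.drop_duplicates(subset=['Text'], inplace=True)    # drop duplicate reviews
```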

Now it is time for some initial data preprocessing: cleaning the text, removing duplicates and deleting rows whose summary is blank. First, we use a word mapping that expands contractions and a few informal spellings, for example “I’ve” to “I have”. The mapping is declared as a dictionary whose keys are the contracted forms and whose values are their expansions, and it is applied to both the text and the summary fields.

```python
# Maps contractions and informal spellings to their expanded forms
word_mapping = {"ain't": "is not", "aint": "is not", "aren't": "are not", "arent": "are not", "can't": "cannot", "cant": "cannot",
    "'cause": "because", "cause": "because", "could've": "could have", "couldn't": "could not",
    "didn't": "did not", "doesn't": "does not", "don't": "do not", "hadn't": "had not", "hasn't": "has not", "haven't": "have not",
    "he'd": "he would", "he'll": "he will", "he's": "he is", "how'd": "how did", "how'd'y": "how do you", "how'll": "how will", "how's": "how is",
    "I'd": "I would", "I'd've": "I would have", "I'll": "I will", "I'll've": "I will have", "I'm": "I am", "I've": "I have", "i'd": "i would",
    "i'd've": "i would have", "i'll": "i will", "i'll've": "i will have", "i'm": "i am", "i've": "i have", "isn't": "is not", "it'd": "it would",
    "it'd've": "it would have", "it'll": "it will", "it'll've": "it will have", "it's": "it is", "let's": "let us", "ma'am": "madam",
    "mayn't": "may not", "might've": "might have", "mightn't": "might not", "mightn't've": "might not have", "must've": "must have", "mstake": "mistake",
    "mustn't": "must not", "mustn't've": "must not have", "needn't": "need not", "needn't've": "need not have", "o'clock": "of the clock",
    "oughtn't": "ought not", "oughtn't've": "ought not have", "shan't": "shall not", "sha'n't": "shall not", "shan't've": "shall not have",
    "she'd": "she would", "she'd've": "she would have", "she'll": "she will", "she'll've": "she will have", "she's": "she is",
    "should've": "should have", "shouldn't": "should not", "shouldn't've": "should not have", "so've": "so have", "so's": "so as",
    "this's": "this is", "that'd": "that would", "that'd've": "that would have", "that's": "that is", "there'd": "there would",
    "there'd've": "there would have", "there's": "there is", "here's": "here is", "they'd": "they would", "they'd've": "they would have",
    "they'll": "they will", "they'll've": "they will have", "they're": "they are", "they've": "they have", "to've": "to have",
    "wasn't": "was not", "wasnt": "was not", "we'd": "we would", "we'd've": "we would have", "we'll": "we will", "we'll've": "we will have", "we're": "we are",
    "we've": "we have", "weren't": "were not", "what'll": "what will", "what'll've": "what will have", "what're": "what are",
    "what's": "what is", "what've": "what have", "when's": "when is", "when've": "when have", "where'd": "where did", "where's": "where is",
    "where've": "where have", "who'll": "who will", "who'll've": "who will have", "who's": "who is", "who've": "who have",
    "why's": "why is", "why've": "why have", "will've": "will have", "won't": "will not", "won't've": "will not have",
    "would've": "would have", "wouldn't": "would not", "wouldn't've": "would not have", "y'all": "you all",
    "y'all'd": "you all would", "y'all'd've": "you all would have", "y'all're": "you all are", "y'all've": "you all have",
    "you'd": "you would", "you'd've": "you would have", "you'll": "you will", "you'll've": "you will have",
    "you're": "you are", "you've": "you have", "youve": "you have", "goin": "going", "4ward": "forward", "shant": "shall not", "tat": "that",
    "u": "you", "v": "we", "b4": "before", "sayin'": "saying"
}
```

Next, we define a cleaning function. It converts the input text to lower case, strips HTML markup with BeautifulSoup, and removes any text within parentheses along with the parentheses themselves (on the assumption that nothing essential is mentioned in brackets). It then removes double quotation marks, expands contractions using the mapping above, drops possessive “'s” endings and replaces every character that is not a letter with a space. Stop words and words shorter than three characters are removed as well. Finally, a space is inserted before punctuation marks, because otherwise ‘my’ and ‘my,’ would be treated as two different words; since the letter-only filter has already removed all punctuation, this last step has no effect in the current pipeline, but it is harmless. Both the text and the summary data are preprocessed.

```python
stop_words = set(stopwords.words('english'))

def text_cleaner(text):
    newString = text.lower()
    newString = BeautifulSoup(newString, "lxml").text   # strip HTML tags (requires lxml)
    newString = re.sub(r'\([^)]*\)', '', newString)     # drop text in parentheses
    newString = re.sub('"', '', newString)              # drop double quotes
    # expand contractions using word_mapping
    newString = ' '.join([word_mapping[t] if t in word_mapping else t for t in newString.split(" ")])
    newString = re.sub(r"'s\b", "", newString)          # drop possessive 's
    newString = re.sub("[^a-zA-Z]", " ", newString)     # keep letters only
    tokens = [w for w in newString.split() if w not in stop_words]
    long_words = [w for w in tokens if len(w) >= 3]     # remove short words
    text = " ".join(long_words).strip()

    def no_space(char, prev_char):
        return char in set(',!";.?') and prev_char != " "

    # insert a space before punctuation (a no-op here, since non-letters were removed above)
    text = text.replace('\u202f', ' ').replace('\xa0', ' ').lower()
    out = [' ' + char if i > 0 and no_space(char, text[i - 1]) else char for i, char in enumerate(text)]
    return ''.join(out)
```

```python
data['cleaned_text'] = data['Text'].apply(text_cleaner)
data['cleaned_summary'] = data['Summary'].apply(text_cleaner)
# remove all rows that have a blank summary after cleaning
data['cleaned_summary'] = data['cleaned_summary'].replace('', np.nan)
data.dropna(axis=0, inplace=True)
```
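To see what these cleaning steps do to a single review, here is a stripped-down sketch of the same pipeline. It is an illustration only: the two-entry `mini_mapping` and the five-word `mini_stop_words` set are hypothetical stand-ins for `word_mapping` and NLTK's English stop words, and the HTML-stripping step is omitted so that the sketch runs without BeautifulSoup.

```python
import re

mini_mapping = {"i've": "i have", "it's": "it is"}   # stand-in for word_mapping
mini_stop_words = {"i", "have", "it", "is", "the"}    # stand-in for NLTK's stop words

def mini_clean(text):
    s = text.lower()
    s = re.sub(r'\([^)]*\)', '', s)                              # drop parenthesised text
    s = re.sub('"', '', s)                                       # drop double quotes
    s = ' '.join(mini_mapping.get(t, t) for t in s.split(" "))   # expand contractions
    s = re.sub(r"'s\b", "", s)                                   # drop possessive 's
    s = re.sub("[^a-zA-Z]", " ", s)                              # keep letters only
    words = [w for w in s.split() if w not in mini_stop_words]
    return " ".join(w for w in words if len(w) >= 3)             # drop words shorter than 3

print(mini_clean('I\'ve tried (a lot of) brands, it\'s the "best"!!'))  # tried brands best
```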

Now we need to fix the maximum length of the text and summary. If we plot the summary and text histogram, it looks like this:

[Figure: Histogram of Summary and Text based on word count]

Most of the reviews are shorter than 200 words, and the majority are shorter than even 100 words. Therefore, for ease of computation in this toy project, let me set the maximum length of the text and the summary to 100 and 10 words, respectively.
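Before fixing the limits, it can help to check how many examples actually fit. The helper below is a hypothetical sketch (the name `coverage` and the `reserve_eos` flag are assumptions, not part of the article); it keeps one position free for the end-of-sequence token that is appended before padding or truncation. With the article's data it could be called as `coverage(data['cleaned_text'], max_len_text)` or `coverage(data['cleaned_summary'], max_len_summary)`.

```python
def coverage(texts, max_len, reserve_eos=True):
    """Fraction of examples that fit in max_len positions without truncation.

    With reserve_eos=True one position is kept free for the <eos> token
    that is appended before padding/truncation.
    """
    limit = max_len - 1 if reserve_eos else max_len
    counts = [len(t.split()) for t in texts]
    return sum(c <= limit for c in counts) / len(counts)

toy_texts = ["great taste", "too sweet", "arrived quickly well packed"]
print(coverage(toy_texts, 4))  # 2 of the 3 examples fit
```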

```python
max_len_text = 100
max_len_summary = 10
```

At this point, we split the data into training and validation sets. This could be done with simple Python slicing, but I wanted to shuffle the data, so I have used the train_test_split function from the scikit-learn library.

```python
x_train, x_test, y_train, y_test = tts(data['cleaned_text'], data['cleaned_summary'], test_size=0.1, shuffle=True, random_state=111)
```

Next, we tokenize the texts and then pad or truncate the tokens to the fixed sequence lengths decided upon before. Since I am using PyTorch without torchtext, we need to write the tokenization functions and the vocabulary class ourselves, which is done below. I have explained how to do this in detail in one of my previous posts on Medium, and I would request the reader to go through that post for a detailed explanation. Since this is sequence-to-sequence learning, special tokens mark the boundaries of each sequence: an end-of-sequence token is appended to each text and summary, and a beginning-of-sequence token is prepended to the decoder input during training. Also, if an example is shorter than the maximum length, padding tokens are appended. These tokens need to be explicitly declared as reserved tokens when the vocabulary is created.

```python
# Tokenize function: split each line into words (or characters)
def tokenize(lines, token='word'):
    assert token in ('word', 'char'), 'Unknown token type: ' + token
    return [line.split() if token == 'word' else list(line) for line in lines]

# Padding function: truncate or pad a token list to exactly num_steps items
def truncate_pad(line, num_steps, padding_token):
    if len(line) > num_steps:
        return line[:num_steps]  # Truncate
    return line + [padding_token] * (num_steps - len(line))  # Pad

# The vocabulary class
class Vocab:
    def __init__(self, tokens=None, min_freq=0, reserved_tokens=None):
        tokens = tokens or []
        reserved_tokens = reserved_tokens or []
        # Flatten a 2D list if needed
        if tokens and isinstance(tokens[0], list):
            tokens = [token for line in tokens for token in line]
        # Count token frequencies
        counter = collections.Counter(tokens)
        self.token_freqs = sorted(counter.items(), key=lambda x: x[1], reverse=True)
        # The list of unique tokens
        self.idx_to_token = list(sorted(set(['<unk>'] + reserved_tokens + [
            token for token, freq in self.token_freqs if freq >= min_freq])))
        self.token_to_idx = {token: idx for idx, token in enumerate(self.idx_to_token)}

    def __len__(self):
        return len(self.idx_to_token)

    def __getitem__(self, tokens):
        # Map a token (or a list/tuple of tokens) to indices, falling back to <unk>
        if not isinstance(tokens, (list, tuple)):
            return self.token_to_idx.get(tokens, self.unk)
        return [self.__getitem__(token) for token in tokens]

    def to_tokens(self, indices):
        # Map an index, or a sequence of indices (list, tuple, 1-D tensor), back to tokens
        if isinstance(indices, (list, tuple)) or (torch.is_tensor(indices) and indices.dim() > 0):
            return [self.idx_to_token[int(index)] for index in indices]
        return self.idx_to_token[int(indices)]

    @property
    def unk(self):  # Index for the unknown token
        return self.token_to_idx['<unk>']

# Tokenize
src_tokens = tokenize(x_train)
tgt_tokens = tokenize(y_train)
# Build the vocabularies on the training data
src_vocab = Vocab(src_tokens, reserved_tokens=['<pad>', '<bos>', '<eos>'])
tgt_vocab = Vocab(tgt_tokens, reserved_tokens=['<pad>', '<bos>', '<eos>'])
```

The next step is to create the minibatches of data. For this project I have chosen a batch size of 64. Before that, a few more steps are needed. First, an end-of-sequence (<eos>) token is appended to each text and summary, and the result is padded or truncated to the maximum text or summary length determined earlier. Next, the valid length of each record, that is, the number of non-padding tokens including <eos>, is computed and stored in a vector with one entry per example. The transformed text array, the summary array and the valid lengths of both sequences are stored in a tuple as the final data array. Here the text-related matrices are referred to as the source and the summary-related matrices as the target.

```python
# Append <eos>, pad/truncate, and determine the valid length of each data sample
def build_array_sum(lines, vocab, num_steps):
    lines = [vocab[l] for l in lines]                # tokens -> indices
    lines = [l + [vocab['<eos>']] for l in lines]    # append <eos>
    array = torch.tensor([truncate_pad(l, num_steps, vocab['<pad>']) for l in lines])
    valid_len = (array != vocab['<pad>']).type(torch.int32).sum(1)  # non-padding tokens per row
    return array, valid_len

src_array, src_valid_len = build_array_sum(src_tokens, src_vocab, max_len_text)
tgt_array, tgt_valid_len = build_array_sum(tgt_tokens, tgt_vocab, max_len_summary)
data_arrays = (src_array, src_valid_len, tgt_array, tgt_valid_len)
```
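The effect of `build_array_sum` on two short summaries can be traced by hand. The sketch below follows the same three steps (append `<eos>`, truncate or pad, count the non-padding tokens); the seven-token vocabulary, the plain dictionary used in place of the `Vocab` class and `num_steps = 4` are toy assumptions.

```python
import torch

# Hypothetical toy vocabulary; the article builds this with the Vocab class
vocab = {'<pad>': 0, '<bos>': 1, '<eos>': 2, 'great': 3, 'taste': 4, 'too': 5, 'sweet': 6}
summaries = [['great', 'taste'], ['too', 'sweet', 'great', 'taste']]
num_steps = 4

rows = []
for words in summaries:
    ids = [vocab[w] for w in words] + [vocab['<eos>']]                  # append <eos>
    ids = ids[:num_steps] + [vocab['<pad>']] * (num_steps - len(ids))   # truncate or pad
    rows.append(ids)
array = torch.tensor(rows)
valid_len = (array != vocab['<pad>']).sum(1)
print(array.tolist())      # [[3, 4, 2, 0], [5, 6, 3, 4]]
print(valid_len.tolist())  # [3, 4]
```

The second summary is longer than `num_steps`, so truncation cuts off its `<eos>` token; with `max_len_summary = 10` this happens only for summaries of 10 or more words.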

To feed the data to the neural network in PyTorch, we wrap the tensors of the data array from the previous step in a TensorDataset object. This object, together with the batch size, is then passed to PyTorch's DataLoader class, which yields (shuffled) minibatches of the training dataset.

```python
# Create the tensor dataset object and wrap it in a DataLoader
def load_array(data_arrays, batch_size, is_train=True):
    dataset = torch.utils.data.TensorDataset(*data_arrays)
    return torch.utils.data.DataLoader(dataset, batch_size, shuffle=is_train)

batch_size = 64
data_iter = load_array(data_arrays, batch_size)
```
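To see what `data_iter` yields, the sketch below builds a loader the same way as `load_array`, from random stand-in tensors with the article's sequence lengths (100 source steps and 10 target steps). The 200 examples and the token-id ranges are arbitrary assumptions.

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

# Random stand-ins with the article's shapes
n, src_steps, tgt_steps = 200, 100, 10
src = torch.randint(0, 50, (n, src_steps))
src_len = torch.randint(1, src_steps + 1, (n,))
tgt = torch.randint(0, 30, (n, tgt_steps))
tgt_len = torch.randint(1, tgt_steps + 1, (n,))

loader = DataLoader(TensorDataset(src, src_len, tgt, tgt_len), batch_size=64, shuffle=True)
X, X_valid_len, Y, Y_valid_len = next(iter(loader))
print(X.shape, X_valid_len.shape, Y.shape, Y_valid_len.shape)
# torch.Size([64, 100]) torch.Size([64]) torch.Size([64, 10]) torch.Size([64])
print(len(loader))  # 4 batches: 64 + 64 + 64 + 8
```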

Now let us start building the transformer model, which is in fact the building block of sophisticated models in NLP (natural language processing) and machine learning such as BERT and GPT. A transformer has two parts, an encoder and a decoder. The encoder is made up of multiple identical blocks, each with two main sublayers: a multi-head self-attention layer and a position-wise feed-forward network. Since this is a deep architecture, the ResNet design is employed, and residual connections are added around both sublayers.

The idea behind ResNet comes from nested function classes: if each larger function class contains the smaller ones, adding layers can only increase the expressive power of the network. This holds when every additional layer can easily represent the identity function f(x) = x. A residual block makes this easy: instead of learning the desired mapping f(x) directly, it learns the residual f(x) - x, and if the identity is the best mapping, the weights only have to push this residual towards 0 (hence the name). The shortcut connections also ease the vanishing gradient problem and help preserve information that would otherwise be distorted as each layer transforms its input.

In an encoder block, the queries, keys and values all come from the output of the previous block. For the first block, the original input tokens are embedded, positional encodings are added to each token, and the result is fed into the block. The position-wise feed-forward network transforms the representation at every sequence position using the same multi-layer perceptron (MLP), and hence the same weight parameters. The ReLU activation it employs adds non-linearity and helps in modeling complex relationships between tokens.
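The residual idea can be seen in a few lines. In this sketch the class name `Residual` and its single linear layer `g` are assumptions for illustration, not part of the article's model. The block outputs x + g(x), so once g is driven to zero the block is exactly the identity, which is why stacking such blocks never removes a function that a shallower stack could represent.

```python
import torch
from torch import nn

class Residual(nn.Module):
    """y = x + g(x): the block only has to learn the residual g(x) = f(x) - x."""
    def __init__(self, dim):
        super().__init__()
        self.g = nn.Linear(dim, dim)

    def forward(self, x):
        return x + self.g(x)

block = Residual(4)
# If the desired mapping is the identity, it suffices to drive g to zero
nn.init.zeros_(block.g.weight)
nn.init.zeros_(block.g.bias)
x = torch.randn(2, 4)
print(torch.equal(block(x), x))  # True
```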

[Figure: Transformer architecture as depicted in the original 2017 paper “Attention Is All You Need” by Vaswani et al.]

After the residual connection is added, layer normalization is applied. This is done because the magnitudes of the values produced by different layers can differ widely, so the outputs of connected layers can vary greatly in scale. Layer normalization rescales each row (the feature vector of one token) to zero mean and unit variance, followed by a learned scale and shift, which keeps these magnitudes consistent. The flowchart of how the encoder works is shown in the diagram below.

[Figure: Work flow diagram of the encoder part of the transformer]
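The encoder and decoder blocks in this article call an `AddNorm` class whose definition does not appear in the code shown. Here is a minimal sketch of it, assuming the post-norm form of the original Transformer, LayerNorm(X + Dropout(Y)), and matching the `AddNorm(norm_shape, dropout)` constructor and `addnorm(X, Y)` calls used below. It has to be defined before the encoder and decoder blocks are instantiated.

```python
import torch
from torch import nn

class AddNorm(nn.Module):
    """Residual connection followed by layer normalization: LayerNorm(X + Dropout(Y))."""
    def __init__(self, normalized_shape, dropout, **kwargs):
        super().__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)
        self.ln = nn.LayerNorm(normalized_shape)

    def forward(self, X, Y):
        # X: input to the sublayer, Y: output of the sublayer (same shape as X)
        return self.ln(self.dropout(Y) + X)

add_norm = AddNorm([32], 0.1)
add_norm.eval()  # disable dropout for a deterministic check
out = add_norm(torch.ones(2, 3, 32), torch.randn(2, 3, 32))
print(out.shape)  # torch.Size([2, 3, 32])
```

After normalization, every token's feature vector has (approximately) zero mean, regardless of the scale of X and Y.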

The multi-head attention class and its helper functions are shown here.

```python
# The main class
class MultiHeadAttention(nn.Module):
    def __init__(self, key_size, query_size, value_size, num_hiddens, num_heads, dropout, bias=False, **kwargs):
        super(MultiHeadAttention, self).__init__(**kwargs)
        self.num_heads = num_heads  # num_hiddens must be divisible by num_heads
        self.attention = DotProductAttention(dropout)
        self.w_q = nn.Linear(query_size, num_hiddens, bias=bias)
        self.w_k = nn.Linear(key_size, num_hiddens, bias=bias)
        self.w_v = nn.Linear(value_size, num_hiddens, bias=bias)
        self.w_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)

    def forward(self, queries, keys, values, valid_lens):
        # Project, then split into heads: (batch_size * num_heads, num_steps, num_hiddens / num_heads)
        queries = transpose_qkv(self.w_q(queries), self.num_heads)
        keys = transpose_qkv(self.w_k(keys), self.num_heads)
        values = transpose_qkv(self.w_v(values), self.num_heads)
        if valid_lens is not None:
            # Repeat each valid length once per head
            valid_lens = torch.repeat_interleave(valid_lens, repeats=self.num_heads, dim=0)
        output = self.attention(queries, keys, values, valid_lens)
        output_concat = transpose_output(output, self.num_heads)  # concatenate the heads
        return self.w_o(output_concat)

# Function to transpose the linearly transformed queries, keys and values
def transpose_qkv(X, num_heads):
    X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)
    X = X.permute(0, 2, 1, 3)
    return X.reshape(-1, X.shape[2], X.shape[3])

# Reverses transpose_qkv for output formatting
def transpose_output(X, num_heads):
    X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
    X = X.permute(0, 2, 1, 3)
    return X.reshape(X.shape[0], X.shape[1], -1)

# The scaled dot-product attention scoring function
class DotProductAttention(nn.Module):
    def __init__(self, dropout, **kwargs):
        super(DotProductAttention, self).__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, valid_lens=None):
        d = queries.shape[-1]
        scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)
        self.attention_weights = masked_softmax(scores, valid_lens)
        return torch.bmm(self.dropout(self.attention_weights), values)

# Here masking is used so that irrelevant padding tokens are not considered
# in the calculations
def sequence_mask(X, valid_len, value=0):
    maxlen = X.size(1)
    # Create the position index on X's device so the code also runs on a GPU
    mask = torch.arange(maxlen, dtype=torch.float32, device=X.device)[None, :] < valid_len[:, None]
    X[~mask] = value
    return X

# The irrelevant positions are given a very large negative value, so they
# receive (practically) zero weight after the softmax
def masked_softmax(X, valid_lens):
    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)
    shape = X.shape
    if valid_lens.dim() == 1:
        valid_lens = torch.repeat_interleave(valid_lens, shape[1])
    else:
        valid_lens = valid_lens.reshape(-1)
    X = sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)
    return nn.functional.softmax(X.reshape(shape), dim=-1)
```
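What the masking achieves can be checked on a toy 2 × 4 score matrix. This sketch uses `masked_fill` instead of the in-place assignment in `sequence_mask` (an equivalent rewrite chosen for brevity), and the scores and valid lengths are made up. Positions beyond each row's valid length receive a score of -1e6 and end up with zero weight, while the remaining weights in each row still sum to 1.

```python
import torch

scores = torch.tensor([[2.0, 1.0, 0.5, 3.0],
                       [1.0, 1.0, 1.0, 1.0]])
valid_lens = torch.tensor([2, 3])  # row 0 has 2 real tokens, row 1 has 3

# Positions at or beyond the valid length are treated as padding
mask = torch.arange(scores.shape[-1])[None, :] < valid_lens[:, None]
weights = torch.softmax(scores.masked_fill(~mask, -1e6), dim=-1)
print(weights)  # padded positions get weight 0; row 1 is uniform over its 3 real tokens
```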

The code for the position-wise feed-forward network is shown below.

```python
class PositionWiseFFN(nn.Module):
    def __init__(self, ffn_num_input, ffn_num_hiddens, ffn_num_output, **kwargs):
        super(PositionWiseFFN, self).__init__(**kwargs)
        self.dense1 = nn.Linear(ffn_num_input, ffn_num_hiddens)
        self.relu = nn.ReLU()
        self.dense2 = nn.Linear(ffn_num_hiddens, ffn_num_output)

    def forward(self, X):
        # The same two-layer MLP is applied at every sequence position
        return self.dense2(self.relu(self.dense1(X)))
```
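“Position-wise” means that the same MLP is applied independently at each position. The sketch below checks this with an `nn.Sequential` stand-in that has the same layer sizes as `PositionWiseFFN(32, 64, 32)` in this article (the stand-in and the random input are assumptions). Applying it to the whole (batch, steps, features) tensor gives the same result as applying it position by position, because `nn.Linear` acts only on the last dimension.

```python
import torch
from torch import nn

torch.manual_seed(0)
# Same layer sizes as PositionWiseFFN(32, 64, 32)
ffn = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 32))
X = torch.randn(2, 10, 32)  # (batch_size, num_steps, num_hiddens)

whole = ffn(X)  # all positions transformed at once
per_position = torch.stack([ffn(X[:, t]) for t in range(X.shape[1])], dim=1)
print(torch.allclose(whole, per_position, atol=1e-5))  # True: one MLP, shared weights
```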

Positional encoding is used in both the encoder and the decoder to retain the position information of the source and target tokens. A positional encoding matrix P with the same shape as the input matrix X (the matrix of token embeddings) is used, and X + P is computed. The rows of P correspond to positions in the sequence, and the columns represent the different positional encoding dimensions. The entries are defined as follows:

$$p_{i,2j} = \sin\left(\frac{i}{10000^{2j/d}}\right), \qquad p_{i,2j+1} = \cos\left(\frac{i}{10000^{2j/d}}\right)$$

where i stands for the row (the position), 2j and 2j+1 for the columns, and d is the number of encoding dimensions (num_hiddens in the code). The class is defined as follows:

```python
class PositionalEncoding(nn.Module):
    def __init__(self, num_hiddens, dropout, max_len=1000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(dropout)
        # P has shape (1, max_len, num_hiddens); row i encodes position i
        self.P = torch.zeros((1, max_len, num_hiddens))
        X = torch.arange(max_len, dtype=torch.float32).reshape(-1, 1) / torch.pow(
            10000, torch.arange(0, num_hiddens, 2, dtype=torch.float32) / num_hiddens)
        self.P[:, :, 0::2] = torch.sin(X)  # even columns
        self.P[:, :, 1::2] = torch.cos(X)  # odd columns

    def forward(self, X):
        X = X + self.P[:, :X.shape[1], :].to(X.device)
        return self.dropout(X)
```
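To check the formula against the code, here is a standalone function (the name `sinusoidal_encoding` is hypothetical) that builds P the same way as `PositionalEncoding`, without the batch dimension and dropout. At position 0 every sine column is 0 and every cosine column is 1, and column 0 (j = 0) is simply sin(i).

```python
import torch

def sinusoidal_encoding(max_len, d):
    """P[i, 2j] = sin(i / 10000^(2j/d)), P[i, 2j+1] = cos(i / 10000^(2j/d))."""
    pos = torch.arange(max_len, dtype=torch.float32).reshape(-1, 1)
    div = torch.pow(10000, torch.arange(0, d, 2, dtype=torch.float32) / d)
    P = torch.zeros(max_len, d)
    P[:, 0::2] = torch.sin(pos / div)
    P[:, 1::2] = torch.cos(pos / div)
    return P

P = sinusoidal_encoding(50, 8)
print(P[0])  # sin(0) = 0 in even columns, cos(0) = 1 in odd columns
print(P[3, 0].item(), torch.sin(torch.tensor(3.0)).item())  # column 0 has frequency 1
```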

Now we are ready to define the encoder. Two classes are defined: one for the block that is repeated within the encoder, and the actual transformer encoder that stacks these blocks.

```python
# Class for the block structure repeated within the encoder
class EncoderBlock(nn.Module):
    def __init__(self, key_size, query_size, value_size, num_hiddens, norm_shape, ffn_num_input,
                 ffn_num_hiddens, num_heads, dropout, use_bias=False, **kwargs):
        super(EncoderBlock, self).__init__(**kwargs)
        self.attention = MultiHeadAttention(key_size, query_size, value_size, num_hiddens, num_heads, dropout, use_bias)
        self.addnorm1 = AddNorm(norm_shape, dropout)
        self.ffn = PositionWiseFFN(ffn_num_input, ffn_num_hiddens, num_hiddens)
        self.addnorm2 = AddNorm(norm_shape, dropout)

    def forward(self, X, valid_lens):
        Y = self.addnorm1(X, self.attention(X, X, X, valid_lens))  # self-attention + add & norm
        return self.addnorm2(Y, self.ffn(Y))                       # feed-forward + add & norm

# The main encoder class
class TransformerEncoder(nn.Module):
    def __init__(self, vocab_size, key_size, query_size, value_size, num_hiddens, norm_shape,
                 ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout, use_bias=False, **kwargs):
        super(TransformerEncoder, self).__init__(**kwargs)
        self.num_hiddens = num_hiddens
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        self.pos_encoding = PositionalEncoding(num_hiddens, dropout)
        self.blks = nn.Sequential()
        for i in range(num_layers):
            self.blks.add_module("block" + str(i), EncoderBlock(key_size, query_size, value_size, num_hiddens, norm_shape,
                                 ffn_num_input, ffn_num_hiddens, num_heads, dropout, use_bias))

    def forward(self, X, valid_lens, *args):
        # Scale the embeddings before adding the positional encoding
        X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))
        self.attention_weights = [None] * len(self.blks)
        for i, blk in enumerate(self.blks):
            X = blk(X, valid_lens)
            self.attention_weights[i] = blk.attention.attention.attention_weights
        return X
```

The decoder of the transformer model is similar to the encoder, with one extra attention layer in which the output of the previous decoder sublayer serves as the query and the encoder outputs serve as the keys and values. So let us walk through the decoder module. First, the target tokens are embedded and a positional encoding is added to them. Next, they pass through a stack of n blocks, where each block does the following. As in the encoder, self-attention is applied to the encoded tokens, following the ResNet design. Here the attention is called masked self-attention: during training the model has the tokens for all positions, but during prediction it has only the tokens generated so far, so a position must not attend to later positions. A decoder valid length is therefore passed, and the tokens are masked based on it. The output of the self-attention layer, together with the block's input, is passed through the first AddNorm layer.

Next, a second attention layer (encoder-decoder attention) takes the output of the previous layer as the queries, the encoder outputs as the keys and values, and the encoder valid lengths for the masking function. Its output, together with its input, is passed through the second AddNorm layer. After that, the position-wise feed-forward network transforms the output of the second AddNorm layer, and that output together with the output of the feed-forward network is passed through the final AddNorm layer. The output of the block, along with the state information, is passed to the subsequent decoder blocks for further processing. The output of the last decoder block passes through a dense layer (nn.Linear in PyTorch) to obtain the output scores over the target vocabulary. The flowchart of the decoder is shown below.

[Figure: Work flow diagram of the decoder part of the Transformer Model]

For simplicity, both the encoder and the decoder diagrams show a single block. With multiple blocks, the output of one block is fed as input to the next block.
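The decoder valid lengths built in `DecoderBlock.forward` during training produce a causal (lower-triangular) mask. The sketch below, with an assumed batch of 2 sequences of 4 steps, reproduces the `dec_valid_lens` tensor and the mask that `sequence_mask` effectively applies to one sequence's 4 × 4 score matrix: query position t may attend only to key positions up to and including t.

```python
import torch

batch_size, num_steps = 2, 4
# As in DecoderBlock.forward during training
dec_valid_lens = torch.arange(1, num_steps + 1).repeat(batch_size, 1)
print(dec_valid_lens.tolist())  # [[1, 2, 3, 4], [1, 2, 3, 4]]

# Row t of the mask: which key positions query t may attend to (1 = allowed)
mask = torch.arange(num_steps)[None, :] < dec_valid_lens[0][:, None]
print(mask.int().tolist())
# [[1, 0, 0, 0], [1, 1, 0, 0], [1, 1, 1, 0], [1, 1, 1, 1]]
```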

As with the encoder, two classes are defined for the decoder: one for the decoder block and the main transformer decoder, as shown below.

```python
class DecoderBlock(nn.Module):
    # The i-th block in the decoder
    def __init__(self, key_size, query_size, value_size, num_hiddens, norm_shape,
                 ffn_num_input, ffn_num_hiddens, num_heads, dropout, i, **kwargs):
        super(DecoderBlock, self).__init__(**kwargs)
        self.i = i
        self.attention1 = MultiHeadAttention(key_size, query_size, value_size, num_hiddens, num_heads, dropout)
        self.addnorm1 = AddNorm(norm_shape, dropout)
        self.attention2 = MultiHeadAttention(key_size, query_size, value_size, num_hiddens, num_heads, dropout)
        self.addnorm2 = AddNorm(norm_shape, dropout)
        self.ffn = PositionWiseFFN(ffn_num_input, ffn_num_hiddens, num_hiddens)
        self.addnorm3 = AddNorm(norm_shape, dropout)

    def forward(self, X, state):
        enc_outputs, enc_valid_lens = state[0], state[1]
        if state[2][self.i] is None:  # true during training (and at the first prediction step)
            key_values = X
        else:  # during prediction, state[2][self.i] holds this block's inputs from earlier time steps
            key_values = torch.cat((state[2][self.i], X), dim=1)
        state[2][self.i] = key_values
        if self.training:
            # Causal mask: position t may attend only to positions up to t
            batch_size, num_steps, _ = X.shape
            dec_valid_lens = torch.arange(1, num_steps + 1, device=X.device).repeat(batch_size, 1)
        else:
            dec_valid_lens = None
        # Masked self-attention
        X2 = self.attention1(X, key_values, key_values, dec_valid_lens)
        Y = self.addnorm1(X, X2)
        # Encoder-decoder attention: queries from the decoder, keys/values from the encoder
        Y2 = self.attention2(Y, enc_outputs, enc_outputs, enc_valid_lens)
        Z = self.addnorm2(Y, Y2)
        return self.addnorm3(Z, self.ffn(Z)), state

# The main decoder class
class TransformerDecoder(nn.Module):
    def __init__(self, vocab_size, key_size, query_size, value_size, num_hiddens,
                 norm_shape, ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout, **kwargs):
        super(TransformerDecoder, self).__init__(**kwargs)
        self.num_hiddens = num_hiddens
        self.num_layers = num_layers
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        self.pos_encoding = PositionalEncoding(num_hiddens, dropout)
        self.blks = nn.Sequential()
        for i in range(num_layers):
            self.blks.add_module("block" + str(i),
                                 DecoderBlock(key_size, query_size, value_size, num_hiddens, norm_shape,
                                              ffn_num_input, ffn_num_hiddens, num_heads, dropout, i))
        self.dense = nn.Linear(num_hiddens, vocab_size)  # scores over the target vocabulary

    def init_state(self, enc_outputs, enc_valid_lens, *args):
        # [encoder outputs, encoder valid lengths, cached inputs of each decoder block]
        return [enc_outputs, enc_valid_lens, [None] * self.num_layers]

    def forward(self, X, state):
        X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))
        self._attention_weights = [[None] * len(self.blks) for _ in range(2)]
        for i, blk in enumerate(self.blks):
            X, state = blk(X, state)
            # Decoder self-attention weights
            self._attention_weights[0][i] = blk.attention1.attention.attention_weights
            # Encoder-decoder attention weights
            self._attention_weights[1][i] = blk.attention2.attention.attention_weights
        return self.dense(X), state

    def attention_weights(self):
        return self._attention_weights
```

The final Transformer class, containing both the encoder and the decoder, is modelled as follows:

```python
class Transformer(nn.Module):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, enc_X, dec_X, *args):
        # *args carries the source valid lengths
        enc_all_outputs = self.encoder(enc_X, *args)
        dec_state = self.decoder.init_state(enc_all_outputs, *args)
        # Return decoder output only
        return self.decoder(dec_X, dec_state)[0]
```

If a CUDA-capable GPU is available on your system, the following code snippet selects it as the device of execution; otherwise, it falls back to the default, the CPU.

```python
def get_device(i=0):
    # Use the i-th GPU if it exists, otherwise the CPU
    if torch.cuda.device_count() >= i + 1:
        return torch.device(f'cuda:{i}')
    return torch.device('cpu')

device = get_device()
```

Now let us instantiate the model and initialize the parameters.

```python
num_hiddens, num_layers, dropout, num_steps = 32, 2, 0.1, 10
ffn_num_input, ffn_num_hiddens, num_heads = 32, 64, 4
key_size, query_size, value_size = 32, 32, 32
norm_shape = [32]
encoder = TransformerEncoder(len(src_vocab), key_size, query_size, value_size, num_hiddens, norm_shape,
                             ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout)
decoder = TransformerDecoder(len(tgt_vocab), key_size, query_size, value_size, num_hiddens, norm_shape,
                             ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout)
net = Transformer(encoder, decoder)

# Xavier-initialize the weights of the fully connected (nn.Linear) layers in the model.
# The Transformer module has no single `weight` attribute, so each layer is visited via apply().
def init_weights(m):
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)

net.apply(init_weights)
```

A deep neural network often has to deal with the exploding gradient problem. The following function rescales the gradients so that their global L2 norm does not exceed a maximum value theta, which is 1 in our case. This function is used in the training function defined later.

```python
def grad_clipping(net, theta):
    if isinstance(net, nn.Module):
        params = [p for p in net.parameters() if p.requires_grad]
    else:
        params = net.params
    # Global L2 norm over all parameter gradients
    norm = torch.sqrt(sum(torch.sum((p.grad ** 2)) for p in params))
    if norm > theta:
        for param in params:
            param.grad[:] *= theta / norm
```
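For comparison, PyTorch provides `torch.nn.utils.clip_grad_norm_`, which applies the same global-norm rescaling (it adds a small epsilon to the norm before dividing). The sketch below, on an arbitrary toy linear layer with deliberately large inputs, checks that the manual version and the built-in one give the same gradients and that the clipped global norm equals theta. In the training loop, `nn.utils.clip_grad_norm_(net.parameters(), 1)` could therefore be used in place of `grad_clipping(net, 1)`.

```python
import torch
from torch import nn

torch.manual_seed(0)
net_a = nn.Linear(4, 3)
net_b = nn.Linear(4, 3)
net_b.load_state_dict(net_a.state_dict())  # identical weights
x = torch.randn(8, 4) * 100                 # large inputs produce large gradients
for model in (net_a, net_b):
    model(x).pow(2).sum().backward()

theta = 1.0
# Manual global-norm clipping, as in grad_clipping above
params = [p for p in net_a.parameters() if p.requires_grad]
norm = torch.sqrt(sum(torch.sum(p.grad ** 2) for p in params))
if norm > theta:
    for p in params:
        p.grad[:] *= theta / norm

# Built-in equivalent
nn.utils.clip_grad_norm_(net_b.parameters(), max_norm=theta)
```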

The Accumulator class keeps running sums over n variables. An object of this class holds the sum of the losses and the number of tokens processed, which are used to calculate the final loss during training.

```python
class Accumulator:
    """Accumulates running sums over n variables."""
    def __init__(self, n):
        self.data = [0.0] * n

    def add(self, *args):
        self.data = [a + float(b) for a, b in zip(self.data, args)]

    def reset(self):
        self.data = [0.0] * len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
```

We use cross-entropy loss for training, but padding tokens should not contribute to the loss. We therefore subclass PyTorch's `nn.CrossEntropyLoss` and use the `sequence_mask` function to build a weight for every position: 1 for valid tokens and 0 for padding tokens. Multiplying the unreduced cross-entropy loss by these weights zeroes out the loss at the padding positions.

class MaskedSoftmaxCELoss(nn.CrossEntropyLoss):
    # `pred` shape: (`batch_size`, `num_steps`, `vocab_size`)
    # `label` shape: (`batch_size`, `num_steps`)
    # `valid_len` shape: (`batch_size`,)
    def forward(self, pred, label, valid_len):
        weights = torch.ones_like(label)
        weights = sequence_mask(weights, valid_len)  # 0 at padding positions
        self.reduction = 'none'
        # CrossEntropyLoss expects the class dimension second: (batch_size, vocab_size, num_steps)
        unweighted_loss = super(MaskedSoftmaxCELoss, self).forward(pred.permute(0, 2, 1), label)
        weighted_loss = (unweighted_loss * weights).mean(dim=1)
        return weighted_loss
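Here is the masking idea on a single toy sequence, as a minimal plain-Python sketch. The helper names are hypothetical, and the logits are made up for illustration. Like `.mean(dim=1)` in the class above, the sketch divides by `num_steps`, padded positions included. Dividing by `valid_len` instead would give a true per-token mean.

```python
import math

def sequence_mask_weights(valid_len, num_steps):
    # 1.0 for real tokens, 0.0 for padding positions at or beyond valid_len
    return [1.0 if t < valid_len else 0.0 for t in range(num_steps)]

def cross_entropy(logits, label):
    # -log softmax(logits)[label], using the log-sum-exp trick for stability
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_z - logits[label]

def masked_ce(pred, label, valid_len):
    """pred: num_steps lists of logits for ONE sequence; label: num_steps token ids."""
    num_steps = len(label)
    weights = sequence_mask_weights(valid_len, num_steps)
    losses = [cross_entropy(p, y) * w for p, y, w in zip(pred, label, weights)]
    return sum(losses) / num_steps  # mean over all steps, as .mean(dim=1) does

# 3 steps, vocabulary of 2; only the first 2 steps are real tokens.
# The third step would have a large loss, but it is masked out.
pred = [[0.0, 0.0], [0.0, 0.0], [5.0, -5.0]]
label = [0, 1, 1]
print(masked_ce(pred, label, valid_len=2))  # 2 * log(2) / 3, about 0.462
```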

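While training the model, we use teacher forcing. Instead of feeding the decoder its own previous predictions, we feed it the ground-truth target sequence shifted one position to the right, with the `<bos>` token prepended. This speeds up training, and the model converges in fewer iterations.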

def train_seq2seq(net, data_iter, lr, num_epochs, tgt_vocab, device):
    net.to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    loss = MaskedSoftmaxCELoss()
    net.train()
    for epoch in range(num_epochs):
        metric = Accumulator(2)  # Sum of training loss, no. of tokens
        for batch in data_iter:
            optimizer.zero_grad()
            X, X_valid_len, Y, Y_valid_len = [x.to(device) for x in batch]
            bos = torch.tensor([tgt_vocab['<bos>']] * Y.shape[0], device=device).reshape(-1, 1)
            dec_input = torch.cat([bos, Y[:, :-1]], 1)  # Teacher forcing
            Y_hat = net(X, dec_input, X_valid_len)
            l = loss(Y_hat, Y, Y_valid_len)
            l.sum().backward()  # Make the loss scalar for `backward`
            grad_clipping(net, 1)
            num_tokens = Y_valid_len.sum()
            optimizer.step()
            with torch.no_grad():
                metric.add(l.sum(), num_tokens)
        print(f"Done with epoch number: {epoch+1}")  # optional step
    print(f'loss {metric[0] / metric[1]:.3f} on {str(device)}')
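The teacher-forcing line `dec_input = torch.cat([bos, Y[:, :-1]], 1)` is easiest to see on plain lists. In this minimal sketch, the token ids (1 for `<bos>`, 2 for `<eos>`, 0 for `<pad>`) and the helper `teacher_forcing_inputs` are hypothetical. At step t, the decoder receives the ground-truth token t - 1 and is trained to predict token t.

```python
def teacher_forcing_inputs(targets, bos_id):
    """Hypothetical helper: prepend <bos> and drop the last target token of each sequence."""
    return [[bos_id] + seq[:-1] for seq in targets]

# Hypothetical ids: 1 = <bos>, 2 = <eos>, 0 = <pad>
Y = [[7, 8, 9, 2], [5, 2, 0, 0]]
print(teacher_forcing_inputs(Y, bos_id=1))  # [[1, 7, 8, 9], [1, 5, 2, 0]]
```

The decoder input keeps the same length as the target, which is why the last target token is dropped when `<bos>` is prepended.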

We will now train the transformer model for 250 epochs with a learning rate of 0.005, as follows.

lr = 0.005
num_epochs = 250
train_seq2seq(net, data_iter, lr, num_epochs, tgt_vocab, device)

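The next code snippet defines a function that generates a summary for an input sequence. It uses greedy decoding: at every step, the most likely token is fed back into the decoder, until the `<eos>` token is produced or `num_steps` tokens have been generated.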

def predict_seq2seq(net, src_sentence, src_vocab, tgt_vocab, num_steps, device, save_attention_weights=False):
    # Set `net` to eval mode for inference
    net.eval()
    src_tokens = src_vocab[src_sentence.lower().split(' ')] + [src_vocab['<eos>']]
    enc_valid_len = torch.tensor([len(src_tokens)], device=device)
    src_tokens = truncate_pad(src_tokens, num_steps, src_vocab['<pad>'])
    # Unsqueeze adds another dimension that works as the batch axis here
    enc_X = torch.unsqueeze(torch.tensor(src_tokens, dtype=torch.long, device=device), dim=0)
    enc_outputs = net.encoder(enc_X, enc_valid_len)
    dec_state = net.decoder.init_state(enc_outputs, enc_valid_len)
    # Add the batch axis to the decoder input as well
    dec_X = torch.unsqueeze(torch.tensor([tgt_vocab['<bos>']], dtype=torch.long, device=device), dim=0)
    output_seq, attention_weight_seq = [], []
    for _ in range(num_steps):
        Y = net.decoder(dec_X, dec_state)[0]
        # We use the token with the highest prediction likelihood as the input
        # of the decoder at the next time step
        dec_X = Y.argmax(dim=2)
        pred = dec_X.squeeze(dim=0).type(torch.int32).item()
        # Save attention weights
        if save_attention_weights:
            attention_weight_seq.append(net.decoder.attention_weights)
        # Once the end-of-sequence token is predicted, the output sequence is complete
        if pred == tgt_vocab['<eos>']:
            break
        output_seq.append(pred)
    if not output_seq:
        return "No output!", attention_weight_seq
    if len(output_seq) == 1:
        # A single token is looked up by its bare index rather than a one-element list
        return tgt_vocab.to_tokens(output_seq[0]), attention_weight_seq
    return ' '.join(tgt_vocab.to_tokens(output_seq)), attention_weight_seq
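The greedy decoding loop above can be sketched independently of the model. In this plain-Python sketch, `greedy_decode` and the toy `toy_step` function are hypothetical stand-ins for the decoder. The toy function simply gives the top score to a fixed next token. For simplicity it receives the whole prefix, whereas the real `predict_seq2seq` passes only the last token plus the decoder state.

```python
def greedy_decode(step_fn, bos_id, eos_id, num_steps):
    """Hypothetical helper: step_fn(prefix) returns one score per vocabulary entry."""
    prefix = [bos_id]
    output = []
    for _ in range(num_steps):
        scores = step_fn(prefix)
        pred = max(range(len(scores)), key=scores.__getitem__)  # argmax over the vocabulary
        if pred == eos_id:  # stop as soon as <eos> is predicted
            break
        output.append(pred)
        prefix.append(pred)  # feed the prediction back in as the next input
    return output

# Toy vocabulary: 0 <pad>, 1 <bos>, 2 <eos>, 3 "great", 4 "coffee"
NEXT = {1: 3, 3: 4, 4: 2}

def toy_step(prefix, vocab_size=5):
    scores = [0.0] * vocab_size
    scores[NEXT[prefix[-1]]] = 1.0
    return scores

print(greedy_decode(toy_step, bos_id=1, eos_id=2, num_steps=10))  # [3, 4]
```

Greedy decoding commits to the single most likely token at each step. Beam search, which keeps several candidate sequences alive, is a common alternative when greedy outputs turn out repetitive.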

Finally, time to taste the pudding (kidding). Let us check how well the model summarizes text. We take the first 10 rows of the test set that we kept aside from the original data set, and the code prints both the predicted and the actual summary. The tokens are unlikely to match exactly, so let us check whether both convey the same meaning.

sample = x_test[:10]
actual = y_test[:10]
for s, a in zip(sample, actual):
    pred_sum, _ = predict_seq2seq(net, s, src_vocab, tgt_vocab, num_steps, device)
    print("PREDICTED : {} ::=>".format(pred_sum), end='\t')
    print("ACTUAL : {}".format(a))

The output of the above code was:

I was a little surprised by the third prediction, where the target says “yummy” and the model predicts “great coffee”, so I checked the original text, and this is what I found:

The third sample, i.e. the one with ID 2384, does indeed say something positive about a coffee.

If you have made it this far, thank you for going through the post. There is still some scope for improvement in this model; for instance, it tends to use superlatives like “great” and “best” where “good” would do the job. Overall, though, the model works pretty well. To improve it, you could tune hyperparameters such as the number of attention heads (just bear in mind that the number of heads must divide the number of hidden units), the number of epochs or the learning rate. Hope you found the tutorial useful.
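To see which head counts are allowed for the configuration used here, the following plain-Python sketch lists every divisor of `num_hiddens` together with the size each head then works with. The helper name `valid_head_counts` is hypothetical.

```python
def valid_head_counts(num_hiddens):
    """Hypothetical helper: map each head count that splits num_hiddens evenly to the per-head size."""
    return {h: num_hiddens // h for h in range(1, num_hiddens + 1) if num_hiddens % h == 0}

print(valid_head_counts(32))  # {1: 32, 2: 16, 4: 8, 8: 4, 16: 2, 32: 1}
```

With `num_hiddens = 32` and `num_heads = 4` as above, each head attends over an 8-dimensional slice. A value such as 3 or 6 heads would not divide 32 evenly.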
