A Transformer Chatbot Tutorial with TensorFlow 2.0

A guest article by Bryan M. Li, FOR.ai

Using artificial neural networks to build chatbots is increasingly popular. However, teaching a computer to hold a natural conversation is very difficult and often requires large, complicated language models.

With all the changes and improvements made in TensorFlow 2.0 we can build complicated models with ease. In this post, we will demonstrate how to build a Transformer chatbot. All of the code used in this post is available in this colab notebook, which will run end to end (including installing TensorFlow 2.0).

This article assumes some knowledge of text generation, attention, and the Transformer. In this tutorial we are going to focus on:

  • Preprocessing the Cornell Movie-Dialogs Corpus using TensorFlow Datasets and creating an input pipeline using tf.data
  • Implementing MultiHeadAttention with Model subclassing
  • Implementing a Transformer with Functional API

input: where have you been ?
output: i m not talking about that .
input: i am not crazy , my mother had me tested .
output: i m not sure . i m not hungry .
input: i m not sure . i m not hungry .
output: you re a liar .
input: you re a liar .
output: i m not going to be a man . i m gonna need to go to school .

Sample conversations of a Transformer chatbot trained on Movie-Dialogs Corpus.


The Transformer, proposed in the paper Attention Is All You Need, is a neural network architecture based solely on the self-attention mechanism, which makes it highly parallelizable.

tf.keras model plot of our Transformer

A Transformer model handles variable-sized input using stacks of self-attention layers instead of RNNs or CNNs. This general architecture has a number of advantages:

  • It makes no assumptions about the temporal/spatial relationships across the data. This is ideal for processing a set of objects.
  • Layer outputs can be calculated in parallel, instead of in series as in an RNN.
  • Distant items can affect each other’s output without passing through many recurrent steps, or convolution layers.
  • It can learn long-range dependencies.

The disadvantages of this architecture:

  • For a time-series, the output for a time-step is calculated from the entire history instead of only the inputs and current hidden-state. This may be less efficient.
  • If the input does have a temporal/spatial relationship, like text, some positional encoding must be added or the model will effectively see a bag of words.

If you want to know more about the Transformer, check out The Annotated Transformer and Illustrated Transformer.


We are using the Cornell Movie-Dialogs Corpus as our dataset, which contains more than 220k conversational exchanges between more than 10k pairs of movie characters.

“+++$+++” is used as the field separator in all the files within the corpus.

movie_conversations.txt has the following format: ID of the first character, ID of the second character, ID of the movie in which the conversation occurred, and a list of line IDs. The character and movie information can be found in movie_characters_metadata.txt and movie_titles_metadata.txt respectively.

u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L194', 'L195', 'L196', 'L197']
u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L198', 'L199']
u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L200', 'L201', 'L202', 'L203']
u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L204', 'L205', 'L206']
u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L207', 'L208']

Samples of conversation pairs from movie_conversations.txt

movie_lines.txt has the following format: ID of the conversation line, ID of the character who uttered this phrase, ID of the movie, name of the character, and the text of the line.

L901 +++$+++ u5 +++$+++ m0 +++$+++ KAT +++$+++ He said everyone was doing it. So I did it.
L900 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ As in…
L899 +++$+++ u5 +++$+++ m0 +++$+++ KAT +++$+++ Now I do. Back then, was a different story.
L898 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ But you hate Joey
L897 +++$+++ u5 +++$+++ m0 +++$+++ KAT +++$+++ He was, like, a total babe

Samples of conversation text from movie_lines.txt
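To make the two file formats concrete, here is a minimal sketch of how question and answer pairs could be extracted from them. It is plain Python rather than the notebook's own code; the function names are hypothetical, and the conversation row in the usage example is made up for illustration so that it matches the sample lines shown above.

```python
import ast

SEPARATOR = " +++$+++ "

def parse_line(raw):
    """Split one movie_lines.txt row into (line_id, text)."""
    fields = raw.rstrip("\n").split(SEPARATOR)
    # Fields: line ID, character ID, movie ID, character name, text.
    text = fields[4] if len(fields) > 4 else ""
    return fields[0], text

def parse_conversation(raw):
    """Return the list of line IDs from one movie_conversations.txt row."""
    fields = raw.rstrip("\n").split(SEPARATOR)
    # The last field is a list literal such as ['L194', 'L195'].
    return ast.literal_eval(fields[3])

def build_pairs(line_rows, conversation_rows):
    """Pair every line with the line that follows it in the same conversation."""
    id2line = dict(parse_line(row) for row in line_rows)
    pairs = []
    for row in conversation_rows:
        ids = parse_conversation(row)
        for first, second in zip(ids, ids[1:]):
            pairs.append((id2line[first], id2line[second]))
    return pairs

sample_lines = [
    "L899 +++$+++ u5 +++$+++ m0 +++$+++ KAT +++$+++ Now I do. Back then, was a different story.",
    "L898 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ But you hate Joey",
    "L897 +++$+++ u5 +++$+++ m0 +++$+++ KAT +++$+++ He was, like, a total babe",
]
# Hypothetical conversation row, written to match the sample lines above.
sample_conversations = ["u0 +++$+++ u5 +++$+++ m0 +++$+++ ['L897', 'L898', 'L899']"]
pairs = build_pairs(sample_lines, sample_conversations)
```

Each conversation of n lines yields n - 1 pairs, because every line except the last serves as an input and every line except the first serves as an output.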

We are going to build the input pipeline with the following steps:

  • Extract a list of conversation pairs from movie_conversations.txt and movie_lines.txt
  • Preprocess each sentence by removing special characters.
  • Build tokenizer (map text to ID and ID to text) with TensorFlow Datasets SubwordTextEncoder.
  • Tokenize each sentence and add START_TOKEN and END_TOKEN to indicate the start and end of each sentence.
  • Filter out sentences that contain more than MAX_LENGTH tokens.
  • Pad tokenized sentences to MAX_LENGTH
  • Build tf.data.Dataset with the tokenized sentences
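The cleaning, tokenizing, filtering, and padding steps above can be sketched in plain Python. This is only an illustration under stated assumptions: a toy whitespace vocabulary stands in for SubwordTextEncoder, the regular expressions are a guess at what "removing special characters" means, MAX_LENGTH is set to a small hypothetical value, and filtering is applied per sentence for brevity.

```python
import re

MAX_LENGTH = 8  # hypothetical value, kept small for the example

def preprocess_sentence(sentence):
    sentence = sentence.lower().strip()
    # Put spaces around punctuation so "go," becomes "go ,".
    sentence = re.sub(r"([?.!,])", r" \1 ", sentence)
    # Replace anything that is not a letter or basic punctuation with a space.
    sentence = re.sub(r"[^a-z?.!,]+", " ", sentence)
    return re.sub(r"\s+", " ", sentence).strip()

def build_vocab(sentences):
    """Toy stand-in for a subword tokenizer: one ID per whitespace token."""
    words = sorted({word for s in sentences for word in s.split()})
    # ID 0 is reserved for padding, so word IDs start at 1.
    return {word: index + 1 for index, word in enumerate(words)}

def tokenize_and_pad(sentences, vocab, max_length=MAX_LENGTH):
    # START and END get the two IDs right after the vocabulary.
    start_token, end_token = len(vocab) + 1, len(vocab) + 2
    padded = []
    for sentence in sentences:
        ids = [start_token] + [vocab[w] for w in sentence.split()] + [end_token]
        if len(ids) > max_length:
            continue  # filter out sentences that are too long
        padded.append(ids + [0] * (max_length - len(ids)))
    return padded

cleaned = [preprocess_sentence(s) for s in
           ["Where have you been?", "I'm not talking about that."]]
vocab = build_vocab(cleaned)
padded = tokenize_and_pad(cleaned, vocab)
```

Here the second sentence is dropped because it needs nine tokens including START and END, which exceeds the hypothetical MAX_LENGTH of eight.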

Notice that the Transformer is an autoregressive model: it makes predictions one part at a time and uses its output so far to decide what to do next. During training this example uses teacher forcing, which means passing the true output to the next time step regardless of what the model predicts at the current time step.
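In practice, teacher forcing usually amounts to shifting the target sequence by one position. The sketch below shows the idea with a hypothetical helper and made-up token IDs: the decoder is fed everything except the last token and is trained to predict everything except the first.

```python
def teacher_forcing_split(target_ids):
    """Return (decoder_input, expected_output) for one tokenized answer."""
    decoder_input = target_ids[:-1]   # START, w1, w2, ...
    expected_output = target_ids[1:]  # w1, w2, ..., END
    return decoder_input, expected_output

START, END = 100, 101  # hypothetical token IDs
decoder_input, expected_output = teacher_forcing_split([START, 5, 9, 3, END])
```

At every position the decoder sees the true previous tokens, so one mistake early in the sentence does not derail the rest of the training example.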

The full preprocessing code can be found in the Prepare Dataset section of the colab notebook.

i really , really , really wanna go , but i can t . not unless my sister goes .
i m workin on it . but she doesn t seem to be goin for him .

Sample preprocessed conversation pair


Like many sequence-to-sequence models, the Transformer consists of an encoder and a decoder. However, instead of recurrent or convolution layers, the Transformer uses multi-head attention layers, each of which consists of multiple scaled dot-product attention heads.

Attention architecture diagrams from Attention is All You Need

Scaled dot product attention

The scaled dot-product attention function takes three inputs: Q (query), K (key), V (value). The equation used to calculate the attention weights is:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^{T}}{\sqrt{d_k}}\right)V$$

where d_k is the dimensionality of the keys. Because the softmax normalization is applied across the keys, the resulting weights determine how much each key's value contributes to the output for a given query. The output is the product of the attention weights and the values. This ensures that the words we want to focus on are kept as is and the irrelevant words are flushed out.

Implementation of a scaled dot-product attention layer
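For readers who want to trace the computation by hand, here is a framework-free sketch of scaled dot-product attention for a single head, using nested Python lists instead of tensors. It is not the notebook's TensorFlow implementation. The mask convention (1 marks a position to hide, and hidden positions receive a large negative logit) is an assumption of this sketch.

```python
import math

def softmax(row):
    peak = max(row)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def scaled_dot_product_attention(query, key, value, mask=None):
    """query: [q_len][depth], key: [k_len][depth], value: [k_len][v_depth].

    mask, if given, is [q_len][k_len] with 1.0 marking positions to hide.
    Returns (output, attention_weights).
    """
    depth = len(key[0])
    output, weights = [], []
    for q_index, q in enumerate(query):
        # Dot product of this query with every key, scaled by sqrt(depth).
        logits = [sum(a * b for a, b in zip(q, k)) / math.sqrt(depth)
                  for k in key]
        if mask is not None:
            logits = [l + m * -1e9 for l, m in zip(logits, mask[q_index])]
        row_weights = softmax(logits)
        weights.append(row_weights)
        # Weighted sum of the value vectors.
        output.append([sum(w * v[d] for w, v in zip(row_weights, value))
                       for d in range(len(value[0]))])
    return output, weights

query = [[0.0, 10.0]]
key = [[10.0, 0.0], [0.0, 10.0]]
value = [[1.0, 0.0], [0.0, 1.0]]
output, weights = scaled_dot_product_attention(query, key, value)
```

The query lines up with the second key, so almost all of the weight goes to the second value vector. Passing mask=[[0.0, 1.0]] hides that key and moves the weight to the first one instead.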

Multi-head Attention Layer

Sequential models allow us to build models very quickly by simply stacking layers on top of each other. For more complicated, non-sequential models, however, the Functional API and Model subclassing are needed. The tf.keras API allows us to mix and match different API styles. My favourite feature of Model subclassing is how easy it makes debugging: I can set a breakpoint in the call() method and inspect each layer's inputs and outputs as if they were numpy arrays.

Here, we are using Model subclassing to implement our MultiHeadAttention layer.

Multi-head attention consists of four parts:

  • Linear layers and split into heads.
  • Scaled dot-product attention.
  • Concatenation of heads.
  • Final linear layer.

Each multi-head attention block takes a dictionary as input, which consists of query, key and value. Notice that when using Model subclassing with the Functional API, the inputs have to be kept as a single argument, hence we have to wrap query, key and value in a dictionary.

The inputs are then put through dense layers and split up into multiple heads. scaled_dot_product_attention() defined above is applied to each head (broadcast for efficiency). An appropriate mask must be used in the attention step. The attention output for each head is then concatenated and put through a final dense layer.

Instead of one single attention head, query, key, and value are split into multiple heads because it allows the model to jointly attend to information at different positions from different representational spaces. After the split each head has a reduced dimensionality, so the total computation cost is the same as a single head attention with full dimensionality.

Implementation of multi-head attention layer with model subclassing
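The split and concatenation steps are easy to get wrong, so here is a small framework-free sketch of what they do to a single sequence. The helper names are hypothetical. In a tensor library the same operation is typically a reshape followed by a transpose.

```python
def split_heads(x, num_heads):
    """x: [seq_len][d_model] -> [num_heads][seq_len][depth]."""
    d_model = len(x[0])
    if d_model % num_heads != 0:
        raise ValueError("d_model must be divisible by num_heads")
    depth = d_model // num_heads
    # Head h receives the h-th contiguous slice of every position's vector.
    return [[row[h * depth:(h + 1) * depth] for row in x]
            for h in range(num_heads)]

def concat_heads(heads):
    """Inverse of split_heads: [num_heads][seq_len][depth] -> [seq_len][d_model]."""
    seq_len = len(heads[0])
    return [[v for head in heads for v in head[pos]] for pos in range(seq_len)]

x = [[1, 2, 3, 4], [5, 6, 7, 8]]        # seq_len = 2, d_model = 4
heads = split_heads(x, num_heads=2)     # two heads of depth 2
restored = concat_heads(heads)
```

Each head works on depth = d_model / num_heads features, which is why the total cost stays comparable to a single head with full dimensionality.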


Transformer architecture diagram from Attention is All You Need

The Transformer uses stacked multi-head attention and dense layers for both the encoder and decoder. The encoder maps an input sequence of symbol representations to a sequence of continuous representations. The decoder then takes these continuous representations and generates an output sequence of symbols one element at a time.

Positional Encoding

Since the Transformer doesn’t contain any recurrence or convolution, positional encoding is added to give the model some information about the relative position of the words in the sentence.

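The formula for calculating the positional encoding, as given in Attention Is All You Need, where pos is the position in the sentence, i indexes pairs of dimensions, and d is the embedding dimension:

$$PE_{(pos,\,2i)} = \sin\left(\frac{pos}{10000^{2i/d}}\right), \qquad PE_{(pos,\,2i+1)} = \cos\left(\frac{pos}{10000^{2i/d}}\right)$$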

The positional encoding vector is added to the embedding vector. Embeddings represent a token in a d-dimensional space where tokens with similar meaning will be closer to each other. But the embeddings do not encode the relative position of words in a sentence. So after adding the positional encoding, words will be closer to each other based on the similarity of their meaning and their position in the sentence, in the d-dimensional space. To learn more about Positional Encoding, check out this tutorial.

We implemented the Positional Encoding with Model subclassing where we apply the encoding matrix to the input in call().

Implementation of Positional Encoding with Model subclassing
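As a quick way to inspect the values, here is a framework-free sketch of the sinusoidal encoding from Attention Is All You Need. It interleaves sines on even dimensions and cosines on odd ones, as in the paper. The layout used in the notebook may differ, and the function names are hypothetical.

```python
import math

def positional_encoding(max_position, d_model):
    """Return a [max_position][d_model] table of sinusoidal encodings."""
    table = []
    for pos in range(max_position):
        row = []
        for dim in range(d_model):
            # Dimensions 2i and 2i + 1 share the same wavelength.
            angle = pos / (10000 ** (2 * (dim // 2) / d_model))
            row.append(math.sin(angle) if dim % 2 == 0 else math.cos(angle))
        table.append(row)
    return table

def add_positional_encoding(embeddings, table):
    """Element-wise sum of token embeddings and their position encodings."""
    return [[e + p for e, p in zip(emb_row, pos_row)]
            for emb_row, pos_row in zip(embeddings, table)]

table = positional_encoding(max_position=50, d_model=4)
encoded = add_positional_encoding([[0.5, 0.5, 0.5, 0.5]] * 3, table)
```

Position 0 always encodes to alternating 0 and 1, since sin(0) = 0 and cos(0) = 1, and every value stays within [-1, 1] regardless of sentence length.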

Transformer with Functional API

With the Functional API, we can stack our layers much as in a Sequential model, but without the constraint of the model being sequential and without declaring all the variables and layers in advance as Model subclassing requires. One advantage of the Functional API is that it validates the model as we build it, for example by checking the input and output shapes of each layer, and raises a meaningful error message when there is a mismatch.

We are implementing our encoder layers, encoder, decoder layers, decoder and the Transformer itself using the Functional API.

Check out how to implement the same models with Model subclassing in this tutorial.

Encoder Layer

Each encoder layer consists of two sublayers:

  • Multi-head attention (with padding mask)
  • 2 dense layers followed by dropout

Implementation of an encoder layer with Functional API

We can use tf.keras.utils.plot_model() to visualize our model. (Check out all the model plots in the colab notebook.)

Flow diagram of an encoder layer


The Encoder consists of:

  • Input Embedding
  • Positional Encoding
  • N encoder layers

The input is put through an embedding which is summed with the positional encoding. The output of this summation is the input to the encoder layers. The output of the encoder is the input to the decoder.

Implementation of encoder with Functional API

Decoder Layer

Each decoder layer consists of sublayers:

  • Masked multi-head attention (with look ahead mask and padding mask)
  • Multi-head attention (with padding mask). value and key receive the encoder output as inputs. query receives the output from the masked multi-head attention sublayer.
  • 2 dense layers followed by dropout

As query receives the output from decoder’s first attention block, and key receives the encoder output, the attention weights represent the importance given to the decoder’s input based on the encoder’s output. In other words, the decoder predicts the next word by looking at the encoder output and self-attending to its own output.

Implementation of decoder layer with Functional API


The Decoder consists of:

  • Output Embedding
  • Positional Encoding
  • N decoder layers

The target is put through an embedding which is summed with the positional encoding. The output of this summation is the input to the decoder layers. The output of the decoder is the input to the final linear layer.

Implementation of a decoder with Functional API


The Transformer consists of the encoder, the decoder and a final linear layer. The output of the decoder is the input to the linear layer, and the linear layer's output is returned.

enc_padding_mask and dec_padding_mask are used to mask out all the padding tokens. look_ahead_mask is used to mask out future tokens in a sequence. As the length of the masks changes with different input sequence length, we are creating these masks with Lambda layers.

Implementation of Transformer with Functional API
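The two kinds of mask are easy to picture on a single short sequence. The sketch below is plain Python rather than the Lambda layers used in the model. It assumes that padding uses token ID 0 and that a mask value of 1 means "hide this position", and it combines the look-ahead mask with the padding mask by taking the element-wise maximum.

```python
PAD_ID = 0  # assumed padding token ID

def create_padding_mask(token_ids):
    """1.0 where the token is padding, 0.0 elsewhere."""
    return [1.0 if token == PAD_ID else 0.0 for token in token_ids]

def create_look_ahead_mask(token_ids):
    """[seq_len][seq_len] mask hiding future positions and padding."""
    size = len(token_ids)
    padding = create_padding_mask(token_ids)
    # Row i may attend to columns 0..i, except where the column is padding.
    return [[max(1.0 if col > row else 0.0, padding[col])
             for col in range(size)]
            for row in range(size)]

padding_mask = create_padding_mask([5, 7, 0, 0])
look_ahead_mask = create_look_ahead_mask([5, 7, 0])
```

For the three-token example the look-ahead mask is [[0, 1, 1], [0, 0, 1], [0, 0, 1]]: the first position can only see itself, and the padded last column is hidden from every row.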

Train the model

We can initialize our Transformer as follows:

After defining our loss function, optimizer and metrics, we can simply train our model with model.fit(). Notice that we have to mask our loss function so that the padding tokens are ignored. We also write our own custom learning rate schedule.
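To show what these two pieces compute, here is a framework-free sketch. The loss averages cross-entropy over non-padding positions only, which is one common way to mask a loss and may differ in detail from the notebook. The schedule assumes the warmup formula from Attention Is All You Need with its default of 4000 warmup steps. Both function names are hypothetical.

```python
import math

def masked_cross_entropy(token_probs, labels, pad_id=0):
    """token_probs[t][v] is the predicted probability of vocab ID v at step t."""
    total, count = 0.0, 0
    for probs, label in zip(token_probs, labels):
        if label == pad_id:
            continue  # padding positions contribute nothing to the loss
        total += -math.log(probs[label])
        count += 1
    return total / max(count, 1)

def warmup_learning_rate(step, d_model, warmup_steps=4000):
    """Linear warmup followed by inverse square-root decay."""
    step = max(step, 1)  # avoid raising 0 to a negative power
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)

loss = masked_cross_entropy(
    token_probs=[[0.5, 0.5], [0.1, 0.9], [0.9, 0.1]],
    labels=[1, 1, 0],  # the last position is padding
)
peak = warmup_learning_rate(4000, d_model=512)
```

The learning rate rises linearly until warmup_steps and decays afterwards, so it peaks exactly at step 4000 with these settings.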


To evaluate, we have to run inference one time-step at a time, and pass in the output from the previous time-step as input.

Notice that we don’t normally apply dropout during inference, yet we didn’t specify a training argument for our model. This is because training and mask are already built in for us: to run the model for evaluation, we can simply call model(inputs, training=False), which runs the model in inference mode.

Transformer evaluation implementation
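The step-by-step inference loop can be sketched independently of the model. In the sketch below, next_token_scores is a hypothetical stand-in for a call to the trained Transformer that returns one score per vocabulary ID for the next position. The toy model simply replays a canned reply so that the loop can be run without any training.

```python
def greedy_decode(next_token_scores, input_ids, start_token, end_token, max_length):
    """Generate token IDs one at a time, always taking the highest score."""
    output = [start_token]
    for _ in range(max_length):
        scores = next_token_scores(input_ids, output)
        predicted = max(range(len(scores)), key=scores.__getitem__)
        if predicted == end_token:
            break  # the model signalled the end of the sentence
        # Feed the prediction back in as input for the next step.
        output.append(predicted)
    return output[1:]  # drop the START token

START, END, VOCAB_SIZE = 5, 6, 7   # hypothetical IDs for the toy example
CANNED_REPLY = [3, 1, 4]

def toy_model(input_ids, output_so_far):
    """Stand-in model: emits CANNED_REPLY, then END, whatever the input is."""
    position = len(output_so_far) - 1
    target = CANNED_REPLY[position] if position < len(CANNED_REPLY) else END
    scores = [0.0] * VOCAB_SIZE
    scores[target] = 1.0
    return scores

reply = greedy_decode(toy_model, [2, 0, 1], START, END, max_length=10)
```

The max_length cap matters for a real model, since nothing guarantees that it will ever produce the END token.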

To test our model, we can call predict(sentence).

>>> output = predict('Where have you been?')
>>> print(output)
i don t know . i m not sure . i m a paleontologist .


And there we have it: a Transformer implemented in TensorFlow 2.0 in around 500 lines of code.

In this tutorial, we focused on two different approaches to implementing complex models, the Functional API and Model subclassing, and on how to combine them.

If you want to know more about the two different approaches and their pros and cons, check out when to use the functional API section on TensorFlow’s guide.

Try using a different dataset or hyper-parameters to train the Transformer! Thanks for reading.