Build a Chatbot to answer questions via Retrieval and Text Generation

Gan Yun Tian
6 min read · Jan 2, 2023


Awed by OpenAI's ChatGPT? Wondering how it works behind the scenes? In this guided code tutorial, I will show you a core component of ChatGPT: Text Generation. Specifically, you'll build a simple Question Answering Chatbot using Text Generative Neural Networks and Document Clustering.

Components of the Chatbot that we’ll build

Most of the Chatbots used today are still based on hard-coded rules, predefined text, and fixed paths to guide conversations or even to answer questions.

While such an approach works, it is generally better to let a Machine Learning model, specifically 1) a Decoder Neural Network, handle the text generation aspect of a Chatbot. A decoder also allows for more diversified generations (we could train for this, but not in this article). Additionally, we want an easy way to insert external and newer knowledge into text generation (because pretrained language models are expensive to retrain), as well as to enhance prediction relevancy and explainability. We can achieve that via 2) a Retriever Neural Network and 3) an Index.

Core Idea: conditioning text generation on retrieved passages

This article is mainly inspired by the paper Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks (Lewis et al., 2020). Essentially, the core idea is to condition the Decoder on a given context, i.e. a relevant passage, when predicting text tokens. To enable conditional generation, we will work with the three components. Architecturally, it can be described as follows:

Step 1: offline embed text

In step 1, we embed all the text in a corpus and store the embeddings in a search index for later retrieval.

Step 2: embed queries and use the top-1 retrieved passage for conditional generation

In the next step, we need to retrieve relevant passages for text generation. In this tutorial, we retrieve the top-1 passage for each query. Note that top-1 performance tends not to be as strong as top-k with k >> 1. For instance, in the RAG paper, the authors retrieved several passages and conditioned the generation of each token on a unique passage. Technically, we could use the top-k passages by generating k answers for each query, but in this article we focus on top-1.

Step 3: condition text generation on the retrieved passage

With the top-1 relevant passage in hand, the Decoder comes into play. We concatenate the query and the retrieved passage into a single input sequence and feed it to the Decoder for conditional text generation.

Code prerequisites

Before we proceed, you will need Python 3.8+, recent versions of Nvidia CUDA and cuDNN, and a modern GPU and CPU to follow along with the code. If you lack any of those, you can use Google Colab or any other cloud-based solution. The Python libraries you will need are: transformers, pytorch, sentencepiece, faiss-cpu, pandas, and numpy (latest versions). If any library is missing, Python will raise an ImportError; just install it.

Code to create, process, and iterate through a dataset

We will be using the SQuAD training data downloaded from this link. Once downloaded, you will need to extract the JSON file to the correct directory. The following code processes it into a Pandas DataFrame with the columns Question, Passage, and Answer. The Question column contains the queries, while the Passage column holds the text to be retrieved. We chunk the original passages into shorter versions that still contain the answer text.

Code temporarily removed!
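Since the original code block is unavailable, here is a minimal sketch of the processing just described. The function name, the `window` parameter, and the centre-on-answer chunking heuristic are my assumptions; the SQuAD v2 training file is conventionally named train-v2.0.json.

```python
import json
import pandas as pd

def squad_to_dataframe(squad_dict, window=200):
    """Flatten a SQuAD-style dict into (Question, Passage, Answer) rows.

    Each passage is chunked to roughly `window` characters around the
    answer span, so retrieved passages stay short but still contain
    the answer text.
    """
    rows = []
    for article in squad_dict["data"]:
        for para in article["paragraphs"]:
            context = para["context"]
            for qa in para["qas"]:
                if qa.get("is_impossible") or not qa["answers"]:
                    continue  # skip unanswerable questions
                ans = qa["answers"][0]
                start = ans["answer_start"]
                # chunk: centre a text window on the answer span
                lo = max(0, start - window // 2)
                hi = min(len(context), start + len(ans["text"]) + window // 2)
                rows.append({"Question": qa["question"],
                             "Passage": context[lo:hi],
                             "Answer": ans["text"]})
    return pd.DataFrame(rows)

# For the real data: df = squad_to_dataframe(json.load(open("train-v2.0.json")))
# Tiny inline example with the same schema:
toy = {"data": [{"paragraphs": [{
    "context": "Normandy is a region in France. It was settled by Vikings.",
    "qas": [{"question": "Where is Normandy?", "is_impossible": False,
             "answers": [{"text": "France", "answer_start": 24}]}]}]}]}
df = squad_to_dataframe(toy)
```

Chunking around the answer keeps the retrieval targets short, which helps both the index and the decoder's limited input length.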

Now, to put the dataset we created above to use, we need code to support our training process: a Torch Dataloader object to iterate through the training examples in batches. But first, we will predefine some standard parameters in a config object.

Here is the code for creating the Dataloader.

After defining the classes, we will initialize an object out of it.
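With the original code missing, here is a minimal sketch of a config object plus a Dataset/DataLoader pair. The class names and hyperparameter values are illustrative assumptions, not the author's exact code.

```python
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader

class Config:
    # hypothetical hyperparameters; tune to taste
    batch_size = 2
    lr = 1e-4
    epochs = 3
    top_k = 1

class QADataset(Dataset):
    """Wraps the (Question, Passage, Answer) DataFrame for batching."""
    def __init__(self, df):
        self.df = df.reset_index(drop=True)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        # return raw strings; tokenisation happens later, inside the models
        return row["Question"], row["Passage"], row["Answer"]

df = pd.DataFrame({"Question": ["q1", "q2", "q3", "q4"],
                   "Passage": ["p1", "p2", "p3", "p4"],
                   "Answer": ["a1", "a2", "a3", "a4"]})
loader = DataLoader(QADataset(df), batch_size=Config.batch_size, shuffle=False)
first_batch = next(iter(loader))  # (questions, passages, answers)
```

Returning raw strings from `__getitem__` keeps the Dataset simple; PyTorch's default collate function groups them into per-field lists for each batch.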

Creating the Retriever Model to embed text

With the dataset and supporting objects out of the way, we will move on to creating the three main components of our Chatbot. Firstly, we will create the Retriever class.

Notice that I utilize a pretrained model, Sentence Transformer. This model was trained on a very large corpus, so it generalizes well to most domains. Note that in my code this pretrained model is not fine-tuned, whereas in the RAG paper the retriever and decoder are fine-tuned end to end. Nonetheless, the Retriever works well out of the box.
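The heart of a Sentence-Transformer-style retriever is mean pooling over the encoder's token embeddings. Here is a sketch of just that pooling step with plain torch; the pretrained encoder itself is only indicated in comments, since loading it requires a model download.

```python
import torch

def mean_pool(token_embeddings, attention_mask):
    """Sentence-Transformer-style mean pooling: average the token
    embeddings, counting only real (non-padding) tokens."""
    mask = attention_mask.unsqueeze(-1).float()      # (B, T, 1)
    summed = (token_embeddings * mask).sum(dim=1)    # (B, H)
    counts = mask.sum(dim=1).clamp(min=1e-9)         # (B, 1)
    return summed / counts

# In the full Retriever this wraps a pretrained encoder, e.g.:
#   from transformers import AutoModel, AutoTokenizer
#   enc = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
#   hidden = enc(**tokens).last_hidden_state
# Here we just demonstrate the pooling on dummy hidden states:
hidden = torch.ones(1, 4, 8)          # batch=1, 4 tokens, hidden dim=8
mask = torch.tensor([[1, 1, 1, 0]])   # last token is padding
emb = mean_pool(hidden, mask)         # passage/query embedding, shape (1, 8)
```

Masked pooling matters: averaging over padding tokens would dilute the embedding for short inputs in a padded batch.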

Creating the Index for retrieval

For the Index, we will utilize the faiss library from Meta. Essentially, the index supports our retrieval process by searching for the top-k nearest neighbors in the region around a query embedding.

Creating the Generator model for Conditional Text Generation

Notice that the loss function lives inside the Generator class. We simply use a Cross-Entropy loss to train just the generator.
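To make the "loss inside the Generator" idea concrete without downloading a pretrained decoder, here is a toy stand-in: a tiny embedding-plus-linear "decoder" whose forward pass computes the shifted causal-LM cross-entropy, masking prompt positions with -100 so only answer tokens are scored. Everything here (class name, vocab size) is illustrative.

```python
import torch
import torch.nn as nn

class ToyGenerator(nn.Module):
    """Stand-in for the pretrained decoder. The point is the internal
    loss: prompt positions are labelled -100 and ignored, so only the
    answer tokens contribute to the cross-entropy."""
    def __init__(self, vocab=50, hidden=16):
        super().__init__()
        self.embed = nn.Embedding(vocab, hidden)
        self.head = nn.Linear(hidden, vocab)
        self.loss_fn = nn.CrossEntropyLoss(ignore_index=-100)

    def forward(self, input_ids, labels):
        logits = self.head(self.embed(input_ids))            # (B, T, V)
        # standard causal-LM shift: predict token t+1 from token t
        shift_logits = logits[:, :-1, :].reshape(-1, logits.size(-1))
        shift_labels = labels[:, 1:].reshape(-1)
        return self.loss_fn(shift_logits, shift_labels)

gen = ToyGenerator()
input_ids = torch.tensor([[3, 7, 9, 4, 2]])
# mask the prompt (query + retrieved passage + prefix) so only the
# answer tokens are scored
labels = torch.tensor([[-100, -100, 9, 4, 2]])
loss = gen(input_ids, labels)
```

The same -100 masking convention is what Hugging Face causal and seq2seq models use for their built-in loss, so the pattern carries over directly to a real pretrained decoder.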

Training Loop

Finally, we have reached the training part. The training code can be summarized as follows:

1) We embed all the passages and construct an index. We periodically re-construct this index from the updated embeddings (technically unnecessary here, since we don't fine-tune the retriever, but I left it enabled anyway; you can disable it to save time).

2) Next, we embed the queries and use the constructed index to compute a similarity score between each query and the passages; the most similar (top-1) passage is retrieved for each query.

3) Immediately after, we tokenize the answers and the inputs (each query concatenated with its retrieved passage and a decoder prefix) for the generator, and compute its internal loss.

4) In each iteration, we backpropagate the loss, which measures how well the model's predictions fit the ground truth, and update the generator's parameters.
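The four steps above can be sketched as a compact loop. Every component below is a deliberately crude stub (random embeddings for the retriever, a linear layer plus a squared-output loss for the generator), so only the control flow mirrors the real training code, not the numbers.

```python
import numpy as np
import torch
import torch.nn as nn

rng = np.random.default_rng(0)

def embed(texts):
    # stub retriever encoder: random embeddings stand in for the
    # frozen Sentence Transformer
    return rng.standard_normal((len(texts), 8)).astype("float32")

passages = [f"passage {i}" for i in range(20)]
queries = [f"query {i}" for i in range(4)]
gen = nn.Linear(8, 1)                       # stub generator
opt = torch.optim.Adam(gen.parameters(), lr=1e-3)

losses = []
for epoch in range(2):
    # 1) (re)build the index from passage embeddings
    p_embs = embed(passages)
    p_embs /= np.linalg.norm(p_embs, axis=1, keepdims=True)
    for q in queries:
        # 2) embed the query and take its top-1 passage by cosine sim
        q_emb = embed([q])
        q_emb /= np.linalg.norm(q_emb)
        top1 = int(np.argmax(p_embs @ q_emb[0]))
        # 3) build the generator input from query + retrieved passage
        #    (real code tokenises text; the stub consumes embeddings)
        x = torch.from_numpy(p_embs[top1] + q_emb[0])
        loss = gen(x).pow(2).mean()         # stand-in for the CE loss
        # 4) backpropagate and update the generator
        opt.zero_grad()
        loss.backward()
        opt.step()
        losses.append(loss.item())
```

In the real loop, step 1 uses the faiss index, step 3 feeds tokenized text to the pretrained decoder, and the loss is the generator's internal cross-entropy; the structure is otherwise the same.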

We can then easily train the model by calling the train function. Now, let's move on to evaluating our model.

Evaluation

In the following code, we will create the objects and index for evaluation. We will use 1000 random samples from the last 5000 examples in the dataset.

Next, we will just create a simple loop to count the number of correct retrieved passages and predicted answers (out of 1000).
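Such a counting loop might look like the sketch below; the field names are hypothetical, and I normalise case and whitespace before comparing answers, which is a common (but here assumed) exact-match convention.

```python
def evaluate(samples):
    """Count correctly retrieved passages and exact-match answers.

    `samples` is a list of dicts holding the retrieved vs. gold passage
    and the generated vs. gold answer (hypothetical field names).
    """
    retrieved_ok = sum(s["retrieved"] == s["gold_passage"] for s in samples)
    answer_ok = sum(
        s["prediction"].strip().lower() == s["answer"].strip().lower()
        for s in samples
    )
    return retrieved_ok, answer_ok

samples = [
    {"retrieved": "p1", "gold_passage": "p1",
     "prediction": "France", "answer": "france"},
    {"retrieved": "p2", "gold_passage": "p9",
     "prediction": "Paris", "answer": "Lyon"},
]
r_ok, a_ok = evaluate(samples)
```

Run over the 1000 held-out samples, the two counts give the retrieval and exact-match figures reported below.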

You will get around 400+ for correctly retrieved top-1 passages, and around 300+ correctly predicted (exact match) answers.

Some limitations

Now, I must mention that when the top-1 passage is retrieved wrongly, the generated answer will almost certainly be wrong too. So enhancing retriever performance is key. We could even train just the retriever and not the generator; that requires an end-to-end loss that uses the values from the similarity matrix as a training signal. Similarly, the RAG paper's implementation is much better precisely because it utilizes a custom end-to-end loss.

Secondly, note that metrics like BLEU or ROUGE would yield a higher performance score for the predicted answers, because our evaluation only counts an answer as correct if it is an exact match to the actual answer, which is a much stricter criterion than partial-overlap metrics.

Thirdly, in a quick experiment I did, training the models end to end for a short while, using the similarity values in the loss computation as a training signal, led to improvements in both retrieval and generation (500+ correct retrievals and 400+ correct generations). Allowing top-k retrieval would improve results even further. I did not implement this in the code above.

The reason is that implementing papers outright would disincentivize the researchers who spent countless hours and money running the experiments.

Lastly, you could first train a retriever from scratch using state-of-the-art techniques, then either fix the retriever's parameters and fine-tune the generator, or allow both the retriever and generator to be fine-tuned further.

Conclusion

That said, if you have any comments or need help, please feel free to reach out to me. Have fun with the code.
