Next Word Prediction: A Complete Guide

Sonal Sannigrahi
Linagora LABS

--

As part of my summer internship with Linagora’s R&D team, I was tasked with developing a next word prediction and autocomplete system akin to that of Google’s Smart Compose. In this article, we look at everything necessary to get started with next word prediction.

Email continues to be one of the largest forms of communication, both professionally and personally. It is estimated that almost 3.8 billion users send nearly 300 billion emails a day! Thus, for Linagora’s open source collaborative platform OpenPaaS, which includes an email system, improving the user experience is a priority. With real-time assisted writing, users can craft richer messages from scratch and keep a smoother writing flow, compared with the static alternative of fully automated replies.

A real-time assisted writing system

The general pipeline of an assisted writing system relies on an accurate and fast next word prediction model. Building an industrial language model that enhances user experience raises several problems to consider: inference time, model compression, and transfer learning to further personalise suggestions.

At the moment, these issues are addressed by the growing use of Deep Learning techniques for mobile keyboard predictions. Leading companies have already made the switch from statistical n-gram models to machine learning based systems deployed on mobile devices. However, with Deep Learning comes the baggage of extremely large model sizes, unfit for on-device prediction. As a result, model compression is key to maintaining accuracy while not using too much space. Furthermore, mobile keyboards are also tasked with learning your personal texting habits, in order to make predictions suited to your style rather than just to a general language model. Let’s take a look at how it’s done!

Next word predictions in Google’s Gboard

Some useful training corpora

In order to evaluate the different model architectures I would use to build my language model, it was crucial to have a solid benchmark evaluation set. For this task, I chose the Penn Treebank (PTB) dataset: due to its small vocabulary size (roughly 10,000 unique words!), it quickly reveals whether a model is overfitting, making it imperative to build one that is well regularised. In addition to this, the Enron email corpus was used to train on real email data and to test predictions in the context of emails (with a much larger and richer vocabulary of nearly 39,000 unique words).

My task was to build a bilingual model, in French and English, with the intention of generalising to other languages as well. For this, I also considered several widely available French texts. FrWaC, a web archive of the whole .fr domain, is a great corpus to train on a diverse set of French sequences. For those with extensive GPU resources, the entire French Wikipedia dump is also available online. With my code, I trained on a couple of short stories from Project Gutenberg (another great resource for textual data in multiple languages!).
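As a concrete example of the preprocessing these corpora go through, here is a minimal sketch of turning a raw text file into the integer sequences a word-level model consumes. The file path and the 10,000-word cutoff are illustrative assumptions, not the exact internship pipeline:

```python
from collections import Counter

# Read a plain-text corpus, e.g. a Project Gutenberg short story
# ("corpus.txt" is a placeholder path).
with open("corpus.txt", encoding="utf-8") as f:
    words = f.read().lower().split()

# Keep only the most frequent words; everything else maps to <unk>,
# which keeps the vocabulary (and the model's softmax layer) small.
counts = Counter(words)
vocab = ["<unk>"] + [w for w, _ in counts.most_common(9_999)]
word_to_id = {w: i for i, w in enumerate(vocab)}

# Encode the corpus as integer ids, the input format for the models below.
ids = [word_to_id.get(w, 0) for w in words]
print(f"{len(vocab)} words in vocabulary, {len(ids)} tokens in corpus")
```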

Model Architectures

Now comes the fun part: language modelling. Generally speaking, approaches to text generation and prediction can be split into two categories: statistical based and learning based. In this article, we focus on the latter and take a deep dive into several recurrent neural network (RNN) variant architectures. For a next word prediction task, we want to build a word-level language model, as opposed to a character n-gram based approach. If, however, we were also looking to complete the current word along with predicting the next one, we would need to incorporate something known as beam search, which relies on a character-level approach. For this article, we will stick to next word prediction without beam search.

At the time of writing this article, the models with the most success in language generation and prediction tasks are transformer based, exploiting the idea of self-attention: every position in a sequence can attend directly to every other position, so long-range dependencies are captured without recurrence. However, these models are notoriously difficult to train without large amounts of training data, as well as a good amount of GPU/TPU resources. Here we focus on the next best alternative: LSTM models.

An LSTM (Long Short-Term Memory) model was first introduced in the late 90s by Hochreiter and Schmidhuber. Since then, many advancements have been made using LSTM models, with applications ranging from time series analysis to connected handwriting recognition. An LSTM network is a type of RNN that learns dependence on historic data for a sequence prediction task. What allows LSTMs to learn these historic dependencies are their feedback connections. A common LSTM cell has an input gate, an output gate, and a forget gate. The weights of these gates control the flow of information in an LSTM model and are therefore the parameters learnt during the training process.
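To make the gates concrete, here is one step of a standard LSTM cell written out by hand in PyTorch. This is purely illustrative: in practice you would rely on torch.nn.LSTM, which implements the same computation efficiently.

```python
import torch

def lstm_step(x, h_prev, c_prev, W, U, b):
    """One step of a standard LSTM cell, written out to expose the gates.

    W: input weights (4*hidden, input), U: recurrent weights (4*hidden, hidden),
    b: bias (4*hidden), chunked into input gate, forget gate, cell candidate,
    and output gate (PyTorch's ordering).
    """
    gates = W @ x + U @ h_prev + b
    i, f, g, o = gates.chunk(4)
    i, f, o = torch.sigmoid(i), torch.sigmoid(f), torch.sigmoid(o)
    g = torch.tanh(g)
    c = f * c_prev + i * g   # forget part of the old memory, write new memory
    h = o * torch.tanh(c)    # expose a filtered view of the memory
    return h, c
```

The sigmoid gates output values between 0 and 1 that scale how much information is discarded, written, and emitted at each step; the associated weight matrices are exactly what training adjusts.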

Some variants of the LSTM model include the Convolutional LSTM (or CNN-LSTM) and the Bidirectional LSTM (or Bi-LSTM). For our task, these variants correspond to different encodings of the input sequence. A CNN-LSTM model is typically used for image captioning tasks, where the CNN portion recognises features of the image and the LSTM generates a suitable caption based on those features. Experimenting with a CNN as the encoding layer yielded some interesting results; however, with the large number of parameters and even bigger model sizes, it was not fit for this task.

Top: CNN-LSTM, Bottom: Bi-LSTM (credit to owners)

In “Recurrent Neural Network Regularization” by Zaremba et al., the authors describe a highly regularised LSTM model with added dropout. This model is well suited to our task, meeting the requirements of both smaller model size and high performance. In this era of Deep Learning, it is crucial to consider the trade-off between model size and performance: being resource efficient should be the goal, rather than improving accuracy by 0.01%. The model discussed in this paper balances that trade-off quite well, as we will see in the next section.

Implementation of the Chosen Model for Next Word Prediction

Below is an implementation of this model using the Deep Learning library PyTorch. I previously implemented this in TensorFlow as well, but the model could not converge well there, so I wouldn't recommend TensorFlow for this particular task, although ideally there should be no difference!
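What follows is a minimal sketch of the model itself rather than my full training code; the hyperparameters shown are the “medium” configuration from Zaremba et al., so treat the exact values as assumptions:

```python
import torch.nn as nn

class RegularizedLSTM(nn.Module):
    """Word-level language model with dropout applied to the embeddings,
    between LSTM layers, and before the decoder (after Zaremba et al.)."""

    def __init__(self, vocab_size, embed_dim=650, hidden_dim=650,
                 num_layers=2, dropout=0.5):
        super().__init__()
        self.drop = nn.Dropout(dropout)
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers,
                            dropout=dropout, batch_first=True)
        self.decoder = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, hidden=None):
        emb = self.drop(self.embedding(x))       # (batch, seq, embed)
        out, hidden = self.lstm(emb, hidden)     # (batch, seq, hidden)
        logits = self.decoder(self.drop(out))    # (batch, seq, vocab)
        return logits, hidden
```

Training then minimises the cross-entropy between each output position and the next word in the sequence.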

Results

So how did our final model perform on different text corpora? In order to evaluate its performance, we considered both qualitative and quantitative metrics. Qualitatively, it’s crucial to look at the sentence structure (in cases where there isn’t an exact match) and the context of the predictions; it is here that we should see the benefit and power of LSTM-based approaches compared to statistical methods. Quantitatively, we considered word perplexity (the exponential of the loss), which measures how confidently our model predicts words; ExactMatch@N, which shows at which rank (N=1 meaning the highest-probability prediction, N=2 the second highest, and so on) we predict the right word, if at all; and lastly precision and accuracy, which measure… as the names suggest… the accuracy of our model!
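For concreteness, here is how these metrics can be computed from the model’s outputs; exact_match_at_n follows my rank-based reading of the metric described above, and the function names are my own:

```python
import torch.nn.functional as F

def perplexity(logits, targets):
    # Perplexity is the exponential of the average cross-entropy loss.
    loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
    return loss.exp().item()

def exact_match_at_n(logits, targets, n=1):
    # Fraction of positions where the target word is exactly the n-th
    # ranked prediction (n=1: top prediction, n=2: second highest, ...).
    ranked = logits.argsort(dim=-1, descending=True)   # (batch, seq, vocab)
    return (ranked[..., n - 1] == targets).float().mean().item()
```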

Our regularised LSTM performed quite well on the PTB corpus and managed to converge without overfitting, which was the primary goal. Upon increasing the input vocabulary size by using a portion of the Enron corpus, we saw even better results in terms of word-level perplexity. Here are some quantitative results on both the PTB and Enron datasets:

Word level perplexity for Small regularised model (PTB) and for Medium regularised model (Enron)

Our model achieved a final test perplexity of 117.03 on the PTB corpus with the small regularised model, and 111.05 with the medium regularised model. As the PTB corpus was our benchmark, we proceeded with the medium regularised model to train on the Enron corpus, where we achieved a final test perplexity of 55.12 (corresponding to a test loss of 3.56) after training for 39 epochs. While in these two cases the convergence problem was solved, the same could not be achieved for the French text data.

Medium LSTM for French Text

As you can see above, while the perplexity on the training set decreases, the perplexity on the validation set does not converge. This can be attributed to the training and validation sets having very different vocabularies, or to the added complexity of French data, which has many more verb conjugations than English. However, with more training data, it is quite likely that convergence would no longer be an issue.

Some predictions with the Enron corpus

Looking at some qualitative results for the model trained on the Enron corpus, we see a respectable accuracy (ExactMatch@1) of nearly 44%. Considering a vocabulary of 39,000 words, a 44% accuracy is a great result! This shows that the regularised LSTM model works well for the next word prediction task, especially with smaller amounts of training data.

How about using pre-trained models?

Now that we have explored different model architectures, it’s also worth discussing the use of pre-trained language models for this task. Language models like OpenAI’s GPT are trained on very large general text corpora, making them highly nuanced and well suited to text prediction tasks. However, models like GPT are not easily specialised to different corpora, in our case emails. Furthermore, due to their extremely large number of parameters, these pre-trained language models are suited neither to on-device deployment nor to on-device training. One last problem with GPT specifically is that, at the moment, it is an English-only pre-trained model, which makes it unsuitable for our bilingual language modelling problem.

Another popular pre-trained language model is Google’s BERT, which uses the transformer architecture discussed earlier. Unlike GPT, BERT is available in a multilingual format, and specifically for French, researchers from Inria, Facebook AI, and Sorbonne Université recently released CamemBERT (trained on 138 GB of French data), which is quite an exciting advancement! BERT itself was designed for masked language modelling, inspired by the Cloze task, and at the moment its power is not yet fully understood. While BERT can be used for a next word prediction task by setting the mask as the last word, it is best suited to having context on both the left and the right of the mask, which makes full use of its bidirectional nature. Transfer learning with these pre-trained models is quite difficult, which leaves them as general language models rather than use-case-specific ones; and without that specialisation, personalised language models (trained on device) are even harder to achieve. This is not to say that these pre-trained models are not excellent for a first trial; you’ll surely find some interesting results with them!
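As a quick illustration of such a first trial, here is a sketch using the Hugging Face transformers library with a multilingual BERT checkpoint; this is not the model used in this project, and the sentence is an arbitrary example. Placing the mask at the end approximates next word prediction, even though BERT prefers context on both sides:

```python
from transformers import pipeline

# Masked-word prediction with a pre-trained multilingual BERT.
fill_mask = pipeline("fill-mask", model="bert-base-multilingual-cased")

# With the mask as the last token, the fill-mask task degenerates
# into next word prediction.
for candidate in fill_mask("I will send you the report by [MASK].")[:3]:
    print(candidate["token_str"], round(candidate["score"], 3))
```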

Using the Model at Inference Time

So now that you’ve chosen your desired language model, how do you get it to make the predictions?

Prediction on Inference (taken from Trung Tran’s tutorial in Machine Talk)

The above diagram illustrates the inference process for an input sequence and its output word(s). With our language model, for an input sequence of 6 words (label them 1, 2, 3, 4, 5, 6), the model outputs another sequence of 6 words (which should try to match 2, 3, 4, 5, 6, next word), so we take the highest-probability word at the last position of the output sequence as our predicted next word.

Here is how the prediction function could look for this problem:
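This is a minimal sketch rather than production code; it reuses the word_to_id and vocab mappings and the model class from the earlier snippets:

```python
import torch

def predict_next_word(model, word_to_id, vocab, text, top_k=3):
    """Return the top-k candidate next words for a text prefix.

    Encodes the prefix, runs the model, and reads the output distribution
    at the last position, which corresponds to the next word.
    """
    model.eval()
    ids = [word_to_id.get(w, 0) for w in text.lower().split()]
    x = torch.tensor(ids).unsqueeze(0)        # shape (1, seq)
    with torch.no_grad():
        logits, _ = model(x)
    probs = logits[0, -1].softmax(dim=-1)     # distribution over the vocabulary
    top = probs.topk(top_k)
    return [(vocab[i.item()], p.item()) for p, i in zip(top.values, top.indices)]

# Example: predict_next_word(model, word_to_id, vocab, "I will send you the")
```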

A Final Note on Model Compression

Okay, so by now our model is ready to be deployed, eager to start predicting your words! But do you really expect to deploy a model larger than 200 MB on a device like a mobile phone? How are we supposed to maintain model accuracy while using only an insignificant amount of space? In comes model compression!

Industrialising NLP is a challenging task, and there have been many recent developments in this field. One key breakthrough was knowledge distillation. The idea is to have two models: a teacher and a student. The teacher is the complete LSTM model, with all its connections and feedback loops, while the student is a simple one-layer feed-forward network that learns from the teacher’s final softmax layer. This allows the student to imitate the teacher’s predictions while being much shallower and therefore occupying significantly less space! This is a very simple explanation of model compression; if you’re interested in more, I suggest the following video: https://www.youtube.com/watch?v=b3zf-JylUus. The picture below illustrates the concept quite well:

Knowledge Distillation visualised
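A minimal sketch of the idea, with the temperature value and the student architecture as illustrative assumptions:

```python
import torch.nn as nn
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, T=2.0):
    """Match the student's softened predictions to the teacher's.

    Dividing the logits by a temperature T > 1 softens both distributions,
    exposing the teacher's relative preferences among the wrong words.
    """
    soft_teacher = F.softmax(teacher_logits / T, dim=-1)
    log_student = F.log_softmax(student_logits / T, dim=-1)
    return F.kl_div(log_student, soft_teacher, reduction="batchmean") * T * T

class StudentModel(nn.Module):
    """A single feed-forward layer over the word embedding: far smaller
    than the LSTM teacher whose softmax outputs it learns to imitate."""

    def __init__(self, vocab_size, embed_dim=256):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.out = nn.Linear(embed_dim, vocab_size)

    def forward(self, x):
        return self.out(self.embedding(x))
```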

Conclusion — What next?

We’re finally at the end… or is this just the beginning? There are some exciting fronts still left for me to discover, from dealing with out-of-vocabulary and rare words to contextual encoding. Here we explored only the surface of a next word prediction model; the next steps would be to work out model personalisation while still maintaining user privacy, to handle contractions and punctuation marks, and to add a beam search method for autocomplete to make this model even more interactive.

I hope that after reading this article, you look at the little words suggested by your phone’s keyboard a bit differently. The work that goes on behind the scenes to make your user experience smooth and seamless is enormous and exciting! With the current advancements in language understanding and generation, we are at a very rich time in technology, and I personally can’t wait to see what is coming :)

Thank you to Zied Sellami for his supervision and support, and a huge thanks to the entire Linagora team for their warm welcome even while working remotely.
