Next Word Prediction Using LSTMs
A detailed implementation of a classification-based approach using LSTMs and an attention mechanism
This is a continuation of the main article, Next Word Prediction using Swiftkey Data. In the following sections, I will discuss the text features and the architecture of the LSTM models in detail.
Feature Development
Creating sequences of texts
I have created sequences of different lengths from the tokens of the cleaned corpus, maintaining the order of the tokens.
Code Sample
I have created sequences of length 2, 4, and 7.
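The exact code is in the notebook linked below; here is a minimal sketch of how such sequences can be built (the function name `create_sequences` and the sample `tokens` list are illustrative assumptions, not the original code):

```python
# A minimal sketch: build fixed-length sequences from the ordered tokens
# of the cleaned corpus. `tokens` is assumed to be the cleaned token list.

def create_sequences(tokens, seq_len):
    """Return every contiguous run of `seq_len` tokens, joined as a string."""
    sequences = []
    for i in range(seq_len, len(tokens) + 1):
        seq = tokens[i - seq_len:i]          # keep the original token order
        sequences.append(' '.join(seq))
    return sequences

tokens = "i went to see the doctor because i was quite sick".split()
seq_2 = create_sequences(tokens, 2)   # sequences of length 2
seq_4 = create_sequences(tokens, 4)   # sequences of length 4
seq_7 = create_sequences(tokens, 7)   # sequences of length 7
```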
Encoding the sequences and splitting into x and y
I have encoded the words in the sequences using the Keras Tokenizer and split each sequence so that only the last word is in y and the rest are in x.
Thus a sequence of length 2 will have the first word in x and the last in y. Similarly, a sequence of length 4 will have the first 3 words in x and the last word in y.
Code Sample
Below is the function for encoding and splitting the data
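The exact function is embedded in the notebook; the sketch below illustrates the idea, assuming the Keras Tokenizer workflow described above (the name `encode_and_split` is mine, not the original):

```python
import numpy as np
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.utils import to_categorical

def encode_and_split(sequences):
    """Encode word sequences as integers and split off the last word as y."""
    tokenizer = Tokenizer()
    tokenizer.fit_on_texts(sequences)
    encoded = np.array(tokenizer.texts_to_sequences(sequences))

    vocab_size = len(tokenizer.word_index) + 1     # +1 because word indices start at 1
    x, y = encoded[:, :-1], encoded[:, -1]         # last word of each sequence is the target
    y = to_categorical(y, num_classes=vocab_size)  # one class per word in the vocabulary
    return x, y, vocab_size, tokenizer.word_index
```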
The above function also returns the vocabulary of the encoded data and the index assigned to each word.
The idea here is that each word in y acts as a class, and my objective is to predict that class given the sequence of words in x.
Stacked LSTM Model
It is a model developed by stacking a Bidirectional LSTM layer over an LSTM layer. The following is the architecture of the LSTM model.
I have developed 3 such models, corresponding to sequences of length 2, 4, and 7. In each case everything remains the same except the number of unit1 and unit2 cells. The idea is to retain the sequence information of the words in x to predict the next-word class in y. The output layer has a softmax activation, so I get a probability distribution over the output classes.
Code Sample
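The actual layer sizes are in the notebook; the sketch below shows the general shape of the model under the assumptions above, with `unit1`, `unit2`, `embed_dim`, and the optimizer as placeholders:

```python
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Embedding, Bidirectional, LSTM, Dense

def build_stacked_lstm(vocab_size, seq_len, embed_dim=50, unit1=100, unit2=100):
    """Stacked model: Embedding -> Bidirectional LSTM -> LSTM -> softmax over the vocabulary."""
    model = Sequential([
        Embedding(vocab_size, embed_dim, input_length=seq_len - 1),  # x has seq_len - 1 words
        Bidirectional(LSTM(unit1, return_sequences=True)),           # keeps the per-step sequence info
        LSTM(unit2),                                                  # summarises the sequence
        Dense(vocab_size, activation='softmax'),                      # probability over next-word classes
    ])
    model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
    return model
```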
Train Results
For sequence of length 2:
For sequence of length 4:
For sequence of length 7:
Github Link: https://github.com/kurchi1205/Next-word-Prediction-using-Swiftkey-Data/blob/main/LSTM%20Model.ipynb
Test Results
- Sequence 2 - Loss: 9.2959
- Sequence 4 - Loss: 27.3950
- Sequence 7 - Loss: 22.7704
Next Word Predictions
Key Idea:
- The sequence of words (history) whose next word has to be predicted is taken as input (a sketch of this routing logic follows the list).
- If the length of history is 1, it is passed to the model for sequence length 2, and the class with the highest probability is predicted.
- If the length of history is less than 4, it is passed to the model for sequence length 4, and the class with the highest probability is predicted.
- If the length of history is less than 7, it is passed to the model for sequence length 7, and the class with the highest probability is predicted.
- If the length of history is 7 or more, then `history = history[-6:]` and the previous step is repeated.
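A minimal sketch of this routing logic is given below. The function name `predict_next_word`, the fitted `tokenizer`, and the three trained models `model_2`, `model_4`, and `model_7` are assumed to come from the earlier steps; the exact code is in the notebook.

```python
import numpy as np
from tensorflow.keras.preprocessing.sequence import pad_sequences

def predict_next_word(history, tokenizer, model_2, model_4, model_7):
    """Route the history to the model with the matching input length and return the most probable word."""
    words = history.split()
    if len(words) >= 7:
        words = words[-6:]                  # keep only the last 6 words

    if len(words) == 1:
        model, in_len = model_2, 1          # model trained on sequences of length 2
    elif len(words) < 4:
        model, in_len = model_4, 3          # model trained on sequences of length 4
    else:
        model, in_len = model_7, 6          # model trained on sequences of length 7

    encoded = tokenizer.texts_to_sequences([' '.join(words)])[0]
    encoded = pad_sequences([encoded], maxlen=in_len)   # pad to the model's input length
    probs = model.predict(encoded, verbose=0)[0]
    predicted_class = int(np.argmax(probs))             # class with the highest probability

    # map the class index back to a word via the tokenizer's word index
    for word, index in tokenizer.word_index.items():
        if index == predicted_class:
            return word
    return None
```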
Some Predictions:
Overall, the model predicts pretty well, but there is scope for improvement as the losses are too high.
Stacked LSTM Model with Attention
This model is quite similar to the previous one, except that here I have added an attention layer on the output of the Bidirectional LSTM.
Why do we need an attention layer?
When we are predicting the next word for a sequence, all the words of the sequence don't contribute equally to the prediction.
For example, consider the sentence `I went to see the doctor because I was quite sick.` Suppose I have to predict the word sick. Then the words went, see, and doctor will carry more weight, as they contribute more to the probability of the next word being sick.
Thus, to assign a weight to each word in the sequence, we need an attention layer.
Architecture of the model
Below is the architectural flowchart of the model:
Now, coming to the architecture of the attention layer:
Code Sample
For the attention layer:
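The notebook contains the exact implementation; the sketch below is a simple additive attention layer that learns a weight for each time step of the Bidirectional LSTM output (the class name `Attention` and the `return_sequences` flag are assumptions about the variant used):

```python
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Layer

class Attention(Layer):
    """Additive attention: score each time step, softmax the scores, re-weight the inputs."""

    def __init__(self, return_sequences=True, **kwargs):
        self.return_sequences = return_sequences
        super().__init__(**kwargs)

    def build(self, input_shape):
        # one score per time step, computed from the hidden state at that step
        self.W = self.add_weight(name='att_weight', shape=(input_shape[-1], 1),
                                 initializer='glorot_uniform', trainable=True)
        self.b = self.add_weight(name='att_bias', shape=(input_shape[1], 1),
                                 initializer='zeros', trainable=True)
        super().build(input_shape)

    def call(self, x):
        # x: (batch, time_steps, features)
        e = K.tanh(K.dot(x, self.W) + self.b)   # unnormalised score per time step
        a = K.softmax(e, axis=1)                # attention weights over time steps
        output = x * a                          # re-weighted sequence
        if self.return_sequences:
            return output                       # same shape as x, usable by a following LSTM
        return K.sum(output, axis=1)            # weighted sum -> (batch, features)
```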
For the overall model:
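Again a sketch rather than the exact notebook code, reusing the `Attention` class from above and the same placeholder names as the earlier stacked model; the attention layer re-weights the Bidirectional LSTM output before it reaches the stacked LSTM:

```python
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Embedding, Bidirectional, LSTM, Dense

def build_attention_lstm(vocab_size, seq_len, embed_dim=50, unit1=100, unit2=100):
    """Stacked model with attention on the Bidirectional LSTM output."""
    model = Sequential([
        Embedding(vocab_size, embed_dim, input_length=seq_len - 1),
        Bidirectional(LSTM(unit1, return_sequences=True)),
        Attention(return_sequences=True),      # weight each time step of the BiLSTM output
        LSTM(unit2),
        Dense(vocab_size, activation='softmax'),
    ])
    model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
    return model
```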
Github Link: https://github.com/kurchi1205/Next-word-Prediction-using-Swiftkey-Data/blob/main/LSTM%20Model%20with%20Attention.ipynb
Train Results
For sequence of length 2:
For sequence of length 4:
For sequence of length 7:
Test Results
- Sequence 2 - Loss: 7.09
- Sequence 4 - Loss: 8.84
- Sequence 7 - Loss: 9.24
Next Word Predictions
Key Idea: It is the same as that of the Stacked LSTM Model discussed above.
Some Predictions:
Thus, the attention mechanism significantly improved the losses.
I have also tried out some transformer-based models to see whether the results vary; please refer to the main article for those.