RNN Spelling Correction: To crack a nut with a sledgehammer

I learned basic spelling correction with Markov chains and the noisy channel model two years ago, but I didn't implement anything at the time. One year ago, I wrote a spell correction algorithm because the one used at Pinterest was not perfect. The algorithm is based on Peter Norvig's tutorial and the probabilities of unigrams and bigrams. The precision is only 70% to 80% compared with Google's "ground truth", and I don't know the recall. There are many challenges, for example: how to segment words, how to understand special usernames, and how to prevent over-correction.
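For context, that Norvig-style approach looks roughly like the sketch below: generate every candidate within one edit and rank the known candidates by corpus frequency. The `WORD_COUNTS` table here is illustrative, not the actual query-log statistics:

```python
from collections import Counter

# Illustrative unigram counts; in practice these come from a large query log.
WORD_COUNTS = Counter({"house": 900, "home": 1200, "picnic": 300})

def edits1(word):
    """All strings one edit (delete, transpose, replace, insert) away."""
    letters = "abcdefghijklmnopqrstuvwxyz"
    splits = [(word[:i], word[i:]) for i in range(len(word) + 1)]
    deletes = [L + R[1:] for L, R in splits if R]
    transposes = [L + R[1] + R[0] + R[2:] for L, R in splits if len(R) > 1]
    replaces = [L + c + R[1:] for L, R in splits if R for c in letters]
    inserts = [L + c + R for L, R in splits for c in letters]
    return set(deletes + transposes + replaces + inserts)

def correct(word):
    """Pick the most frequent known candidate within one edit."""
    candidates = [w for w in edits1(word) | {word} if w in WORD_COUNTS]
    return max(candidates, key=WORD_COUNTS.get) if candidates else word
```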

Pinterest spelling correction UI

Although I shipped it in the end, I wasn't very happy with it. I changed the algorithm and the parameters, but there was no significant improvement. Recently, I started looking at deep learning and how Google uses it. Google uses RNNs for machine translation, and there has been some new progress there. Since RNNs work for machine translation, with vocabularies of hundreds of thousands of words and long sentences, why can't we use them to solve spell correction on 30 characters?

Use a sledgehammer to crack a nut

I have to admit that both the data complexity and the model complexity of spell correction are more than a hundred times smaller than those of machine translation. But setting up an RNN for spell correction is not over-engineering, because there are already so many deep learning libraries and examples on the Internet.

I started trying an RNN for spell correction last week. There are two parts in the RNN model. The first part is an encoder that contains a list of cells. Each cell takes two inputs: the current character and the output from the previous cell.

An example of RNN

The second part is a decoder that also contains a list of cells. Each cell takes one input from the previous cell and generates a vector as output. The vector corresponds to a character in the vocabulary.
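The post doesn't include model code, so here is a minimal sketch of such a character-level encoder-decoder in TensorFlow's Keras API. This is an assumption of the setup, not the actual code: it uses one LSTM layer per side for brevity (the details below mention 2 layers per cell), and `VOCAB_SIZE` is a guess since the exact character set isn't listed.

```python
import tensorflow as tf
from tensorflow.keras import layers

VOCAB_SIZE = 32   # assumption: 'a'-'z' plus a few special characters
DIM = 512         # "the dimension for each cell and character is 512"
MAX_LEN = 50      # inputs are limited to 50 characters

# Encoder: reads the (reversed) query one character at a time and
# compresses it into a fixed-size state.
enc_in = layers.Input(shape=(MAX_LEN,), name="typo_chars")
enc_emb = layers.Embedding(VOCAB_SIZE, DIM)(enc_in)
_, state_h, state_c = layers.LSTM(DIM, return_state=True)(enc_emb)

# Decoder: starts from the encoder state; every step emits a vector of
# scores over the vocabulary, and the argmax picks the output character.
# At training time the decoder is fed the shifted target (teacher forcing);
# at inference time it would be fed its own previous prediction.
dec_in = layers.Input(shape=(MAX_LEN,), name="shifted_target_chars")
dec_emb = layers.Embedding(VOCAB_SIZE, DIM)(dec_in)
dec_seq = layers.LSTM(DIM, return_sequences=True)(
    dec_emb, initial_state=[state_h, state_c])
logits = layers.Dense(VOCAB_SIZE)(dec_seq)

model = tf.keras.Model([enc_in, dec_in], logits)
model.compile(
    optimizer="adam",
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True))
```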

Some other details:

  • I use TensorFlow to implement the RNN, since it's popular.
  • The input is limited to 50 characters, which covers 99.9% of search traffic.
  • The current model only supports ‘a’ to ‘z’ plus a few special characters.
  • Every input is a single character, since I don't want to deal with a tokenizer.
  • The input is the reversed query, since that makes the mapping easier to learn.
  • There are 2 layers in each cell.
  • The dimension for each cell and character is 512.
  • There are 64 examples per training batch (see the preprocessing sketch after this list).
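As a rough illustration of those input conventions, here is how the character vocabulary, reversal, padding, and batching might look. The exact special characters aren't listed in the post, so `CHARS` below is an assumption:

```python
import numpy as np

# Assumption: the post doesn't list the exact special characters,
# so this vocabulary is illustrative.
CHARS = "abcdefghijklmnopqrstuvwxyz '-"
PAD = 0
CHAR_TO_ID = {c: i + 1 for i, c in enumerate(CHARS)}
MAX_LEN, BATCH_SIZE = 50, 64

def encode(query, reverse):
    """Map a query to a fixed-length array of character ids."""
    ids = [CHAR_TO_ID[c] for c in query.lower() if c in CHAR_TO_ID]
    if reverse:
        ids = ids[::-1]        # the typo side is fed in backwards
    ids = ids[:MAX_LEN]
    return np.array(ids + [PAD] * (MAX_LEN - len(ids)))

def batches(pairs):
    """Yield (typo, correction) id matrices, 64 examples at a time."""
    for i in range(0, len(pairs), BATCH_SIZE):
        chunk = pairs[i:i + BATCH_SIZE]
        x = np.stack([encode(typo, reverse=True) for typo, _ in chunk])
        y = np.stack([encode(fix, reverse=False) for _, fix in chunk])
        yield x, y
```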

Where’s the data from?

High-quality data is the key to every machine learning problem.

I'm using anonymized search session data to extract training and testing data for the model. When I notice two similar queries in one search session, and the second one gets better results, I guess it is probably a good spell correction example.
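A sketch of how that mining step might look, assuming plain string similarity between consecutive queries is the main filter; the "better result" signal would come from engagement logs that aren't shown here:

```python
import difflib

def candidate_pairs(session_queries, min_ratio=0.8):
    """From one search session, keep consecutive query pairs that are
    nearly identical strings; the reformulation is a likely correction.
    (Real filtering would also check that the second query did better.)
    """
    pairs = []
    for first, second in zip(session_queries, session_queries[1:]):
        similarity = difflib.SequenceMatcher(None, first, second).ratio()
        if first != second and similarity >= min_ratio:
            pairs.append((first, second))
    return pairs
```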

For people without good training data, Deep Spelling describes how to generate training data from Google's billion-word dataset. Although the generated typos differ from real ones, they are good enough to train a spell correction model.
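In that spirit, one simple way to manufacture training pairs is to corrupt clean text with random character-level edits. The edit types and the 3% rate below are assumptions for illustration, not Deep Spelling's exact recipe:

```python
import random

def add_noise(text, p=0.03):
    """Corrupt a clean string with random character-level edits."""
    letters = "abcdefghijklmnopqrstuvwxyz"
    out = []
    for ch in text:
        r = random.random()
        if r < p:                        # delete this character
            continue
        elif r < 2 * p:                  # replace it with a random letter
            out.append(random.choice(letters))
        elif r < 3 * p:                  # keep it, then insert a random letter
            out.extend([ch, random.choice(letters)])
        else:                            # keep it unchanged
            out.append(ch)
    return "".join(out)
```

A training pair is then simply (add_noise(sentence), sentence).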

How good is our RNN?

I was just playing with it for fun and had no expectations. However, it turns out the model works very well for spell correction. Around 100 iterations, there were only funny results like “ddd hhhh” and “sssss iiihhhh” in the output. After 400 iterations, I saw a query “diy home”.

After about 4,000 iterations, the model can generate trigram queries like “picnic in house”. After about 6,000 iterations, it achieves 90% accuracy on my testing data. It only takes a couple of hours to train this model.

Some examples: typo — correction — guess

Future work

I didn't expect that deep learning would solve my problem so easily. Compared with other solutions, RNNs have a lot of advantages:

  • The model has a lot of parameters, but it is easy to set up, and it learns contextual relationships easily.
  • It is easy to handle big data.
  • There's no prior data needed for spell correction with an RNN.
  • It is easy to extend the model to other tasks, such as text summarization.

I will continue training the spell correction model with more data, and I will publish my code on GitHub later. In the meantime, I need to write a Java library to load the model into our query understanding service. In the end, I will also implement some logic to handle long sentences, so we can spell-correct tweets, descriptions, and other long texts.


