Normalized center loss for language modeling
Caveat: Some knowledge of recurrent neural networks is assumed.
A short introduction to language modeling and how it works
Q. What is language modeling?
In language modeling, we try to predict the next word given a sequence of words. The model computes a probability distribution over the possible values for the next word, and a word is sampled from that distribution.
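As a concrete illustration (a minimal NumPy sketch, not code from this post), turning a model's raw scores into a distribution and sampling the next word might look like this:

```python
import numpy as np

def sample_next_word(logits, rng):
    """Turn raw model scores (logits) into a probability distribution
    with softmax, then sample one word id from that distribution."""
    shifted = logits - np.max(logits)              # for numerical stability
    probs = np.exp(shifted) / np.sum(np.exp(shifted))
    return rng.choice(len(probs), p=probs)

rng = np.random.default_rng(0)
logits = np.array([2.0, 0.5, -1.0, 0.1, 1.2])     # scores over a toy 5-word vocabulary
word_id = sample_next_word(logits, rng)           # an id in 0..4
```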
Q. How does it work?
You build a recurrent neural network architecture. The model processes one word at a time and computes probabilities of the possible values for the next word. The memory state of the network is initialized with a vector of zeros and gets updated after reading each word.
The output of the RNN depends on arbitrarily distant inputs, which makes back-propagation difficult. To make learning tractable, we truncate back-propagation to a fixed number of steps (called num_steps). The model is then trained on this finite approximation of the RNN: inputs of length num_steps are fed in one block at a time, and a backward pass is performed after each block.
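A minimal sketch of this chunking (the helper name and details are mine, not the original implementation): the token stream is cut into blocks of num_steps inputs, each paired with the same inputs shifted by one position as targets.

```python
def iter_truncated_batches(token_ids, num_steps):
    """Yield (inputs, targets) blocks of length num_steps.
    Targets are the inputs shifted one position to the right, so the
    model learns to predict the next word at every time step. Running
    a backward pass after each block means gradients never flow
    further back than num_steps steps."""
    for start in range(0, len(token_ids) - 1, num_steps):
        inputs = token_ids[start:start + num_steps]
        targets = token_ids[start + 1:start + num_steps + 1]
        if len(inputs) == num_steps and len(targets) == num_steps:
            yield inputs, targets

blocks = list(iter_truncated_batches(list(range(10)), num_steps=3))
# blocks[0] is ([0, 1, 2], [1, 2, 3])
```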
The standard metric for evaluating language models is perplexity, which equals exp(cross_entropy).
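Concretely (a small sketch of the standard definition, not code from this post), perplexity can be computed from the probabilities the model assigns to the true next words:

```python
import math

def perplexity(target_probs):
    """exp of the average cross-entropy, where cross-entropy is the
    mean negative log-probability assigned to the true next word."""
    cross_entropy = -sum(math.log(p) for p in target_probs) / len(target_probs)
    return math.exp(cross_entropy)

# A model that always gives the right word probability 1/4 has
# perplexity ~4: it is as uncertain as a uniform pick among 4 words.
print(perplexity([0.25, 0.25, 0.25]))  # ≈ 4.0
```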
Continuing from my previous blog post, I wanted to see whether adding a center loss would improve the perplexity of the models.
Normalized center loss
I recommend reading the first section of the previous blog post before continuing.
With the center loss exactly as defined in that post, the loss value grew exponentially (roughly doubling every 200 iterations) and eventually exploded. To overcome this, I modified the center loss slightly; I call the modified version the normalized center loss.
In the normalized center loss, I normalize the centers after each update so that every center vector has unit norm. This prevents the loss value from exploding.
When computing the loss, the center for a word is scaled by the norm of the corresponding embedding. This ensures the embeddings receive gradients that push them toward their center's direction.
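The two steps above can be sketched as follows. This is my NumPy reconstruction of the idea, not the original code; the function name, the update rate alpha, and the squared-distance form of the loss are assumptions.

```python
import numpy as np

def normalized_center_loss(embeddings, labels, centers, alpha=0.5):
    """Update each word's center toward its embedding, renormalize the
    center to unit norm, then penalize the squared distance between the
    embedding and its center rescaled to the embedding's norm."""
    for emb, c in zip(embeddings, labels):
        centers[c] = (1 - alpha) * centers[c] + alpha * emb
        centers[c] /= np.linalg.norm(centers[c])   # keep centers at norm 1
    # Scale each unit-norm center by the norm of its embedding.
    scaled = centers[labels] * np.linalg.norm(embeddings, axis=1, keepdims=True)
    return np.mean(np.sum((embeddings - scaled) ** 2, axis=1))

# Toy check: one 2-d embedding whose direction already matches its
# center, so the rescaled center coincides with the embedding.
embeddings = np.array([[3.0, 4.0]])
labels = np.array([0])
centers = np.array([[0.6, 0.8]])
loss = normalized_center_loss(embeddings, labels, centers)  # ≈ 0
```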
Results
I tried this new loss on the WikiText-2 and Penn Treebank datasets. The results were as follows:
Results on the Penn Treebank dataset
As the graphs show, perplexity improves on the validation set for all values of lambda tried.
On the test set, perplexity improves by 4 points, which is quite significant.
Results on the WikiText-2 dataset
Again, the graphs show that perplexity improves on the validation set for all values of lambda tried.
The test set improves by 2 points, which is also significant, though less impressive than on Penn Treebank. I suspect this is because the normalized center loss acts as a regularizer, and since WikiText-2 is much larger than Penn Treebank, the effect of adding a regularizer is diminished.
Conclusions
Since perplexity improved in the experiments on both datasets, it is reasonable to expect that any loss function encouraging features from the same class to cluster close together will lead to improved accuracy and reduced overfitting.
All experiments in this post can be reproduced using the code in this repo.
If you liked this article, please help others find it by clicking the little clap icon below. Thanks a lot!