Paper Summary: Regularizing and Optimizing LSTM Language Models

Mike Plotz Sage
Nov 26, 2018


Part of the series A Month of Machine Learning Paper Summaries. Originally posted here on 2018/11/16, with better formatting.

Regularizing and Optimizing LSTM Language Models (2017) Stephen Merity, Nitish Shirish Keskar, Richard Socher

This paper is a collection of best practices for LSTM language models, at least as of mid-2017, which means it was probably out of date by the end of 2017, and here we are a year later. Still, I expect much of this is useful / usable, so here we go.

Here’s the problem: language models have lots of parameters and are prone to overfitting. We’d like to add regularization, but naive applications of batch norm and dropout don’t really work in RNNs. What do? There’s been lots of work on this, so this paper collects the approaches and tries them all together (Gal & Ghahramani 2016 contributes a lot of material here). The paper also introduces a new variant of SGD (Non-monotonically Triggered Averaged SGD, or NT-ASGD).

Weight-dropped LSTM

Use DropConnect (Wan et al. 2013): dropout is applied to the hidden-to-hidden weight matrices U* of the LSTM (so, the recurrent connections only), once before each forward/backward pass, so the same weights stay dropped at every timestep of that pass. The main benefit here is that this is compatible with optimized black-box LSTM implementations (like NVIDIA’s cuDNN LSTM), so slow custom implementations are unnecessary.
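To make the “mask the weights, not the activations” idea concrete, here’s a minimal sketch in plain Python. It’s my own illustration, not the authors’ code: a plain tanh RNN stands in for the LSTM to keep it short, the function names are hypothetical, and scaling the surviving weights by 1/(1-p) is an assumption (inverted-dropout style). The key point is that the masked recurrent matrix is sampled once and then reused at every timestep.

```python
import math
import random

def drop_connect(U, p, rng):
    """Zero each entry of the recurrent matrix U with probability p.

    Survivors are scaled by 1/(1-p) (inverted-dropout scaling, an assumption
    here) so each entry keeps its expected value.
    """
    keep = 1.0 - p
    return [[u / keep if rng.random() >= p else 0.0 for u in row] for row in U]

def matvec(M, v):
    return [sum(m * x for m, x in zip(row, v)) for row in M]

def run_rnn(xs, W, U, h0):
    """Plain tanh RNN standing in for the LSTM: h_t = tanh(W x_t + U h_{t-1}).

    Whatever U is passed in (here: the weight-dropped one) is reused
    unchanged at every timestep.
    """
    h, hs = h0, []
    for x in xs:
        h = [math.tanh(a + b) for a, b in zip(matvec(W, x), matvec(U, h))]
        hs.append(h)
    return hs

rng = random.Random(0)
W = [[0.5, -0.2], [0.1, 0.3]]  # input-to-hidden: untouched by DropConnect
U = [[0.4, 0.2], [-0.3, 0.6]]  # hidden-to-hidden: the matrix that gets dropped
U_dropped = drop_connect(U, p=0.5, rng=rng)  # sampled once per forward/backward pass
hs = run_rnn([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], W, U_dropped, h0=[0.0, 0.0])
```

Because the dropout lives in the weights, the recurrence itself is unchanged, which is why a fused LSTM kernel can still be used.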

Optimization with NT-ASGD

Momentum variants of SGD don’t work well for RNN language models, it turns out: plain SGD without momentum has traditionally done better there (I knew this at one point and then forgot). So the authors had to look farther afield for improvements to optimization methods. There’s an older method called ASGD (averaged stochastic gradient descent) that, after a certain trigger point T, returns the average of the SGD iterates (the weights after each step) rather than just the final iterate. Problem is, T has to be tuned manually. So they came up with a way to trigger automatically when a validation metric stops improving.

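The algorithm is in the paper, but here’s the gist: do SGD. Periodically (say, every epoch) evaluate the validation metric and log it. Set the trigger T to the current iteration once the new value is worse than the best value logged more than n = 5 checks ago. That lag is the “non-monotonic” part: the most recent n checks are left out of the comparison, so a short run of bad checks doesn’t fire the trigger by itself. Once the trigger is set there’s no further need to check for stagnation. Return the average of the iterates from T onwards:

$$\bar{w} = \frac{1}{k - T + 1} \sum_{i=T}^{k} w_i$$

where w_i are the weights after iteration i and k is the total number of SGD iterations.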
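Here’s a minimal sketch of that loop on a scalar toy problem, assuming plain SGD with a constant learning rate and a check every `check_every` iterations. All names, the toy objective (minimizing (w - 3)^2 with noisy gradients) and the hyperparameter values are my own assumptions for illustration; only n = 5 comes from the paper.

```python
import random

def stagnated(logs, v, n=5):
    """Non-monotone trigger check: is the new validation loss v worse than
    the best loss logged more than n checks ago?"""
    t = len(logs)
    return t > n and v > min(logs[: t - n])

def nt_asgd(grad, val_loss, w0, lr=0.1, total_iters=1000, check_every=10, n=5):
    """Scalar NT-ASGD sketch. Returns (averaged weight, trigger iteration T)."""
    w, T = w0, None
    logs = []                      # one validation loss per check
    running_sum, count = 0.0, 0
    for k in range(1, total_iters + 1):
        w -= lr * grad(w)          # ordinary SGD step, no momentum
        if T is not None:
            running_sum += w       # accumulate w_k for every k > T
            count += 1
        elif k % check_every == 0: # only check while T is unset
            v = val_loss(w)
            if stagnated(logs, v, n):
                T = k
                running_sum, count = w, 1  # averaging starts at w_T
            logs.append(v)
    if T is None:
        return w, None             # never triggered: plain SGD result
    return running_sum / count, T  # mean of w_T .. w_k

# Toy problem: minimize (w - 3)^2 with noisy gradients.
rng = random.Random(0)
w_avg, T = nt_asgd(
    grad=lambda w: 2 * (w - 3) + rng.gauss(0, 1),
    val_loss=lambda w: (w - 3) ** 2,
    w0=0.0,
)
```

The averaged weight smooths out the noise that individual SGD iterates keep bouncing around with once the loss has plateaued.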

Variable Length BPTT

Using a fixed length for backprop through time (BPTT, i.e. the number of recurrent timesteps in each training window) is less than ideal: the same words will always land at the beginning of a window, so these early words systematically receive less gradient information (a word at the very start of a window has nothing earlier in that window to backprop into). This is like Alan Turing’s bicycle chain wearing down because it’s always coming into contact with the same gear teeth.

The simple fix is to vary the BPTT length while keeping the average length close to seq, the longest length that’s still efficient. The first adjustment is to sometimes use seq/2 instead. The second adjustment is to add jitter by drawing the actual length from a normal distribution. With seq = 70, they drew the length from N(70, 5) with p = 0.95 and from N(35, 5) with p = 0.05, where 5 is the standard deviation.
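A minimal sketch of the sampler and how it shifts window boundaries between epochs, using the numbers above (base 70, p = 0.95, standard deviation 5). The function names are hypothetical, and clamping each length to at least one token is my assumption, not something from the paper.

```python
import random

def sample_seq_len(rng, bptt=70, p=0.95, s=5):
    """One window length: base bptt w.p. p, else bptt/2, then Gaussian jitter."""
    base = bptt if rng.random() < p else bptt / 2
    # Clamping to >= 1 token is an assumption; it just avoids empty windows.
    return max(1, int(round(rng.gauss(base, s))))

def variable_bptt_windows(tokens, rng, **kwargs):
    """Split a token stream into consecutive windows of random length."""
    windows, i = [], 0
    while i < len(tokens):
        seq_len = sample_seq_len(rng, **kwargs)
        windows.append(tokens[i : i + seq_len])
        i += seq_len
    return windows

def window_starts(windows):
    """Positions that begin a window (nothing earlier in it to backprop into)."""
    starts, i = [], 0
    for w in windows:
        starts.append(i)
        i += len(w)
    return starts

tokens = list(range(1000))
fixed = [tokens[i : i + 70] for i in range(0, len(tokens), 70)]
rng = random.Random(0)
epoch1 = variable_bptt_windows(tokens, rng)
epoch2 = variable_bptt_windows(tokens, rng)
```

With fixed windows the starts are always 0, 70, 140 and so on, every epoch; with the variable split they move around, so no word is stuck at a window boundary. The expected length works out to 0.95 × 70 + 0.05 × 35 = 68.25, close to seq.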

Variational Dropout and Embedding Dropout

For the LSTM’s inputs and outputs they use variational dropout (Gal & Ghahramani 2016), which uses the same dropout mask at every timestep (the hidden-to-hidden weights use DropConnect, as mentioned above). Each minibatch example still gets a different mask, though, and the masks are resampled for every forward/backward pass. This guarantees that any non-dropped input or output unit stays available at every timestep instead of flickering in and out; without this the training task gets too hard (my interpretation), especially as BPTT gets longer. If I understand DropConnect correctly, the same thing applies there, since the dropped weights are fixed for the whole pass.
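A minimal sketch of the “one mask per sequence” idea, assuming inputs are given as a list of timesteps of feature vectors. The function name is hypothetical and the 1/(1-p) scaling of survivors is the usual inverted-dropout convention, assumed here. Calling it once per example gives each minibatch example its own mask.

```python
import random

def variational_dropout(seq, p, rng):
    """Dropout with one mask shared by every timestep of a sequence.

    seq: list of timesteps, each a list of features (e.g. LSTM inputs).
    A fresh mask is drawn per call, i.e. per example per forward pass.
    Survivors are scaled by 1/(1-p) (inverted dropout).
    """
    keep = 1.0 - p
    mask = [1.0 / keep if rng.random() >= p else 0.0 for _ in seq[0]]
    return [[x * m for x, m in zip(step, mask)] for step in seq], mask

rng = random.Random(1)
# A minibatch of 3 examples, each 6 timesteps x 4 features.
batch = [[[1.0, 2.0, 3.0, 4.0] for _ in range(6)] for _ in range(3)]
dropped = [variational_dropout(example, p=0.5, rng=rng) for example in batch]
```

Standard dropout would draw a new mask inside the timestep loop; here the mask is drawn outside it, which is the whole difference.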

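Embedding dropout (also from G&G 2016) is the same idea applied to the embedding matrix: entire rows (that is, whole words) are dropped, so a dropped word disappears everywhere it occurs in that forward/backward pass, and the surviving embeddings are scaled by 1/(1-p).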
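A minimal sketch of word-level embedding dropout on a toy 5-word vocabulary with 2-d embeddings. The names and numbers are hypothetical. The point to notice is that the mask is per vocabulary row, so every occurrence of a word in the sentence sees the same (kept or dropped) vector.

```python
import random

def embedding_dropout(E, p, rng):
    """Drop whole rows (words) of embedding matrix E with probability p,
    scaling surviving rows by 1/(1-p)."""
    keep = 1.0 - p
    out = []
    for row in E:
        if rng.random() < p:
            out.append([0.0] * len(row))          # word dropped everywhere
        else:
            out.append([x / keep for x in row])   # word kept, rescaled
    return out

vocab_emb = [[0.1 * (i + 1), -0.1 * (i + 1)] for i in range(5)]  # 5 words, 2-d
rng = random.Random(3)
E_drop = embedding_dropout(vocab_emb, p=0.4, rng=rng)  # once per forward pass
sentence = [2, 4, 2, 0]                                # word 2 appears twice
embedded = [E_drop[w] for w in sentence]
```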

Weight Tying

This is just sharing weights between the embedding and softmax layers to reduce the number of parameters. It also encodes prior knowledge about the one-to-one correspondence between inputs and outputs, so the model doesn’t have to learn it. From Inan et al. 2016.
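A minimal sketch of tying: one V × d matrix is used both to look up input embeddings and to score every vocabulary word at the output. The toy matrix and hidden state are made up, names are hypothetical, and it assumes the final hidden vector has the same size as an embedding (which tying requires).

```python
import math

def embed(E, token):
    """Input side: look up the token's row."""
    return E[token]

def tied_logits(E, h):
    """Output side reuses E: logit for word v is E[v] . h (needs len(h) == d)."""
    return [sum(e * x for e, x in zip(row, h)) for row in E]

def softmax(z):
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    s = sum(exps)
    return [e / s for e in exps]

E = [[1.0, 0.0], [0.0, 1.0], [0.7, 0.7]]  # 3-word vocab, 2-d, shared both ways
h = [0.9, 0.1]                            # pretend final hidden state
probs = softmax(tied_logits(E, h))
```

Untied, the model would store two V × d matrices; tied, it stores one.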

Independent Embedding and Hidden Sizes

What it says on the tin: the embedding size doesn’t have to match the LSTM hidden size. They used 400-d embeddings and 1150 hidden units. For some reason most previous models used the same size for both.
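Some quick parameter arithmetic shows why decoupling matters: the embedding table costs V × d parameters, so forcing d up to the hidden size gets expensive fast. The 10,000-word vocabulary below is a hypothetical figure, and the LSTM formula assumes one bias vector per gate (PyTorch’s `nn.LSTM` keeps two).

```python
def embedding_params(vocab_size, emb_dim):
    """One row of emb_dim weights per vocabulary word."""
    return vocab_size * emb_dim

def lstm_layer_params(input_size, hidden_size):
    """Four gates, each with input weights, recurrent weights and one bias."""
    return 4 * (hidden_size * (input_size + hidden_size) + hidden_size)

V = 10_000                            # hypothetical vocabulary size
coupled = embedding_params(V, 1150)   # embedding forced to match 1150 hidden units
decoupled = embedding_params(V, 400)  # 400-d embedding, hidden size set freely
first_layer = lstm_layer_params(400, 1150)
```

That’s 11.5M versus 4M parameters just for the embedding table, before counting the LSTM itself.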

(Temporal) Activation Regularization

AR (activation regularization) is an L2 penalty on the LSTM outputs (after dropout is applied), not on the weights, so it discourages large activations. TAR (temporal activation regularization) is a “slowness” regularizer that keeps the hidden state from changing too fast, implemented as an L2 penalty on the difference between LSTM outputs at consecutive timesteps. Both are applied only to the final layer.
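A minimal sketch of both penalties on a toy 3-step sequence of final-layer outputs. The function names are hypothetical, averaging over timesteps is my choice of normalization, and computing TAR on the outputs before dropout is how I read the paper’s formulation. The penalties are simply added to the usual cross-entropy loss.

```python
def ar_penalty(dropped_outputs, alpha):
    """AR: alpha * mean squared L2 norm of the final layer's (dropped-out) outputs."""
    sq = [sum(x * x for x in h) for h in dropped_outputs]
    return alpha * sum(sq) / len(sq)

def tar_penalty(outputs, beta):
    """TAR: beta * mean squared L2 norm of h_t - h_{t+1} on the final layer."""
    diffs = [sum((a - b) ** 2 for a, b in zip(h_t, h_next))
             for h_t, h_next in zip(outputs, outputs[1:])]
    return beta * sum(diffs) / len(diffs)

outputs = [[1.0, 0.0], [1.0, 1.0], [0.0, 1.0]]  # final-layer h_t, 3 timesteps
dropped = [[2.0, 0.0], [2.0, 0.0], [0.0, 0.0]]  # same outputs, scaled mask [2, 0]
# total_loss = cross_entropy + ar_penalty(dropped, alpha) + tar_penalty(outputs, beta)
```

A hidden state that never changes gets zero TAR penalty, which is exactly the “slowness” the regularizer rewards.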

Pointer Models

There’s a thing called a “continuous cache pointer” which I guess is a way of doing attention. I didn’t follow that bunny trail, so I don’t know how it works, but I gather that there’s reason to believe that pointer models might not help when you also do weight tying and other tricks from this paper. So they tried combining all the things and it turns out that pointer models still help substantially, but not quite as much as they do on their own.

(I’m actually not sure how to think about this. Once you’ve applied optimization A, I’d expect optimization B not to help quite as much as it would by itself, because something about low hanging fruit. This doesn’t imply there’s anything particularly bad about how A and B interact. Of course if B doesn’t help at all, that’s another story.)

Ablation

All modifications recommended in the paper had a measurable effect (it would be strange if not). Weight-dropping (DropConnect) had the biggest single impact.

These kinds of grab-bag papers are pretty nice to read, and seem useful, but I wonder about the propriety of writing a paper consisting almost entirely of others’ methods. As long as there’s some kind of novelty (here NT-ASGD and the ablation studies) I guess it’s considered cool. I suppose I’m doing something similar — with even less novelty — by writing up summaries. I’m not claiming the result is a paper, though.

Inan et al. 2016, “Tying Word Vectors and Word Classifiers: A Loss Framework for Language Modeling” https://arxiv.org/abs/1611.01462

Gal & Ghahramani 2016, “A Theoretically Grounded Application of Dropout in Recurrent Neural Networks” https://arxiv.org/abs/1512.05287

Wan et al. 2013, “Regularization of Neural Networks using DropConnect” http://proceedings.mlr.press/v28/wan13.html
