Language models and RNN

This story covers two topics: language models (LM) and RNNs. For LM, it covers the n-gram language model and the neural LM; for RNNs, it goes from the vanilla RNN to the vanishing gradient problem, then introduces LSTM/GRU and two RNN variants: bidirectional RNNs and multi-layer RNNs.

Qiurui Chen
11 min read · May 12, 2020

This article is a summary of the course ‘Stanford CS224N: NLP with Deep Learning | Winter 2019 | Lecture 6 — Language Models and RNNs’.

1. Language models

Language modeling is the task of predicting what word comes next. More formally: given a sequence of words x(1), x(2), …, x(t), compute the probability distribution of the next word x(t+1). A system that does language modeling is called a Language Model.

Source: the course slide

You can also think of a Language Model as a system that assigns a probability to a piece of text. For example, if we have some text x(1), x(2), …, x(t), then the probability of this text (according to the Language Model) is given by the chain rule, P(x(1), …, x(t)) = P(x(1)) × P(x(2) | x(1)) × … × P(x(t) | x(t−1), …, x(1)), as shown on the slide.

1.1 N-gram language models

For example, let’s complete the sentence the students opened their ____. How do we learn a Language Model (LM)? The pre-deep-learning method is to learn an n-gram Language Model. An n-gram is a chunk of n consecutive words. Unigrams here would be the, students, opened, and their. Bigrams would be the students, students opened, and opened their. The idea is to collect statistics about how frequent different n-grams are, and use these to predict the next word.

N-gram LM. Source: the course slide
calculating n-grams and (n-1)-gram probabilities. Source: the course slide

First, we make a simplifying assumption: x(t+1) depends only on the preceding n−1 words. We can estimate these n-gram and (n−1)-gram probabilities by counting them in some large corpus of text: P(x(t+1) | x(t), …, x(t−n+2)) ≈ count(x(t−n+2), …, x(t), x(t+1)) / count(x(t−n+2), …, x(t)).

4-gram LM example. Source: the course slide

For example, suppose we are learning a 4-gram LM and the sentence is as the proctor started the clock, the students opened their ___. A 4-gram LM conditions only on the last three words, students opened their, and discards the rest of the context. Suppose that in the corpus students opened their occurred 1000 times and students opened their books occurred 400 times; then P(books | students opened their) = 0.4. If students opened their exams occurred 100 times, then P(exams | students opened their) = 0.1.
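As a concrete (toy) illustration of the counting, here is a minimal count-based 4-gram estimator in Python; the tiny corpus and names such as ngram_prob are illustrative and not from the lecture.

```python
from collections import Counter

corpus = "as the proctor started the clock the students opened their books".split()
n = 4

# Count every 4-gram, and every 3-gram context that is followed by some word.
ngram_counts = Counter(tuple(corpus[i:i + n]) for i in range(len(corpus) - n + 1))
context_counts = Counter(tuple(corpus[i:i + n - 1]) for i in range(len(corpus) - n + 1))

def ngram_prob(word, context):
    """P(word | context) = count(context + word) / count(context)."""
    context = tuple(context)
    if context_counts[context] == 0:
        return 0.0                        # unseen context: the sparsity problem
    return ngram_counts[context + (word,)] / context_counts[context]

print(ngram_prob("books", ("students", "opened", "their")))   # 1.0 in this tiny corpus
```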

sparsity problems with n-gram language models. Source: the course slide

The main problems with n-gram LMs are sparsity and storage. If a particular n-gram never occurs in the corpus, its estimated probability is zero (this is partially addressed by smoothing and backoff). Increasing n makes the sparsity problem worse; typically we cannot have n bigger than 5. Increasing n or increasing the corpus also increases model size.

generating text with an n-gram LM. Source: the course slide

You can also use an n-gram LM to generate text, which is surprisingly grammatical but incoherent. We need to consider more than three words at a time if we want to model language well, but increasing n worsens the sparsity problem and increases model size.

1.2 A neural LM

a fixed-window neural LM. Source: the course slide

A fixed-window neural LM improves over the n-gram LM: there is no sparsity problem, and you do not need to store all observed n-grams. There are still some remaining problems: 1) the fixed window is too small; 2) enlarging the window enlarges W; 3) the window can never be large enough; 4) x(1) and x(2) are multiplied by completely different weights in W, so there is no symmetry in how the inputs are processed. We need a neural architecture that can process input of any length.
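To make this architecture concrete, here is a minimal fixed-window neural LM sketch in PyTorch. The window size, dimensions, and the class name FixedWindowLM are assumptions for illustration, not the lecture's exact model.

```python
import torch
import torch.nn as nn

class FixedWindowLM(nn.Module):
    def __init__(self, vocab_size, window=4, embed_dim=64, hidden_dim=128):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.W = nn.Linear(window * embed_dim, hidden_dim)   # grows with the window size
        self.U = nn.Linear(hidden_dim, vocab_size)

    def forward(self, window_ids):                    # (batch, window) word ids
        e = self.embed(window_ids).flatten(1)         # concatenate the window's embeddings
        return self.U(torch.tanh(self.W(e)))          # logits over the next word

model = FixedWindowLM(vocab_size=10_000)
logits = model(torch.randint(0, 10_000, (8, 4)))      # e.g. "the students opened their" -> ?
print(logits.shape)                                   # torch.Size([8, 10000])
```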

2. RNN

2.1 Vanilla RNN

To address the need for a neural architecture that can process input of any length, we introduce the RNN (recurrent neural network). The core idea is that we apply the same weights W repeatedly.

this diagram shows the most important features of RNN. Source: the course slide
An RNN LM. Source: the course slide

RNN advantages: it can process input of any length; computation for step t can (in theory) use information from many steps back; model size doesn’t increase for longer input; and the same weights are applied at every timestep, so there is symmetry in how inputs are processed. Downsides: recurrent computation is slow, and in practice it is difficult to access information from many steps back.
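A minimal RNN-LM sketch in PyTorch makes the weight sharing concrete: nn.RNN applies the same recurrent weights at every timestep, whatever the sequence length. The sizes and the class name RNNLM are illustrative assumptions.

```python
import torch
import torch.nn as nn

class RNNLM(nn.Module):
    def __init__(self, vocab_size, embed_dim=64, hidden_dim=128):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.rnn = nn.RNN(embed_dim, hidden_dim, batch_first=True)
        self.out = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, h=None):
        e = self.embed(x)                  # (batch, seq) -> (batch, seq, embed_dim)
        h_all, h_last = self.rnn(e, h)     # the same recurrent weights at every timestep
        return self.out(h_all), h_last     # logits over the vocab at every timestep

model = RNNLM(vocab_size=10_000)
logits, h = model(torch.randint(0, 10_000, (8, 20)))   # batch of 8 sequences, 20 steps each
print(logits.shape)                                    # torch.Size([8, 20, 10000])
```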

The loss function for RNN LM training. Source: the course slide

How do we train an RNN LM? First, we need a big corpus of text, which is a sequence of words x(1), x(2), x(3), …. Then we feed these inputs into the RNN-LM and compute the output distribution ŷ(t) for every step t, i.e. predict the probability distribution of the next word given the words so far. The loss at step t is the cross-entropy between ŷ(t) and the true next word x(t+1), i.e. −log ŷ(t)[x(t+1)], and the overall loss is the average of this over all timesteps, as shown on the slide.

Training an RNN Language Model. Source: the course slide

However, computing the loss and gradients across the entire corpus at once is too expensive. Stochastic Gradient Descent (SGD) allows us to compute the loss and gradient on a small chunk of data and then update. So in practice we compute the loss for a sentence (actually a batch of sentences), compute gradients, update the weights, and repeat.
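Here is a sketch of one such SGD step, assuming the RNNLM class from the sketch above; the batch is random stand-in data, and the target at each position is simply the next word.

```python
import torch
import torch.nn.functional as F

model = RNNLM(vocab_size=10_000)                  # the RNNLM sketch from above
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

batch = torch.randint(0, 10_000, (8, 21))         # stand-in for a batch of tokenised sentences
inputs, targets = batch[:, :-1], batch[:, 1:]     # the target is simply the next word

logits, _ = model(inputs)                         # (8, 20, vocab_size)
loss = F.cross_entropy(                           # average per-timestep cross-entropy
    logits.reshape(-1, logits.size(-1)), targets.reshape(-1))

optimizer.zero_grad()
loss.backward()                                   # backpropagation (through time)
optimizer.step()                                  # one SGD update; repeat over batches
```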

“ The gradient w.r.t. a repeated weight is the sum of the gradient w.r.t each time it appears”. Source: the course slide
backpropagation through time. Source: the course slide

What is the derivative of the loss w.r.t. the repeated weight matrix Wh? Backpropagating over timesteps i = t, …, 0 and summing gradients as you go, it turns out that “the gradient w.r.t. a repeated weight is the sum of the gradients w.r.t. each time it appears”. This procedure is called backpropagation through time.

math explanation for backpropagation calculation. Source: the course slide
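This summation can be checked numerically with a toy two-step unrolling (not from the lecture): compute gradients once with a single shared weight, then again with a separate copy of the weight per timestep, and compare.

```python
import torch

torch.manual_seed(0)
W = torch.randn(3, 3)
h0 = torch.randn(3)

# (a) Unroll two steps reusing one shared weight, as an RNN does.
W_shared = W.clone().requires_grad_(True)
h2 = torch.tanh(W_shared @ torch.tanh(W_shared @ h0))
h2.sum().backward()

# (b) The same computation, but each timestep gets its own copy of the weight.
W1 = W.clone().requires_grad_(True)
W2 = W.clone().requires_grad_(True)
h2 = torch.tanh(W2 @ torch.tanh(W1 @ h0))
h2.sum().backward()

# The gradient of the shared weight equals the sum of the per-timestep gradients.
print(torch.allclose(W_shared.grad, W1.grad + W2.grad))   # True
```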

2.2 RNN applications

Just like an n-gram Language Model, you can use an RNN Language Model to generate text by repeated sampling: the sampled output at each step becomes the next step’s input. You can train an RNN-LM on any kind of text and then generate text in that style.
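A minimal sampling loop, assuming the RNNLM sketch above plus a hypothetical itos list that maps word ids back to strings: at each step we sample from the predicted distribution and feed the sample back in as the next input.

```python
import torch

@torch.no_grad()
def sample(model, start_id, itos, max_len=20):
    """Repeatedly sample the next word and feed it back in as the next input."""
    ids = [start_id]
    h = None
    for _ in range(max_len):
        x = torch.tensor([[ids[-1]]])                 # shape (batch=1, seq=1)
        logits, h = model(x, h)                       # keep the hidden state between steps
        probs = torch.softmax(logits[0, -1], dim=-1)  # distribution over the next word
        ids.append(torch.multinomial(probs, 1).item())
    return " ".join(itos[i] for i in ids)

# e.g. print(sample(model, start_id=0, itos=vocab_list))  # vocab_list is hypothetical
```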

RNN LM example. Source: the course slide
Evaluating Language Models. Source: the course slide

The standard evaluation metric for Language Models is perplexity, which is equal to the exponential of the cross-entropy loss. Lower perplexity is better. Results show that RNN-LMs outperform n-gram models.
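Since perplexity is just the exponential of the average per-word cross-entropy, the conversion is a one-liner; the loss value below is made up for illustration.

```python
import math

avg_cross_entropy = 4.2            # made-up mean loss per word, in nats
perplexity = math.exp(avg_cross_entropy)
print(round(perplexity, 1))        # 66.7: roughly as uncertain as a uniform choice among ~67 words
```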

Language Modeling is a benchmark task that helps us measure our progress in understanding language. It is also a subcomponent of many NLP tasks, especially those involving generating text or estimating the probability of text: predictive typing, speech recognition, handwriting recognition, spelling/grammar correction, authorship identification, machine translation, summarization, dialogue, etc.

RNN for text classification. Source: the course slide
RNN as an encoder in question answering (left); RNN applied in speech recognition (right). Source: the course slide

2.3 Vanishing gradient

vanishing gradient proof sketch. Source: the course slide
why is vanishing gradient a problem? Source: the course slide

The gradient can be viewed as a measure of the effect of the past on the future. If the gradient becomes vanishingly small over longer distances (step t to step t+n), then we cannot tell whether there is simply no dependency between step t and step t+n in the data, or whether we have the wrong parameters to capture the true dependency between them. Due to the vanishing gradient, RNN-LMs are better at learning from sequential recency than from syntactic recency, so they make this type of error more often than we’d like. For example, given the prompt The writer of the books ___, the correct continuation is is, agreeing with the syntactically relevant word writer (syntactic recency), but an RNN-LM will often predict are, agreeing with the nearby word books (sequential recency).
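A tiny numerical illustration of the effect (not from the lecture): unroll a vanilla-RNN-style recurrence for 50 steps with a recurrent weight of small norm, and almost no gradient signal reaches the first state.

```python
import torch

torch.manual_seed(0)
W = 0.5 * torch.eye(4)                  # recurrent weight with small norm
h0 = torch.randn(4, requires_grad=True)

h = h0
for _ in range(50):                     # 50 timesteps of h <- tanh(W h)
    h = torch.tanh(W @ h)

h.sum().backward()
print(h0.grad.norm())                   # ~1e-15: almost no gradient reaches step 0
```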

2.4 Exploding gradient

The exploding gradient is also a problem. If the gradient becomes too big, then the SGD update step becomes too big.

SGD. Source: the course slide

This can cause bad updates: we take too large a step and reach a bad parameter configuration (with large loss). In the worst case, this will result in Inf or NaN values in your network (and then you have to restart training from an earlier checkpoint).

gradient clipping. Source: the course slide

Gradient clipping can fix the exploding gradient problem: if the norm of the gradient is greater than some threshold, scale it down before applying the SGD update. The intuition is that we still take a step in the same direction, but a smaller step.
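In PyTorch, gradient clipping is one extra line between backward() and the optimizer step; the tiny linear model and the threshold below are placeholders for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

model = nn.Linear(10, 10)                              # stand-in for a real RNN-LM
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

x, y = torch.randn(4, 10), torch.randn(4, 10)
loss = F.mse_loss(model(x), y)

optimizer.zero_grad()
loss.backward()
# If the global gradient norm exceeds max_norm, rescale all gradients so that the
# norm equals max_norm: same direction, smaller step.
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
optimizer.step()
```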

2.5 Long short-term memory (LSTM)

In a vanilla RNN, the hidden state is constantly being rewritten. Source: the course slide

How to fix the vanishing gradient problem? The main problem is that it’s too difficult for the RNN to learn to preserve information over many timesteps. In a vanilla RNN, the hidden state is constantly being rewritten. How about an RNN with separate memory?

The LSTM is a type of RNN proposed by Hochreiter and Schmidhuber in 1997 as a solution to the vanishing gradient problem. On step t, there is a hidden state h(t) and a cell state c(t). Both are vectors of length n, and the cell stores long-term information. The LSTM can erase, write, and read information from the cell. The selection of which information is erased/written/read is controlled by three corresponding gates, which are also vectors of length n. On each timestep, each element of the gates can be open (1), closed (0), or somewhere in between. The gates are dynamic: their value is computed based on the current context.

We have a sequence of inputs x(t), and we will compute a sequence of hidden states h(t), and cell states c(t).

LSTM. Source: the course slide
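In a PyTorch sketch, the extra cell state is directly visible in the API (the sizes below are arbitrary):

```python
import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=32, hidden_size=64, batch_first=True)
x = torch.randn(8, 20, 32)           # (batch, timesteps, input_size)

output, (h_n, c_n) = lstm(x)
print(output.shape)                  # torch.Size([8, 20, 64])  hidden state at every timestep
print(h_n.shape)                     # torch.Size([1, 8, 64])   final hidden state
print(c_n.shape)                     # torch.Size([1, 8, 64])   final cell state (long-term memory)
```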

How does the LSTM solve the vanishing gradient problem? The LSTM architecture makes it easier for the RNN to preserve information over many timesteps. For example, if the forget gate is set to remember everything on every timestep, then the information in the cell is preserved indefinitely. By contrast, it is harder for a vanilla RNN to learn a recurrent weight matrix Wh that preserves information in the hidden state. The LSTM doesn’t guarantee that there is no vanishing/exploding gradient, but it does provide an easier way for the model to learn long-distance dependencies.

LSTMs have obtained real-world success. In 2013–2015, LSTMs started achieving state-of-the-art results. Successful tasks include handwriting recognition, speech recognition, machine translation, parsing, and image captioning, and LSTM became the dominant approach. However, since 2019, other approaches (e.g. Transformers) have become more dominant for certain tasks. For example, in WMT (an MT conference + competition), the WMT 2016 summary report contains “RNN” 44 times, while the WMT 2018 report contains “RNN” 9 times and “Transformer” 63 times.

2.6 Gated Recurrent Units (GRU)

GRU. Source: the course slide

The GRU was proposed by Cho et al. in 2014 as a simpler alternative to the LSTM. On each timestep t, we have an input x(t) and a hidden state h(t) (no cell state).
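The corresponding PyTorch sketch returns only a hidden state, which is the most visible API-level difference from the LSTM (sizes are arbitrary):

```python
import torch
import torch.nn as nn

gru = nn.GRU(input_size=32, hidden_size=64, batch_first=True)
x = torch.randn(8, 20, 32)

output, h_n = gru(x)                 # only a hidden state comes back, no cell state
print(output.shape, h_n.shape)       # torch.Size([8, 20, 64]) torch.Size([1, 8, 64])
```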

Researchers have proposed many gated RNN variants, but LSTM and GRU are the most widely used.

What is the difference between GRU and LSTM? The biggest difference is that the GRU is quicker to compute and has fewer parameters. There is no conclusive evidence that one consistently performs better than the other. LSTM is a good default choice (especially if your data has particularly long dependencies, or you have lots of training data). Rule of thumb: start with LSTM, but switch to GRU if you want something more efficient.

2.7 Vanishing/exploding gradient is not just an RNN problem

Vanishing/exploding gradients are not just an RNN problem. They can be a problem for all neural architectures (including feed-forward and convolutional ones), especially deep ones. Due to the chain rule and the choice of nonlinearity function, the gradient can become vanishingly small as it backpropagates, so lower layers are learned very slowly (they are hard to train). To fix this problem, many new deep feedforward/convolutional architectures add more direct connections (thus allowing the gradient to flow), such as ResNet, DenseNet, and HighwayNet.

ResNet (left) and DenseNet (right). Source: the course slide

Residual connections, aka “ResNet”, are also known as skip connections. The identity connection preserves information by default, which makes deep networks much easier to train. Dense connections, aka “DenseNet”, directly connect everything to everything. Highway connections, aka “HighwayNet”, are similar to residual connections, but the balance between the identity connection and the transformation layer is controlled by a dynamic gate. They are inspired by LSTMs but applied to deep feedforward/convolutional networks.
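As a minimal sketch (illustrative sizes, not a full ResNet), a residual block adds its input back to the transformed output, so the identity path gives gradients a direct route:

```python
import torch
import torch.nn as nn

class ResidualBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.layer = nn.Sequential(nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, dim))

    def forward(self, x):
        return x + self.layer(x)       # identity connection + transformation

x = torch.randn(4, 16)
print(ResidualBlock(16)(x).shape)      # torch.Size([4, 16])
```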

In conclusion, though vanishing/exploding gradients are a general problem, RNNs are particularly unstable due to the repeated multiplication by the same weight matrix.

2.8 Bidirectional RNNs

Since the hidden states (contextual representations) in an RNN only contain information about the left context but not the right context, and the right context can be useful in some NLP tasks (such as text classification), bidirectional RNNs were proposed.

Bidirectional RNNs. Source: the course slide

Bidirectional RNNs are only applicable if you have access to the entire input sequence. They are not applicable to Language Modeling, because in LM you only have left context available. If you do have an entire input sequence (e.g. any kind of encoding), bidirectionality is powerful (you should use it by default). For example, BERT (Bidirectional Encoder Representations from Transformers) is a powerful pretrained contextual representation system built on bidirectionality.
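In a PyTorch sketch, bidirectionality is a single flag; note that the per-timestep output concatenates the forward and backward hidden states (sizes are arbitrary):

```python
import torch
import torch.nn as nn

bilstm = nn.LSTM(input_size=32, hidden_size=64, batch_first=True, bidirectional=True)
x = torch.randn(8, 20, 32)

output, (h_n, c_n) = bilstm(x)
print(output.shape)                  # torch.Size([8, 20, 128])  forward and backward states concatenated
print(h_n.shape)                     # torch.Size([2, 8, 64])    one final state per direction
```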

2.9 Multi-layer RNNs

RNNs are already “deep” in one dimension (they unroll over many timesteps). We can also make them “deep” in another dimension by applying multiple RNNs: this is a multi-layer RNN. It allows the network to compute more complex representations: the lower RNNs should compute lower-level features and the higher RNNs should compute higher-level features. Multi-layer RNNs are also called stacked RNNs.

multi-layer RNN example. Source: the course slide

High-performing RNNs are often multi-layer (but not as deep as convolutional or feed-forward networks). For example, in a 2017 paper, Britz et al. find that for Neural Machine Translation, 2 to 4 layers is best for the encoder RNN and 4 layers is best for the decoder RNN. However, skip connections/dense connections are needed to train deeper RNNs (e.g. 8 layers). Transformer-based networks (e.g. BERT) can be up to 24 layers deep; Transformers have a lot of skip-like connections.
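In a PyTorch sketch, stacking is controlled by num_layers; the output exposes only the top layer's hidden states, while h_n and c_n hold one final state per layer (sizes are arbitrary):

```python
import torch
import torch.nn as nn

stacked = nn.LSTM(input_size=32, hidden_size=64, num_layers=3, batch_first=True)
x = torch.randn(8, 20, 32)

output, (h_n, c_n) = stacked(x)
print(output.shape)                  # torch.Size([8, 20, 64])  top layer's hidden states only
print(h_n.shape)                     # torch.Size([3, 8, 64])   one final hidden state per layer
```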
