Transformer-XL

Shaurya Goel
5 min read · Jul 27, 2019


We will discuss Transformer-XL in this article. It is assumed that you know about Transformers.

A Transformer is a self-attention model that processes sequential input like an RNN, but does so in parallel.

Language Modelling

Given a sequence of tokens x=(x₁,x₂,…,xₙ), a language model tries to estimate the joint probability P(x). By the chain rule, P(x) factorizes autoregressively into a product of conditionals, ∏ₜ P(xₜ | x₁,…,xₜ₋₁).

Now, the problem is reduced to estimating each conditional factor. We can train a neural network to encode the context x₁,x₂,…,xₜ into a fixed-size hidden state, which is multiplied with the word embeddings to get the logits. These logits are passed to the softmax function, which yields a probability distribution over the next token.
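
As a minimal sketch of that last step (the sizes here are hypothetical, and tying the output projection to the input embedding matrix is just one common choice):

import torch
import torch.nn.functional as F

vocab_size, d_model = 10000, 512                      # hypothetical sizes
embedding = torch.nn.Embedding(vocab_size, d_model)   # word embedding matrix
h_t = torch.randn(d_model)                            # hidden state encoding the context x₁..xₜ
logits = embedding.weight @ h_t                       # multiply with word embeddings -> (vocab_size,)
p_next = F.softmax(logits, dim=-1)                    # probability distribution over the next token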

RNN/LSTM and Transformers can both be used for language modelling. Both try to learn long-term dependencies (measured in number of tokens), but each has its pros and cons:

RNN/LSTM- Processes arbitrary-length sequences recursively, since only a few fixed-size matrices need to be learned. It suffers from vanishing/exploding gradients, even after introducing gated units (which decide how much information to pass through) and gradient clipping (if the norm of the gradient becomes larger than a threshold, the gradient is rescaled down to that pre-defined threshold).

Transformers- Process sequences in parallel, but can only handle a fixed-length segment at a time. If we used the whole sequence, we would have to operate on very large attention matrices, which requires a large amount of memory. So we divide the given sequence into fixed-length segments and process each one separately, as sketched below. But this leads to a limited context (maximum context length = chosen segment length), because the current segment can't interact with the previous one. This makes predicting the first few words of each segment difficult, as they have little or no context. This problem is called context fragmentation.
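
For concreteness, the segmenting step amounts to something like this (the segment length and names are illustrative):

def split_into_segments(token_ids, seg_len=4):
    # chop a long token sequence into fixed-length segments that a vanilla
    # Transformer processes independently, which causes context fragmentation
    return [token_ids[i:i + seg_len] for i in range(0, len(token_ids), seg_len)]

segments = split_into_segments(list(range(10)), seg_len=4)   # [[0,1,2,3], [4,5,6,7], [8,9]]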

Transformer during training (segment length=4)

During evaluation, at each step, the vanilla Transformer consumes a segment of the same length as used during training, but predicts only the word at the last position of that segment. At the next step, the segment is shifted one position to the right, and the new segment is processed from scratch. This gives every prediction the maximum context length seen during training and relieves the model of the context fragmentation issue, but the process is slow and expensive.

Transformers during evaluation (segment length=4) at t=0, 1 and 2
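
A rough sketch of this sliding-window evaluation, assuming a hypothetical model that maps a window of token ids to per-position logits:

def vanilla_eval(model, token_ids, seg_len=4):
    # re-encode the last seg_len tokens from scratch at every step and
    # keep only the prediction for the final position
    preds = []
    for t in range(seg_len, len(token_ids)):
        window = token_ids[t - seg_len:t]        # fixed-length context, shifted by one each step
        logits = model(window)                   # recomputed from scratch
        preds.append(int(logits[-1].argmax()))   # predict only the last position
    return preds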

Transformer-XL

Transformer-XL (extra long) combines the pros of both models. It works like the vanilla Transformer, but caches the previous segment's hidden states at every layer and uses them as memory/context for the current segment. This introduces a notion of recurrence and helps the model learn longer-term dependencies. We operate on two consecutive segments at a time and compute gradients for the current segment only.

In eq. 1, we concatenate the hidden states (along the length dimension) at layer n-1 for segments τ and τ+1. SG denotes stop-gradient: we do not compute gradients for the τ-th segment.

In eq. 2, we calculate the Query, Key and Value using three learnable weight matrices W. The Query is calculated from the current (τ+1) segment only; the Key and Value are calculated from both the current and the previous segment.

In eq. 3, we use the Query, Key and Value from above to calculate the hidden state of the n-th layer. This is just like self-attention in the vanilla Transformer.
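
A reconstruction of eqs. 1-3 in the paper's notation, where hⁿ_τ is the n-th layer hidden state of segment τ, ∘ denotes concatenation along the length dimension and SG is stop-gradient:

\begin{aligned}
\tilde{\mathbf{h}}^{\,n-1}_{\tau+1} &= \left[\, \mathrm{SG}\!\left(\mathbf{h}^{\,n-1}_{\tau}\right) \circ \mathbf{h}^{\,n-1}_{\tau+1} \,\right] && (1)\\
\mathbf{q}^{\,n}_{\tau+1},\; \mathbf{k}^{\,n}_{\tau+1},\; \mathbf{v}^{\,n}_{\tau+1} &= \mathbf{h}^{\,n-1}_{\tau+1}\mathbf{W}_q^{\top},\; \tilde{\mathbf{h}}^{\,n-1}_{\tau+1}\mathbf{W}_k^{\top},\; \tilde{\mathbf{h}}^{\,n-1}_{\tau+1}\mathbf{W}_v^{\top} && (2)\\
\mathbf{h}^{\,n}_{\tau+1} &= \text{Transformer-Layer}\!\left(\mathbf{q}^{\,n}_{\tau+1}, \mathbf{k}^{\,n}_{\tau+1}, \mathbf{v}^{\,n}_{\tau+1}\right) && (3)
\end{aligned}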

However, note that the hidden states of the n-th layer of the current segment depend on the hidden states of layer n-1 of the previous segment (this is different from the same-layer recurrence in RNNs). We need not be restricted to only the previous segment: we can cache as many previous segments as our memory allows. Also, the largest possible dependency length grows linearly with the number of layers as well as the segment length; for example, with 16 layers and a segment length of 4, information can in principle propagate across roughly 16×4=64 tokens.

During evaluation, we move segment by segment instead of one position at a time (as in the vanilla Transformer) and predict a whole segment's worth of tokens at once. Hidden representations from the previous segments are reused.
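
A rough sketch of segment-wise evaluation with such a cache, assuming a hypothetical model whose forward pass accepts and returns per-layer memories:

def xl_eval(model, token_ids, seg_len=4):
    # hidden states of already-processed segments are cached in `mems`
    # and reused as extra context; nothing is recomputed from scratch
    preds, mems = [], None
    for s in range(0, len(token_ids), seg_len):
        segment = token_ids[s:s + seg_len]
        logits, mems = model(segment, mems)      # reuse cached hidden states
        preds.extend(int(l.argmax()) for l in logits)
    return preds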

Relative Positional Encoding

By using segment-level recurrence, we have created a problem with the positional encoding. Before looking below, think about what the problem might be.

Suppose the segment length is 4, and the positional encoding of the current segment is [0,1,2,3]. Then the positional encodings of the previous and the current segment combined will be [0,1,2,3,0,1,2,3]. If we are calculating the self-attention score for a query in the current segment, it cannot distinguish between the word at position 1 of the current segment and the word at position 1 of the previous segment, since both carry the same encoding. Even if we swapped those two words, we would get the same attention score, resulting in degraded performance.

To counter this problem, we can use relative positional encodings instead of the absolute ones used in the vanilla Transformer paper.

The positional encodings give the model a clue, or bias, about where to attend. It is also more intuitive and generalizable to define this bias relatively. Hence, we define a relative positional encoding matrix R ∈ ℝ^(L×d). R is sinusoidal, with no learnable parameters. L is the maximum context length and d is the size of the hidden dimension. Rᵢ denotes the i-th row of R, the encoding for a relative distance of i.
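
One common way to build such a sinusoidal R (the same construction as the vanilla Transformer's encoding, but indexed by relative distance) could look like this:

import torch

def relative_positional_encoding(L, d):
    # sinusoidal matrix R of shape (L, d); row i encodes a relative distance of i
    inv_freq = 1.0 / (10000 ** (torch.arange(0, d, 2).float() / d))
    dist = torch.arange(L).float()                             # relative distances 0 .. L-1
    angles = dist[:, None] * inv_freq[None, :]                 # shape (L, d/2)
    return torch.cat([angles.sin(), angles.cos()], dim=-1)     # shape (L, d)

R = relative_positional_encoding(L=8, d=16)                    # R[i] corresponds to Rᵢ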

Eqs. 1 and 2 are the same as before.

Eq. 3 is the novel relative positional encoding scheme.
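
As a reconstruction in the paper's notation, the relative attention score between a query at position i and a key at position j decomposes into four terms, where E_x are the content representations, R_(i-j) is the relative encoding, and u, v are learnable vectors replacing the absolute positional terms of the query:

\begin{aligned}
\mathbf{A}^{\mathrm{rel}}_{i,j}
&= \underbrace{\mathbf{E}_{x_i}^{\top}\mathbf{W}_q^{\top}\mathbf{W}_{k,E}\,\mathbf{E}_{x_j}}_{(a)\ \text{content addressing}}
 + \underbrace{\mathbf{E}_{x_i}^{\top}\mathbf{W}_q^{\top}\mathbf{W}_{k,R}\,\mathbf{R}_{i-j}}_{(b)\ \text{content-dependent positional bias}} \\
&\quad + \underbrace{\mathbf{u}^{\top}\mathbf{W}_{k,E}\,\mathbf{E}_{x_j}}_{(c)\ \text{global content bias}}
 + \underbrace{\mathbf{v}^{\top}\mathbf{W}_{k,R}\,\mathbf{R}_{i-j}}_{(d)\ \text{global positional bias}}
\end{aligned}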

In eq. 4, we use a masked softmax to zero out the attention scores of future words. Then we multiply these scores by the Values.

In eq. 5, we apply layer normalization and a residual connection.

In eq. 6, we pass the output from above through a position-wise feed-forward network.

Eqs. 4–6 are similar to those of the vanilla Transformer.
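
Reconstructed in the paper's notation for layer n and segment τ, these steps read roughly as:

\begin{aligned}
\mathbf{a}^{\,n}_{\tau} &= \text{Masked-Softmax}\!\left(\mathbf{A}^{\,n}_{\tau}\right)\mathbf{v}^{\,n}_{\tau} && (4)\\
\mathbf{o}^{\,n}_{\tau} &= \text{LayerNorm}\!\left(\text{Linear}\!\left(\mathbf{a}^{\,n}_{\tau}\right) + \mathbf{h}^{\,n-1}_{\tau}\right) && (5)\\
\mathbf{h}^{\,n}_{\tau} &= \text{Positionwise-Feed-Forward}\!\left(\mathbf{o}^{\,n}_{\tau}\right) && (6)
\end{aligned}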

Using both of the above techniques is necessary to achieve SOTA results on both character-level and word-level language modelling. Transformer-XL achieved SOTA results on the following datasets: WikiText-103, enwik8, text8, One Billion Word and Penn Treebank.

Transformer-XL has also been used to generate text. Examples are given at the end of the paper [1].
