XLNet

Shaurya Goel
Aug 5, 2019


In this article, we discuss XLNet. The article assumes familiarity with Transformers, Transformer-XL and BERT.

Many methods have been proposed for unsupervised representation learning in NLP. They first pre-train a neural network on large unlabelled corpora and then fine-tune it on downstream tasks. Two pre-training objectives have been especially successful: autoregressive (AR) language modelling and autoencoding (AE).

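  1. AR: An AR language model estimates the probability distribution of a text sequence by factorizing its likelihood with the product rule, either forward, p(x) = ∏ p(xₜ | x<t), or backward, p(x) = ∏ p(xₜ | x>t), with t running from 1 to T, the length of the sequence. In other words, it predicts the token at time t given the tokens from 1 to (t-1), or given the tokens from (t+1) to T. Each conditional therefore sees context from only one direction, so an AR model does not capture the bidirectional context that many NLP tasks, such as question answering (QA), require.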

  2. AE: An AE model corrupts some input tokens by replacing them with a [MASK] token and is trained to recover the original tokens from the corrupted input. A notable example is BERT. Because BERT uses context from both directions, it performs better than AR language modelling. However, the [MASK] tokens appear only during pre-training and never during fine-tuning, which creates a pretrain-finetune discrepancy. BERT also cannot model the joint probability of the predicted tokens with the product rule, as AR models do, because it assumes that the masked tokens are independent of each other given the unmasked tokens. This assumption does not hold, since natural language has many dependencies between tokens.
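To make the independence assumption concrete, here is a toy sketch built on a made-up joint distribution over two masked slots (not data from the paper): only "New York" and "Los Angeles" can fill the blanks, each with probability 0.5. The AR product rule recovers the joint exactly, while multiplying the per-slot marginals, which is what the independence assumption amounts to, leaks probability onto impossible pairs.

```python
from itertools import product

# Hypothetical joint distribution over the two masked slots (made up for illustration).
joint = {("New", "York"): 0.5, ("Los", "Angeles"): 0.5}
firsts = {a for a, _ in joint}
seconds = {b for _, b in joint}

def p_first(a):
    """Marginal probability of the first slot."""
    return sum(p for (x, _), p in joint.items() if x == a)

def p_second(b):
    """Marginal probability of the second slot."""
    return sum(p for (_, y), p in joint.items() if y == b)

def p_second_given_first(b, a):
    """Conditional an AR model would use: p(second | first)."""
    return joint.get((a, b), 0.0) / p_first(a)

def ar_prob(a, b):
    # Product rule: p(a, b) = p(a) * p(b | a), exact by construction.
    return p_first(a) * p_second_given_first(b, a)

def independent_prob(a, b):
    # Independence assumption: p(a, b) is approximated by p(a) * p(b).
    return p_first(a) * p_second(b)

for a, b in product(sorted(firsts), sorted(seconds)):
    print(f"{a} {b}: AR={ar_prob(a, b):.2f}  independent={independent_prob(a, b):.2f}")
```

Running it prints AR=0.50 for "Los Angeles" and "New York" and AR=0.00 for the mixed pairs "Los York" and "New Angeles", whereas the independent factorization assigns 0.25 to all four.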

The authors propose XLNet, a generalized autoregressive method that combines the best of AR and AE while avoiding their limitations.

Permutation Language Modelling Objective

Is there a pre-training objective that takes advantage of both methods? The authors propose a permutation language modelling objective that allows AR models to capture bidirectional context. Consider a sequence x of length T. There are T! different orders in which a valid autoregressive factorization can be performed. If the model parameters are shared across all factorization orders and we train on many such orders in an AR way, then in expectation the model learns to gather information from positions on both sides of each token, i.e. bidirectional context. Because the objective stays within the AR framework, it avoids both the pretrain-finetune discrepancy and BERT's independence assumption.
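The "in expectation" claim can be checked by brute force. This minimal sketch uses the article's 1-based positions: it enumerates all T! factorization orders of a length-4 sequence and counts how often each other position comes before target position 3, and is therefore available as context when x₃ is predicted.

```python
from collections import Counter
from itertools import permutations

def context_counts(seq_len, target):
    """For each position j != target, count the factorization orders in which
    j precedes target, i.e. x_j is visible when predicting x_target."""
    counts = Counter()
    orders = list(permutations(range(1, seq_len + 1)))
    for order in orders:
        # Every position before the target in this order is context for it.
        for j in order[:order.index(target)]:
            counts[j] += 1
    return counts, len(orders)

counts, n_orders = context_counts(seq_len=4, target=3)
for j in sorted(counts):
    side = "left" if j < 3 else "right"
    print(f"position {j} ({side} of 3) is context in {counts[j]}/{n_orders} orders")
```

Each of positions 1, 2 and 4 is visible in exactly 12 of the 24 orders, so a single set of parameters trained over all orders learns to use context from both the left and the right of x₃.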

Figure: Permutation language modelling objective for predicting x₃ given the same input sequence x but with different factorization orders.

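The objective function is:

$$\max_{\theta} \; \mathbb{E}_{\mathbf{z} \sim \mathcal{Z}_T}\left[\sum_{t=1}^{T} \log p_{\theta}\left(x_{z_t} \mid \mathbf{x}_{\mathbf{z}_{<t}}\right)\right]$$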

where Z_T is the set of all possible permutations of the length-T index sequence [1, 2, ..., T], zₜ denotes the t-th element and z<t denotes the first t-1 elements of a permutation z ∈ Z_T.
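Since T! grows far too quickly to enumerate, the expectation over Z_T has to be estimated by sampling factorization orders. The sketch below is a plain Monte Carlo estimate; `log_prob(target_pos, context_positions, tokens)` is a hypothetical callable standing in for the model p_θ, and the uniform toy model exists only so the code runs. It shows the shape of the objective, not how XLNet batches it.

```python
import math
import random
from typing import Callable, Sequence

# Hypothetical model interface: log p(tokens[target] | tokens at the context positions).
LogProb = Callable[[int, Sequence[int], Sequence[str]], float]

def permutation_lm_objective(tokens: Sequence[str], log_prob: LogProb,
                             num_orders: int, rng: random.Random) -> float:
    """Monte Carlo estimate of E_z [ sum_t log p(x_{z_t} | x_{z<t}) ]."""
    T = len(tokens)
    total = 0.0
    for _ in range(num_orders):
        z = list(range(T))
        rng.shuffle(z)                     # sample z uniformly from Z_T
        for t in range(T):
            # Context = positions earlier in the factorization order;
            # the token sequence itself is never reordered.
            total += log_prob(z[t], z[:t], tokens)
    return total / num_orders

def uniform_log_prob(target, context, tokens, vocab_size=10):
    """Toy stand-in model: uniform over a vocabulary of `vocab_size` tokens."""
    return -math.log(vocab_size)

tokens = ["the", "cat", "sat", "down"]
estimate = permutation_lm_objective(tokens, uniform_log_prob, num_orders=8,
                                    rng=random.Random(0))
print(estimate)
```

With the uniform stand-in every order contributes T · log(1/10), so the estimate equals -4 · log 10; a real model would give order-dependent values, which is why several orders are averaged.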

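NOTE: The proposed objective only permutes the factorization order, not the sequence order. We keep the original sequence order, use the positional encodings of the original sequence, and rely on proper attention masks to achieve the permutation of the factorization order. This matters because during fine-tuning the model only ever sees text in its natural order.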

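But there is a problem. Consider two factorization orders, z = 1->3->4->2 and z′ = 1->3->2->4, which agree on their first two elements. At the third step, z predicts the word at position 4 and z′ predicts the word at position 2, yet both condition on exactly the same context: the words at positions 1 and 3. A standard softmax computed from a representation of the context alone would therefore produce the same distribution in both cases, regardless of the target position. To fix this, we feed the target position into the representation as well and re-parameterise the next-token distribution to be target-position aware:

$$p_{\theta}\left(X_{z_t} = x \mid \mathbf{x}_{\mathbf{z}_{<t}}\right) = \frac{\exp\left(e(x)^{\top} g_{\theta}(\mathbf{x}_{\mathbf{z}_{<t}}, z_t)\right)}{\sum_{x'} \exp\left(e(x')^{\top} g_{\theta}(\mathbf{x}_{\mathbf{z}_{<t}}, z_t)\right)}$$

where e(x) is the embedding of token x and g_θ(x_z<t, zₜ) is a representation that takes both the context x_z<t and the target position zₜ as input.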
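A toy sketch of the problem and the fix, under loudly stated assumptions: all vectors are random stand-ins rather than trained parameters, the "context representation" is just a mean of per-position vectors, and the target position is injected by adding a position embedding. That last step is an illustrative simplification, not how XLNet computes g (it uses the query stream described next). The point is only that a target-unaware softmax cannot tell the third steps of z and z′ apart, while a target-aware one can.

```python
import math
import random

rng = random.Random(0)
VOCAB, DIM, SEQ_LEN = 5, 4, 4

def rand_vec():
    return [rng.gauss(0.0, 1.0) for _ in range(DIM)]

# Hypothetical toy parameters (random, not trained):
e = [rand_vec() for _ in range(VOCAB)]                     # token embeddings e(x)
content = {p: rand_vec() for p in range(1, SEQ_LEN + 1)}   # stand-in context vectors
pos_emb = {p: rand_vec() for p in range(1, SEQ_LEN + 1)}   # target-position embeddings

def softmax_over_vocab(hidden):
    """p(X = x) proportional to exp(e(x)^T hidden) over the toy vocabulary."""
    logits = [sum(a * b for a, b in zip(e[x], hidden)) for x in range(VOCAB)]
    top = max(logits)
    exps = [math.exp(v - top) for v in logits]
    total = sum(exps)
    return [v / total for v in exps]

def summarize(context):
    """Target-unaware stand-in for a context representation: mean of context vectors."""
    return [sum(content[p][d] for p in context) / len(context) for d in range(DIM)]

def dist_unaware(context, target):
    return softmax_over_vocab(summarize(context))  # `target` is ignored

def dist_aware(context, target):
    # Stand-in for g(x_{z<t}, z_t): context summary plus the target's position embedding.
    g = [c + p for c, p in zip(summarize(context), pos_emb[target])]
    return softmax_over_vocab(g)

shared_context = [1, 3]  # common prefix of z = 1->3->4->2 and z' = 1->3->2->4
print(dist_unaware(shared_context, 4) == dist_unaware(shared_context, 2))
print(dist_aware(shared_context, 4) == dist_aware(shared_context, 2))
```

The unaware distributions for targets 4 and 2 are identical by construction, while the aware ones differ because the target position now enters the representation.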

Two-Stream Self-Attention

There are two contradictory requirements for g:

  1. To predict the token at position zₜ, g should use only the position zₜ and not the content at zₜ. Otherwise, the objective becomes trivial, since the model could simply copy the token it is asked to predict.
  2. To predict the tokens that come later in the factorization order (zⱼ with j > t), g should also encode the content at zₜ to provide full contextual information.

We can resolve this conflict by using two sets of hidden representations:

  1. Content stream representation h: encodes the context and the content at position zₜ, like the hidden states of a standard Transformer. It is initialized for the first layer with the corresponding word embedding at zₜ.
  2. Query stream representation g: encodes the context and the position zₜ, but not the content at zₜ. It is initialized for the first layer with a learnable vector w.

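For each self-attention layer m, the two streams are updated with a shared set of parameters:

$$g_{z_t}^{(m)} \leftarrow \text{Attention}\left(Q = g_{z_t}^{(m-1)},\; KV = \mathbf{h}_{\mathbf{z}_{<t}}^{(m-1)};\; \theta\right)$$

$$h_{z_t}^{(m)} \leftarrow \text{Attention}\left(Q = h_{z_t}^{(m-1)},\; KV = \mathbf{h}_{\mathbf{z}_{\le t}}^{(m-1)};\; \theta\right)$$

Q, K and V denote query, key and value respectively. The query stream may use the position zₜ but not its content, so its keys and values come only from z<t; the content stream may use both, so it attends to z≤t. Otherwise the update is the same as self-attention in Transformers; multi-head attention, residual connections, layer normalization and the position-wise feed-forward network are omitted to avoid clutter. The query representation g from the last layer is used in the proposed objective.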
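The sketch below runs one heavily simplified two-stream layer in NumPy (a single head, no residuals, layer norm, feed-forward or memory) to check the key property: the query stream at a position does not depend on that position's own content, while the content stream does. The projection matrices are random stand-ins for trained parameters, and returning zeros for a row with nothing to attend to (the first position in the order, in the query stream) is a choice made for this sketch only.

```python
import numpy as np

def masks_from_order(order):
    """Boolean (T, T) masks indexed by original (0-based) sequence position.
    content[i, j]: i may attend to j if j is predicted no later than i.
    query[i, j]:   i may attend to j only if j is predicted strictly earlier."""
    T = len(order)
    rank = np.empty(T, dtype=int)
    rank[np.asarray(order)] = np.arange(T)   # rank[pos] = step at which pos is predicted
    return rank[None, :] <= rank[:, None], rank[None, :] < rank[:, None]

def masked_attention(q_in, kv_in, mask, Wq, Wk, Wv):
    """Single-head scaled dot-product attention restricted by `mask`.
    Rows with no visible key return zeros (a choice made for this sketch)."""
    q, k, v = q_in @ Wq, kv_in @ Wk, kv_in @ Wv
    scores = np.where(mask, q @ k.T / np.sqrt(k.shape[1]), -1e9)
    weights = np.exp(scores - scores.max(axis=1, keepdims=True)) * mask
    denom = weights.sum(axis=1, keepdims=True)
    weights = np.divide(weights, denom, out=np.zeros_like(weights), where=denom > 0)
    return weights @ v

def two_stream_layer(h, g, order, params):
    """One two-stream layer; both streams share the same projection matrices."""
    content_mask, query_mask = masks_from_order(order)
    h_new = masked_attention(h, h, content_mask, *params)  # Q = h_{z_t}, KV = h_{z<=t}
    g_new = masked_attention(g, h, query_mask, *params)    # Q = g_{z_t}, KV = h_{z<t}
    return h_new, g_new

rng = np.random.default_rng(0)
T, d = 4, 8
params = tuple(rng.normal(size=(d, d)) / np.sqrt(d) for _ in range(3))  # Wq, Wk, Wv
h0 = rng.normal(size=(T, d))               # stand-in word embeddings
g0 = np.tile(rng.normal(size=d), (T, 1))   # the same learnable vector w at every position
order = [2, 1, 3, 0]                       # 3->2->4->1 written as 0-based positions

h1, g1 = two_stream_layer(h0, g0, order, params)
h0_edited = h0.copy()
h0_edited[3] += 1.0                        # change the content at original position 4
h1_edited, g1_edited = two_stream_layer(h0_edited, g0, order, params)
print("query stream at position 4 unchanged:", np.allclose(g1[3], g1_edited[3]))
print("content stream at position 4 unchanged:", np.allclose(h1[3], h1_edited[3]))
```

With the order 3->2->4->1, the query vector at position 4 is built only from positions 3 and 2, so editing the embedding at position 4 leaves it untouched, whereas the content vector at position 4 changes.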

Using Transformer XL

Since XLNet uses an AR framework, we can incorporate the state-of-the-art AR language model, Transformer-XL, into it. Transformer-XL introduced two techniques: relative positional encoding and a segment recurrence mechanism.

  • Relative positional encoding: relative distances are measured in the original sequence order, and because of the attention masks they are only applied to the positions a token may attend to, i.e. the previous positions (z<t) in the factorization order.
  • Segment recurrence mechanism: suppose we have two segments x̃ and x of length T each, taken from a long sequence, with factorization orders z̃ and z respectively. We process the first segment according to z̃ and cache the content-stream hidden states h̃ for every layer. For segment x, the content-stream update becomes:

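$$h_{z_t}^{(m)} \leftarrow \text{Attention}\left(Q = h_{z_t}^{(m-1)},\; KV = \left[\tilde{\mathbf{h}}^{(m-1)}, \mathbf{h}_{\mathbf{z}_{\le t}}^{(m-1)}\right];\; \theta\right)$$

where [.,.] denotes concatenation along the sequence dimension. Since z̃ does not appear in this update, we don't need to know the factorization order of the previous segment(s) to reuse the cached states. The query stream is computed similarly.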
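In mask terms, the concatenation means the cached states h̃ become extra keys and values that every position of the current segment may attend to, whatever z̃ was. The helper names below (`mask_with_memory`, `concat_memory`) are hypothetical, and the sketch does not model Transformer-XL's practice of not propagating gradients into the cached memory.

```python
def mask_with_memory(mask, mem_len):
    """Prepend `mem_len` always-visible memory columns to each row of a
    (T x T) 0/1 attention mask over the current segment."""
    return [[1] * mem_len + list(row) for row in mask]

def concat_memory(mem, h):
    """KV input [h~, h]: cached states of the previous segment, then the current ones."""
    return list(mem) + list(h)

# Content-stream mask for the order 3->2->4->1 (rows/cols = positions 1..4).
content = [[1, 1, 1, 1],
           [0, 1, 1, 0],
           [0, 0, 1, 0],
           [0, 1, 1, 1]]
extended = mask_with_memory(content, mem_len=3)
kv = concat_memory(["h~1", "h~2", "h~3"], ["h1", "h2", "h3", "h4"])
for row in extended:
    print(row)
print(kv)
```

The first three columns of every row are 1, so the memory contributes the same way no matter how the previous segment was factorized; the remaining columns are the unchanged mask of the current segment.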

Figure: (a) Content stream attention, (b) query stream attention, and (c) overview of permutation language modelling with two-stream self-attention.

In the above figure, panel (c) shows the attention masks, where white denotes 0 and red denotes 1. Given the factorization order 3->2->4->1, we first predict the word at position 3 with no context (positions 2, 4 and 1 are masked). Then we predict the word at position 2 given the word at position 3 as context, masking positions 4 and 1. Similarly, we predict the words at positions 4 and 1, each using only the positions that precede it in the order.
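The masks for this example can be generated directly from the factorization order. In this sketch, 1 means the row position may attend to the column position, and both rows and columns stay in the original sequence order 1..4; only the order [3, 2, 4, 1] decides which cells are 1.

```python
def masks_from_order(order):
    """order: factorization order as 1-based positions, e.g. [3, 2, 4, 1].
    Returns (content, query) as lists of 0/1 rows; row i / column j are
    original sequence positions i+1 / j+1."""
    T = len(order)
    rank = {pos: step for step, pos in enumerate(order)}  # step at which pos is predicted
    positions = range(1, T + 1)
    # Content stream: attend to positions predicted no later than yourself.
    content = [[int(rank[j] <= rank[i]) for j in positions] for i in positions]
    # Query stream: attend only to positions predicted strictly earlier.
    query = [[int(rank[j] < rank[i]) for j in positions] for i in positions]
    return content, query

content, query = masks_from_order([3, 2, 4, 1])
for name, mask in (("content stream", content), ("query stream", query)):
    print(name)
    for pos, row in enumerate(mask, start=1):
        print(f"  position {pos}: {row}")
```

Row 3 of the query mask is all zeros, matching the "no context" prediction of the word at position 3, and the content mask is the query mask plus the diagonal, since the content stream may also see its own token.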

Multiple Sentence Segments

Many downstream tasks have multiple input segments, e.g. a question and a context paragraph in question answering. Similar to BERT, we take two sentence segments, A and B: 50% of the time B is the segment that follows A in the corpus, and 50% of the time it is a random segment. We apply permutation language modelling to the concatenation of A and B, and we reuse the hidden states of A only if B actually followed A. The input to the model is similar to BERT's: [A, SEP, B, SEP, CLS], where SEP and CLS are special symbols and A and B are the sentence segments.
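A minimal sketch of building one pre-training pair in this format. The symbol strings, the helper name `make_pair_input` and the segment-id scheme (0 for A and its SEP, 1 for B and its SEP, 2 for CLS) are illustrative assumptions, not XLNet's actual preprocessing code.

```python
import random

SEP, CLS = "[SEP]", "[CLS]"   # hypothetical spellings of the special symbols

def make_pair_input(seg_a, true_next, random_seg, rng):
    """Build [A, SEP, B, SEP, CLS], where B is the true continuation of A
    with probability 0.5 and a random segment otherwise."""
    is_next = rng.random() < 0.5
    seg_b = true_next if is_next else random_seg
    tokens = seg_a + [SEP] + seg_b + [SEP, CLS]
    # Assumed segment ids: 0 for A and its SEP, 1 for B and its SEP, 2 for CLS.
    seg_ids = [0] * (len(seg_a) + 1) + [1] * (len(seg_b) + 1) + [2]
    # Following the text above, A's hidden states are reused only when is_next is True.
    return tokens, seg_ids, is_next

rng = random.Random(0)
tokens, seg_ids, is_next = make_pair_input(
    ["where", "is", "paris", "?"],
    ["paris", "is", "in", "france", "."],
    ["the", "cat", "sat", "."],
    rng)
print(is_next)
print(tokens)
print(seg_ids)
```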

Relative Segment Encoding

BERT adds an absolute segment embedding to the word embedding at each position to distinguish between sentence segments. XLNet extends the idea of relative encodings from Transformer-XL to segments as well. Given a pair of positions i and j, we use the segment encoding sᵢⱼ = s⁺ if they are from the same segment and sᵢⱼ = s⁻ otherwise, where s⁺ and s⁻ are learnable model parameters. When i attends to j, we compute an attention weight aᵢⱼ = (qᵢ + b)ᵀ sᵢⱼ, where qᵢ is the query vector of standard attention and b is a learnable head-specific bias vector, and add aᵢⱼ to the standard self-attention weight.
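A NumPy sketch of the segment term for a single head, assuming the form aᵢⱼ = (qᵢ + b)ᵀ sᵢⱼ from the XLNet paper, with random stand-ins for the learnable vectors s⁺, s⁻ and b. The resulting (T, T) matrix would be added to the ordinary attention scores.

```python
import numpy as np

def segment_bias(q, seg_ids, s_plus, s_minus, b):
    """Relative segment term a_ij = (q_i + b)^T s_ij for one attention head.

    q:        (T, d) query vectors.
    seg_ids:  length-T segment labels; only equality between labels matters.
    s_plus:   (d,) encoding used when i and j share a segment.
    s_minus:  (d,) encoding used otherwise.
    b:        (d,) head-specific bias.
    Returns a (T, T) matrix to add to the standard attention scores.
    """
    seg_ids = np.asarray(seg_ids)
    same = seg_ids[:, None] == seg_ids[None, :]   # (T, T): same segment?
    qb = q + b
    score_same = qb @ s_plus                      # (T,) term when segments match
    score_diff = qb @ s_minus                     # (T,) term when they differ
    return np.where(same, score_same[:, None], score_diff[:, None])

rng = np.random.default_rng(0)
T, d = 5, 4
q = rng.normal(size=(T, d))
s_plus, s_minus, b = (rng.normal(size=d) for _ in range(3))
seg_ids = [0, 0, 1, 1, 1]                         # e.g. a 2-token A followed by a 3-token B
bias = segment_bias(q, seg_ids, s_plus, s_minus, b)
print(bias.round(3))
```

Because only label equality is used, relabelling the segments (say [7, 7, 3, 3, 3]) yields the same matrix, which is the sense in which the encoding is relative rather than absolute.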

XLNet achieves state-of-the-art (SOTA) results on 18 tasks: 7 GLUE tasks, 3 reading comprehension tasks including SQuAD and RACE, 7 text classification tasks including Yelp and IMDB, and the ClueWeb09-B document ranking task.
