Understanding XLNet

Shweta Baranwal · Published in The Startup · May 30, 2020

Unsupervised representation learning has been highly successful in the NLP domain: models are pre-trained on unsupervised tasks so that they perform well on downstream tasks without starting from scratch. Traditional language models trained networks left-to-right to predict the next word in the sequence, clearly lacking “bidirectional context”. ELMo tried to capture it, but it concatenated two separately trained models (left-to-right and right-to-left) and was therefore not simultaneously bidirectional. Then BERT showed up with both bidirectional context and parallelizable training, producing state-of-the-art results on almost every NLP dataset. But BERT suffers from a few discrepancies and was outperformed by XLNet, which we will discuss in the next sections.

Unsupervised pre-training methods:

Here we will discuss two pre-training objectives.

  1. Auto-regressive (AR) language modeling:

AR language modeling seeks to estimate the probability distribution of a text corpus with an autoregressive model. Given a text sequence x = [x1,… ,xT ], AR language modeling performs pre-training by maximizing the likelihood under the forward autoregressive factorization:
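In the notation of the XLNet paper:

$$
\max_\theta \ \log p_\theta(x) \;=\; \sum_{t=1}^{T} \log p_\theta(x_t \mid x_{<t}) \;=\; \sum_{t=1}^{T} \log \frac{\exp\!\big(h_\theta(x_{1:t-1})^{\top} e(x_t)\big)}{\sum_{x'} \exp\!\big(h_\theta(x_{1:t-1})^{\top} e(x')\big)}
$$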

where h(X(1:t−1)) denotes the context representation produced by the neural model up to index t−1, and e(X(t)) is the embedding of the word at index t.

2. Auto-encoding (AE) language modeling:

An AE model does not perform explicit density estimation like AR; instead, it reconstructs the original data from a corrupted input. BERT is an example of the AE approach, where Masked Language Modeling (MLM) is performed on the original sequence: the model masks 15% of the tokens at random with the [MASK] token and then predicts those masked tokens at the output layer.

Let the original sequence be x = [x1, x2, …, xT], let x̂ be the corrupted version in which a few tokens are replaced by [MASK], and let those masked tokens be x̄. The objective is to reconstruct x̄ from x̂:
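Following the XLNet paper, the objective is (approximately):

$$
\max_\theta \ \log p_\theta(\bar{x} \mid \hat{x}) \;\approx\; \sum_{t=1}^{T} m_t \log p_\theta(x_t \mid \hat{x}) \;=\; \sum_{t=1}^{T} m_t \log \frac{\exp\!\big(H_\theta(\hat{x})_t^{\top} e(x_t)\big)}{\sum_{x'} \exp\!\big(H_\theta(\hat{x})_t^{\top} e(x')\big)}
$$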

where m(t)= 1 indicates x(t) is masked, and H is a Transformer.

Pros and Cons of AR and AE modeling:

Context dependency: Since an AR language model is only trained to encode a uni-directional context (either forward or backward), it is not effective at modeling deep bidirectional contexts, whereas BERT accesses contextual information from both sides.

Independence assumption: BERT assumes that all masked tokens are reconstructed separately, i.e. the masked tokens are independent of each other, whereas the AR language modeling objective factorizes with the product rule, which holds universally without such an independence assumption. Consider the example [New, York, is, a, city] and let “New” and “York” be the masked words. The BERT objective is then:
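Following the example in the XLNet paper:

$$
\mathcal{J}_{\text{BERT}} \;=\; \log p(\text{New} \mid \text{is a city}) \;+\; \log p(\text{York} \mid \text{is a city})
$$

That is, BERT predicts “New” and “York” independently of each other, each conditioned only on the unmasked words.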

Input noise: The input to BERT contains artificial symbols like [MASK] that never occur in downstream tasks, which creates a pretrain-finetune discrepancy. In comparison, AR language modeling does not rely on any input corruption and does not suffer from this issue.

XLNet to the rescue:

As we saw, both approaches have their own pros and cons, and we would like to get the best of both pre-training methods. XLNet achieves that in the following ways:

Permutation Language Modeling:

Permutation-based language modeling retains the benefits of the AR model and also captures bidirectional context. For a sequence of length T, there are T! different orders in which to perform a valid autoregressive factorization. Since the parameters are shared across all factorization orders, the model gathers information from the context on both sides.

Let Z be the set of all possible permutations of the length-T index sequence [1, …, T], and let z(t) and z(<t) denote the tᵗʰ element and the first t−1 elements of a permutation z ∈ Z. The permutation language modeling objective can then be expressed as follows:
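In the paper's notation:

$$
\max_\theta \ \mathbb{E}_{z \sim Z}\!\left[\sum_{t=1}^{T} \log p_\theta\big(x_{z(t)} \mid x_{z(<t)}\big)\right]
$$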

Since this objective fits into the AR framework, it naturally avoids the independence assumption and the pretrain-finetune discrepancy.

Example of permutation language modeling, where t iterates from 1 to T.

Permutation LM only permutes the factorization order, not the sequence order. In other words, XLNet keeps the original sequence order, uses the positional encodings corresponding to the original sequence, and relies on a proper attention mask in the Transformer to achieve the permutation of the factorization order.

Example:

Permutation mask: (i, j) cell represents whether token i attends to token j

Consider the example [quick, brown, fox, jumps] and assume the permutation z = [1, 3, 4, 2] for this sequence. The attention/permutation mask is then as shown in the figure above (and sketched in the code below).
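As a concrete sketch (not the actual XLNet implementation; `permutation_mask` is a hypothetical helper), the mask for an arbitrary factorization order can be built like this:

```python
import numpy as np

def permutation_mask(z, include_self=True):
    """Build a T x T attention mask for a factorization order z over
    1-indexed original positions. mask[i-1, j-1] == 1 means the token at
    original position i may attend to the token at original position j,
    i.e. j is predicted before i in the factorization order (or j == i)."""
    T = len(z)
    step = {pos: k for k, pos in enumerate(z)}   # step at which each position is predicted
    mask = np.zeros((T, T), dtype=int)
    for i in range(1, T + 1):
        for j in range(1, T + 1):
            if step[j] < step[i] or (include_self and i == j):
                mask[i - 1, j - 1] = 1
    return mask

# Example from the text: [quick, brown, fox, jumps] with z = [1, 3, 4, 2]
print(permutation_mask([1, 3, 4, 2]))
# [[1 0 0 0]   "quick" sees only itself
#  [1 1 1 1]   "brown" sees everything (it is last in the factorization order)
#  [1 0 1 0]   "fox"   sees "quick" and itself
#  [1 0 1 1]]  "jumps" sees "quick", "fox", and itself
```

Calling it with `include_self=False` removes the diagonal, which corresponds to the query-stream mask introduced in the two-stream self-attention section below.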

The XLNet objective for the previously mentioned example, [New, York, is, a, city]:
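Assuming, as in the paper's example, that the sampled factorization order places “is”, “a”, “city” before “New” and “York”:

$$
\mathcal{J}_{\text{XLNet}} \;=\; \log p(\text{New} \mid \text{is a city}) \;+\; \log p(\text{York} \mid \text{New}, \text{is a city})
$$

Unlike BERT, the second term captures the dependency between “New” and “York”.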

What would the factorization order be for the second term of the above objective?

Answer: any permutation of [1, 3, 4, 5], followed by 2.

Given the same targets, XLNet always learns more dependency pairs than BERT and therefore receives a denser training signal.

Issue in Permutation Language Modeling:

The standard Transformer parameterization may not work in the case of permutation LM; let’s see why.

Consider two different permutations [This, great, is] and [This, is, great], and suppose that, given z(<t), we are predicting the word at the tᵗʰ position (t = 2).

P(Xz(₂) = ‘great’|[This]) → z(2)=3

P(Xz(₂) = ‘is’|[This]) → z(2)=2

Here z(<t) is [This], and we are predicting the next word using the standard softmax formulation. Since z(<t) is the same for both permutations, the predicted distribution is identical, even though the next word could be either “great” or “is”. It is therefore important to know where the next word belongs in the original sequence (positions 3 and 2, respectively), but the softmax formulation does not have this information.

In the standard parameterization (snapshot from the XLNet paper), the next-token distribution is:
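$$
p_\theta\big(X_{z(t)} = x \mid x_{z(<t)}\big) \;=\; \frac{\exp\!\big(e(x)^{\top} h_\theta(x_{z(<t)})\big)}{\sum_{x'} \exp\!\big(e(x')^{\top} h_\theta(x_{z(<t)})\big)}
$$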

Here, h(Xz(<t)) comes from the Transformer and e(x) is the embedding of the target word. The point to notice is that h(Xz(<t)) does not depend on the position it will predict (i.e. z(t), the target position), creating target-position unawareness in the model. This means the model is cut off from knowledge of the position of the token it is predicting.

XLNet’s Two-Stream Self-Attention:

To avoid the problem of target unawareness, a re-parameterization of the next-token distribution is required, one that takes into account the position at which the prediction is being made along with the bidirectional context.
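The XLNet paper re-parameterizes the distribution as:

$$
p_\theta\big(X_{z(t)} = x \mid x_{z(<t)}\big) \;=\; \frac{\exp\!\big(e(x)^{\top} g_\theta(x_{z(<t)}, z(t))\big)}{\sum_{x'} \exp\!\big(e(x')^{\top} g_\theta(x_{z(<t)}, z(t))\big)}
$$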

Here, g(Xz(<t), z(t)) denotes a new type of Transformer representation which additionally takes the target position z(t) as input. The formulation is such that g relies on the position z(t) to gather information from the context Xz(<t) through attention.

Now, for the permutation [This, great, is], while computing P(Xz(₂) = ‘great’ | [This]), the model will know that Xz(₂) refers to the word at index z(2) = 3, through the modified Transformer function g.

But this raises two contradictions with the standard Transformer architecture:

  1. To predict the token Xz(t), g(Xz(<t); z(t)) should only use the position z(t) and not the content Xz(t); otherwise the objective becomes trivial.
  2. To predict the other tokens Xz(j) with j > t, g(Xz(<t); z(t)) should also encode the content Xz(t) to provide full contextual information.

To resolve this contradiction, two sets of attention streams are defined instead of one.

  1. Query stream:

The query stream g(Xz(<t); z(t)) has access to the contextual information Xz(<t) and the position z(t), but not to the content Xz(t). It is initialized with a trainable vector, i.e. g⁰ᵢ = w, where g⁰ represents the query stream at the 0ᵗʰ (initial) layer.

2. Content stream:

The content stream h(Xz(≤t)) includes both the contextual information Xz(<t) and the content at z(t). It is initialized with the corresponding word embedding, i.e. h⁰ᵢ = e(xᵢ), where h⁰ represents the content stream at the 0ᵗʰ (initial) layer.

Attention/Permutation masks for query and content stream

For the subsequent layers (m = 1, 2, …, M), the query and content streams are updated as follows:
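In the paper's notation (simplified here by omitting the mems term discussed in the next section):

$$
g^{(m)}_{z(t)} \leftarrow \text{Attention}\big(Q = g^{(m-1)}_{z(t)},\; KV = h^{(m-1)}_{z(<t)};\; \theta\big) \quad \text{(query stream: cannot see } x_{z(t)}\text{)}
$$

$$
h^{(m)}_{z(t)} \leftarrow \text{Attention}\big(Q = h^{(m-1)}_{z(t)},\; KV = h^{(m-1)}_{z(\le t)};\; \theta\big) \quad \text{(content stream: can see } x_{z(t)}\text{)}
$$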

where Q, K, V denote the query, key, and value in an attention operation. The query stream of the last layer is used to compute the softmax for next-word prediction.

mems: the cached memory states, explained in the segment recurrence section below.

Segment recurrence mechanism:

Transformers take a fixed-length input sequence, so there is an upper limit on the distance of relationships that a Transformer can model. There are many scenarios where long sequences need to be fed to the model, and in those cases the vanilla Transformer fails. XLNet derives the solution to this problem from Transformer-XL.

For example, if we had the consecutive sentences

“I went to the store. I bought some cookies.”

we can feed “I went to the store.” first, cache the outputs of the intermediate layers, then feed the sentence “I bought some cookies.” and the cached outputs into the model.

In the two figures above, you might have noticed the [mem] hidden units: these are the cached content streams saved from each Transformer layer of the first sequence and re-used while modeling the second sequence.

The cached content stream (h) is appended to the second sequence’s keys and values when computing attention for both the query and content streams:
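In the paper's notation, the content stream update with memory becomes:

$$
h^{(m)}_{z(t)} \leftarrow \text{Attention}\big(Q = h^{(m-1)}_{z(t)},\; KV = \big[\tilde{h}^{(m-1)};\, h^{(m-1)}_{z(\le t)}\big];\; \theta\big)
$$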

where h̃ (h tilde) is the cached content stream of layer m−1 from the first sequence.

This allows caching and reusing the memory without knowing the factorization order of the previous sequence.
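As a minimal sketch of this behaviour on the two sentences above, using the Hugging Face transformers library (the exact `mems` / `use_mems` arguments of `XLNetModel` are assumptions about the current API version):

```python
import torch
from transformers import XLNetTokenizer, XLNetModel

tokenizer = XLNetTokenizer.from_pretrained("xlnet-base-cased")
model = XLNetModel.from_pretrained("xlnet-base-cased")
model.eval()

with torch.no_grad():
    # First segment: run it and keep the cached hidden states ("mems")
    first = tokenizer("I went to the store.", return_tensors="pt")
    out1 = model(**first, use_mems=True)

    # Second segment: pass the cached mems so it can also attend to the first segment
    second = tokenizer("I bought some cookies.", return_tensors="pt")
    out2 = model(**second, mems=out1.mems, use_mems=True)

print(len(out1.mems), out1.mems[0].shape)   # one cached tensor per layer
print(out2.last_hidden_state.shape)
```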

Relative positional and segment embeddings:

XLNet derives the concept of relative positional encoding from Transformer-XL and extends it to segments.

Relative positioning: Instead of the absolute position embeddings used in BERT, XLNet uses relative position encodings. Rather than adding a position embedding to the word embedding, attention scores are computed on the basis of relative positions and added to the content/query attention score. When a query vector q(i) attends to the key vectors k(j), j < i, it does not need to know the absolute position of each key vector to identify the temporal order of the segment; it suffices to know the relative distance between each key vector k(j) and the query q(i), i.e. i − j.

Relative segmentation: In BERT, a specific embedding is assigned to each segment, whereas XLNet learns an embedding that represents whether a pair of positions belongs to the same segment or not. Given a pair of positions i and j in the sequence, if i and j are from the same segment we use a segment encoding s(i,j) = s+, otherwise s(i,j) = s−, where s+ and s− are learnable model parameters. Such a method can handle more than two input segments. Similar to the relative position encoding, this embedding is used during the attention computation between any two words.

The attention score is the sum of all three components (content, relative position, and segment):
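A simplified sketch (the full Transformer-XL formulation also includes additional learnable bias terms):

$$
\text{score}(i, j) \;=\; \underbrace{q_i^{\top} k_j}_{\text{content}} \;+\; \underbrace{q_i^{\top} r_{i-j}}_{\text{relative position}} \;+\; \underbrace{(q_i + b)^{\top} s_{ij}}_{\text{relative segment}}
$$

where r(i−j) is the relative position encoding, s(i,j) ∈ {s+, s−} is the relative segment encoding, and b is a learnable bias.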

End note:

XLNet performs consistently better than BERT on almost all datasets. XLNet can be applied to various tasks just like BERT, and Hugging Face provides pre-trained versions of the model as well. Check this out.
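For instance, a sentence-classification setup might start like this (a minimal sketch; the classification head is freshly initialized and still needs fine-tuning on labeled data):

```python
import torch
from transformers import XLNetTokenizer, XLNetForSequenceClassification

tokenizer = XLNetTokenizer.from_pretrained("xlnet-base-cased")
model = XLNetForSequenceClassification.from_pretrained("xlnet-base-cased", num_labels=2)

inputs = tokenizer("XLNet retains the benefits of AR modeling.", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits   # shape: (1, num_labels); head is untrained here
print(logits)
```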
