XLNET at a quick glance — to begin with !!

Aishwarya V Srinivasan
6 min read · Aug 1, 2019

--

In this article, I am excited to take you through XLNET, a recently published Natural Language Understanding algorithm from Google Brain and CMU. This algorithm is a breakthrough in NLP, as it outperforms the state-of-the-art BERT algorithm on 20 different tasks. Come on, let us explore what this new algorithm has to offer to the world!

This article is structured as follows —

  • Take a brief walk through BERT
  • Understand BERT’s shortcomings
  • Look into transformer architecture
  • Dive into XLNET

BERT Architecture

BERT stands for “Bidirectional Encoder Representations from Transformers”. It is a neural network architecture that can model bidirectional contexts in text data using the Transformer.

What is bidirectional?

  • Traditional language models predict the current token given the previous “n” tokens, or predict the current token given all the tokens after it.
  • Neither approach uses the previous and the next tokens at the same time when predicting the current token.
  • BERT overcomes this shortcoming by considering both the previous and the next tokens when predicting the current token. This property is called “bidirectional”.

Bidirectionality is achieved through a pre-training objective called “Masked Language Modeling”. The model is pre-trained and can then be used for a suite of token-level and sentence-level tasks.

What is Masked Language Modeling (MLM)?

In order to achieve a bidirectional representation, 15% of the tokens in the input sentence are masked at random, and the Transformer is trained to predict the masked words. For example, consider the sentence “The cat sat on the wall”. The input to BERT would be “The cat [MASK] on the [MASK]”.
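To make the masking step concrete, here is a minimal sketch in Python. It is a simplification and not BERT's actual preprocessing: the function name `mask_tokens` is made up for illustration, tokens are plain whitespace-split words, and every selected token is always replaced by [MASK] (the original BERT recipe sometimes keeps the token or swaps in a random one instead).

```python
import random

MASK_TOKEN = "[MASK]"

def mask_tokens(tokens, mask_prob=0.15, seed=None):
    """Replace each token with [MASK] with probability mask_prob.

    Returns the corrupted token list and a dict {position: original token}
    holding the words the model would be trained to predict.
    (Hypothetical helper, simplified for illustration.)
    """
    rng = random.Random(seed)
    masked, targets = [], {}
    for i, tok in enumerate(tokens):
        if rng.random() < mask_prob:
            masked.append(MASK_TOKEN)
            targets[i] = tok  # remember what was hidden at position i
        else:
            masked.append(tok)
    return masked, targets

tokens = "The cat sat on the wall".split()
masked, targets = mask_tokens(tokens, mask_prob=0.15, seed=0)
print(masked)   # corrupted input fed to the model
print(targets)  # positions and words to be predicted
```

With only six tokens and a 15% rate, many runs will mask nothing at all; on real corpora the rate is applied over much longer sequences.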

BERT is also pre-trained on a Next Sentence Prediction task: it is given pairs of sentences in which sentence B actually follows sentence A 50% of the time and does not follow it the rest of the time. Once the model is pre-trained, it is fine-tuned for specific applications such as question answering, sentence completion, and checking the semantic equivalence of two sentences.
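The sentence-pair construction can be sketched in a few lines of Python. This is only an illustration under simplifying assumptions: `make_nsp_pairs` is a hypothetical helper, the negative sentence is drawn from the same small list (BERT draws it from elsewhere in the corpus), and the sentences are assumed to be distinct.

```python
import random

def make_nsp_pairs(sentences, seed=None):
    """Build (sentence_a, sentence_b, is_next) triples.

    Roughly half of the pairs use the true next sentence; the rest use a
    randomly chosen other sentence. (Hypothetical helper for illustration.)
    """
    if len(sentences) < 3:
        raise ValueError("need at least 3 sentences to sample a negative")
    rng = random.Random(seed)
    pairs = []
    for i in range(len(sentences) - 1):
        a = sentences[i]
        if rng.random() < 0.5:
            pairs.append((a, sentences[i + 1], True))
        else:
            # Any sentence except A itself and its true successor.
            candidates = [s for j, s in enumerate(sentences) if j not in (i, i + 1)]
            pairs.append((a, rng.choice(candidates), False))
    return pairs

sentences = [
    "The cat sat on the wall.",
    "It watched the birds.",
    "New York is a city.",
    "It never sleeps.",
]
for a, b, is_next in make_nsp_pairs(sentences, seed=1):
    print(is_next, "|", a, "|", b)
```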

Disadvantages of BERT

There are two main limitations of the BERT algorithm:

  1. BERT corrupts the input with masks and suffers from a pretrain-finetune discrepancy. In real-life applications, we do not have inputs that are masked, and how BERT copes with this mismatch in practice remains unclear.
  2. BERT neglects the dependency between masked positions. For example, consider the sentence “New York is a city” and let the input to BERT be “[MASK] [MASK] is a city”. The objective of BERT would be

log p(New | is a city) + log p(York | is a city)

From the above function, it is clear that there is no dependency between learning “New” and “York”. So, BERT can produce a prediction like “New Francisco is a city”.
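A toy calculation shows how this can happen. The joint distribution below is entirely made up for illustration (it does not come from BERT or from the paper); the point is only that picking the most likely word for each masked position separately can produce a pair that never occurs together.

```python
from collections import defaultdict

# Hypothetical joint distribution over the two masked words, given the
# context "is a city". The numbers are invented for illustration.
joint = {
    ("New", "York"): 0.30,
    ("New", "Delhi"): 0.25,
    ("San", "Francisco"): 0.45,
}

def marginal(joint, index):
    """Marginal distribution of the word at `index` (0 or 1)."""
    probs = defaultdict(float)
    for pair, p in joint.items():
        probs[pair[index]] += p
    return dict(probs)

first = marginal(joint, 0)   # p(word 1 | is a city)
second = marginal(joint, 1)  # p(word 2 | is a city)

# Predict each masked position independently of the other.
independent_guess = (max(first, key=first.get), max(second, key=second.get))
print(independent_guess)                  # ('New', 'Francisco')
print(joint.get(independent_guess, 0.0))  # 0.0, this pair never occurs
```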

Transformer Architecture

XLNET integrates ideas from Transformer-XL, the state-of-the-art autoregressive model, into pretraining. Transformer-XL builds on the Transformer, a model introduced by Google researchers, originally for language translation. It basically revolves around “attention”. It is an encoder-decoder model that maps one sequence to another, for example English to French. To translate an English sentence into French, the decoder needs to look at the entire sentence and selectively extract information from it at any point in time (because the order of tokens in English need not be the same in French). So, all the hidden states of the encoder are made available to the decoder.

How does the decoder know which hidden state it should look up at any point?

It does so by weighting each of the hidden states of the encoder. In the Transformer, the weights are computed by comparing the decoder’s query with each encoder key using a scaled dot product followed by a softmax (earlier attention models used a small feed-forward network for this). These weights are called attention weights. Some of the terms used in the paper are

  • Query (Q): derived from the decoder’s hidden state.
  • Keys (K): derived from the encoder’s hidden states; they are compared with the query to produce the attention weights.
  • Values (V): derived from the encoder’s hidden states; they are averaged using the attention weights when processing a query.
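The interplay of queries, keys and values can be sketched in plain Python. This is a minimal single-query version of scaled dot-product attention; it leaves out the learned projection matrices, multiple heads and batching that a real Transformer uses, and the toy vectors are invented for illustration.

```python
import math

def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def attention(query, keys, values):
    """Scaled dot-product attention for a single query vector.

    Returns the weighted average of `values` and the attention weights.
    """
    d_k = len(query)
    scores = [sum(q * k for q, k in zip(query, key)) / math.sqrt(d_k) for key in keys]
    weights = softmax(scores)
    dim = len(values[0])
    output = [sum(w * v[j] for w, v in zip(weights, values)) for j in range(dim)]
    return output, weights

# Toy example: one decoder query attending over three encoder states.
query = [1.0, 0.0]
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
values = [[10.0, 0.0], [0.0, 10.0], [5.0, 5.0]]
output, weights = attention(query, keys, values)
print(weights)  # keys that align with the query get larger weights
print(output)
```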

Two important ideas from Transformer-XL are integrated into XLNET.

  • Relative positional encoding: keeps track of the position of each token in a sequence (we will see why this is needed in the later sections).
  • Segment recurrence: caches the hidden states of the previous segment in memory at each layer and lets the current segment attend to them. This allows the memory to be reused from one segment to the next.
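Segment recurrence can be pictured with a heavily simplified, single-layer sketch. Everything here is an assumption made for illustration: the helper names are invented, keys double as values, and there is no causal mask, no relative positional encoding and no gradient handling. The only idea it keeps is that the current segment attends over the cached states of the previous segment plus its own.

```python
import math

def attend(query, context):
    """Softmax-weighted average of the context vectors (keys double as values)."""
    scale = math.sqrt(len(query))
    scores = [sum(q * c for q, c in zip(query, vec)) / scale for vec in context]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [sum(e / total * vec[j] for e, vec in zip(exps, context))
            for j in range(len(query))]

def process_segments(segments):
    """Run one attention layer over consecutive segments with a cached memory.

    `memory` holds the layer inputs of the previous segment; the current
    segment attends over memory + itself. (Hypothetical, single-layer sketch.)
    """
    memory = []
    outputs = []
    for segment in segments:
        context = memory + segment          # cached states + current segment
        outputs.append([attend(vec, context) for vec in segment])
        memory = segment                    # cache for the next segment
    return outputs

segments = [
    [[1.0, 0.0], [0.0, 1.0]],   # first segment (two token vectors)
    [[1.0, 1.0], [0.5, 0.5]],   # second segment
]
with_memory = process_segments(segments)
without_memory = process_segments([segments[1]])
print(with_memory[1])     # second segment, seeing the cached first segment
print(without_memory[0])  # same segment processed in isolation
```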

Now, we are good to delve into XLNET :)

XLNET — Generalized Auto-Regressive model for NLU

XLNET is a generalized autoregressive model, in which each token is predicted from the tokens that come before it. XLNET is “generalized” because it captures bidirectional context by means of a mechanism called “permutation language modeling”. It integrates the ideas of autoregressive modeling and bidirectional context modeling while overcoming the disadvantages of BERT. It outperforms BERT on 20 tasks, often by a large margin, in tasks such as question answering, natural language inference, sentiment analysis, and document ranking.

Permutation Language Modeling (PLM)

PLM is the idea of capturing bidirectional context by training an autoregressive model on all possible permutations of the words in a sentence. Instead of fixed left-to-right or right-to-left modeling, XLNET maximizes the expected log likelihood over all possible permutations of the sequence. In expectation, each position learns to utilize contextual information from all positions, thereby capturing bidirectional context. No [MASK] is needed and the input data need not be corrupted.
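Written out, the objective described above looks roughly as follows. The notation is borrowed from the XLNET paper rather than from this article: T is the sequence length, Z_T is the set of all permutations of the positions 1..T, z_t is the t-th position in a permutation z, and z_{<t} are the positions that come before it in that permutation.

```latex
\max_{\theta} \;
\mathbb{E}_{\mathbf{z} \sim \mathcal{Z}_T}
\left[ \sum_{t=1}^{T} \log p_{\theta}\!\left( x_{z_t} \mid \mathbf{x}_{\mathbf{z}_{<t}} \right) \right]
```

Each term is an ordinary autoregressive prediction; only the order in which the positions are predicted changes from one permutation to the next.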

The above diagram illustrates PLM. Let us consider that we are learning x3 (the token at the 3rd position in the sentence). PLM trains an autoregressive model on various permutations of the tokens in the sentence, so that, over all such permutations, the model learns to predict x3 given all the other words in the sentence. In the above illustration, we can see that the next layer takes as inputs only the tokens preceding x3 in the permutation sequence. This way, autoregression is also achieved.
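A small enumeration makes this tangible. The sketch below lists every context from which x3 is predicted as we go through all permutations of a four-token sentence; the helper name `contexts_for` and the token labels are made up for illustration, and a real implementation samples permutations rather than enumerating them.

```python
from itertools import permutations

def contexts_for(tokens, target_index):
    """Collect the contexts the target token is predicted from,
    across all factorization orders of the sequence."""
    contexts = set()
    for order in permutations(range(len(tokens))):
        step = order.index(target_index)   # when the target gets predicted
        visible = sorted(order[:step])     # positions predicted before it
        contexts.add(tuple(tokens[i] for i in visible))
    return contexts

tokens = ["x1", "x2", "x3", "x4"]
ctxs = contexts_for(tokens, target_index=2)  # x3
for c in sorted(ctxs, key=lambda c: (len(c), c)):
    print(c)
```

The contexts include tokens to the left of x3, tokens to the right of it, and both at once, which is how the bidirectional context arises without any [MASK] token.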

Comparison between XLNET and BERT

For example, consider the line “New York is a city” and suppose that we need to predict “New York”. Let us assume that the current permutation is one in which “New” and “York” come last, such as [is, a, city, New, York].

BERT would predict tokens 4 and 5 independently of each other, whereas XLNET, being an autoregressive model, predicts in the order of the sequence, i.e., it first predicts token 4 and then predicts token 5.

In this case, XLNET would compute

log P(New | is a city) + log P(York | New, is a city)

whereas BERT would reduce to

log P(New | is a city) + log P(York | is a city)
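The difference between the two objectives can be checked numerically on a toy example. The probabilities below are invented for illustration and do not come from either model. The autoregressive sum recovers the log probability of the pair “New York” exactly (this is just the chain rule), while the sum of independent terms does not.

```python
import math

# Hypothetical joint distribution over the two target words, given the
# context "is a city". The numbers are made up for illustration.
joint = {
    ("New", "York"): 0.30,
    ("New", "Delhi"): 0.25,
    ("San", "Francisco"): 0.45,
}

def p_first(w1):
    """P(w1 | is a city)."""
    return sum(p for (a, _), p in joint.items() if a == w1)

def p_second(w2):
    """P(w2 | is a city), ignoring the other target word."""
    return sum(p for (_, b), p in joint.items() if b == w2)

def p_second_given_first(w2, w1):
    """P(w2 | w1, is a city)."""
    return joint.get((w1, w2), 0.0) / p_first(w1)

# XLNET-style: log P(New | ctx) + log P(York | New, ctx)
autoregressive = math.log(p_first("New")) + math.log(p_second_given_first("York", "New"))
# BERT-style: log P(New | ctx) + log P(York | ctx)
independent = math.log(p_first("New")) + math.log(p_second("York"))

print(autoregressive, math.log(joint[("New", "York")]))  # these two agree
print(independent)                                       # this one differs
```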

How is XLNET implemented using Transformers?

We saw that the Transformer looks at the hidden representation of the entire sentence to make predictions. To implement XLNET, the Transformer is tweaked to look only at the hidden representations of the tokens preceding the token to be predicted. Recall that we embed the positional information of every token when it is fed into the model. Suppose token 3 is to be predicted; the subsequent layers

  • cannot access the content of token 3 from the input layer.
  • can only access the content of tokens preceding it and only the positional information of token 3.

The Q, K and V are updated according to the above principle when computing the attention.
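One way to picture this principle is as a pair of attention masks built from the permutation. The sketch below is an illustration, not the paper's implementation: the function name and the 0-indexed order are assumptions. The "query" mask is the one used when predicting a token (it hides the token's own content), and the "content" mask additionally lets a token see itself so that later tokens can use it as context. The paper calls this arrangement two-stream attention.

```python
def permutation_masks(order):
    """Build attention masks for one factorization order.

    order[t] is the (0-indexed) position predicted at step t.
    mask[i][j] is True when position i may attend to position j.
    (Hypothetical helper for illustration.)
    """
    n = len(order)
    step = {pos: t for t, pos in enumerate(order)}
    # Query stream: only positions strictly earlier in the order are visible.
    query_mask = [[step[j] < step[i] for j in range(n)] for i in range(n)]
    # Content stream: earlier positions plus the position itself.
    content_mask = [[step[j] <= step[i] for j in range(n)] for i in range(n)]
    return query_mask, content_mask

# Factorization order x3 -> x2 -> x4 -> x1, written with 0-indexed positions.
order = [2, 1, 3, 0]
query_mask, content_mask = permutation_masks(order)
for i, row in enumerate(query_mask):
    visible = [f"x{j + 1}" for j, ok in enumerate(row) if ok]
    print(f"predicting x{i + 1}: can see content of {visible}")
```

In this order, x3 is predicted first and therefore sees no other token's content, only its own position, while x1 is predicted last and sees everything else.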

Experiments conducted with XLNET

The authors of the paper tested XLNET on the following datasets.

  • RACE dataset: 100K questions from English exams. XLNET outperforms the best previous model by 7.6 points in accuracy.
  • SQuAD: reading comprehension tasks. XLNET outperforms BERT by 7 points.
  • Text classification: XLNET significantly outperforms BERT on a variety of datasets (see the paper for more details).
  • GLUE dataset: consists of 9 NLU tasks. According to the figures reported in the paper, XLNET outperforms BERT.
  • ClueWeb09-B dataset: used to evaluate document ranking performance. XLNET outperforms BERT.

Conclusion

I hope you enjoyed reading this blog. If you have any questions, please do post them below.
