[Part-2] Which Attention(architecture) do you need?

Venkata Dikshit · Published in ETHER Labs · Aug 13, 2019 · 7 min read

At EtherLabs, we build AI-first collaboration products. Check out EtherMeet, an AI-based video conferencing service for teams who use Slack.

At EtherMeet (www.etherlabs.io), we apply advanced AI to make video conferences smarter and more relevant.
Attention Architectures at EtherLabs

Part-1 of this two-part series covered the GPT- and BERT-based architectures — the upper two quadrants in [fig-1]. This post covers the architectures from the lower quadrants — XLNet, MASS and GPT-2 (yes, again) — to better understand all the hype around them.

[fig-1] Transformer-based models segregated by their architectural design

GPT-2 and XLNet

These two are the current state of the art across multiple language tasks and are notorious for their data and resource consumption — [1], [2]. The two architectures address some of the limitations of their predecessors — GPT and BERT respectively.

Before we delve into the details of these architectures, we need to be aware of some of the technical terms:

Auto-regressive language modeling — A classical language modeling approach, where the objective is to maximize the likelihood of a sequence of words under the forward autoregressive factorization. GPT is based on a left-to-right autoregressive factorization.
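Concretely, the forward factorization that AR models such as GPT maximize can be written as (standard notation, following the XLNet paper):

```latex
% Forward autoregressive language modeling objective
\max_{\theta} \; \log p_{\theta}(\mathbf{x})
  = \sum_{t=1}^{T} \log p_{\theta}(x_t \mid \mathbf{x}_{<t})
```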

Auto-encoder language modeling — Aims to reconstruct the input sequence of text and, in the process, learn the likelihood of word sequences. BERT is based on a variant of the auto-encoder called the denoising auto-encoder, where the sequence is reconstructed from a corrupted version of the actual input sequence. BERT uses a modified AE approach that predicts only the masked words instead of the entire sequence.
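In the same notation, BERT's denoising objective reconstructs only the masked tokens x̄ from the corrupted input x̂, where m_t = 1 marks a masked position; the ≈ sign is precisely the independence assumption discussed later:

```latex
% Denoising auto-encoding (BERT-style masked LM) objective:
% reconstruct only the masked tokens from the corrupted input \hat{x}
\max_{\theta} \; \log p_{\theta}(\bar{\mathbf{x}} \mid \hat{\mathbf{x}})
  \approx \sum_{t=1}^{T} m_t \, \log p_{\theta}(x_t \mid \hat{\mathbf{x}})
```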

Permutation language modeling — A variant of the auto-regressive approach where, instead of using only the left-to-right and/or right-to-left factorization, we use all possible factorization orders, i.e. for a sequence x of length T, there are T! different orders in which to perform a valid autoregressive factorization. This is an efficient way of encoding deep bi-directional representations of token sequences.
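Formally, with Z_T denoting the set of all T! permutations of [1, …, T], the permutation language modeling objective proposed in XLNet is:

```latex
% Permutation language modeling objective (XLNet):
% expectation over factorization orders z, AR factorization within each order
\max_{\theta} \;
  \mathbb{E}_{\mathbf{z} \sim \mathcal{Z}_T}
  \left[ \sum_{t=1}^{T} \log p_{\theta}\!\left(x_{z_t} \mid \mathbf{x}_{\mathbf{z}_{<t}}\right) \right]
```

The parameters θ are shared across all factorization orders, so in expectation each token learns to condition on every other token in the sequence.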

[fig-2] Using different factorization sequences as mentioned in the XLNet paper

For example, consider a sequence of length 4 and the task of predicting the token x3 given the same input sequence x but under different factorization orders. [fig-2] shows the combinations through which representations of x3 can be learnt.
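A small self-contained sketch (plain Python, not modeling code) makes this concrete: it enumerates all 4! factorization orders for a length-4 sequence and collects the distinct contexts under which x3 can be predicted — covering the combinations shown in [fig-2]:

```python
from itertools import permutations

T = 4
target = 3  # we want to predict token x3

contexts = set()
for order in permutations(range(1, T + 1)):   # all T! = 24 factorization orders
    pos = order.index(target)                 # where x3 appears in this order
    # x3 is conditioned only on the tokens that precede it in the order
    contexts.add(frozenset(order[:pos]))

# Every subset of {x1, x2, x4} shows up as a possible context for x3,
# which is how the model learns a bidirectional view of x3 in expectation.
for c in sorted(contexts, key=lambda s: (len(s), sorted(s))):
    print("predict x3 given:", sorted(c) or "nothing (empty context)")
```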

XLNet — uses a Transformer-XL-based pre-training paradigm that addresses the following limitations of AR- and AE-based approaches:

Independence assumption — BERT inherently assumes conditional independence between all the masked tokens, as they are reconstructed separately, thereby limiting its learning capabilities. In comparison, AR-based models factorize the sequence probability using the product rule, which holds universally without such an independence assumption.

Input noise — The input to BERT contains artificial symbols like [MASK] that never occur in downstream tasks, which creates a pretrain-finetune discrepancy. Replacing [MASK] with original tokens, as mentioned in the original paper, does not solve the problem because the original tokens can only be used with a small probability. In comparison, AR language modeling does not rely on any input corruption and does not suffer from this issue.

Context dependency — The AR representation hθ(x1:t−1) is only conditioned on the tokens up to position t (i.e. tokens to the left), while the BERT representation Hθ(x)t has access to the contextual information on both sides. As a result, the BERT objective allows the model to be pretrained to better capture the bidirectional context

XLNet brings in the best of both worlds — bi-directional capabilities of BERT and fully dependent factorization of AR models.

XLNet Architecture details

  • XLNet uses a permutation language model — during permutation, only the factorization order is permuted and not the sequence itself — the positional encodings of the original sequence are reused instead of being recalculated for the permuted sequence. Check out Section 3.5 in “Attention Is All You Need” for details on positional embeddings
  • Target-aware representations — giving the XLNet objective additional information about which position it will predict. This is achieved using the Two-Stream Self-Attention mechanism discussed in the paper
XLNet base Objective function with the target position
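(The objective referenced by this caption appears as an image in the original post; reconstructed from the XLNet paper, the target-position-aware prediction distribution looks like this, with g_θ taking the target position z_t as an extra input.)

```latex
% Target-aware prediction distribution (XLNet):
% g_theta takes the target position z_t as an additional input
p_{\theta}\!\left(X_{z_t} = x \mid \mathbf{x}_{\mathbf{z}_{<t}}\right)
  = \frac{\exp\!\left(e(x)^{\top} g_{\theta}(\mathbf{x}_{\mathbf{z}_{<t}}, z_t)\right)}
         {\sum_{x'} \exp\!\left(e(x')^{\top} g_{\theta}(\mathbf{x}_{\mathbf{z}_{<t}}, z_t)\right)}
```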
  • Adaptations from the Transformer-XL architecture — The key adaptation from Transformer-XL is integrating the recurrence mechanism into the proposed permutation setting, enabling the model to reuse hidden states from previous segments. Suppose we have two segments taken from a long sequence s, i.e. x̃ = s_{1:T} and x = s_{T+1:2T}. Let z̃ and z be permutations of [1, …, T] and [T+1, …, 2T] respectively. Then, based on the permutation z̃, we process the first segment and cache the obtained content representations h̃^(m) for each layer m. For the next segment x, the attention update with memory can be written as:
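(The update below is reconstructed from the XLNet paper, since the original post shows it as an image; [·, ·] denotes concatenation along the sequence dimension.)

```latex
% Attention update with cached memory from the previous segment:
% the cached representations \tilde{h}^{(m-1)} are prepended to the keys/values
h_{z_t}^{(m)} \leftarrow \mathrm{Attention}\!\left(
    \mathrm{Q} = h_{z_t}^{(m-1)},\;
    \mathrm{KV} = \left[\tilde{\mathbf{h}}^{(m-1)},\, \mathbf{h}_{\mathbf{z}_{\le t}}^{(m-1)}\right];\;
    \theta \right)
```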

This allows caching and reusing the memory without knowing the factorization order of the previous segment. In expectation, the model learns to utilize the memory over all factorization orders of the previous segment.

GPT-2 — Architecturally, GPT-2 is not a major deviation from its predecessor. In the discussion that follows, we highlight some of the notable aspects of GPT-2:

  • It uses WebText as the pre-training corpus — a collection of all outbound Reddit links with at least three karma. The input token representations are Byte Pair Encodings, which effectively interpolate between word-level inputs for frequent symbol sequences and character-level inputs for infrequent symbol sequences. With this approach, GPT-2 gets rid of conventional language modeling pre-processing steps such as lowercasing, tokenization, and out-of-vocabulary tokens, which restrict the space of model-able strings. This allows GPT/GPT-2 to be evaluated on any dataset regardless of pre-processing, tokenization, or vocab size (see the short sketch after this list).
  • GPT-2 has an expanded vocabulary compared to GPT and is trained with a context window of 1024 tokens instead of 512 and a larger batch size of 512 — tons of GPUs!!
[table-1] GPT-2 architecture experiments with different model sizes
  • [table-1] shows models of 4 different sizes. The final GPT-2 architecture is a 48-layer Transformer with a feature dimension of 1600, and it has over 10X more parameters (~1.5B vs ~117M) compared to its predecessor
  • The GPT-2 paper does not report fine-tuning performance on downstream datasets; instead it reports zero-shot performance on eight datasets, improving the state of the art on seven of them. This establishes a baseline for GPT-2's performance, and any dataset/task-specific fine-tuning will only improve it.
  • GPT-2 does not address the inefficiency of its uni-directional representations, which fail to capture the full context of the text. So, the reported success of GPT-2 should be attributed to the choice of data sources and the ~1.5B learnable parameters
Image Source: https://blog.floydhub.com/gpt2/
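To make the byte-level BPE point and the model sizes above more concrete, here is a short sketch using the Hugging Face transformers library (not part of the original post); the model names are the publicly released checkpoints, and only the small config files are downloaded for the size comparison:

```python
# pip install transformers
from transformers import GPT2Tokenizer, GPT2Config

# Byte-level BPE: no lowercasing, no out-of-vocabulary tokens —
# any Unicode string can be encoded.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
ids = tokenizer.encode("EtherLabs builds AI-first collaboration products 🚀")
# Frequent words stay whole; rare strings/emoji split into byte pieces.
print(tokenizer.convert_ids_to_tokens(ids))

# The four released checkpoint sizes, roughly matching [table-1]
for name in ["gpt2", "gpt2-medium", "gpt2-large", "gpt2-xl"]:
    cfg = GPT2Config.from_pretrained(name)
    print(f"{name:12s} layers={cfg.n_layer:2d}  d_model={cfg.n_embd:4d}  "
          f"context={cfg.n_positions}  vocab={cfg.vocab_size}")
```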

MASS

MASS — Masked Sequence to Sequence Pre-training for Language Generation — proposes an encoder-decoder extension of BERT-style pre-training, where the objective is modified to predict a contiguous series of tokens (a fragment) instead of the randomly masked individual tokens proposed in BERT.

MASS Encoder-Decoder framework
  • The MASS objective — jointly training the encoder and decoder — makes it outperform most of the current approaches on language generation tasks
  • The MASS encoder is forced to understand the meaning of the unmasked tokens, since the decoder needs that context to generate the masked fragment
  • On the decoder side, the tokens that are unmasked in the encoder input are masked, thereby forcing the decoder to predict the masked fragment from the encoder output instead of relying on the previous (non-fragment) tokens — see the sketch after this list
  • MASS addresses the masked-token conditional independence problem of BERT but falls short of the deep bi-directional contextuality described in XLNet
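A minimal toy sketch of the masking scheme described above (an illustration only, not the authors' implementation; the teacher-forcing shift of the decoder input is omitted for clarity, and the token/mask names are made up for the example):

```python
MASK = "[MASK]"

def mass_example(tokens, frag_start, frag_len):
    """Build encoder/decoder inputs and the decoder target for one MASS-style example."""
    frag_end = frag_start + frag_len
    fragment = tokens[frag_start:frag_end]

    # Encoder sees the sentence with the fragment masked out
    encoder_input = [MASK if frag_start <= i < frag_end else tok
                     for i, tok in enumerate(tokens)]

    # Decoder sees only the fragment positions; everything outside the fragment
    # is masked, so it must rely on the encoder output rather than surrounding tokens
    decoder_input = [tok if frag_start <= i < frag_end else MASK
                     for i, tok in enumerate(tokens)]
    return encoder_input, decoder_input, fragment

tokens = ["x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8"]
enc, dec, tgt = mass_example(tokens, frag_start=2, frag_len=3)  # mask x3..x5
print("encoder input :", enc)   # x1 x2 [MASK] [MASK] [MASK] x6 x7 x8
print("decoder input :", dec)   # [MASK] [MASK] x3 x4 x5 [MASK] [MASK] [MASK]
print("decoder target:", tgt)   # x3 x4 x5
```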

Conclusion

In this two-part series, we have covered various Transformer-based architectures, highlighting their architectural differences, training objectives and other significant aspects. These architectures have pushed the limits of text understanding engines and are fast heading towards generalized understanding, enabling faster task-specific adaptation through effective pre-training strategies. More to come on Language Models, NLP, Geometric Deep Learning, Knowledge Graphs, contextual search and recommendations. Stay tuned!!

Check out EtherMeet, an AI-enabled video conferencing service for teams who use Slack.

Sign up at etherlabs.io
