Replacing Recurrence with Attention: Improving Large Language Models with Self-Attention

Kjung · Edge Analytics · Oct 14, 2021

In a series of previous posts, we described how we used GPT-3, OpenAI’s large language model (LLM), to improve market research for InVibe. Here we talk about some of the innovations behind LLMs, specifically how they are able to:

1. Train more efficiently on very large datasets, and

2. Better model long-range dependencies between words.

It turns out that these two things are related to each other, and, in this post, we’ll give some intuition as to why. We’ll start our journey by refreshing our memories of what language models are, and how ideas from machine translation helped make LLMs like GPT-3 possible.

Language models

Language models are just probability distributions over sequences of words¹ in some language, e.g., English or French. We want language models to assign higher probability to sequences like, “The rain in Spain falls mainly on the plain” than, “Piano koala fish-sticks”. Language models like GPT-3 are autoregressive — that is, they express this probability distribution as the probability of the next word given (the jargon is “conditioned on”) the previous words, but not later words. For example, an autoregressive language model estimates the probability of words that might follow, “The rain in Spain”, and hopefully assigns a higher probability to “falls” than “elephant”. You can imagine building up long sequences of words iteratively (one word at a time), by starting with a prompt like “Four score and seven”, sampling the next word according to the language model, and repeating the process. To put this another way, autoregressive language models try to estimate, “what word comes next” (as opposed to masked language models, which “fill in the blanks”).
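To make the “one word at a time” idea concrete, here is a minimal Python sketch of autoregressive sampling. The toy next_word_probs table and the <end> token are stand-ins of our own invention; a real LLM like GPT-3 would supply a conditional distribution over its whole vocabulary, conditioned on the entire prefix.

```python
import random

# Toy conditional distributions standing in for a real language model.
# A real autoregressive LM conditions on the entire prefix; for simplicity,
# this stand-in only looks at the most recent word.
next_word_probs = {
    "The": {"rain": 0.9, "elephant": 0.1},
    "rain": {"in": 1.0},
    "in": {"Spain": 1.0},
    "Spain": {"falls": 0.8, "elephant": 0.2},
    "falls": {"mainly": 1.0},
    "mainly": {"<end>": 1.0},
    "elephant": {"<end>": 1.0},
}

def sample_next(prefix):
    """Sample the next word conditioned on the words so far."""
    dist = next_word_probs[prefix[-1]]          # a real model would use all of `prefix`
    words, probs = zip(*dist.items())
    return random.choices(words, weights=probs, k=1)[0]

def generate(prompt, max_words=10):
    """Build up a sequence one word at a time, autoregressively."""
    sequence = prompt.split()
    while len(sequence) < max_words:
        next_word = sample_next(sequence)
        if next_word == "<end>":
            break
        sequence.append(next_word)
    return " ".join(sequence)

print(generate("The rain"))   # e.g. "The rain in Spain falls mainly"
```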

Long-range dependencies

Consider the two sentences, “The dog told them something was amiss by barking urgently” and “The cat told them something was amiss by meowing urgently”. In these sentences, the word that follows “amiss by…” depends a lot on whether the second word was dog or cat. So an autoregressive language model is ideally able to remember whether it saw “dog” or “cat” by the time it makes a choice between “barking” or “meowing”. Of course, words can be much farther apart than in this simple example — they may not even be in the same sentence! Instead of dog or cat, the sentence may say “Daisy” or “Luna”, and whether the name refers to a dog or cat is far away indeed.


In general, language models should ideally be able to refer back to words that occurred a while ago in the text. In order to explain how LLMs like GPT-3 handle this, we will digress into a related problem — neural machine translation — where the key innovations were developed.

Neural Machine Translation

In 2014, Sutskever et al demonstrated the use of recurrent neural nets (RNNs) for machine translation, e.g., given an English sentence, “The dog told them something was amiss by barking urgently”, produce a French translation, “Le chien leur a dit que quelque chose n’allait pas en aboyant d’urgence.”³ At the time, machine translation was still dominated by complex phrase-based systems, often painstakingly tuned by hand. Sutskever et al showed that you could approach state-of-the-art performance using no specialized linguistic knowledge — you just needed a big matched corpus (a big set of English sentences and their French translations) and could let gradient descent do its thing with a general-purpose, off-the-shelf neural net — a recurrent neural net with LSTM units². The RNN acts first as an encoder, processing the English sentence one word at a time (it turned out to work a bit better to process the English sentence in reverse order!) and yielding a single embedding (i.e., an ordered list of real numbers) summarizing the entire sentence. The same RNN then acts as a decoder, and outputs the French translation one word at a time, in order.

Adapted from “Neural Machine Translation By Jointly Learning to Align and Translate” by D Bahdanau et al, ICLR 2015.
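As a rough illustration (not the authors’ actual implementation), here is what that encode-then-decode loop might look like in PyTorch. The vocabulary sizes, dimensions, token ids, and the greedy decoding loop are placeholders of our own; as in the description above, a single LSTM plays both the encoder and decoder roles.

```python
import torch
import torch.nn as nn

# Sketch of the sequence-to-sequence recipe: an LSTM reads the (reversed)
# English sentence into a single summary state, then the same LSTM generates
# French tokens one at a time from that state. All sizes are toy values.
SRC_VOCAB, TGT_VOCAB, EMB, HID = 1000, 1200, 64, 128

src_embed = nn.Embedding(SRC_VOCAB, EMB)   # English word embeddings
tgt_embed = nn.Embedding(TGT_VOCAB, EMB)   # French word embeddings
rnn = nn.LSTM(EMB, HID, batch_first=True)  # one RNN, used first to encode, then to decode
to_vocab = nn.Linear(HID, TGT_VOCAB)       # hidden state -> scores over French words

def translate(src_token_ids, bos_id=1, eos_id=2, max_len=20):
    # Encode: process the English sentence in reverse order, keeping only the final state.
    src = torch.tensor(src_token_ids[::-1]).unsqueeze(0)     # shape (1, src_len)
    _, state = rnn(src_embed(src))                           # the single summary embedding

    # Decode: emit French tokens one at a time, greedily, starting from <bos>.
    output, prev = [], torch.tensor([[bos_id]])
    for _ in range(max_len):
        step_out, state = rnn(tgt_embed(prev), state)
        next_id = to_vocab(step_out[:, -1]).argmax(dim=-1)   # most likely next French word
        if next_id.item() == eos_id:
            break
        output.append(next_id.item())
        prev = next_id.unsqueeze(0)
    return output

print(translate([5, 17, 42, 7]))   # token ids of a made-up English sentence
```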

This was a pretty big deal and it surprised a lot of people that it worked as well as it did! But there was clearly plenty of room for improvement and lots of ideas for how to do it. For instance, the only path for information about the input sentence to inform the French translation is through a single embedding, and this path gets longer and longer as you generate the translation. And arguably worse, this single embedding is the same size whether the input sentence is a terse declarative statement or a stream-of-consciousness ramble. Wouldn’t it be great if you could focus on specific parts of the input sentence while generating each French word, and focus on them in arbitrary order? Bahdanau et al presented work that does exactly this. First, they built on previous work that separates the encoder and decoder into separate neural nets. As before, the encoder processes the input sentence word by word, in order (technically, they used a bi-directional RNN, which maintains two RNNs — one that processes words in the forward order and one that processes words in the reverse order), generating an embedding at each input sentence position. Crucially, the decoder can use the entire sequence of embeddings while it is decoding! Bahdanau et al then used a soft attention mechanism that computed how important each input embedding was for generating the current French word. This recipe of a bi-directional encoder coupled with an autoregressive decoder and an attention mechanism was at the heart of the next generation of machine translation systems, and by 2016 Google Translate had switched to this paradigm for many language pairs!
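To give a flavor of what “soft attention” means, here is a minimal PyTorch sketch of an additive attention step in the spirit of Bahdanau et al. The weight matrices, dimensions, and random tensors are illustrative placeholders, not the paper’s exact formulation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Given the decoder's current state and the full sequence of encoder embeddings,
# score each input position, turn the scores into weights, and mix the encoder
# embeddings into a single context vector for generating the next French word.
HID = 128
W_enc = nn.Linear(HID, HID, bias=False)
W_dec = nn.Linear(HID, HID, bias=False)
v = nn.Linear(HID, 1, bias=False)

def attend(decoder_state, encoder_states):
    # decoder_state:  (1, HID)         current decoder hidden state
    # encoder_states: (src_len, HID)   one embedding per input position
    scores = v(torch.tanh(W_enc(encoder_states) + W_dec(decoder_state)))  # (src_len, 1)
    weights = F.softmax(scores, dim=0)                                    # how important each input word is
    context = (weights * encoder_states).sum(dim=0, keepdim=True)         # (1, HID) mixture of encoder embeddings
    return context, weights

encoder_states = torch.randn(9, HID)   # e.g. embeddings for a 9-word English sentence
decoder_state = torch.randn(1, HID)
context, weights = attend(decoder_state, encoder_states)
print(weights.squeeze())               # attention over the input positions
```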

What does this all have to do with language models? If you think about it, the decoder of a neural machine translation (NMT) system has a lot in common with an autoregressive language model. Each outputs words one at a time, in sequence, with the next word depending on previously generated words. Now, in NMT, the decoder has full access to the sequence of embeddings resulting from the input sentence while producing output. Note that the decoder doesn’t have access to its own previous states — its knowledge of previously generated French words must be summarized solely by its current hidden state, which by design is a mixture of information from previous states. But this is arguably not a big handicap since it has access to the entire input sentence embedding sequence. In our example above, by the time it gets to choosing between aboyant, miaulant, and perhaps meuglant, it can refer back directly to the input sequence embedding corresponding to dog or cat, and adjust accordingly.

What about in an autoregressive language model? Well, there isn’t a source sentence per se. How can words we have already output, even if we output them a while ago, influence the next word? Again, it is only through the current hidden state of the RNN. Say we have a language model for English, and we have already output, “The dog told them something was amiss by…”. By the time we get to this point with a regular RNN decoder, the fact we output dog as the second word is a somewhat distant memory, available to us only through the current RNN state and thus mixed up with everything that happened between then and now.

Can’t we use one of those fancy attention mechanisms so the decoder can figure out, oh, the subject of the sentence was dog, so the next word is more likely bark than meow? You can! And in principle, this is not a big deal when we’re generating a sample. In that setting, we have to go ahead and commit to a specific choice (or maintain a set of sequences of choices for, e.g., beam search) for each word as we go. One can imagine a scheme that allows the decoder to look at embeddings for previous states, or for the previously generated words themselves, to help decide what words are likely to come next. This process is intrinsically sequential, and there isn’t really any way to avoid that during generation (at least, not one that is widely used!), so this sort of attention mechanism is not that big a deal to bolt on top. In particular, we don’t have to work at all to make sure we don’t somehow “cheat” by letting the decoder peek ahead at future words — they haven’t been sampled yet!

But recall that one of our desiderata for LLMs was training efficiently on really large datasets. What happens during training for an RNN-based model⁴? Even though training doesn’t involve committing to a specific sample (i.e., losses are computed from the distributions over possible output words vs observed words in a given input sentence, NOT with respect to a specific choice of output words), RNNs are intrinsically sequential, and must process input sentences one word at a time, in order. Tricks like using a bi-directional RNN make matters even worse — you would have to run the reverse RNN many times per sentence, each time from a different starting point (from the current position back to the first)!

Attention is all you need!

In 2017, Vaswani et al presented a new approach to machine translation that didn’t use RNNs at all — the Transformer. This novel architecture introduced a lot of building blocks that have since been used for all sorts of problems, including image classification! You could spend a lot of time delving into the details of the Transformer, but for this post, we are most interested in two critical properties:

  1. Unlike RNNs, which must process the input sequence sequentially (one word at a time, in forward and/or reverse order), the Transformer calculates embeddings for all words in the input sequence in parallel during training. During sampling, when we are outputting a concrete sentence, this doesn’t hold, but that’s okay — we aren’t as concerned with efficiency at sampling time.
  2. The embedding at position k is influenced by the embeddings for the other positions solely through self-attention⁵. This influence is direct in the sense that the number of steps it takes for information to flow from one embedding to another is independent of how far apart they are in the sequence — in our English to French example, “barking” is no further from “dog” than it is from “amiss”. A short code sketch after the figure below makes this concrete.
The flow of information in an RNN vs in a Transformer decoder block. In the RNN, “dog” informs “barking” through the sequence of hidden states, and is mixed up with the hidden states between the two positions. In a transformer decoder, “barking” can be informed directly by “dog” (solid red line). Note that the flow of information in the transformer decoder is autoregressive! Adapted from “Why Self-Attention? A Targeted Evaluation of Neural Machine Translation Architectures” by Tang et al, EMNLP 2018.
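Here is a bare-bones sketch of (single-head) scaled dot-product self-attention, the mechanism behind the second property. The queries, keys, and values below are random placeholders; in a real Transformer they are learned linear projections of the embeddings at each position.

```python
import torch
import torch.nn.functional as F

# Every position scores every other position, turns the scores into weights,
# and mixes the value vectors accordingly -- all in one matrix computation.
seq_len, d = 6, 16                      # e.g. 6 words, 16-dimensional embeddings
Q = torch.randn(seq_len, d)             # one query per position
K = torch.randn(seq_len, d)             # one key per position
V = torch.randn(seq_len, d)             # one value per position

scores = Q @ K.T / d ** 0.5             # (seq_len, seq_len): position-to-position scores
weights = F.softmax(scores, dim=-1)     # row k = how much position k attends to each position
output = weights @ V                    # new embedding for every position, computed in parallel

print(weights[4])                       # position 4 attends to position 0 directly -- one step, not four
```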

These two properties together make training on very large datasets much more efficient than with RNN-based methods. How they interact is a bit subtle. The Transformer, like previous MT systems, uses an encoder and a decoder. While the encoder module is allowed to “see” the entire English sentence, and the decoder can see the entire sequence of embeddings produced by the encoder, producing the French translation is still inherently autoregressive; words are generated one at a time, in sequence, and at each stage the decoder can’t be allowed to cheat by peeking ahead at future words. Because of the second property, we can achieve this very simply during training by using an “attention mask” — since the only way embeddings at different positions can influence each other is through attention, we can enforce the autoregressive property of the decoder by multiplying the information from later positions by zero (masking it) when generating the word at position k, all without breaking our ability to compute everything in parallel. Finally, to bring this back to language models, we strip away the encoder part — the language model is just the decoder part of the transformer.
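Continuing the sketch above, the decoder’s “attention mask” amounts to one extra step: scores for later positions are pushed to negative infinity before the softmax, so their attention weights come out as exactly zero, and every position is still computed in one parallel matrix operation.

```python
import torch
import torch.nn.functional as F

# Same self-attention as before, but with a causal mask so each position can
# only attend to itself and earlier positions -- no peeking ahead, no loop.
seq_len, d = 6, 16
Q, K, V = (torch.randn(seq_len, d) for _ in range(3))

scores = Q @ K.T / d ** 0.5
causal_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
scores = scores.masked_fill(causal_mask, float("-inf"))   # later positions get zero weight after softmax

weights = F.softmax(scores, dim=-1)     # row k now puts zero weight on positions > k
output = weights @ V

print(weights[2])                       # only the first three entries are nonzero
```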

It’s worth pausing to reiterate the really cool trick transformers accomplish, and shed some light on the title of Vaswani et al’s paper — Attention is all you need. With RNNs, words influence each other because the hidden state at a given position is a function of previous words (or, in the case of a reverse RNN, subsequent words). This serializes the computation, and we really want to avoid that. In a transformer, or a language model based on the decoder part of a transformer, each position is processed in parallel, and the information at each position interacts with other positions through attention, again in parallel, with the autoregressive property enforced during training by attention masks. In other words, we have replaced recurrence with attention!

  1. Note that we are using “word” somewhat loosely here. Technically, LLMs don’t work at the word level, but rather the “token” level, where a token is a variable length sequence of characters. The reason for this is simple — there are lots and lots of words! At the other extreme there are relatively very few characters, but trying to learn LLMs at the character level is hard. So we find a middle ground. The jargon is that we model sequences of “tokens”, not words, but for purposes of this post, we’ll stick with the more familiar, er, word. Similarly, there is no reason to restrict ourselves to sentences per se — but we’ll often refer to input sentences because it is more familiar and less jargony than the more correct and neutral “sequences”.
  2. This is not quite true — there were various tricks that helped performance a lot, like processing the English sentence in reverse order! But compared to the MT systems in use at the time, this was simply far less work.
  3. We keep the example of English to French for simplicity, but this approach is general — it could just as well be German to English, or Chinese to Dutch — as long as you have a big enough matched corpus you can try this approach.
  4. We keep talking about RNNs but of course there is another alternative for dealing with sequences — convolutional neural nets. These models can work pretty well too, but not as well as transformers, so we’ll leave that aside for simplicity here.
  5. “Self-attention” sounds awfully fancy, but the basic idea is simple — the decoder can pay attention to itself and its internal representations instead of just the embeddings of a source language sentence. In fact, since we are doing language modeling instead of translation, it’s not even clear what you could consider the “source sentence”. The key thing is that the auto-regressive property is enforced during training. During sampling, this will take care of itself.

Edge Analytics is a company that specializes in data science, machine learning, and algorithm development both on the edge and in the cloud. We provide end-to-end support throughout a product’s lifecycle, from quick exploratory prototypes to production-level AI/ML algorithms. We partner with our clients, who range from Fortune 500 companies to innovative startups, to turn their ideas into reality. Have a hard problem in mind? Get in touch at info@edgeanalytics.io.
