Paper Summary: Attention Is All You Need

Mike Plotz
5 min read · Nov 19, 2018


Part of the series A Month of Machine Learning Paper Summaries. Originally posted here on 2018/11/18.

Attention Is All You Need (2017), https://arxiv.org/abs/1706.03762, by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, and Illia Polosukhin

If attention is all you need, this paper certainly got enough of it. All this fancy recurrent convolutional NLP stuff? Turns out it’s all a waste. Just point your Transformer’s monstrous multi-headed attention at your text instead. (Did that make any sense? Probably not.) Anyway, I’m excited about this one, because I tried grokking it a few months ago and bounced off, so now I’m back for more.

Like Michelangelo, the authors carved away all the non-Transformer marble from the statue that is the Transformer architecture, leaving only the divinely inspired latent structure beneath. Attention in NLP is of course nothing new (see e.g. Bahdanau 2014), but it has mostly been combined with RNNs, which are complex(ish), tricky to train and regularize (though there's been lots of work on this), and, the clincher, hard to parallelize. Convolutional approaches are sometimes effective, and I haven't talked about them as much, but they tend to be memory-intensive, and the number of operations needed to relate two distant positions still grows with the distance between them. Plus we'd like to have the shortest possible path through the network between any two input-output locations. The Transformer does this.

The architecture is pretty simple, but I had trouble understanding all the details when I looked at the paper a couple months ago, maybe because the paper is rather terse. So I’ll try to summon my past self and explain it like I wanted it to be explained, though I’ll leave out some details like exactly where and how much dropout is added — you’ll have to read the paper or the code for that. For reference, here’s the high-level architecture diagram:

Some of those boxes are a bit complicated (which we'll get to), but first an overview. The encoder is on the left and the decoder is on the right; each is divided into N = 6 layers (so the gray boxes are actually stacked 6 high), and each layer has some sublayers. Each sublayer is wrapped in a residual connection and followed by layer norm. So far so easy.
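In code, that sublayer pattern is just LayerNorm(x + Sublayer(x)). A minimal NumPy sketch (function names are mine; the real model also applies dropout and uses learned scale/shift parameters in the layer norm):

```python
import numpy as np

def layer_norm(x, eps=1e-6):
    # Normalize each position's feature vector to zero mean and unit variance.
    # (The real layer norm also applies a learned scale and shift.)
    mean = x.mean(axis=-1, keepdims=True)
    std = x.std(axis=-1, keepdims=True)
    return (x - mean) / (std + eps)

def add_and_norm(x, sublayer_fn):
    # Residual connection around the sublayer, then layer norm:
    # LayerNorm(x + Sublayer(x)). Dropout on the sublayer output is omitted.
    return layer_norm(x + sublayer_fn(x))
```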

There are three components worth diving into: the multi-head attention (orange), the position-wise feed-forward networks (light blue), and the positional encoding. The attention parts are the most complicated and confusing (plus I hear they’re all you need…), so let’s tackle those first.

A diagram:

The style of attention is scaled dot-product attention, which is a bit different from the “additive attention” in Bahdanau 2014, but conceptually similar and faster (because optimized matrix math). The idea is that we have a conditioning signal or query that is applied to a set of key-value pairs: the query and each key interact somehow, producing normalized weights, and these weights are applied to the values, producing a weighted sum.

“Interact somehow” here means a dot product, divided by a scaling factor of sqrt(dim(key)), and normalized with a softmax. The queries, keys, and values are packed into matrices, so the dot products and weighted sums become matrix multiplies. To keep the architecture simple (and to make the residual connections make sense), all model dimensions are 512.
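Concretely, that's Attention(Q, K, V) = softmax(QKᵀ / sqrt(d_k)) V. A NumPy sketch (my own function names; the mask argument anticipates the decoder, with -1e9 standing in for -∞):

```python
def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def scaled_dot_product_attention(Q, K, V, mask=None):
    # Q: (n_queries, d_k), K: (n_keys, d_k), V: (n_keys, d_v)
    d_k = K.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)            # query-key dot products, scaled
    if mask is not None:
        scores = np.where(mask, scores, -1e9)  # disallowed positions -> ~ -inf
    weights = softmax(scores, axis=-1)         # normalized attention weights
    return weights @ V                         # weighted sum of the values
```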

(Why scaled? Because, the authors speculate, the query-key dot products get big for large key dimensions, pushing the softmax into regions where its gradients are vanishingly small.)

What about the multi-headedness? The idea is that we’d like to focus on a bunch of places at once, kind of like how when you read text you fix your fovea at several different locations sequentially. Since there are no timesteps, the only way to do this is with multiple eyes. Heads. Something like that. The authors used h = 8 heads (see below), projecting each 512-dimensional key, value, and query down to 64 dimensions with separate learnable projections. The outputs are concatenated and projected back up to 512 dimensions. This ends up having a similar computational cost to a single unprojected head.
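Reusing the attention function above, multi-head attention looks roughly like this (a sketch: the per-head projection matrices would be learned, and real implementations fold the loop into a few big batched matmuls):

```python
def multi_head_attention(X_q, X_kv, W_q, W_k, W_v, W_o, mask=None):
    # X_q: (n_q, 512) inputs that supply the queries
    # X_kv: (n_kv, 512) inputs that supply the keys and values
    # W_q, W_k, W_v: lists of h = 8 projection matrices, each (512, 64)
    # W_o: output projection, (512, 512)
    heads = []
    for Wq_i, Wk_i, Wv_i in zip(W_q, W_k, W_v):
        Q = X_q @ Wq_i    # project queries down to 64 dims
        K = X_kv @ Wk_i   # project keys
        V = X_kv @ Wv_i   # project values
        heads.append(scaled_dot_product_attention(Q, K, V, mask))
    return np.concatenate(heads, axis=-1) @ W_o  # concat the 8 heads, project back
```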

And masked multi-headed attention? Yeah, that’s important too. On the decoder side we don’t want information about future output words to leak into the network, so they get masked out to -∞ just before the softmax (the sharp-eyed will have noticed the pink “Mask (opt.)” box in the scaled dot-product attention diagram).
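The mask itself is just a lower-triangular matrix saying which positions each query may attend to; passed as the mask argument above, it zeroes out the weights on future positions after the softmax:

```python
def causal_mask(n):
    # True where attention is allowed: position i may attend to positions <= i.
    return np.tril(np.ones((n, n), dtype=bool))

# causal_mask(3) ->
# [[ True False False]
#  [ True  True False]
#  [ True  True  True]]
```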

It’s also worth scrolling back up to take a close look at where the multi-head attention inputs come from — e.g. the second decoder attention block takes its keys and values from the encoder outputs. Also note that the keys and values are always the same — not strictly true since they get projected differently, but they always come from the same source.

Moving along. There are two ways to think of the position-wise feed-forward networks. They’re either a two-layer fully connected network with a ReLU in between, applied identically at every position. Or (and I like this better) they’re two kernel-size-1 convolutions applied across position space: conv → ReLU → conv. The hidden dimension is 2048. You might ask why these sublayers are here. As might I: I don’t have a good intuition for this.
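Either way it comes out to FFN(x) = max(0, xW1 + b1)W2 + b2 at each position, with W1 of shape 512×2048 and W2 of shape 2048×512. A sketch (the weights and biases would be learned):

```python
def position_wise_ffn(x, W1, b1, W2, b2):
    # x: (n_positions, 512); W1: (512, 2048); W2: (2048, 512)
    # The same weights are applied independently at every position.
    return np.maximum(0, x @ W1 + b1) @ W2 + b2
```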

And positional encodings. We have to inject position information somehow, so the authors decided to use fixed sinusoids of different frequencies that get added directly to the input embeddings. Kind of like a Fourier transform. Learned positional encodings also work, but the authors hoped that the fixed version might generalize better to sequences longer than those seen in training. In any case, this is pretty clever: it allows easy modeling of relative positions, since for any fixed offset k, the encoding at position pos + k is a linear function of the encoding at pos.
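The encodings follow PE(pos, 2i) = sin(pos / 10000^(2i/d_model)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model)). A sketch (the paper also scales the embeddings by sqrt(d_model) before adding these):

```python
def positional_encoding(max_len, d_model=512):
    # One row per position; even columns get sines, odd columns get cosines,
    # with wavelengths forming a geometric progression from 2*pi to 10000*2*pi.
    pos = np.arange(max_len)[:, None]      # (max_len, 1)
    i = np.arange(0, d_model, 2)[None, :]  # (1, d_model/2)
    angles = pos / np.power(10000, i / d_model)
    pe = np.zeros((max_len, d_model))
    pe[:, 0::2] = np.sin(angles)
    pe[:, 1::2] = np.cos(angles)
    return pe  # added directly to the input embeddings
```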

Results: works real good. The large model does take 3.5 days to train on 8 P100s, which is a bit beefy. Fortunately the small model (~4 GPU-days) is competitive. Lots more details on training, by the way, including a form of regularization called label smoothing that I hadn’t heard of (the idea: don’t use probabilities of 0 and 1 for your labels, which seems eminently reasonable to me). There’s also a learning rate schedule that has a warmup period sort of like ULMFiT’s, though I think for different reasons.
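For reference, that schedule is lrate = d_model^(-0.5) · min(step^(-0.5), step · warmup_steps^(-1.5)), with warmup_steps = 4000 in the paper: the learning rate ramps up linearly during warmup, then decays as 1/sqrt(step). In code:

```python
def transformer_lr(step, d_model=512, warmup_steps=4000):
    # Linear warmup for the first warmup_steps steps, then 1/sqrt(step) decay.
    step = max(step, 1)  # avoid dividing by zero at step 0
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)
```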
