The FLOPs Calculus of Language Model Training
Extremely large language models like the famous GPT-3 by OpenAI are all the rage, and many of us are now trying to get a sense of the scale of the compute that goes into training them.
In this article, I will offer you a very useful tool to reason about large Transformer LMs. This tool will help you roughly answer questions like
- How much does it cost to train GPT-3?
- How long will training this big model take me?
Turns out quick back-of-the-envelope calculations can be sufficient to answer these questions if you use a simple equation that ties together
- the compute required to train a Transformer model ( C )
- its number of parameters, or model size ( N )
- the number of tokens that the model is trained on ( D )
Without further ado, meet the Transformer FLOPs Equation:
C ≈ 6ND.
A slightly more sophisticated version of the equation expresses the compute C as the product of the cluster's throughput 𝜏 and the training time T:
𝜏T = 6ND.
Example
Let’s apply the Transformer FLOPs Equation to some middle-school style problem solving:
Problem An 82B parameter Korean variant of GPT-3 called HyperCLOVA was trained on 150B tokens using a cluster of 1024 Nvidia A100 GPUs. How long could that take?
Solution The peak float16 throughput of an A100 is 𝜏 = 312 teraFLOP/s = 3.12e14 FLOP/s. The total compute is C = 6 ∙ 8.2e10 ∙ 1.5e11 = 7.38e22 FLOPs. The training must have taken at least T = C / (1024𝜏) / 86400 ≈ 2.67 days.
Answer Validation According to the white paper, training took 13.4 days. Our estimate is 5 times off, but we did get the order of magnitude right.
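For convenience, here is the same naive estimate as a few lines of Python (the numbers are taken straight from the problem statement and the peak throughput; nothing is corrected yet):

```python
# Naive back-of-the-envelope estimate for the HyperCLOVA example above.
N = 82e9            # model parameters
D = 150e9           # training tokens
n_gpus = 1024
tau_peak = 312e12   # peak A100 float16 throughput per GPU, FLOP/s

C = 6 * N * D                       # Transformer FLOPs Equation
seconds = C / (n_gpus * tau_peak)   # idealized wall-clock time
print(f"C = {C:.2e} FLOPs, T = {seconds / 86400:.2f} days")
# C = 7.38e+22 FLOPs, T = 2.67 days (the reported training time is 13.4 days)
```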
As I explain later in this post, the error comes from naively plugging in the theoretical peak throughput 𝜏, which is not achievable in distributed training, nor when models do anything other than large matrix multiplications. If you correct 𝜏 accordingly (I will discuss this later), the FLOPs equation becomes much more accurate. The other correction is that with activation checkpointing², which is a must for the largest models, the required compute goes up to C ≈ 8ND.
This equation is not something I came up with¹: you can find it in both the scaling laws and the GPT-3 papers by OpenAI. Yet I think it is not as widely known as it should be. In the rest of the article I will derive the equation and discuss what realistic throughput 𝜏 you can expect. I will assume you are familiar with Transformers (check out the paper or a blog post like this one if you are not).
Derivation of Transformer FLOPs Equation
To derive the Transformer FLOPs equation we will have to make a key assumption.
The Weight FLOPs Assumption
The FLOPs that matter the most are weight FLOPs, that is, the ones performed when intermediate states are multiplied by weight matrices.
The weight FLOPs are the majority of Transformer FLOPs, meaning that we can put aside the FLOPs required for bias vector addition, layer normalization, residual connections, non-linearities, softmax and even attention. If you do not believe this, you are not wrong: while the other FLOPs are less numerous, they also require a lot of memory access and will in practice matter quite a bit. I will return to this later.
The beauty of matrix multiplications is that each of them adds a predictable and easy to compute number of FLOPs to the training total:
weight FLOPs for multiplying by a matrix W = 6 times batch size times size of W
This Weight FLOPs Equation can take some time to wrap one’s head around. To understand where it comes from, consider a weight w that connects an input unit i to an output unit j:
For each example in the batch, the weight w generates exactly 6 FLOPs combined in the forward and backward pass:
1. The unit i multiplies its output h(i) by w to send it to the unit j.
2. The unit j adds the unit i's contribution to its total input a(j).
3. The unit j multiplies the incoming loss gradient dL/da(j) by w to send it back to the unit i.
4. The unit i adds the unit j's contribution to its total loss gradient dL/dh(i).
5. The unit j multiplies its loss gradient dL/da(j) by the unit i's output h(i) to compute the loss gradient dL/dw for the given example.
6. (The sneakiest FLOP, IMHO) The weight w adds the contribution from step 5 to its loss gradient accumulator dL/dw that aggregates the gradients for all examples.
The Weight FLOPs Equation directly follows from the fact that we need 6 FLOPs per example per weight. And from this equation follows the Transformer FLOPs Equation. To understand this, think about how many times a Transformer multiplies each input token's state by each weight matrix, no matter how many input sequences the batch consists of. The answer is exactly once per weight matrix! So the total number of weight FLOPs for each token is 6 times the model size N, Q.E.D.
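To make this concrete, here is a small sketch that tallies the weight-matrix parameters of a GPT-style decoder stack and applies the 6N-per-token rule. I assume the vanilla block layout (Q, K, V and output projections plus a two-matrix MLP) and ignore embeddings, biases and layer norms:

```python
# Count weight-matrix parameters of a GPT-style stack and the resulting
# weight FLOPs per training token (forward pass: 2N, backward pass: 4N).
def weight_flops_per_token(n_layers, d_model, d_ff):
    per_layer = 4 * d_model * d_model + 2 * d_model * d_ff  # attention + MLP weights
    n_weights = n_layers * per_layer
    return n_weights, 6 * n_weights

# Roughly GPT-3-sized: 96 layers, d_model = 12288, d_ff = 4 * d_model.
N, flops = weight_flops_per_token(96, 12288, 4 * 12288)
print(f"N = {N / 1e9:.0f}B weights, {flops / 1e12:.1f} teraFLOPs per token")
# N = 174B weights, 1.0 teraFLOPs per token
```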
Why only weight FLOPs matter
I have asked you to swallow the assumption that only FLOPs from multiplications by weight matrices matter, and you might be wondering now if this assumption is too strong. It is a fair question to ask: it is commonly thought that attention is the bottleneck in Transformers, and here I am boldly brushing it aside. The reason we can do this is that attention only adds O(dL) FLOPs³ per token per layer, whereas weight multiplications add O(d²) (for more precise counting check out the Megatron paper by Nvidia). And when we talk about really large models, d tends to be considerably larger than L, and hence weight FLOPs dwarf attention FLOPs!
The other FLOPs (softmax, layer norm, activations, etc.) should be even more negligible, but there is a catch: GPU memory bandwidth becomes the bottleneck when these operations are performed, so in practice these elementwise operations can take non-negligible time. I therefore find it helpful to think about the weight FLOPs (WFLOPs) throughput that a particular implementation can deliver on particular hardware.
Estimating the WFLOPs Throughput
The theory of counting Transformer FLOPs is elegant, but as seen in the HyperCLOVA example, naive application results in a significant underestimation of the time required to train the language model. For more precise reasoning, we need a better idea of what actual WFLOPs throughput looks like.
I have done a little case study on an A100 GPU. According to Nvidia documentation, it can deliver up to 312 bfloat16 teraFLOP/s, that is 3.12e14 operations per second! Nvidia docs also show how ~250 teraFLOP/s can actually be achieved by doing 4096 x 8192 x 4096 matrix multiplications, and I was able to reproduce that. But what practical throughput can we get when training neural networks?
I have experimented with the GPT-2 implementation from HuggingFace Transformers and with a bespoke MLP implementation (note that the WFLOPs calculus for an MLP is the same as for a Transformer). Both models have states of d=1600 dimensions for each input and d_ff=6400 intermediate units. I train the MLP on batches of 8192 input vectors; GPT-2 receives 8192 input tokens as B=32 contexts of length L=256 to get an upper bound on the throughput that is less affected by attention (which should become relatively cheaper for larger models). To fit in the memory of a single GPU, I use only 8 GPT-2 layers.
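Something along these lines can be used to measure the MLP throughput (a simplified sketch rather than my exact benchmarking code; it assumes PyTorch and a CUDA GPU, and the exact numbers will depend on your software stack):

```python
import time
import torch

# Bias-free two-layer MLP with the dimensions from the setup above.
d, d_ff, batch, steps = 1600, 6400, 8192, 50
model = torch.nn.Sequential(
    torch.nn.Linear(d, d_ff, bias=False),
    torch.nn.Linear(d_ff, d, bias=False),
).cuda().half()
x = torch.randn(batch, d, device="cuda", dtype=torch.half)
n_weights = sum(p.numel() for p in model.parameters())

for _ in range(5):                      # warm-up
    model(x).sum().backward()
torch.cuda.synchronize()
start = time.time()
for _ in range(steps):                  # timed forward + backward passes
    model(x).sum().backward()
torch.cuda.synchronize()
elapsed = time.time() - start

# Weight FLOPs per step = 6 * batch size * number of weights.
print(f"{6 * batch * n_weights * steps / elapsed / 1e12:.1f} teraWFLOP/s")
```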
Here are the throughputs I measured:

| Setup | Throughput (teraWFLOP/s) |
| --- | --- |
| Pure float16 matrix multiplication | 237 |
| Linear float16 MLP, no activations (training step with backprop) | 230 |
| MLP, mixed precision (float32 weights) | 207.6 |
| MLP, mixed precision, with ReLU and residual connection | 185.7 |
| GPT-2 (HuggingFace Transformers, 8 layers) | 68 |

A training step (including backprop) for a linear float16 MLP with no activations yields a throughput of 230 teraWFLOP/s, very close to the 237 teraWFLOP/s that I registered for pure matrix multiplications. With float32 weights and mixed precision training (which AFAIK is the standard for training big Transformers these days) the throughput drops to 207.6 teraWFLOP/s due to the float32 -> bfloat16 conversion that mixed precision training involves. Adding a ReLU activation and a residual connection causes a further drop to 185.7 teraWFLOP/s. This might be surprising but is very much in line with Nvidia's performance documentation, which explains how the throughput of activations and elementwise operations is bounded by the memory bandwidth.
The MLP throughput looks encouraging, but for the actual GPT-2 implementation from HuggingFace Transformers the throughput was merely 68 teraWFLOP/s. I have not looked deeper into the exact breakdown, but a likely explanation is that the memory-intensive computations, such as residual connections, activations, layer normalization, attention masking and attention softmax, cost a lot when combined.
WFLOPs throughput in the literature
Throughput estimates can also be obtained by looking at white papers. For example, from here, here and here we can estimate the throughput achieved with various forks of Megatron-LM by Nvidia²: depending on the setup, it comes out to roughly 50 to 150 teraWFLOP/s per A100 GPU.
Note that these numbers are for highly distributed setups; single-GPU throughput for Megatron is likely to be much higher, thanks to the extensive use of fused operations.
To Sum It Up
In this article I have shared with you the Transformer FLOPs Equation, which makes reasoning about extremely large language models easy. The equation ties together the throughput 𝜏, the training time T, the model size N and the number of training tokens D:
𝜏T = 6ND.
Looking at publicly available white papers, the throughput 𝜏 is likely to be anywhere between 50 and 150 teraWFLOP/s per A100 GPU.
My favorite corollary of this equation is that, assuming constant throughput, the training time grows linearly with the model size. So if you want to make the model 2 times bigger, you have to either use 2 times as many GPUs or wait 2 times longer. Easy mathematics that you can do in your head!
Beyond Transformers
As the art of compressing the internet into matrices (a.k.a. language modeling) develops further, the FLOPs calculus might get hairier. For example, the math is different for Mixture of Experts (MoE) models trained by researchers at Google and Meta. In these models, only a fraction of the model's weights is active for every input token. The equation can thus be fixed by replacing the total model size N with the number of active weights.
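As a toy illustration (the parameter counts and dataset size below are made up for the example, not taken from any particular MoE paper):

```python
# Toy MoE estimate: only the active weights enter the FLOPs equation.
n_total = 1.0e12               # hypothetical total parameter count
n_active = 0.1 * n_total       # assumed fraction of weights active per token
tokens = 300e9                 # hypothetical dataset size

moe_compute = 6 * n_active * tokens    # N replaced by the active weights
dense_compute = 6 * n_total * tokens   # what a dense model of this size would need
print(f"{moe_compute:.1e} vs {dense_compute:.1e} FLOPs")
```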
The FLOPs calculus for LSTMs would look very similar to that of the Transformer: the total number of FLOPs grows linearly with the model size. The achievable throughput, however, is a key factor explaining their demise. To train LSTMs on long sequences while fitting in GPU memory one has to reduce the batch size, and with a small batch size the GPU throughput for the sequential LSTM computations falls dramatically. For example, for a 32x1600x6400 matrix multiplication the throughput is below 20 teraFLOP/s, more than 10 times slower than for 8192x1600x6400! Recurrence comes at a price: the computation for later tokens must wait until the computations for previous tokens are done, making the computation less parallel and thus less GPU-friendly.
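If you want to see the small-batch penalty yourself, a quick sketch along these lines (again assuming PyTorch and a CUDA GPU) compares the two matrix multiplication shapes mentioned above:

```python
import time
import torch

def matmul_tflops(b, d_in, d_out, steps=200):
    """Measure float16 matmul throughput in teraFLOP/s for (b, d_in) x (d_in, d_out)."""
    x = torch.randn(b, d_in, device="cuda", dtype=torch.half)
    w = torch.randn(d_in, d_out, device="cuda", dtype=torch.half)
    for _ in range(10):            # warm-up
        x @ w
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(steps):
        x @ w
    torch.cuda.synchronize()
    # Each matmul performs 2 * b * d_in * d_out FLOPs (one multiply and one add per MAC).
    return 2 * b * d_in * d_out * steps / (time.time() - start) / 1e12

print(matmul_tflops(32, 1600, 6400), matmul_tflops(8192, 1600, 6400))
```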
The End
I hope you found this article useful! Many thanks to Harm de Vries, Amine el Hattami, Torsten Scholak, Nicolas Chapados, Sebastien Paquet, and my other fabulous colleagues at ServiceNow Research for discussions that greatly helped me in researching this topic.
Footnotes
[1] Disclaimer: this is meant to be a popular article, not an academic contribution. While I am trying to give credit where credit is due, please don't freak out if you find this not sufficiently rigorous; just get in touch and I will try to address the issue.
[2] Note that Nvidia reports different teraFLOP/s numbers, namely 138 teraFLOP/s vs the 100.8 teraWFLOP/s that I calculated. The major source of the difference is that they include the FLOPs needed for the extra forward pass that recomputes the activations. Activation recomputation (checkpointing) allows back-propagation without storing all intermediate states in memory; it is now routinely used to train the largest models. The extra forward pass requires 2ND FLOPs. The rest of the difference comes from them including attention and output layer FLOPs. In this article I take the view that activation recomputation FLOPs do not directly contribute to learning and thus reduce the system's effective WFLOPs throughput. But if you are an HPC person and what you want to showcase is optimal device utilization, it makes sense to include these FLOPs in the total.