An intuitive understanding of mixed precision training

Shrish
May 6, 2023


Photo by Andrea De Santis on Unsplash

Deep learning models are well-known for their ability to achieve higher accuracy as their size increases. However, as models grow larger, they also require more memory and computing power to train. This increased computational requirement can result in higher costs, especially when using cloud computing services.

To overcome these challenges, the deep learning community has been exploring different ways to optimize the training process. One promising approach is mixed precision training, introduced by Sharan Narang and colleagues, which has gained popularity in recent years.

While there have been several earlier attempts to train large models with reduced precision, those approaches typically came with a non-trivial loss of accuracy for larger models. Mixed precision training, on the other hand, uses lower-precision data types for certain parts of the neural network while maintaining higher-precision data types for others. By doing so, it reduces the memory requirements of the neural network, speeds up the training process, and reduces costs.

The general idea behind mixed precision training is to carefully choose which parts of the neural network can use lower-precision data types without significantly impacting model accuracy. Although the technical details of mixed precision training can be complex, several resources are available to help you get started, such as the NVIDIA Deep Learning Documentation, which provides a good introduction to mixed precision training and offers guidance on how to implement it using their hardware and software tools.

Moreover, almost all deep learning frameworks, including TensorFlow and PyTorch, have built-in support for mixed precision training, making it easy to implement this technique in your own deep learning projects.

Floating point precision

Computers store numbers using the binary format, which represents numbers using only 0s and 1s. For example, with a single bit, we can represent only two unique numbers: 0 and 1. With two bits, we can represent four unique numbers: 00, 01, 10, and 11. By reserving the first bit to signify the sign of a number (0 for positive and 1 for negative), we can represent the signed integers from -1 to +1 (with two encodings of zero).

Using an agreed-upon convention, we can also map the 4 unique bit patterns above to 4 decimal numbers, say 0, 0.1, 0.2, and 0.3. This covers the range 0 to 0.3 with a precision (step size) of 0.1. If we instead use a step of 0.2, the same 4 patterns cover a larger range: 0, 0.2, 0.4, and 0.6. It's essential to note that 2 bits can only ever represent 4 distinct values, so increasing the precision narrows the range, while decreasing the precision widens it.

Computers use 32 bits to represent a total of 2³² unique numbers, which as unsigned integers range from 0 to 2³² − 1. However, we also rely heavily on floating-point (decimal) numbers, and it is impossible to represent a decimal number with arbitrary precision in a fixed number of bits. Nonetheless, this limited precision is sufficient for most practical purposes.

To represent a decimal number in a computer, we again rely on a convention. First, we write the number in binary scientific notation; for example, 0.0000001001011110101 becomes .1001011110101 × 2⁻⁶. We then use the first bit to store the sign of the number, the next eight bits for the exponent (here −6), and the remaining 23 bits for the significand, which in this example is .1001011110101.
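
As a concrete illustration, the small Python snippet below (purely illustrative, not from the original article) unpacks the bit layout of an FP32 number. Note that the actual IEEE 754 format is a slight refinement of the description above: the 8-bit exponent is stored with a bias of 127, and the significand is normalized so that its leading 1 is implicit.

import struct
def fp32_fields(x):
    # Reinterpret the 32-bit float as an unsigned integer and slice its bit string.
    [bits] = struct.unpack(">I", struct.pack(">f", x))
    b = f"{bits:032b}"
    return b[0], b[1:9], b[9:]   # sign (1 bit), exponent (8 bits), significand (23 bits)
# 0.15625 = 1.01 (binary) x 2^-3 -> sign 0, biased exponent 124, fraction .01
print(fp32_fields(0.15625))      # ('0', '01111100', '01000000000000000000000')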

Ranges of different floating-point precision formats (source).

Note: In deep learning use cases, we are usually more concerned about the scale of gradients rather than the precision of numbers. This is where BF16, a number format developed by Google, comes into play. BF16 reserves more space for the exponent of a number at the cost of precision, making it more suitable for deep learning tasks. BF16 is typically deployed on specialized hardware such as TPUs and some of the latest GPUs to achieve optimal performance.
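
If PyTorch is available, these trade-offs can be inspected directly with torch.finfo. The snippet below (an illustrative aside, not from the original article) prints the largest representable value and the machine epsilon, a proxy for precision, of each format:

import torch
for dtype in (torch.float16, torch.bfloat16, torch.float32):
    info = torch.finfo(dtype)
    # BF16 keeps FP32's 8-bit exponent (same range) but has far fewer
    # significand bits, hence its much larger eps.
    print(f"{str(dtype):>15}  max={info.max:.3e}  eps={info.eps:.3e}")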

What is mixed precision training?

Using half-precision can significantly speed up the training and inference of deep learning models while saving memory. However, optimizing for speed and memory may come at the cost of model accuracy. To mitigate this issue, mixed precision training provides a hybrid approach that maintains accuracy while still benefiting from the speed gains of half-precision.

In mixed precision training, all tensors and arithmetic for forward and backward passes use reduced precision, such as FP16. No hyper-parameters, such as layer width or gradient clipping, are adjusted, and models trained using these techniques do not experience any loss of accuracy when compared to single-precision baselines.

The approach involves keeping the model parameters (weights) in single precision and performing all training processes in half-precision before updating the model in single precision again. By doing so, we can achieve more efficient computations and memory usage during training, while still ensuring that the final model retains the same level of accuracy as if it had been trained using full precision throughout the entire process.

Keep a copy of the model weights

The majority of weight updates, which are obtained by multiplying the learning rate with the gradients, have a very small magnitude, often below about 6 × 10⁻⁸ (2⁻²⁴, roughly the smallest positive value FP16 can represent). Such values are rounded to zero in the FP16 format, making it unsuitable for performing the weight update.
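
A quick numpy demonstration of how such tiny updates get lost in FP16 (the values are arbitrary and purely for illustration):

import numpy as np
print(np.float16(1e-8))                    # 0.0 -- underflows; the smallest positive FP16 value is ~6e-8
print(np.float16(1.0) + np.float16(1e-4))  # 1.0 -- even a representable update is swallowed by rounding near 1.0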

Therefore, it is advisable to maintain a single-precision master copy of the weights that accumulates the weight updates at each optimizer step. This copy is then cast to half precision for the forward- and back-propagation steps.

During the training of a neural network, the forward pass, backward pass, and gradient calculations are performed in half precision to speed up the process, while the weight update is applied to the single-precision master copy. This keeps the model accurate by preventing the loss of precision that repeated rounding of values in FP16 would otherwise cause (source).

A simple version of the algorithm (sketched in code after the list) is:

  1. Maintain a primary copy of weights in FP32.
  2. For each iteration:
  • Make an FP16 copy of the weights.
  • Forward propagation (FP16 weights and activations).
  • Backward propagation (FP16 weights, activations, and their gradients).
  • Complete the weight update (including gradient clipping, etc.).
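
Below is a deliberately naive PyTorch-style sketch of this loop (loss scaling comes in the next section). The model, data loader, and loss function are placeholders, and real frameworks implement the FP16 copy and the FP32 update far more efficiently than a per-iteration deep copy:

import copy
import torch
model_fp32 = torch.nn.Linear(1024, 10).cuda()           # 1. FP32 master copy of the weights
optimizer = torch.optim.SGD(model_fp32.parameters(), lr=1e-2)
loss_fn = torch.nn.CrossEntropyLoss()
for inputs, targets in loader:                           # `loader` is assumed to exist
    model_fp16 = copy.deepcopy(model_fp32).half()        # 2a. FP16 copy of the weights
    loss = loss_fn(model_fp16(inputs.cuda().half()),     # 2b. forward pass in FP16
                   targets.cuda())
    loss.backward()                                      # 2c. backward pass in FP16
    for p32, p16 in zip(model_fp32.parameters(), model_fp16.parameters()):
        p32.grad = p16.grad.float()                      # cast gradients back to FP32
    optimizer.step()                                     # 2d. update the FP32 master weights
    optimizer.zero_grad()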

Scaling loss to scale gradients

While half-precision (FP16) can accelerate deep learning model training, there is a limitation to this approach. The histogram of activation gradient magnitudes shown below reveals that many values fall below the range FP16 can represent, causing a considerable number of gradients to become zero. As a result, the model's accuracy decreases.

To overcome this issue, the solution is to scale the loss by a factor S before back-propagation; on the logarithmic histogram below, this shifts the gradient values to the right, into the representable range. By the chain rule, all gradients are scaled by S as well, so they fit within the FP16 range. It is essential, however, to unscale the gradients (multiply by 1/S) before updating the model weights.

Histogram of activation gradient magnitudes throughout FP32 training of Multibox SSD network. Both x- and y-axes are logarithmic.
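
To make the effect of scaling concrete, here is a tiny numpy illustration (the gradient value and S are arbitrary):

import numpy as np
grad = 1e-8                              # a true gradient value, too small for FP16
print(np.float16(grad))                  # 0.0 -- lost
S = 1024.0                               # loss scale; every gradient is scaled by S via the chain rule
print(np.float16(grad * S))              # ~1.02e-05 -- now representable in FP16
print(float(np.float16(grad * S)) / S)   # ~1e-08 -- unscaled in FP32 before the weight update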

The updated algorithm then looks like this:

  1. Maintain a primary copy of weights in FP32.
  2. For each iteration:
  • Make an FP16 copy of the weights.
  • Forward propagation (FP16 weights and activations).
  • Multiply the resulting loss with the scaling factor S.
  • Backward propagation (FP16 weights, activations, and their gradients).
  • Multiply the weight gradient by 1/S.
  • Complete the weight update (including gradient clipping, etc.).

When training in plain half precision (FP16), the loss can diverge. With loss scaling, however, the FP16 training loss closely tracks the FP32 training loss, so the speed and memory benefits of half precision can be kept without compromising model accuracy (source).
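
In PyTorch, this loss-scaled loop is what the built-in automatic mixed precision (AMP) utilities implement. A minimal sketch, with the model, data, and loss as placeholders:

import torch
model = torch.nn.Linear(1024, 10).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = torch.nn.CrossEntropyLoss()
scaler = torch.cuda.amp.GradScaler()         # maintains and adapts the scale factor S
for inputs, targets in loader:               # `loader` is assumed to exist
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():          # runs eligible forward ops in FP16
        loss = loss_fn(model(inputs.cuda()), targets.cuda())
    scaler.scale(loss).backward()            # multiply the loss by S, then back-propagate
    scaler.step(optimizer)                   # unscale gradients by 1/S; skip the step on overflow
    scaler.update()                          # grow or shrink S based on overflow history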

How does it perform?

Mixed precision training has proven to be a highly effective approach for deep learning, achieving up to 8x faster computation without sacrificing model accuracy. The trade-off is slightly more memory, roughly 1.5x for the weights, since they are stored in both half- and single-precision formats.

Accuracy for some deep learning models on classification tasks (source). Model accuracy with mixed precision (FP16 & FP32) is more or less the same as the FP32 baseline.
Accuracy for some deep learning models on object detection tasks (source).

Python Implementations

All popular deep learning frameworks have this implemented. For example, in Keras, mixed precision training with automatic loss scaling can be enabled as shown below.

import tensorflow as tf
# `model` and `loss` are assumed to be defined elsewhere.
opt = tf.keras.optimizers.Adam()
# Rewrite eligible ops to FP16 and apply dynamic loss scaling automatically.
opt = tf.train.experimental.enable_mixed_precision_graph_rewrite(opt)
model.compile(loss=loss, optimizer=opt)
model.fit(...)
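
Note that the graph-rewrite API above is deprecated in recent TensorFlow releases. From TF 2.4 onward, the Keras-native route is to set a global dtype policy; when the model is compiled with an optimizer, Keras wraps it in a LossScaleOptimizer automatically:

import tensorflow as tf
tf.keras.mixed_precision.set_global_policy("mixed_float16")
opt = tf.keras.optimizers.Adam()         # wrapped in a LossScaleOptimizer at compile time
model.compile(loss=loss, optimizer=opt)  # `model` and `loss` as above
model.fit(...)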

The DeepSpeed Python package by Microsoft has incorporated mixed precision training into its framework, with a range of settings and loss-scaling options.

Here are example configuration snippets showing the available JSON keys.

"bf16": {
"enabled": true
}
"fp16": {
"enabled": true,
"auto_cast": false,
"loss_scale": 0,
"initial_scale_power": 16,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
}
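
In the fp16 block, "loss_scale": 0 selects dynamic loss scaling, starting at 2^initial_scale_power and adjusted based on overflow statistics (governed by loss_scale_window, hysteresis, and min_loss_scale). A minimal usage sketch, assuming the config above is saved as ds_config.json and that model, loss_fn, and loader are defined elsewhere:

import deepspeed
# deepspeed.initialize returns an engine that runs the forward/backward passes
# and the optimizer step in mixed precision according to the JSON config.
engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config="ds_config.json",
)
for inputs, targets in loader:
    loss = loss_fn(engine(inputs), targets)
    engine.backward(loss)                # loss scaling is handled internally
    engine.step()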

We can also use PyTorch Lightning to implement mixed precision training.
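
A sketch with Lightning, assuming a LightningModule called LitModel (the flag is precision="16-mixed" on Lightning 2.x; older 1.x versions use precision=16):

import pytorch_lightning as pl
trainer = pl.Trainer(precision="16-mixed")   # enables AMP-style FP16 mixed precision
trainer.fit(LitModel())                      # `LitModel` is a placeholder LightningModule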


Thank you for taking the time to read my writing. If you enjoyed it, I invite you to follow me for future content. Additionally, feel free to connect with me on LinkedIn.
