Post-training Static Quantization in PyTorch

Sanjana Srinivas
Aug 1, 2020

For the complete code, check out the GitHub repository.

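Quantization refers to techniques for performing computations and storing tensors at lower bit-widths than floating-point precision. The mapping from floating-point to fixed-point values is as follows:

$$Q = \mathrm{clamp}\left(\mathrm{round}\left(\frac{x}{S}\right) + Z,\ Q_{\min},\ Q_{\max}\right)$$

where x is the floating-point input, Q is the quantized output, S is the scale, Z is the zero point, and [Q_min, Q_max] is the range of the integer type (for example, -128 to 127 for int8).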

For the detailed math behind this mapping, refer to the link below.
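A minimal sketch of PyTorch's per-tensor affine mapping, Q = clamp(round(x / S) + Z, Q_min, Q_max), where S is the scale and Z the zero point. The scale of 0.01 and zero point of 0 are hand-picked, hypothetical values (in static quantization they come from calibration). The sketch quantizes a small tensor with torch.quantize_per_tensor and checks the stored integers against the formula computed by hand:

```python
import torch

# Hand-picked quantization parameters for illustration only (hypothetical values);
# in post-training static quantization they are computed during calibration.
scale, zero_point = 0.01, 0
x = torch.tensor([-1.0, -0.26, 0.0, 0.5, 1.27])

# PyTorch's built-in per-tensor affine quantization to signed 8-bit integers.
q = torch.quantize_per_tensor(x, scale=scale, zero_point=zero_point, dtype=torch.qint8)

# The same mapping by hand: Q = clamp(round(x / S) + Z, -128, 127) for int8.
manual = torch.clamp(torch.round(x / scale) + zero_point, -128, 127).to(torch.int8)

print(q.int_repr())  # the raw int8 values stored inside the quantized tensor
print(manual)        # should match the line above
```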

Quantization in PyTorch supports conversion of a typical float32 model to an int8 model, thus allowing:

  1. Reduction in model size (roughly 4x, since each int8 value takes 1 byte instead of 4).
  2. Reduction in memory bandwidth requirements.
  3. Faster on-device computation, since int8 arithmetic is faster than float32 arithmetic. Quantization is primarily a technique to speed up inference (only the forward pass is supported for quantized operators).
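A quick sketch of the size reduction in item 1, assuming a hypothetical 256 x 256 float32 weight matrix and an arbitrary scale of 0.05. Each int8 value takes 1 byte instead of 4, so the raw storage shrinks by 4x; the per-tensor scale and zero point add only a few bytes, which this count ignores:

```python
import torch

# A hypothetical float32 weight matrix standing in for one layer of a model.
w = torch.randn(256, 256)
q = torch.quantize_per_tensor(w, scale=0.05, zero_point=0, dtype=torch.qint8)

# Bytes needed for the raw values (ignores the scale/zero-point metadata).
fp32_bytes = w.nelement() * w.element_size()                        # 4 bytes per float32
int8_bytes = q.int_repr().nelement() * q.int_repr().element_size()  # 1 byte per int8

print(fp32_bytes, int8_bytes)  # 262144 65536
```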

However, quantization is an approximation, so it typically causes a slight drop in accuracy.
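To see where the accuracy loss comes from, the sketch below (random data stands in for real activations) lets PyTorch's MinMaxObserver choose the scale and zero point, quantizes to 8 bits, and dequantizes back. Each value returns rounded to the nearest quantization step, so for values inside the observed range the per-element error is at most half a step (scale / 2):

```python
import torch
from torch.quantization import MinMaxObserver

torch.manual_seed(0)
x = torch.randn(1000)  # hypothetical activations

# Let an observer pick scale and zero point from the observed min/max (unsigned 8-bit).
obs = MinMaxObserver(dtype=torch.quint8)
obs(x)
scale, zero_point = obs.calculate_qparams()

q = torch.quantize_per_tensor(x, float(scale), int(zero_point), torch.quint8)
x_hat = q.dequantize()  # back to float32

# Rounding error per element is at most half a quantization step (scale / 2).
max_err = (x_hat - x).abs().max().item()
print(f"scale={float(scale):.5f}  max error={max_err:.5f}")
```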

Modes of Quantization

  1. Dynamic quantization.
  2. Post-training static quantization.
  3. Quantization aware training.
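For contrast with the static approach this article covers, here is a minimal sketch of dynamic quantization (mode 1) using torch.quantization.quantize_dynamic on a hypothetical two-layer model. Weights are quantized up front, but activation ranges are computed at run time, so no calibration step is needed:

```python
import torch
import torch.nn as nn

# A hypothetical float model made of Linear layers.
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4)).eval()

# Dynamic quantization: Linear weights are converted to int8 ahead of time, while
# activation scales are computed on the fly for each batch (no calibration step).
qmodel = torch.quantization.quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)

out = qmodel(torch.randn(2, 16))
print(qmodel)     # the Linear layers are replaced by dynamically quantized versions
print(out.shape)  # torch.Size([2, 4])
```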

Operators and Hardware Support

Quantization support is restricted to a subset of the available operators; refer to the PyTorch quantization documentation for operator coverage. The set of available operators and the quantization numerics also depend on the backend used to run the quantized model. At the time of writing, quantized operators are supported only for CPU inference, on two backends: x86 (fbgemm) and ARM (qnnpack). You can specify the backend as follows:

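import torch

# 'fbgemm' for server (x86), 'qnnpack' for mobile (ARM)
backend = 'fbgemm'
my_model.qconfig = torch.quantization.get_default_qconfig(backend)  # used by prepare/convert
torch.backends.quantized.engine = backend  # used when the quantized model runs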
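Which backends are available depends on how PyTorch was built. A sketch, assuming you want fbgemm when it is present and qnnpack otherwise, that checks torch.backends.quantized.supported_engines before choosing:

```python
import torch

# Engines compiled into this PyTorch build (a list of strings such as 'fbgemm', 'qnnpack').
engines = torch.backends.quantized.supported_engines

# Prefer fbgemm (x86 servers) and fall back to qnnpack (ARM/mobile) if it is missing.
backend = 'fbgemm' if 'fbgemm' in engines else 'qnnpack'

torch.backends.quantized.engine = backend
qconfig = torch.quantization.get_default_qconfig(backend)
print(backend, qconfig)
```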

Quantization aware training, by contrast, runs entirely in floating point (quantization is only simulated), so it can run on either GPU or CPU. It is typically used only when post-training static or dynamic quantization doesn't yield sufficient accuracy.
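The sketch below illustrates what "simulated" quantization means, using torch.fake_quantize_per_tensor_affine with an arbitrary scale of 0.1 and zero point of 0 (illustrative values, not from the article). Values are rounded to the int8 grid but remain float32 tensors, and gradients still flow through them, which is why quantization aware training can run on ordinary floating-point hardware:

```python
import torch

torch.manual_seed(0)
x = torch.randn(5, requires_grad=True)  # hypothetical float activations

# Fake quantization: snap values to the int8 grid (scale 0.1, zero point 0,
# range -128..127) while keeping the result as a float32 tensor.
y = torch.fake_quantize_per_tensor_affine(x, 0.1, 0, -128, 127)

print(y.dtype)      # torch.float32
y.sum().backward()  # gradients pass through (straight-through estimator for in-range values)
print(x.grad)
```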

Post-training static quantization

This article focuses on implementing post-training static quantization. In this method, we first tweak the model, then calibrate it on training data to find the right scale factors, and finally quantize both the weights and the activations. An implementation using the ResNet18 architecture is available on GitHub.
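The article does not show its model changes, so the following is an assumption based on the PyTorch eager-mode quantization docs: tweaking a ResNet-style model usually means adding QuantStub/DeQuantStub where tensors enter and leave int8, and replacing the residual "+" with a FloatFunctional so that convert can swap in a quantized add. TinyResidualBlock is a hypothetical stand-in for a ResNet18 basic block; FloatFunctional is imported from torch.ao.nn.quantized here (older releases expose it as torch.nn.quantized.FloatFunctional):

```python
import torch
import torch.nn as nn
from torch.ao.nn.quantized import FloatFunctional
from torch.quantization import QuantStub, DeQuantStub

class TinyResidualBlock(nn.Module):
    """Hypothetical stand-in for a ResNet18 basic block, tweaked for static quantization."""

    def __init__(self, channels=8):
        super().__init__()
        self.quant = QuantStub()  # where float inputs will be converted to int8
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)
        # A plain `out + identity` has no module to swap; FloatFunctional does.
        self.skip_add = FloatFunctional()
        self.relu2 = nn.ReLU()
        self.dequant = DeQuantStub()  # where int8 outputs will go back to float

    def forward(self, x):
        x = self.quant(x)
        out = self.relu1(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = self.relu2(self.skip_add.add(out, x))
        return self.dequant(out)

block = TinyResidualBlock().eval()
out = block(torch.randn(1, 8, 4, 4))  # float forward pass: stubs act as identities
print(out.shape)
```

Before conversion, the stubs pass tensors through unchanged and FloatFunctional.add is an ordinary addition, so the tweaked model can still be trained and evaluated in float.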

Tweaking the model

Fusing the modules

torch.quantization.fuse_modules fuses sequences of layers such as [conv, bn], [conv, bn, relu], and the other combinations listed in the documentation into a single module. This slightly reduces the model size (the batch-norm parameters are folded into the convolution weights), reduces memory accesses, and decreases the number of operations.
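A minimal sketch of fusion on a hypothetical ConvBNReLU module. Module names are passed as strings, the model is put in eval mode first (batch-norm folding for inference expects it), and the fused copy should produce the same outputs as the original:

```python
import torch
import torch.nn as nn

class ConvBNReLU(nn.Module):
    """Hypothetical minimal block with a conv -> batch norm -> ReLU sequence."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, bias=False)
        self.bn = nn.BatchNorm2d(8)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))

model = ConvBNReLU().eval()  # conv-bn fusion for inference requires eval mode

# Returns a fused copy (inplace=False by default); bn and relu become Identity.
fused = torch.quantization.fuse_modules(model, [['conv', 'bn', 'relu']])
print(fused)

x = torch.randn(1, 3, 16, 16)
print(torch.allclose(model(x), fused(x), atol=1e-5))  # fusion should not change the output
```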

Prepare model for quantization

torch.quantization.prepare attaches observers to the model. We then calibrate by running a representative sample of the training data through the prepared model, during which the observers record the distribution of each activation. These distributions are then used to determine how activations should be quantized at inference time. Importantly, this additional step allows us to pass quantized values between operations instead of converting them to floats (and then back to ints) between every operation, resulting in a significant speed-up.
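A minimal sketch of preparation and calibration, assuming a hypothetical single-convolution model wrapped in QuantStub/DeQuantStub, with random tensors standing in for the training batches (MNIST in the article). After calibration, the observer attached to the input stub has recorded the input range:

```python
import torch
import torch.nn as nn
from torch.quantization import QuantStub, DeQuantStub

# Hypothetical toy model; the stubs mark where tensors enter and leave int8.
model = nn.Sequential(QuantStub(), nn.Conv2d(1, 4, 3), nn.ReLU(), DeQuantStub()).eval()
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')

# prepare() returns a copy of the model with observers attached (inplace=False by default).
prepared = torch.quantization.prepare(model)

# Calibration: feed representative batches; random tensors stand in for MNIST images here.
with torch.no_grad():
    for _ in range(10):
        prepared(torch.randn(8, 1, 28, 28))

# The observer attached to the QuantStub has now recorded the input range.
obs = prepared[0].activation_post_process
print(type(obs).__name__, obs.min_val.item(), obs.max_val.item())
```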

Convert model

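torch.quantization.convert converts the calibrated floating-point model into a quantized model: it uses the statistics collected by the observers to fix the scale and zero point of each quantized tensor, removes the observers, and replaces supported modules with their int8 counterparts.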
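Putting the steps together, a sketch of the full post-training static quantization flow (fuse, prepare, calibrate, convert) on a hypothetical toy model, assuming an x86 CPU so that the fbgemm engine is available. Random data stands in for real calibration batches, so the printed error only shows that the quantized model approximates the float one:

```python
import torch
import torch.nn as nn
from torch.quantization import QuantStub, DeQuantStub

torch.backends.quantized.engine = 'fbgemm'  # assumption: running on an x86 CPU

# Hypothetical toy model: quantize -> conv -> relu -> dequantize.
model = nn.Sequential(QuantStub(), nn.Conv2d(1, 4, 3), nn.ReLU(), DeQuantStub()).eval()

# 1. Fuse conv + relu (children of nn.Sequential are named '0', '1', ...).
fused = torch.quantization.fuse_modules(model, [['1', '2']])
fused.qconfig = torch.quantization.get_default_qconfig('fbgemm')

# 2. Attach observers, then 3. calibrate on representative data (random stand-in here).
prepared = torch.quantization.prepare(fused)
calib = [torch.randn(8, 1, 28, 28) for _ in range(10)]
with torch.no_grad():
    for batch in calib:
        prepared(batch)

# 4. Fix the qparams from the observers and swap in int8 modules.
quantized = torch.quantization.convert(prepared)
print(quantized)

with torch.no_grad():
    x = calib[0]
    out = quantized(x)
    err = (model(x) - out).abs().max().item()
print(out.shape, f"max abs difference vs. float model: {err:.4f}")
```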

In case of reduced accuracy!

As mentioned earlier, quantization can reduce accuracy. In such cases, we can often recover much of the lost accuracy simply by using a different quantization configuration. The default quantization configuration uses the MinMax observer; to improve accuracy, we use the Histogram observer instead. To do this, we repeat the testing exercise with the configuration recommended for quantizing on x86 architectures:

import torch

# 'fbgemm' for server (x86), 'qnnpack' for mobile (ARM)
my_model.qconfig = torch.quantization.get_default_qconfig('fbgemm')

This ‘fbgemm’ configuration does the following:

  • Quantizes weights on a per-channel basis.
  • Uses a histogram observer that collects a histogram of activations and then picks quantization parameters that approximately minimize the quantization error.
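To check which observers each configuration uses, you can instantiate the observer factories stored in a QConfig. This sketch compares torch.quantization.default_qconfig with the 'fbgemm' configuration:

```python
import torch
from torch.quantization import HistogramObserver, MinMaxObserver, PerChannelMinMaxObserver

default_cfg = torch.quantization.default_qconfig               # generic default
fbgemm_cfg = torch.quantization.get_default_qconfig('fbgemm')  # recommended for x86

# qconfig.activation / qconfig.weight are observer factories; calling them builds observers.
print(type(default_cfg.activation()).__name__, type(default_cfg.weight()).__name__)
# MinMaxObserver MinMaxObserver
print(type(fbgemm_cfg.activation()).__name__, type(fbgemm_cfg.weight()).__name__)
# HistogramObserver PerChannelMinMaxObserver
```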

Experiment results

The code is available on GitHub. The results were computed using the ResNet18 architecture on the MNIST dataset.

Results for post-training static quantization on the ResNet18 architecture using the MNIST dataset.