The Power of Quantization in ML: A PyTorch Tutorial Part 2

Ebad Sayed
Jul 1, 2024


https://cdn.prod.website-files.com/64d9e7e32e307274f238b1ae/65b7d39ab1916d44837f2c91_blog_post_header_07-5.png

In the previous article we saw what quantization is, what the different data types are, and how to load ML models in different data types. In this article we will learn about linear quantization.

Previous Article: Mastering Quantization Part 1

Linear Quantization

Quantization is the process of mapping values from a large set to a smaller set of values. For example, consider applying 8-bit linear quantization to the following matrix:

Image by Author

We can map the most positive value in the matrix, 728.6, to the maximum value int8 can store, 127, and the most negative value, -184, to the minimum, -128. The remaining values are then mapped linearly in between. After this we can delete the original tensor to free up memory, keeping only the quantized tensor and the two parameters of the linear mapping: s (scale) and z (zero point).
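To make the mapping concrete, here is a minimal sketch that computes s and z from the matrix's minimum and maximum (the same formulas derived later in this article) and checks that the two extremes land at the ends of the int8 range. The matrix values are the ones used for test_tensor later on; the variable names are our own.

```python
import torch

# The example matrix (same values as test_tensor later in the article)
x = torch.tensor([[191.6, -13.5, 728.6],
                  [92.14, 295.5, -184.0],
                  [0.0, 684.6, 245.5]])

q_min, q_max = torch.iinfo(torch.int8).min, torch.iinfo(torch.int8).max  # -128, 127
r_min, r_max = x.min().item(), x.max().item()                             # -184.0, ~728.6

# Scale: how many original units one int8 step represents
s = (r_max - r_min) / (q_max - q_min)
# Zero point: the int8 value that the original 0.0 maps to
z = int(round(q_min - r_min / s))

# Linear mapping q = round(r / s + z), clamped to the int8 range
q = torch.clamp(torch.round(x / s + z), q_min, q_max).to(torch.int8)

print(round(s, 4), z)                  # 3.5788 -77
print(q[0, 2].item(), q[1, 2].item())  # 127 -128
```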

How can we go back from the quantized tensor to the original tensor?

We can apply the reverse of the same linear mapping, but we won't get back exactly the same values: rounding during quantization loses information. Applying the reverse mapping to the quantized matrix gives us the dequantized matrix.

Image by Author

We can see that the values of the original and dequantized matrices are approximately the same. The error matrix is the difference between the original and the dequantized matrix; the error is not zero, but it is small relative to the values themselves.
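As a cross-check (our own illustration, not part of the original walkthrough), PyTorch ships a built-in per-tensor quantizer, torch.quantize_per_tensor, which applies the same affine mapping with an integer zero point. The sketch reuses the min/max-based s and z, round-trips the example matrix through int8, and prints the error matrix. Each entry should be at most about s/2 in magnitude, since rounding moves a value by at most half a quantization step.

```python
import torch

x = torch.tensor([[191.6, -13.5, 728.6],
                  [92.14, 295.5, -184.0],
                  [0.0, 684.6, 245.5]])

# Scale and zero point from the tensor's min/max, mapped onto [-128, 127]
r_min, r_max = x.min().item(), x.max().item()
scale = (r_max - r_min) / 255
zero_point = int(round(-128 - r_min / scale))

# Built-in affine quantization to signed 8-bit
xq = torch.quantize_per_tensor(x, scale, zero_point, torch.qint8)

print(xq.int_repr())       # the stored int8 values
x_hat = xq.dequantize()    # scale * (q - zero_point), back to float32
error = x - x_hat          # the error matrix
print(error)
```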

Quantizing the Model

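We will use Google's Flan-T5 model (flan-t5-small), which contains roughly 77 million parameters, each stored in FP32. Since 8 bits = 1 byte, an FP32 parameter takes 32 bits = 4 bytes, so the model needs about 77 × 10^6 × 4 ≈ 308 × 10^6 bytes (308 MB) ≈ 0.3 GB, which matches the output below.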

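import torch
from transformers import T5ForConditionalGeneration
# compute_module_sizes is not defined in this article; module_sizes[''] is the whole model's size in bytes
model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-small")
module_sizes = compute_module_sizes(model)
print(f"The model size is {module_sizes[''] * 1e-9} GB")
# OUTPUT --> The model size is 0.307844608 GB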
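compute_module_sizes is not shown in this article. If you don't have it, a rough stand-in is to sum numel() × element_size() over a model's parameters and buffers. model_size_bytes below is a hypothetical helper, not the original function; it only counts plain tensors, and quantized tensor subclasses (such as those quanto produces) may report their storage differently, so it is not guaranteed to reproduce the numbers above.

```python
import torch
import torch.nn as nn

def model_size_bytes(model: nn.Module) -> int:
    """Hypothetical helper: total bytes held in a model's parameters and buffers."""
    total = 0
    for t in list(model.parameters()) + list(model.buffers()):
        # numel() = number of elements, element_size() = bytes per element
        total += t.numel() * t.element_size()
    return total

layer = nn.Linear(10, 5)          # 10*5 weights + 5 biases = 55 float32 values
print(model_size_bytes(layer))    # 220 (55 * 4 bytes)
```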

The model has many kinds of layers, but we will quantize only the linear layers. To quantize the model we just call the quantize() function from the quanto library. Here we quantize only the weights to integers, not the activations.

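import torch
from quanto import quantize, freeze  # the library is now published as optimum-quanto

# Swap nn.Linear layers for quantized ones: int8 weights, activations left unquantized
quantize(model, weights=torch.int8, activations=None)
# Convert the weights to their final quantized form
freeze(model)
module_sizes = compute_module_sizes(model)
print(f"The model size is {module_sizes[''] * 1e-9} GB")
# OUTPUT --> The model size is 0.12682868 GB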

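After quantize(), all the Linear layers are replaced by QLinear (quantized linear) layers, and calling freeze() then gives us the quantized model. The size drops from about 0.31 GB to about 0.13 GB rather than by a full 4×, because only the linear layers were quantized; other weights, such as the embedding layer, stay in FP32. The model's performance remains similar.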
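To build intuition for what a quantized linear layer does, here is a minimal weight-only sketch. W8Linear is a hypothetical class, not quanto's QLinear: it stores the weight as int8 plus a single scale and zero point (per tensor, using this article's min/max formulas) and dequantizes the weight on every forward pass. Activations stay in FP32, matching activations=None above.

```python
import torch
import torch.nn as nn

class W8Linear(nn.Module):
    """Hypothetical weight-only int8 linear layer (not quanto's QLinear)."""

    def __init__(self, linear: nn.Linear):
        super().__init__()
        w = linear.weight.detach()
        q_min, q_max = -128, 127
        r_min, r_max = w.min().item(), w.max().item()
        # Per-tensor scale and zero point (assumes the weight is not constant)
        scale = (r_max - r_min) / (q_max - q_min)
        zero_point = min(max(int(round(q_min - r_min / scale)), q_min), q_max)
        q = torch.clamp(torch.round(w / scale + zero_point), q_min, q_max).to(torch.int8)
        self.register_buffer("q_weight", q)   # int8 storage: 1 byte per weight
        self.scale = scale
        self.zero_point = zero_point
        self.bias = None if linear.bias is None else nn.Parameter(linear.bias.detach().clone())

    def forward(self, x):
        # Dequantize the weight, then run an ordinary FP32 linear layer
        w = self.scale * (self.q_weight.float() - self.zero_point)
        return nn.functional.linear(x, w, self.bias)

torch.manual_seed(0)
linear = nn.Linear(16, 8)
qlinear = W8Linear(linear)
x = torch.randn(4, 16)
print((linear(x) - qlinear(x)).abs().max())  # small, but not zero
```

Dequantizing on every call keeps the sketch short; a real implementation would typically use an int8 matrix-multiply kernel instead.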

Intermediate State

The quanto library keeps the model in an intermediate state after we call quantize(); calling freeze() then produces the final quantized weights. This intermediate state is useful because, when we run inference on a model by passing an input such as an image or text, the activations of the model vary depending on the input, and observing them helps us find good linear quantization parameters.

Calibration

  1. Calibrate the model when quantizing its activations:
    - The range of activation values depends on what input was given.
    - e.g. a different input text will generate different activations.
  2. The min/max of the activation ranges are used to perform linear quantization.
  3. How do we get the min and max range of the activations?
    - Gather sample input data.
    - Run inference.
    - Calculate the min/max of the activations.
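The three calibration steps above can be sketched with PyTorch forward hooks. calibrate_min_max is a hypothetical helper (not a quanto API) that records a running min/max of every nn.Linear output over some sample inputs; the toy model and random "sample data" are assumptions for illustration. The recorded ranges are then turned into an int8 scale and zero point with this article's formulas.

```python
import torch
import torch.nn as nn

def calibrate_min_max(model, sample_inputs):
    """Hypothetical helper: record the min/max output of every nn.Linear over sample inputs."""
    stats = {}

    def make_hook(name):
        def hook(module, inputs, output):
            lo, hi = output.min().item(), output.max().item()
            old_lo, old_hi = stats.get(name, (float("inf"), float("-inf")))
            stats[name] = (min(old_lo, lo), max(old_hi, hi))  # running min/max
        return hook

    handles = [m.register_forward_hook(make_hook(name))
               for name, m in model.named_modules() if isinstance(m, nn.Linear)]
    model.eval()
    with torch.no_grad():
        for x in sample_inputs:   # run inference on each sample batch
            model(x)
    for h in handles:
        h.remove()                # stop observing once calibration is done
    return stats

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
samples = [torch.randn(16, 4) for _ in range(5)]
stats = calibrate_min_max(model, samples)

# Turn each observed range into an int8 scale and (clipped) zero point
for name, (r_min, r_max) in stats.items():
    scale = (r_max - r_min) / 255
    zero_point = min(max(int(round(-128 - r_min / scale)), -128), 127)
    print(name, round(r_min, 3), round(r_max, 3), round(scale, 5), zero_point)
```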

Quantization Aware Training

Training in a way that controls how the model performs once it is quantized.

  1. The intermediate state holds both a quantized version of the weights and the original, unquantized weights.
  2. Forward pass (inference): use the quantized version of the model weights to make predictions, e.g. in BF16.
  3. Backpropagation (updating model weights): update the original, unquantized version of the model weights, e.g. in FP32.
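One common way to implement this scheme (our choice of technique; it is not necessarily what any particular library does) is "fake quantization" with a straight-through estimator (STE): the forward pass sees quantize-then-dequantize weights, while the gradient is passed straight through to the FP32 weights as if rounding were the identity. Here the quantized version uses int8 with this article's min/max scale and zero point rather than BF16; fake_quantize and ste_weight are hypothetical names.

```python
import torch

def fake_quantize(w, dtype=torch.int8):
    """Quantize w to integers and immediately dequantize (values only, no int storage)."""
    q_min, q_max = torch.iinfo(dtype).min, torch.iinfo(dtype).max
    r_min, r_max = w.min().item(), w.max().item()
    scale = (r_max - r_min) / (q_max - q_min)
    zero_point = min(max(int(round(q_min - r_min / scale)), q_min), q_max)
    q = torch.clamp(torch.round(w / scale + zero_point), q_min, q_max)
    return scale * (q - zero_point)

def ste_weight(w):
    # Forward value: fake_quantize(w). Backward: gradient flows to w unchanged,
    # because the non-differentiable part is wrapped in .detach().
    return w + (fake_quantize(w) - w).detach()

torch.manual_seed(0)
w = torch.randn(3, 4, requires_grad=True)   # the original FP32 weights
x = torch.randn(5, 4)
lr = 0.1

# Forward pass uses the (fake-)quantized weights
y = x @ ste_weight(w).T
loss = y.pow(2).mean()
loss.backward()

# Backward pass updates the original FP32 weights
w_before = w.detach().clone()
with torch.no_grad():
    w -= lr * w.grad
```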

Theory for Linear Quantization

Image by Author

Linear quantization uses a linear mapping to map a higher-precision range (FP32) to a lower-precision range (INT8). There are two parameters in linear quantization: the scale (s) and the zero point (z). The scale is stored in the same data type as the original tensor, and z is stored in the same data type as the quantized tensor.
The mapping is r = s(q - z), where r is the original value and q is the quantized value. For example, with s = 2 and z = 0 we get r = 2q, so a quantized value of q = 10 dequantizes to r = 20. To go the other way, from the original tensor to the quantized tensor, we isolate q, which gives the formula:

Image by Author
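Written out, the relationship between an original value r and its quantized value q, and the quantization formula obtained by solving it for q (matching the code in the next section, which also clamps the result to the range of the integer type), are:

```latex
r = s\,(q - z)
\quad\Longrightarrow\quad
q = \frac{r}{s} + z
\quad\Longrightarrow\quad
q = \operatorname{clamp}\!\left(\operatorname{int}\!\left(\operatorname{round}\!\left(\frac{r}{s} + z\right)\right),\; q_{\min},\; q_{\max}\right)
```

For the s = 2, z = 0 example, r = 20 gives q = round(20 / 2 + 0) = 10.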

Quantization with Random Scale and Zero Point

We will first implement linear quantization for the case where the scale and the zero point are already known (here, chosen arbitrarily).

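import torch

def linear_q_with_scale_and_zero_point(tensor, scale, zero_point, dtype=torch.int8):
    # Scale and shift: r / s + z
    scaled_and_shifted_tensor = tensor / scale + zero_point

    # Round to the nearest integer
    rounded_tensor = torch.round(scaled_and_shifted_tensor)

    # Range of the target integer type (e.g. -128..127 for int8)
    q_min = torch.iinfo(dtype).min
    q_max = torch.iinfo(dtype).max

    # Clamp into that range, then cast
    q_tensor = rounded_tensor.clamp(q_min, q_max).to(dtype)

    return q_tensor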

We define a function that takes the tensor, the scale, the zero point and the target dtype. The first step computes the scaled and shifted tensor, r/s + z, which we then round. Next we need to make sure that the quantized values lie between the minimum and maximum values the quantized type can hold; torch.iinfo() gives us these limits for the dtype. Finally, the tensor is clamped between those min and max values and converted to the required dtype.

We define the matrix (input tensor) and initialize scale and zero_point with arbitrary values.

test_tensor=torch.tensor(
[[191.6, -13.5, 728.6],
[92.14, 295.5, -184],
[0, 684.6, 245.5]])

scale = 3.5
zero_point = -70

quantized_tensor = linear_q_with_scale_and_zero_point(test_tensor, scale, zero_point)
quantized_tensor

# tensor([[ -15, -74, 127],
# [ -44, 14, -123],
# [ -70, 126, 0]], dtype=torch.int8)

Dequantization with Random Scale and Zero Point

dequantized_tensor = scale * (quantized_tensor.float() - zero_point)
dequantized_tensor

# tensor([[ 192.5000, -14.0000, 689.5000],
# [ 91.0000, 294.0000, -185.5000],
# [ 0.0000, 686.0000, 245.0000]])

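Let’s try to dequantize without casting to float():

scale * (quantized_tensor - zero_point)
# tensor([[ 192.5000, -14.0000, -206.5000],
# [ 91.0000, 294.0000, -185.5000],
# [ 0.0000, -210.0000, 245.0000]])

Two entries are now wrong. Without the cast, quantized_tensor - zero_point is computed in int8 and overflows: for example, 127 - (-70) = 197 does not fit in int8 and wraps around to -59, and 3.5 × -59 = -206.5. Casting to float before subtracting avoids this, so we wrap dequantization in a function that always does the cast:

def linear_dequantization(quantized_tensor, scale, zero_point):
    # Cast to float first so that (q - z) cannot overflow the int8 range
    return scale * (quantized_tensor.float() - zero_point)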

dequantized_tensor = linear_dequantization(quantized_tensor, scale, zero_point)
dequantized_tensor

# tensor([[ 192.5000, -14.0000, 689.5000],
# [ 91.0000, 294.0000, -185.5000],
# [ 0.0000, 686.0000, 245.0000]])

Quantization Error

from helper import plot_quantization_errors  # plotting helper module, not included in this article
plot_quantization_errors(test_tensor, quantized_tensor, dequantized_tensor)
Image by Author

Note: for the plot above, Quantization Error Tensor = abs(Original Tensor - Dequantized Tensor).
Next we calculate an overall quantization error using the mean squared error (MSE).

dequantized_tensor - test_tensor
# tensor([[ 0.9000, -0.5000, -39.1000],
# [ -1.1400, -1.5000, -1.5000],
# [ 0.0000, 1.4000, -0.5000]])
(dequantized_tensor - test_tensor).square()
# tensor([[8.0999e-01, 2.5000e-01, 1.5288e+03],
# [1.2996e+00, 2.2500e+00, 2.2500e+00],
# [0.0000e+00, 1.9601e+00, 2.5000e-01]])
(dequantized_tensor - test_tensor).square().mean()
# tensor(170.8753)

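The quantization error is this high because we chose the scale and zero point arbitrarily. With s = 3.5 and z = -70, the largest value, 728.6, maps to about 138.2, which is outside the int8 range and gets clamped to 127; it dequantizes to 689.5, and that single squared error (about 1528.8) accounts for almost all of the MSE.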

Finding Scale and Zero Point for Quantization

Image by Author
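The scale and zero point used in the code below follow from requiring the two ends of the original range to map onto the two ends of the quantized range, using r = s(q - z):

```latex
\begin{aligned}
r_{\min} &= s\,(q_{\min} - z), \qquad r_{\max} = s\,(q_{\max} - z) \\
r_{\max} - r_{\min} &= s\,(q_{\max} - q_{\min})
  \;\Longrightarrow\; s = \frac{r_{\max} - r_{\min}}{q_{\max} - q_{\min}} \\
z &= q_{\min} - \frac{r_{\min}}{s}
  \;\Longrightarrow\; z = \operatorname{int}\!\left(\operatorname{round}\!\left(q_{\min} - \frac{r_{\min}}{s}\right)\right)
\end{aligned}
```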

Why make z an integer?

The goal behind this choice is for zero in the original range to be represented exactly by an integer in the quantized range: quantizing 0 gives exactly z, and dequantizing z gives exactly 0 again.
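A quick check of this property, using the min/max of the example matrix (the variable names are our own): with an integer z, zero survives the round trip exactly, while with the unrounded zero point it does not.

```python
import torch

q_min, q_max = torch.iinfo(torch.int8).min, torch.iinfo(torch.int8).max
r_min, r_max = -184.0, 728.6                # range of the example matrix
scale = (r_max - r_min) / (q_max - q_min)

zero = torch.tensor([0.0])

# Integer zero point: 0.0 -> z -> 0.0 exactly
zero_point = int(round(q_min - r_min / scale))   # -77
q = torch.clamp(torch.round(zero / scale + zero_point), q_min, q_max).to(torch.int8)
r = scale * (q.float() - zero_point)
print(q.item(), r.item())   # -77 0.0

# Unrounded zero point (about -76.59): 0.0 comes back as roughly -1.48
z_float = q_min - r_min / scale
q_f = torch.round(zero / scale + z_float)
r_f = scale * (q_f - z_float)
print(r_f.item())
```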

What if z is out of range?

Since z is cast to the quantized data type, it must lie within [q_min, q_max]:
If z < q_min then z = q_min
If z > q_max then z = q_max
This way we avoid overflow and underflow.
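Clipping keeps z representable, but it is worth seeing when it triggers. In this small sketch (our own example; quant_dequant is a hypothetical helper), every value is positive, so the raw zero point falls far below q_min and is clipped, and the whole tensor saturates at 127. Many frameworks avoid this by widening the range to include zero before computing s and z (PyTorch's min/max observers do this, for instance); the second half of the sketch shows that alternative.

```python
import torch

q_min, q_max = -128, 127
t = torch.tensor([100.0, 150.0, 200.0])   # all-positive example tensor

def quant_dequant(t, r_min, r_max):
    scale = (r_max - r_min) / (q_max - q_min)
    raw_z = q_min - r_min / scale
    z = int(round(min(max(raw_z, q_min), q_max)))   # clip, then round to an integer
    q = torch.clamp(torch.round(t / scale + z), q_min, q_max)
    return raw_z, z, q, scale * (q - z)

# 1) Range taken straight from the data: raw z is about -383 and gets clipped to -128
raw_z, z, q, r_hat = quant_dequant(t, t.min().item(), t.max().item())
print(raw_z, z, q)    # every q is 127, so all three values dequantize to ~100

# 2) Range widened to include 0: z = -128 is already in range, no saturation
raw_z2, z2, q2, r_hat2 = quant_dequant(t, min(t.min().item(), 0.0), max(t.max().item(), 0.0))
print(z2, r_hat2)     # close to the original values
```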

q_min = torch.iinfo(torch.int8).min  # --> -128
q_max = torch.iinfo(torch.int8).max  # --> 127

r_min = test_tensor.min().item()  # --> -184.0
r_max = test_tensor.max().item()  # --> 728.5999755859375

scale = (r_max - r_min) / (q_max - q_min)  # --> 3.578823433670343
zero_point = q_min - (r_min / scale)       # --> about -76.586
zero_point = int(round(zero_point))        # --> -77

We can create a function that can perform the above operations.

def get_q_scale_and_zero_point(tensor, dtype=torch.int8):

    q_min, q_max = torch.iinfo(dtype).min, torch.iinfo(dtype).max
    r_min, r_max = tensor.min().item(), tensor.max().item()

    # Map [r_min, r_max] onto [q_min, q_max]
    scale = (r_max - r_min) / (q_max - q_min)

    zero_point = q_min - (r_min / scale)

    # clip the zero_point to fall in [quantized_min, quantized_max]
    if zero_point < q_min:
        zero_point = q_min
    elif zero_point > q_max:
        zero_point = q_max
    else:
        # round and cast to int
        zero_point = int(round(zero_point))

    return scale, zero_point
new_scale, new_zero_point = get_q_scale_and_zero_point(test_tensor)
new_scale # --> 3.578823433670343
new_zero_point # --> -77

We now reuse the functions defined earlier, linear_q_with_scale_and_zero_point and linear_dequantization.

quantized_tensor = linear_q_with_scale_and_zero_point(test_tensor, new_scale, new_zero_point)
dequantized_tensor = linear_dequantization(quantized_tensor, new_scale, new_zero_point)
plot_quantization_errors(test_tensor, quantized_tensor, dequantized_tensor)
Image by Author
(dequantized_tensor-test_tensor).square().mean()
# OUTPUT --> tensor(1.5730)

Create Our Own Linear Quantizer

def linear_quantization(tensor, dtype=torch.int8):
    # Pick the scale and zero point from the tensor's own min/max ...
    scale, zero_point = get_q_scale_and_zero_point(tensor, dtype=dtype)

    # ... then apply the linear mapping with those parameters
    quantized_tensor = linear_q_with_scale_and_zero_point(tensor, scale, zero_point, dtype=dtype)

    return quantized_tensor, scale, zero_point
r_tensor = torch.randn((4, 4))
r_tensor

# tensor([[ 0.4932, 0.5593, 0.3367, -0.5350],
# [-0.2344, -0.2833, -0.2291, 0.7843],
# [-2.1898, -0.2059, -0.7546, 0.8257],
# [ 0.1114, 0.2245, -0.4224, -0.0353]])
quantized_tensor, scale, zero_point = linear_quantization(r_tensor)
quantized_tensor

# tensor([[ 99, 104, 85, 12],
# [ 37, 33, 38, 123],
# [-128, 40, -7, 127],
# [ 66, 76, 21, 54]], dtype=torch.int8)
scale # --> 0.0118255521736893
zero_point # --> 57
dequantized_tensor = linear_dequantization(quantized_tensor, scale, zero_point)
plot_quantization_errors(r_tensor, quantized_tensor, dequantized_tensor)
Image by Author
(dequantized_tensor-r_tensor).square().mean()
# OUTPUT --> tensor(1.0914e-05)
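One edge case the functions above don't handle: if every element of a tensor is equal, r_max - r_min is 0, the scale is 0, and r_min / scale raises a ZeroDivisionError. A minimal guard (our own hypothetical addition, not from the original walkthrough) is to fall back to z = 0 and a scale that maps the constant value to ±q_max. The sketch below is self-contained and restates the article's formulas with that guard; get_q_scale_and_zero_point_safe and quantize_dequantize are hypothetical names.

```python
import torch

def get_q_scale_and_zero_point_safe(tensor, dtype=torch.int8):
    """Like get_q_scale_and_zero_point, plus a guard for constant tensors (hypothetical)."""
    q_min, q_max = torch.iinfo(dtype).min, torch.iinfo(dtype).max
    r_min, r_max = tensor.min().item(), tensor.max().item()

    if r_max == r_min:
        # Constant tensor: the min/max formula would give scale = 0.
        # Use z = 0 and pick the scale so the value maps to +/- q_max.
        return max(abs(r_max), 1e-8) / q_max, 0

    scale = (r_max - r_min) / (q_max - q_min)
    zero_point = q_min - r_min / scale
    zero_point = int(round(min(max(zero_point, q_min), q_max)))  # clip, then round
    return scale, zero_point

def quantize_dequantize(tensor, dtype=torch.int8):
    scale, zero_point = get_q_scale_and_zero_point_safe(tensor, dtype)
    info = torch.iinfo(dtype)
    q = torch.clamp(torch.round(tensor / scale + zero_point), info.min, info.max).to(dtype)
    return scale * (q.float() - zero_point)

print(quantize_dequantize(torch.full((2, 2), 5.0)))  # back to 5.0 everywhere (up to float rounding)
print(quantize_dequantize(torch.zeros(3)))           # zeros stay zero
```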

Ebad Sayed

I am currently a final year undergraduate at IIT Dhanbad, looking to help out aspiring AI/ML enthusiasts with easy AI/ML guides.