The Power of Quantization in ML: A PyTorch Tutorial Part 3

Ebad Sayed
6 min read · Jul 1, 2024


Previously, we learned about linear quantization. In this article, we will look at other variants of quantization (symmetric vs. asymmetric, per-channel and per-group granularity) and at quantizing weights and activations for inference.

Previous Article: Mastering Quantization Part 2

Symmetric vs. Asymmetric

Asymmetric: we map [π‘Ÿπ‘šπ‘–π‘› , π‘Ÿπ‘šπ‘Žπ‘₯] to [π‘žπ‘šπ‘–π‘› , π‘žπ‘šπ‘Žπ‘₯]. Which we were doing uptill now.
Symmetric: we map [β€“π‘Ÿπ‘šπ‘Žπ‘₯ , π‘Ÿπ‘šπ‘Žπ‘₯] to [β€“π‘žπ‘šπ‘Žπ‘₯ , π‘žπ‘šπ‘Žπ‘₯]. Where we can set π‘Ÿπ‘šπ‘Žπ‘₯ = max(|π‘Ÿ_π‘‘π‘’π‘›π‘ π‘œπ‘Ÿ|).
Here we don’t need to use the zero point (z=0), this happens because the FP range and the quantized range are symmetric w.r.t zero. The quantized tensor is simply the original tensor divided by the scale that we run and cast to the data type of the quantized tensor, and the scale S is simply
π‘Ÿπ‘šπ‘Žπ‘₯/π‘žπ‘šπ‘Žπ‘₯.

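import torch

def get_q_scale_symmetric(tensor, dtype=torch.int8):
    # r_max = max(|r|) over the whole tensor, q_max = largest value of the integer dtype
    r_max = tensor.abs().max().item()
    q_max = torch.iinfo(dtype).max

    return r_max / q_max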

test_tensor = torch.randn((4, 4))
get_q_scale_symmetric(test_tensor)
# OUTPUT --> 0.015278718602938914
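The code in this section reuses helpers from Part 2: `linear_q_with_scale_and_zero_point`, `linear_dequantization`, and `quantization_error`. If you are following along without Part 2, the following is a minimal sketch of what they are assumed to do here:

* round and clamp values to the integer range;
* dequantize as $S(q - z)$;
* measure error as the mean squared error.

These are stand-ins written under those assumptions, not the exact Part 2 code. `plot_quantization_errors` is only a plotting helper and is not reproduced.

```python
import torch

def linear_q_with_scale_and_zero_point(tensor, scale, zero_point, dtype=torch.int8):
    # q = clamp(round(r / S + z), q_min, q_max), then cast to the integer dtype
    q_min, q_max = torch.iinfo(dtype).min, torch.iinfo(dtype).max
    scaled_and_shifted = tensor / scale + zero_point
    rounded = torch.round(scaled_and_shifted)
    return rounded.clamp(q_min, q_max).to(dtype)

def linear_dequantization(quantized_tensor, scale, zero_point):
    # r_hat = S * (q - z), computed in float32
    return scale * (quantized_tensor.float() - zero_point)

def quantization_error(tensor, dequantized_tensor):
    # mean squared error between the original and reconstructed tensors
    return (dequantized_tensor - tensor).square().mean()
```

Because `scale` is used with broadcasting, the same helpers also work when `scale` is a tensor of per-channel or per-group scales, as later in this article.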

Quantization Process

def linear_q_symmetric(tensor, dtype=torch.int8):
    scale = get_q_scale_symmetric(tensor, dtype=dtype)
    # in symmetric quantization the zero point is 0
    quantized_tensor = linear_q_with_scale_and_zero_point(
        tensor, scale=scale, zero_point=0, dtype=dtype)

    return quantized_tensor, scale

quantized_tensor, scale = linear_q_symmetric(test_tensor)

Dequantization Process

dequantized_tensor = linear_dequantization(quantized_tensor,scale,0)
plot_quantization_errors(test_tensor, quantized_tensor, dequantized_tensor)
Image by Author
print(f"""Quantization Error : {quantization_error(test_tensor, dequantized_tensor)}""")
# OUTPUT --> Quantization Error : 2.1057938283775002e-05
Symmetric vs. asymmetric trade-offs:
  1. Utilization of the quantized range:
    - With asymmetric quantization, the quantized range is fully utilized.
    - In symmetric mode, if the float range is skewed towards one side, part of the quantized range is dedicated to values we will never see (e.g. after a ReLU, whose output is always non-negative).
  2. Simplicity: symmetric mode is much simpler than asymmetric mode.
  3. Memory: we don't need to store a zero point for symmetric quantization.
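To see the utilization point concretely, the sketch below quantizes a non-negative tensor (as a ReLU output would be) to int8 both ways. It then counts how many of the 256 int8 codes each scheme actually reaches. It is self-contained and uses its own inline formulas rather than the Part 2 helpers. The asymmetric branch assumes the usual min/max definitions of scale and zero point.

```python
import torch

relu_out = torch.relu(torch.linspace(-1.0, 3.0, steps=10_000))  # values in [0, 3]
q_min, q_max = torch.iinfo(torch.int8).min, torch.iinfo(torch.int8).max  # -128, 127

# Symmetric: S = r_max / q_max, z = 0
s_sym = relu_out.abs().max().item() / q_max
q_sym = torch.round(relu_out / s_sym).clamp(q_min, q_max)

# Asymmetric: S = (r_max - r_min) / (q_max - q_min), z = round(q_min - r_min / S)
r_min, r_max = relu_out.min().item(), relu_out.max().item()
s_asym = (r_max - r_min) / (q_max - q_min)
z_asym = round(q_min - r_min / s_asym)
q_asym = torch.round(relu_out / s_asym + z_asym).clamp(q_min, q_max)

print("symmetric codes used: ", q_sym.unique().numel())   # 128 (only 0..127)
print("asymmetric codes used:", q_asym.unique().numel())  # 256 (-128..127)
```

In symmetric mode the negative half of the int8 range is never used, so the step size is about twice as coarse as in asymmetric mode for the same data.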

Finer Granularity for more Precision

The finer the granularity of quantization, the more accurate the results tend to be. However, this accuracy comes at the cost of higher memory usage, since more quantization parameters need to be stored. Quantization can be performed at different levels of granularity:

* Per-tensor quantization uses the same scale and zero point for the entire tensor.
* Per-channel quantization computes a separate scale and zero point for each channel along a chosen axis (for example, each row or each column of a matrix). This gives more precise results.
* Per-group quantization goes further. It splits the tensor into groups of n elements, computes a scale and zero point for each group, and quantizes each group with its own parameters.
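To put rough numbers on the memory trade-off, the helper below counts how many scales each granularity needs for a 2-D weight matrix. `num_scales` and the 4096×4096 shape are hypothetical, for illustration only. With symmetric quantization the scales are all we store; asymmetric mode would add one zero point per scale.

```python
def num_scales(shape, granularity, group_size=None):
    rows, cols = shape
    if granularity == "per_tensor":
        return 1                      # one scale for the whole tensor
    if granularity == "per_channel":
        return rows                   # one scale per row (quantizing along dim 0)
    if granularity == "per_group":
        assert cols % group_size == 0, "each row must split evenly into groups"
        return rows * (cols // group_size)  # one scale per group of n elements
    raise ValueError(f"unknown granularity: {granularity}")

shape = (4096, 4096)  # an example linear-layer weight, chosen for illustration
for g in ("per_tensor", "per_channel"):
    print(g, num_scales(shape, g))
print("per_group (n=128)", num_scales(shape, "per_group", group_size=128))
# per_tensor 1
# per_channel 4096
# per_group (n=128) 131072
```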

Per Channel Quantization

If we quantize along the rows, we need to store a scale (and, in asymmetric mode, a zero point) for each row. If we quantize along the columns, we store one per column. The memory needed for these extra parameters is small compared to the tensor itself. Per-channel quantization is commonly used when quantizing models to 8 bits.

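# Target signature; the body is built step by step below
def linear_q_symmetric_per_channel(tensor, dim, dtype=torch.int8):
    ...  # compute one scale per channel along `dim`, then quantize
    # return quantized_tensor, scale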

test_tensor=torch.tensor(
[[191.6, -13.5, 728.6],
[92.14, 295.5, -184],
[0, 684.6, 245.5]])

dim=0 # along rows
output_dim = test_tensor.shape[dim]

scale = torch.zeros(output_dim) # --> tensor([0., 0., 0.])

# Iterate through each row to calculate its scale
for index in range(output_dim):
    sub_tensor = test_tensor.select(dim, index)
    scale[index] = get_q_scale_symmetric(sub_tensor)

scale # tensor([5.7370, 2.3268, 5.3906])

We have now stored the scale for each row in a tensor. Next, we need a little processing to reshape the scale. Then, when we divide the original tensor by it, each row is divided by its own scale.

scale_shape = [1]*test_tensor.dim() # --> [1,1]
scale_shape[dim] = -1 # --> [-1,1]
copy_scale = scale.view(scale_shape)
copy_scale
# tensor([[5.7370],
# [2.3268],
# [5.3906]])
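The loop makes each step explicit, but the same reshaped scale can also be computed in one vectorized call. Take the absolute maximum over every dimension except `dim`, and keep those dimensions with `keepdim=True`. The sketch below assumes int8 and a 2-D tensor as in the example; `per_channel_scale` is a hypothetical name.

```python
import torch

def per_channel_scale(tensor, dim, dtype=torch.int8):
    # reduce over every axis except `dim`; keepdim=True yields the [-1, 1] / [1, -1] layout directly
    reduce_dims = tuple(d for d in range(tensor.dim()) if d != dim)
    r_max = tensor.abs().amax(dim=reduce_dims, keepdim=True)
    return r_max / torch.iinfo(dtype).max

test_tensor = torch.tensor(
    [[191.6, -13.5, 728.6],
     [92.14, 295.5, -184],
     [0, 684.6, 245.5]])

print(per_channel_scale(test_tensor, dim=0))
# tensor([[5.7370],
#         [2.3268],
#         [5.3906]])
```

Passing `dim=1` instead returns a `[1, 3]` tensor with one scale per column, ready to broadcast in the same way.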

To understand how this tensor-by-tensor division works, consider a small example:

m = torch.tensor([[1,2,3],[4,5,6],[7,8,9]])
s = torch.tensor([1,5,10])
s.view(1, 3).shape # --> torch.Size([1, 3])

Division along the rows

scale = torch.tensor([[1], [5], [10]])
m / scale
# tensor([[1.0000, 2.0000, 3.0000],
# [0.8000, 1.0000, 1.2000],
# [0.7000, 0.8000, 0.9000]])

Division along the columns

scale = torch.tensor([[1, 5, 10]])
m / scale
# tensor([[1.0000, 0.4000, 0.3000],
# [4.0000, 1.0000, 0.6000],
# [7.0000, 1.6000, 0.9000]])

Since we are performing symmetric quantization, z = 0.

quantized_tensor = linear_q_with_scale_and_zero_point(test_tensor, scale=copy_scale, zero_point=0)
quantized_tensor
# tensor([[ 33, -2, 127],
# [ 40, 127, -79],
# [ 0, 127, 46]], dtype=torch.int8)

Putting it all together inside one function

def linear_q_symmetric_per_channel(r_tensor, dim, dtype=torch.int8):

    output_dim = r_tensor.shape[dim]
    # store the scales
    scale = torch.zeros(output_dim)

    for index in range(output_dim):
        sub_tensor = r_tensor.select(dim, index)
        scale[index] = get_q_scale_symmetric(sub_tensor, dtype=dtype)

    # reshape the scale so it broadcasts along `dim`
    scale_shape = [1] * r_tensor.dim()
    scale_shape[dim] = -1
    scale = scale.view(scale_shape)
    quantized_tensor = linear_q_with_scale_and_zero_point(
        r_tensor, scale=scale, zero_point=0, dtype=dtype)

    return quantized_tensor, scale

test_tensor=torch.tensor(
[[191.6, -13.5, 728.6],
[92.14, 295.5, -184],
[0, 684.6, 245.5]])

### along the rows (dim = 0)
quantized_tensor_0, scale_0 = linear_q_symmetric_per_channel(test_tensor, dim=0)

### along the columns (dim = 1)
quantized_tensor_1, scale_1 = linear_q_symmetric_per_channel(test_tensor, dim=1)
dequantized_tensor_0 = linear_dequantization(quantized_tensor_0, scale_0, 0)
plot_quantization_errors(test_tensor, quantized_tensor_0, dequantized_tensor_0)

print(f"""Quantization Error : {quantization_error(test_tensor, dequantized_tensor_0)}""")
# OUTPUT --> Quantization Error : 1.8084441423416138
Image by Author
dequantized_tensor_1 = linear_dequantization(quantized_tensor_1, scale_1, 0)
plot_quantization_errors(test_tensor, quantized_tensor_1, dequantized_tensor_1, n_bits=8)

print(f"""Quantization Error : {quantization_error(test_tensor, dequantized_tensor_1)}""")
# OUTPUT --> Quantization Error : 1.0781488418579102
Image by Author

With per-tensor quantization we got an error of about 2.5. Quantizing along the rows gives 1.81, and quantizing along the columns gives 1.08. This is because an outlier value only affects the scale of the channel it belongs to, instead of the scale of the entire tensor.
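The effect of an outlier is easy to reproduce. The self-contained sketch below uses made-up values: a matrix whose last row holds one large value. It quantizes the matrix first with a single per-tensor scale and then with one scale per row. Both cases use symmetric int8 and mean squared error as the metric.

```python
import torch

def sym_quant_dequant(tensor, scale):
    # symmetric int8 round trip: q = clamp(round(r / S)), r_hat = S * q
    q = torch.round(tensor / scale).clamp(-128, 127).to(torch.int8)
    return scale * q.float()

w = torch.tensor([[0.10, -0.20, 0.30],
                  [0.05, 0.40, -0.10],
                  [100.0, 0.20, -0.30]])  # 100.0 is the outlier

per_tensor_scale = w.abs().max() / 127                   # one scale, set by the outlier
per_row_scale = w.abs().amax(dim=1, keepdim=True) / 127  # shape [3, 1], one scale per row

err_tensor = (sym_quant_dequant(w, per_tensor_scale) - w).square().mean().item()
err_row = (sym_quant_dequant(w, per_row_scale) - w).square().mean().item()
print(f"per-tensor MSE: {err_tensor:.4f}, per-row MSE: {err_row:.4f}")
```

With the per-tensor scale (about 0.79), almost every small value in the first two rows collapses to 0. With per-row scales, only the row containing the outlier loses that precision.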

Per Group Quantization

Now let's go even finer and perform per-group quantization. Here we quantize groups of n elements, each with its own parameters. Common values for n are 32, 64, and 128. Per-group quantization can require a lot of memory for the extra parameters. For example, suppose we quantize a tensor to 4 bits with a group size of 32, use symmetric mode (z = 0), and store each scale in FP16. Then we are effectively quantizing the tensor in 4.5 bits:

4 bits per element (each element is stored using 4 bits), plus
16 / 32 = 0.5 bits per element (one 16-bit scale for every 32 elements).
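The 4.5-bit figure is the element width plus the scale's cost spread over its group. A small helper (hypothetical, for illustration) makes the arithmetic reusable for other settings. It assumes symmetric mode, so each group stores one scale and no zero point.

```python
def effective_bits(element_bits, scale_bits, group_size):
    # each element costs element_bits, plus its share of one scale per group
    return element_bits + scale_bits / group_size

print(effective_bits(4, 16, 32))   # 4.5   (4-bit weights, FP16 scale, n = 32)
print(effective_bits(4, 16, 128))  # 4.125 (larger groups amortize the scale better)
```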

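def linear_q_symmetric_per_group(tensor, group_size, dtype=torch.int8):

    t_shape = tensor.shape
    # only 2-D tensors whose rows split evenly into groups are supported
    assert tensor.dim() == 2
    assert t_shape[1] % group_size == 0

    # each row of this view is one group of `group_size` elements
    tensor = tensor.view(-1, group_size)

    # per-channel quantization along dim 0 = one scale per group
    quantized_tensor, scale = linear_q_symmetric_per_channel(tensor, dim=0, dtype=dtype)
    quantized_tensor = quantized_tensor.view(t_shape)

    return quantized_tensor, scale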


def linear_dequantization_per_group(quantized_tensor, scale, group_size):

    q_shape = quantized_tensor.shape
    # regroup so each row matches one scale
    quantized_tensor = quantized_tensor.view(-1, group_size)

    dequantized_tensor = linear_dequantization(quantized_tensor, scale, 0)
    dequantized_tensor = dequantized_tensor.view(q_shape)

    return dequantized_tensor

test_tensor = torch.rand((6, 6))
group_size = 3


quantized_tensor, scale = linear_q_symmetric_per_group(test_tensor, group_size=group_size)
dequantized_tensor = linear_dequantization_per_group(quantized_tensor, scale, group_size=group_size)
plot_quantization_errors(test_tensor, quantized_tensor, dequantized_tensor)

print(f"Quantization Error : {quantization_error(test_tensor, dequantized_tensor)}")
# OUTPUT --> Quantization Error : 2.731435415626038e-06
Image by Author

Looking at the quantized tensor, every group of three consecutive elements along a row contains the maximum value 127. This shows that each group of three elements was indeed quantized with its own scale. The quantization error is also very low.
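That observation can be checked programmatically. The sketch below is a self-contained, vectorized version of per-group symmetric quantization. Like the loop-based function above, it computes one scale per group of `group_size` consecutive elements along each row. The name `quantize_per_group` and the fixed seed are assumptions for illustration. The code then verifies that the largest magnitude in every group maps to 127.

```python
import torch

def quantize_per_group(tensor, group_size, dtype=torch.int8):
    assert tensor.dim() == 2 and tensor.shape[1] % group_size == 0
    groups = tensor.reshape(-1, group_size)  # one row per group
    info = torch.iinfo(dtype)
    scale = groups.abs().amax(dim=1, keepdim=True) / info.max
    q = torch.round(groups / scale).clamp(info.min, info.max)
    return q.to(dtype).reshape(tensor.shape), scale

torch.manual_seed(0)
x = torch.rand((6, 6))
q, scale = quantize_per_group(x, group_size=3)

print(scale.shape)  # torch.Size([12, 1]): 36 elements / 3 per group
# every group of three consecutive elements in a row contains the code 127
print((q.reshape(-1, 3).abs().amax(dim=1) == 127).all())  # tensor(True)
```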

Quantizing Weights & Activations for Inference

How do we perform inference with linear quantization? If we quantize only the weights, the computation still uses floating-point arithmetic (FP32, FP16, BF16), so the weights must be dequantized before the computation. If we also quantize the activations, we can use integer arithmetic (INT8, INT4, etc.). However, this is not supported by all hardware.
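For contrast with the weight-only case below, here is a minimal sketch of what integer arithmetic means when activations are quantized too (W8A8). It assumes symmetric per-tensor quantization for both the input and the weight. The int8 values are multiplied and summed as integers, and a single float rescale by $S_x S_w$ at the end recovers the output, since $y_i = \sum_j W_{ij} x_j \approx S_w S_x \sum_j q^{w}_{ij} q^{x}_{j}$. Real kernels do this on dedicated hardware. The function names here are hypothetical.

```python
import torch

def sym_quantize(t, dtype=torch.int8):
    # symmetric per-tensor quantization: S = max|t| / q_max, z = 0
    info = torch.iinfo(dtype)
    scale = t.abs().max().item() / info.max
    q = torch.round(t / scale).clamp(info.min, info.max)
    return q.to(dtype), scale

def linear_w8a8_symmetric(x, weight):
    q_x, s_x = sym_quantize(x)
    q_w, s_w = sym_quantize(weight)
    # integer multiply-accumulate: widen to int32 so the int8 products don't overflow
    # (torch.sum promotes the integer result to int64)
    acc = (q_w.to(torch.int32) * q_x.to(torch.int32)).sum(dim=1)
    # one floating-point rescale at the end
    return acc.to(torch.float32) * (s_x * s_w)

x = torch.tensor([1, 2, 3], dtype=torch.float32)
weight = torch.tensor([[-2, -1.13, 0.42],
                       [-1.51, 0.25, 1.62],
                       [0.23, 1.35, 2.15]])
print(linear_w8a8_symmetric(x, weight))       # close to the FP32 result
print(torch.nn.functional.linear(x, weight))  # tensor([-3.0000, 3.8500, 9.3800])
```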

Let's see what a linear layer looks like if we quantize only the weights and not the activations. W8A32 means weights in 8 bits and activations in 32 bits. For simplicity, the linear layer has no bias.

def quantized_linear_W8A32_without_bias(input, q_w, s_w, z_w):
    assert input.dtype == torch.float32
    assert q_w.dtype == torch.int8

    # dequantize the weights: r = s * (q - z)
    dequantized_weight = s_w * (q_w.to(torch.float32) - z_w)
    output = torch.nn.functional.linear(input, dequantized_weight)

    return output



input = torch.tensor([1, 2, 3], dtype=torch.float32)
weight = torch.tensor([[-2, -1.13, 0.42],
[-1.51, 0.25, 1.62],
[0.23, 1.35, 2.15]])


q_w, s_w = linear_q_symmetric(weight)
s_w # --> 0.016929134609192376
q_w
# tensor([[-118, -67, 25],
# [ -89, 15, 96],
# [ 14, 80, 127]], dtype=torch.int8)
output = quantized_linear_W8A32_without_bias(input, q_w, s_w, 0)
print(f"This is the W8A32 output: {output}")
# OUTPUT --> This is the W8A32 output: tensor([-2.9965, 3.8768, 9.3957])


fp32_output = torch.nn.functional.linear(input, weight)
print(f"Output if we don't quantize: {fp32_output}")
# OUTPUT --> Output if we don't quantize: tensor([-3.0000, 3.8500, 9.3800])


Ebad Sayed

I am currently a final year undergraduate at IIT Dhanbad, looking to help out aspiring AI/ML enthusiasts with easy AI/ML guides.