The Power of Quantization in ML: A PyTorch Tutorial Part 4

Ebad Sayed
6 min read · Jul 1, 2024


In the previous article we learned about symmetric, asymmetric, per-channel, and per-group quantization, as well as quantizing weights and activations. In this article we will build our own custom 8-bit quantizer, which can quantize any model to 8-bit precision using a per-channel quantization scheme. The quantizer is modality agnostic, meaning we can apply it to vision, audio, text, and even multimodal models. We will also quantize Hugging Face models.

Previous Article: Mastering Quantization Part 3

Custom 8-Bit Quantizer

First we will create a W8A16LinearLayer class to store 8-bit weights and scales. Then we will replace all the torch.nn.Linear layers with W8A16LinearLayer. Finally, we will build a quantizer, quantize a model end-to-end, and test the naive absmax quantization in several scenarios to study its impact.

import torch
import torch.nn as nn
import torch.nn.functional as F

random_int8 = torch.randint(-128, 127, (32, 16)).to(torch.int8)
random_hs = torch.randn((1, 16), dtype=torch.bfloat16)
scales = torch.randn((1, 32), dtype=torch.bfloat16)
bias = torch.randn((1, 32), dtype=torch.bfloat16)

NOTE: the weight matrix has the shape (output dimension, input dimension). When we perform the matrix multiplication between the int8 matrix and the hidden states, we get an output of shape (batch size, output dimension). So it is important that the scales (and the bias) match the output dimension of the weight matrix.
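As a quick sanity check on those shapes (an illustrative snippet using the tensors defined above):

# random_int8: (32, 16) -> (out_features, in_features)
# random_hs:   (1, 16)  -> (batch, in_features)
# scales/bias: (1, 32)  -> broadcast over the output dimension
print(F.linear(random_hs, random_int8.to(random_hs.dtype)).shape)  # torch.Size([1, 32])
print(scales.shape, bias.shape)  # torch.Size([1, 32]) torch.Size([1, 32])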

F.linear(random_hs, random_int8.to(random_hs.dtype))
# tensor([[-125.0000, 240.0000, -346.0000, -524.0000, -436.0000, -185.0000,
# -150.0000, -17.6250, 160.0000, -330.0000, -330.0000, 328.0000,
# -225.0000, 418.0000, 580.0000, -136.0000, -122.0000, 31.1250,
# 32.2500, -107.5000, 169.0000, 36.0000, 276.0000, -33.5000,
# -380.0000, 143.0000, 97.0000, -162.0000, -199.0000, 74.0000,
# 159.0000, -612.0000]], dtype=torch.bfloat16)

(F.linear(random_hs, random_int8.to(random_hs.dtype)) * scales) + bias

First we cast the weights to the same data type as the hidden states. Then we perform the matrix multiplication via PyTorch's F.linear() function. Finally, we multiply the result by the scales and optionally add a bias term.

def w8_a16_forward(weight, input, scales, bias=None):

    casted_weights = weight.to(input.dtype)
    output = F.linear(input, casted_weights) * scales

    if bias is not None:
        output = output + bias

    return output


print("With bias:\n\n", w8_a16_forward(random_int8, random_hs, scales, bias))
# With bias:
# tensor([[ -3.2031, 286.0000, -266.0000, -752.0000, -155.0000, -158.0000,
# -132.0000, -13.8750, 100.0000, -430.0000, 107.5000, -416.0000,
# -33.0000, -104.0000, 26.2500, -80.5000, -178.0000, 38.7500,
# -9.4375, -28.3750, 104.5000, 5.0312, -55.7500, 0.8750,
# 124.0000, -40.2500, 97.0000, 220.0000, 12.6250, 133.0000,
# 3.2188, 520.0000]], dtype=torch.bfloat16)


print("\nWithout bias:\n\n", w8_a16_forward(random_int8, random_hs, scales))
# Without bias:
# tensor([[ -3.7188, 284.0000, -264.0000, -752.0000, -154.0000, -158.0000,
# -131.0000, -14.3125, 100.5000, -430.0000, 109.0000, -416.0000,
# -32.7500, -103.5000, 24.7500, -81.5000, -178.0000, 39.5000,
# -9.0000, -28.5000, 103.0000, 4.5625, -54.7500, 2.4844,
# 124.0000, -40.5000, 97.0000, 220.0000, 11.6875, 132.0000,
# 2.5000, 520.0000]], dtype=torch.bfloat16)

class W8A16LinearLayer(nn.Module):
    def __init__(self, in_features, out_features, bias=True, dtype=torch.float32):
        super().__init__()

        self.int8_weights = nn.Parameter(torch.Tensor([0, 1]).to(dtype=torch.int8))


try:
    W8A16LinearLayer(1, 1)
except Exception as error:
    print("\033[91m", type(error).__name__, ": ", error, "\033[0m")

This will raise an error: "RuntimeError: Only Tensors of floating point and complex dtype can require gradients".
When we create an nn.Parameter, PyTorch expects to be able to compute gradients on it. We can't compute gradients on int8 tensors yet, so we get an error.
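The same restriction applies to plain tensors; a minimal illustration (not part of the original code):

try:
    # Asking an int8 tensor to track gradients should raise the same kind of error
    torch.randint(-128, 127, (2, 2), dtype=torch.int8).requires_grad_(True)
except RuntimeError as error:
    print(error)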

class W8A16LinearLayer(nn.Module):
    def __init__(self, in_features, out_features, bias=True, dtype=torch.float32):
        super().__init__()

        self.register_buffer(
            "int8_weights",
            torch.randint(-128, 127, (out_features, in_features), dtype=torch.int8))

        self.register_buffer("scales", torch.randn((out_features), dtype=dtype))

        if bias:
            self.register_buffer("bias", torch.randn((1, out_features), dtype=dtype))
        else:
            self.bias = None

    def quantize(self, weights):
        w_fp32 = weights.clone().to(torch.float32)

        scales = w_fp32.abs().max(dim=-1).values / 127
        scales = scales.to(weights.dtype)

        int8_weights = torch.round(weights / scales.unsqueeze(1)).to(torch.int8)

        self.int8_weights = int8_weights
        self.scales = scales

    def forward(self, input):
        return w8_a16_forward(self.int8_weights, input, self.scales, self.bias)

The right approach to storing int8 weights is to call register_buffer() instead of saving the attribute as an nn.Parameter. This way we store a buffer rather than a parameter, which means we don't need to compute gradients on the tensor, and we can initialize it with whatever dtype we want.
With that, we have a working linear layer and a forward pass.
Then we add the quantization method: first upcast the weights to FP32, then compute the per-row scale (absmax / 127) and make sure the scale has the same dtype as the input weights. Finally, divide by the scales, round, and cast to get the int8 weights.
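A quick sanity check of quantize(): quantize a random matrix, dequantize it back, and look at the error (an illustrative snippet that assumes the class defined above):

module = W8A16LinearLayer(4, 8)
print("Weights dtype:", module.int8_weights.dtype)  # torch.int8

random_matrix = torch.randn((8, 4), dtype=torch.bfloat16)
module.quantize(random_matrix)

# Dequantize and compare with the original weights
dequantized = module.int8_weights * module.scales.unsqueeze(1)
print("Mean absolute quantization error:", (random_matrix - dequantized).abs().mean())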

Quantization Pipeline

Replace all of the torch.nn.Linear layers with the W8A16LinearLayer layer. Call quantize on the linear layers using the original weights.

def replace_linear_with_target(module, target_class, module_name_to_exclude):
    for name, child in module.named_children():
        if isinstance(child, nn.Linear) and not \
            any([x == name for x in module_name_to_exclude]):
            old_bias = child.bias

            new_module = target_class(child.in_features, child.out_features,
                                      old_bias is not None, child.weight.dtype)
            setattr(module, name, new_module)
            if old_bias is not None:
                getattr(module, name).bias = old_bias
        else:
            # Recursively call the function for nested modules
            replace_linear_with_target(child, target_class, module_name_to_exclude)

We pass in the model as module, the target class that will replace the linear layers, and module_name_to_exclude, the names of modules we want to skip in the replacement logic. For better results it is usually best to keep the last module (the LM head) unquantized.
We simply loop over the module's named children; if a child is an instance of nn.Linear and its name does not match any name in module_name_to_exclude, we move forward with the replacement. We grab the child's bias in old_bias, because we need it to create the new target class.
Then we create the new module from target_class, keeping the same in_features and out_features as the linear layer and the same dtype as the child's weights.
Then we call setattr() to replace the attribute of module named name with new_module.
If the old module had a bias, we explicitly set the bias of the new module to old_bias.
Finally, if the child is not a linear layer we want to replace, we recursively call the function on it with the same arguments so that nested modules are handled too.

class DummyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = torch.nn.Embedding(1, 1)
        # Try with bias
        self.linear_1 = nn.Linear(1, 1)
        # Try without bias
        self.linear_2 = nn.Linear(1, 1, bias=False)
        # LM prediction head
        self.lm_head = nn.Linear(1, 1, bias=False)


model_1 = DummyModel()
model_2 = DummyModel()

replace_linear_with_target(model_1, W8A16LinearLayer, ["lm_head"])
replace_linear_with_target(model_2, W8A16LinearLayer, [])

For testing purposes we created a dummy model with two linear layers and a language-model head, which is usually the last module in a transformer model. The function modifies the model in place, so we created two copies: one to test the module_name_to_exclude feature, and another where every nn.Linear instance is replaced with the new layer.
For model_2 we passed an empty exclusion list, so the function replaces all the linear layers, including lm_head. We can print both models to verify the replacement, as shown below.
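A quick way to verify the replacement (an illustrative snippet using the two models created above):

print(model_1)
# linear_1 and linear_2 should now be W8A16LinearLayer(); lm_head stays an nn.Linear

print(model_2)
# all three linear layers, including lm_head, should now be W8A16LinearLayer()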

Linear Layer Replacement + Quantization

Modify the replace_linear_with_target function to also perform quantization. Implement replace_linear_with_target_and_quantize.

def replace_linear_with_target_and_quantize(module, target_class, module_name_to_exclude):
    for name, child in module.named_children():
        if isinstance(child, nn.Linear) and not \
            any([x == name for x in module_name_to_exclude]):
            old_bias = child.bias
            old_weight = child.weight

            new_module = target_class(child.in_features, child.out_features,
                                      old_bias is not None, child.weight.dtype)
            setattr(module, name, new_module)

            getattr(module, name).quantize(old_weight)

            if old_bias is not None:
                getattr(module, name).bias = old_bias
        else:
            # Recursively call the function for nested modules
            replace_linear_with_target_and_quantize(child, target_class, module_name_to_exclude)



model_3 = DummyModel()
replace_linear_with_target_and_quantize(model_3, W8A16LinearLayer, ["lm_head"])

This is the same function with one extra line: right after setattr() we fetch the newly set module with getattr() and call quantize() on it with the original weights.
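As a quick check (an illustrative snippet, not from the original post), the replaced layers of model_3 should now hold int8 weights while the excluded lm_head keeps its floating point weights:

print(model_3.linear_1.int8_weights.dtype)  # torch.int8
print(model_3.linear_2.int8_weights.dtype)  # torch.int8
print(model_3.lm_head.weight.dtype)         # unchanged floating point dtype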

Quantize any Open Source PyTorch Model

1. Salesforce/codegen-350M-mono

from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

model_id = "Salesforce/codegen-350M-mono"

model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)


pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
print(pipe("def hello_world():", max_new_tokens=20, do_sample=False))

# OUTPUT --> [{'generated_text': 'def hello_world():\n print("Hello World")\n\nhello_world()\n\n# 파'}]

replace_linear_with_target_and_quantize(model, W8A16LinearLayer, ["lm_head"])

# OG: (fc_in): Linear(in_features=1024, out_features=4096, bias=True)
# Quantized: (fc_in): W8A16LinearLayer()
print(pipe("def hello_world():", max_new_tokens=20, do_sample=False)[0]["generated_text"])

# def hello_world():
# print("Hello World")

# hello_world()

# def hello_
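To confirm the replacement and check memory usage, we can print the model and its footprint (an illustrative snippet; get_memory_footprint() is a standard transformers method and the exact layer names depend on the model):

print(model)
# The transformer blocks should now list W8A16LinearLayer entries, e.g.
# (fc_in): W8A16LinearLayer()
# (fc_out): W8A16LinearLayer()

print("Footprint of the quantized model in MBs: ", model.get_memory_footprint()/1e+6)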

2. facebook/detr-resnet-50

from transformers import DetrImageProcessor, DetrForObjectDetection

processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50", revision="no_timm")
model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50", revision="no_timm")

previous_memory_footprint = model.get_memory_footprint()
print("Footprint of the model in MBs: ", previous_memory_footprint/1e+6)
# OUTPUT --> Footprint of the model in MBs: 166.524032
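The calls to plot_results below assume an image, detection results, and a plot_results helper from the original notebook. A minimal sketch (not from the original post) of how image and results could be obtained; the COCO image URL is just an example:

import requests
from PIL import Image

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# Keep detections above a confidence threshold
target_sizes = torch.tensor([image.size[::-1]])
results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]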
plot_results(model, image, results)
replace_linear_with_target_and_quantize(model, W8A16LinearLayer, ["0", "1", "2", "class_labels_classifier"])
plot_results(model, image, results)
new_footprint = model.get_memory_footprint()
print("Footprint of the model in MBs: ", new_footprint/1e+6)
# OUTPUT --> Footprint of the model in MBs: 114.80384


# Memory saved
print("Memory saved in MBs: ", (previous_memory_footprint - new_footprint)/1e+6)
# OUTPUT --> Memory saved in MBs: 51.720192

