Binary Magic: Building BitNet 1.58bit Using PyTorch from Scratch

Chidhambararajan R
Mar 12, 2024


Spoiler Alert:

The series of small experiments I conducted suggests that 1.58-bit models can actually rival full-precision models!!!

DALL·E 3 generated illustration of a barebones, highly quantized LLM fighting a full-precision LLM

What Is Quantization and Why Do You Need It?

Quantization is the process of representing floating-point numbers with fewer bits. When numbers are stored with fewer bits, memory use shrinks by the same factor (fp32 to int8 is 4×), and, in theory, the cost of arithmetic on them drops by a roughly similar factor. This lets us increase the speed and reduce the RAM consumption of ML models. However, quantization usually loses information, which lowers accuracy; we can recover some of it by fine-tuning the quantized model a bit more.
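To make the idea concrete, here is a minimal sketch of ordinary symmetric ("absmax") int8 quantization in PyTorch. This is a generic baseline chosen for illustration, not BitNet's scheme: storing fp32 values as int8 cuts memory by 4×, at the price of a rounding error of up to half a quantization step. The function names are hypothetical.

```python
import torch


def quantize_int8(x: torch.Tensor):
    """Symmetric absmax quantization of a float tensor to int8 (generic baseline)."""
    scale = x.abs().max() / 127  # the largest magnitude maps to +/-127
    q = torch.clamp(torch.round(x / scale), -127, 127).to(torch.int8)
    return q, scale


def dequantize_int8(q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # map the integer codes back to (approximate) float values
    return q.to(torch.float32) * scale


torch.manual_seed(0)
x = torch.randn(1000)  # fp32: 4 bytes per value
q, scale = quantize_int8(x)
x_hat = dequantize_int8(q, scale)

print(x.element_size() * x.numel())  # 4000 bytes
print(q.element_size() * q.numel())  # 1000 bytes
print((x - x_hat).abs().max())       # rounding error, bounded by scale / 2
```

The lost precision in `x_hat` is exactly the "information loss" mentioned above.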

Existing Quantization Approaches Vs BitNet 1.58bit

Most quantization algorithms on the market require a model that has already been pretrained in full precision. Techniques such as Post-Training Quantization (PTQ) and Quantization-Aware Training (QAT) are then applied for these algorithms to work effectively.

PTQ is a quantization technique where the model is quantized after it has been trained. QAT is a finetuning of the PTQ model, where the model is further trained with quantization in mind. — Deci.AI Article

BitNet takes a radically different approach: the model is trained from scratch with quantization built in!!

BitNet’s Quantization Algorithm

In the paper's formulation, the weight matrix W (of size n×m) is divided by a threshold γ, the mean of the absolute values of all its entries (the "absmean"), plus a tiny ε to avoid division by zero. I initially read γ as half of that mean. After the division, weights ≥ γ become values ≥ 1, weights ≤ -γ become values ≤ -1, and weights strictly between -γ and γ land in the open interval (-1, 1).

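When RoundClip (round to the nearest integer, then clip to [-1, 1]) is applied:

weights with |w| > γ/2 become 1.0 or -1.0 (keeping their sign), and weights with |w| < γ/2 become 0.0. The zero band is therefore (-γ/2, γ/2) rather than (-γ, γ), because rounding already sends any scaled value above 0.5 in magnitude to ±1.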
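For reference, the absmean quantization function from the BitNet b1.58 paper, for an n×m weight matrix W and a small constant ε that prevents division by zero, can be written as follows. Note that the code in this article divides by 2γ + 1e-4 instead of γ + ε.

```latex
\gamma = \frac{1}{nm} \sum_{i,j} \lvert W_{ij} \rvert, \qquad
\widetilde{W} = \operatorname{RoundClip}\!\left( \frac{W}{\gamma + \epsilon},\, -1,\, 1 \right),
\qquad
\operatorname{RoundClip}(x, a, b) = \max\bigl(a, \min(b, \operatorname{round}(x))\bigr)
```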

Each resulting weight takes one of three values, so by information theory it carries log2(3) ≈ 1.58 bits of information, which is where the name comes from. Since bits can't be fractional, in practice we can store each weight in 2 bits.
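The arithmetic, plus one way to actually store ternary weights at 2 bits each, can be sketched as follows. The layout (mapping -1, 0, 1 to the 2-bit codes 0, 1, 2 and packing four codes per byte) and the helper names are my own illustrative choices, not the format used by the paper or by any particular kernel.

```python
import math

import torch

print(math.log2(3))  # ≈ 1.585 bits of information per ternary weight


def pack_ternary(w: torch.Tensor) -> torch.Tensor:
    """Pack a 1-D tensor of {-1, 0, 1} values into uint8, four 2-bit codes per byte."""
    codes = w.to(torch.int64) + 1                 # map -1, 0, 1 -> 0, 1, 2 (fits in 2 bits)
    pad = (-codes.numel()) % 4
    codes = torch.cat([codes, codes.new_ones((pad,))])  # pad with code 1, i.e. weight 0
    codes = codes.view(-1, 4)
    packed = codes[:, 0] | (codes[:, 1] << 2) | (codes[:, 2] << 4) | (codes[:, 3] << 6)
    return packed.to(torch.uint8)


def unpack_ternary(packed: torch.Tensor, n: int) -> torch.Tensor:
    """Inverse of pack_ternary: recover the first n ternary values as float32."""
    p = packed.to(torch.int64)
    codes = torch.stack([(p >> s) & 0b11 for s in (0, 2, 4, 6)], dim=1).flatten()
    return (codes[:n] - 1).to(torch.float32)      # map codes 0, 1, 2 back to -1, 0, 1


w = torch.tensor([1.0, 0.0, -1.0, -1.0, 0.0, 1.0])
packed = pack_ternary(w)
print(packed.numel())  # 2 bytes for 6 weights (fp32 would need 24 bytes)
print(torch.equal(unpack_ternary(packed, w.numel()), w))  # True
```

Two bits per weight leave one of the four codes unused, which is the gap between 2 bits and the theoretical 1.58.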

Quantization Function Implementation in PyTorch

Threshold calculation:

def compute_adjustment_factor(self, input_tensor: torch.Tensor):
    absmean_weight = torch.mean(torch.abs(input_tensor))
    adjustment_factor = 1e-4 + absmean_weight * 2  # 1e-4 to avoid zero division error
    return adjustment_factor

I made a minor mistake here: instead of halving the absmean, I multiplied it by 2. Still, the experiment worked!

RoundClip (1.58 bits ≈ 2 bits)

def compute_2bit_quantized_tensor(self, input_tensor: torch.Tensor):
    # round to the nearest integer, then clip to the ternary set {-1, 0, 1}
    twobit_matrix = torch.clip(input=torch.round(input_tensor), min=-1, max=1)
    return twobit_matrix

def compute_1bit_quantized_tensor(self, input_tensor: torch.Tensor):
    # binary variant: keep only the sign (torch.sign maps exact zeros to 0)
    return torch.sign(input_tensor)

def compute_quantized_tensor(self, input_tensor: torch.Tensor):
    if self.quantization_mode == QuantizationMode.two_bit:
        return self.compute_2bit_quantized_tensor(input_tensor)
    else:
        return self.compute_1bit_quantized_tensor(input_tensor)
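A standalone sketch of the same quantization on a tiny tensor, comparing a threshold equal to the plain absmean (the BitNet b1.58 paper's γ) with the doubled threshold used in the code above. The function name `absmean_quantize` and its `scale_multiplier` parameter are hypothetical, introduced only for this comparison.

```python
import torch


def absmean_quantize(w: torch.Tensor, scale_multiplier: float = 1.0, eps: float = 1e-4):
    """Ternary absmean quantization (hypothetical helper).

    scale_multiplier=1.0 divides by the plain absmean; scale_multiplier=2.0
    reproduces the doubled threshold of compute_adjustment_factor above.
    """
    gamma = scale_multiplier * w.abs().mean() + eps
    return torch.clamp(torch.round(w / gamma), -1, 1), gamma


w = torch.tensor([0.9, 0.3, -0.05, -0.6])    # absmean = 0.4625
ternary_paper, _ = absmean_quantize(w, 1.0)   # values 1, 1, 0, -1
ternary_doubled, _ = absmean_quantize(w, 2.0) # values 1, 0, 0, -1
binary = torch.sign(w)                        # 1-bit variant: 1, 1, -1, -1
print(ternary_paper, ternary_doubled, binary)
```

Doubling the threshold widens the band of weights that are rounded to zero: here 0.3 becomes 0 instead of 1.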

Quantization Step

weight_adjustment_factor = self.compute_adjustment_factor(self.weight)
adjusted_weight = self.weight / weight_adjustment_factor
quantized_weight = self.compute_quantized_tensor(adjusted_weight)

Linear Layer Operation

F.linear(weight_adjustment_factor * x, quantized_weight, self.bias)
# the weight was divided by the adjustment factor before quantization,
# so multiplying the input by the same factor restores the original scale
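Why moving the factor onto the input is legitimate: for a scalar s, (s * x) @ W.T equals x @ (s * W).T, so scaling the input is the same as dequantizing the weight. A quick numerical check, with arbitrary shapes and an arbitrary stand-in value for `weight_adjustment_factor`:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(8, 16)                        # a batch of inputs
w_q = torch.randint(-1, 2, (4, 16)).float()   # a ternary weight matrix
bias = torch.randn(4)
scale = torch.tensor(0.37)                    # stand-in for weight_adjustment_factor

# scaling the input (as in the article) ...
out_input_scaled = F.linear(scale * x, w_q, bias)
# ... gives the same result as scaling the quantized weight back up
out_weight_scaled = F.linear(x, scale * w_q, bias)
print(torch.allclose(out_input_scaled, out_weight_scaled, atol=1e-6))  # True
```

The bias is left untouched in both forms, which is why it is passed to F.linear unscaled.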

But the model won't learn!!!!

If the weights are quantized before being passed to the linear layer, no useful gradient flows back through the quantization step. Rounding (and sign) are piecewise constant, so their derivative is zero almost everywhere, and PyTorch's autograd returns a zero gradient for torch.round and torch.sign. Even if the gradient did pass through, typical updates (roughly 1e-4 to 1e-2) are too small to change the rounded value of most weights. As a result, the original weight matrix would never be updated and the model would never learn!!
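A small check of the zero-gradient problem, using arbitrary example values: autograd treats round() as having zero derivative, so nothing flows back to the latent weights.

```python
import torch

w = torch.tensor([0.3, -1.2, 0.8], requires_grad=True)
loss = torch.clamp(torch.round(w), -1, 1).sum()  # the 2-bit RoundClip from above
loss.backward()
print(w.grad)  # tensor([0., 0., 0.]): no learning signal reaches w
```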

But there is a neat engineering trick for this: the straight-through estimator (STE).

This is what the full forward pass looks like:

def forward(self, x):
    weight_adjustment_factor = self.compute_adjustment_factor(self.weight)
    adjusted_weight = self.weight / weight_adjustment_factor

    if self.training:
        # Straight-through estimator: the forward value equals the quantized weight,
        # but the detached difference contributes no gradient, so the backward pass
        # treats quantization as the identity function.
        quantized_weight = (
            adjusted_weight
            + (
                self.compute_quantized_tensor(adjusted_weight) - adjusted_weight
            ).detach()
        )
    else:
        quantized_weight = self.compute_quantized_tensor(adjusted_weight)

    return F.linear(weight_adjustment_factor * x, quantized_weight, self.bias)

The value of quantized_weight is the same (up to floating-point rounding) whether or not self.training is True. The difference is in the backward pass: in training mode the detached term contributes no gradient, so the gradient computed for quantized_weight is copied unchanged to adjusted_weight. This lets adjusted_weight, and therefore the original weight matrix, get updated during training.
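The same toy tensor with the straight-through trick applied: the forward value matches the quantized weights, while the backward pass hands the upstream gradient to `w` unchanged. The upstream gradient values are arbitrary.

```python
import torch

w = torch.tensor([0.3, -1.2, 0.8], requires_grad=True)
q = torch.clamp(torch.round(w), -1, 1)    # quantized values: 0, -1, 1

w_ste = w + (q - w).detach()              # straight-through estimator
print(torch.allclose(w_ste, q))           # True: same forward values as q

upstream = torch.tensor([1.0, 2.0, 3.0])  # an arbitrary upstream gradient
(w_ste * upstream).sum().backward()
print(w.grad)                             # tensor([1., 2., 3.]): copied straight through
```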

I borrowed this simple yet mind-blowing trick from a PyTorch implementation of Google DeepMind's VQ-VAE (VQ-VAE is the holy grail for any discrete learning).

Experiment Results With a Custom PyTorch Implementation

Each of the experiments below pairs a small model with a dataset that is sufficiently large relative to that model. To create the quantized variants of the target model, I simply use the following code block, which replaces every nn.Linear module with my custom implementation:

import copy

import torch.nn as nn


def create_quantized_copy_of_model(
    input_model: nn.Module, quantization_mode: QuantizationMode
):
    # BitNetLinearLayer and QuantizationMode come from my module code
    model_copy = copy.deepcopy(input_model)
    hash_table = {n: m for n, m in model_copy.named_modules()}

    for key in list(hash_table.keys()):
        if isinstance(hash_table[key], nn.Linear):
            # build a fresh BitNet layer with the same shape (original weights are not copied)
            new_module = BitNetLinearLayer(
                in_features=hash_table[key].in_features,
                out_features=hash_table[key].out_features,
                bias=hash_table[key].bias is not None,
                quantization_mode=quantization_mode,
            )
            # swap it into the parent module under the same attribute name
            name_chain = key.split(".")
            parent_module_attr_name = ".".join(name_chain[:-1])
            parent_module = hash_table[parent_module_attr_name]
            setattr(parent_module, name_chain[-1], new_module)

    # sanity check: no plain nn.Linear layers should remain
    for _, m in model_copy.named_modules():
        assert not isinstance(m, nn.Linear)
    return model_copy
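For readers who want to run the fragments above end to end, here is a minimal sketch that assembles them into a single layer. The `QuantizationMode` member values, the constructor signature, and the initialization (borrowed from nn.Linear's default) are my assumptions; the actual class in the repository linked below may differ. The layer deliberately subclasses nn.Module rather than nn.Linear, since the assert in create_quantized_copy_of_model would otherwise fail.

```python
from enum import Enum

import torch
import torch.nn as nn
import torch.nn.functional as F


class QuantizationMode(Enum):
    # assumed definition: the article references these members but does not show the enum
    one_bit = 1
    two_bit = 2


class BitNetLinearLayer(nn.Module):
    """Hypothetical reassembly of the article's fragments into one layer."""

    def __init__(self, in_features, out_features, bias=True,
                 quantization_mode=QuantizationMode.two_bit):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.quantization_mode = quantization_mode
        # assumption: reuse nn.Linear's default initialization for the latent fp32 weights
        reference = nn.Linear(in_features, out_features, bias=bias)
        self.weight = nn.Parameter(reference.weight.detach().clone())
        if bias:
            self.bias = nn.Parameter(reference.bias.detach().clone())
        else:
            self.register_parameter("bias", None)

    def compute_adjustment_factor(self, input_tensor):
        # same as the article: 2 * absmean, plus 1e-4 to avoid division by zero
        return 1e-4 + torch.mean(torch.abs(input_tensor)) * 2

    def compute_quantized_tensor(self, input_tensor):
        if self.quantization_mode == QuantizationMode.two_bit:
            return torch.clip(torch.round(input_tensor), min=-1, max=1)
        return torch.sign(input_tensor)

    def forward(self, x):
        factor = self.compute_adjustment_factor(self.weight)
        adjusted_weight = self.weight / factor
        quantized_weight = self.compute_quantized_tensor(adjusted_weight)
        if self.training:
            # straight-through estimator: quantized values forward, identity gradient backward
            quantized_weight = adjusted_weight + (quantized_weight - adjusted_weight).detach()
        return F.linear(factor * x, quantized_weight, self.bias)


# Usage: swap the nn.Linear layers of a small MLP by hand
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))
for idx, module in enumerate(list(model)):
    if isinstance(module, nn.Linear):
        model[idx] = BitNetLinearLayer(
            module.in_features, module.out_features,
            bias=module.bias is not None,
            quantization_mode=QuantizationMode.two_bit,
        )

x = torch.randn(8, 16)
out = model(x)                            # modules start in training mode
out.sum().backward()
print(out.shape)                          # torch.Size([8, 4])
print(model[0].weight.grad is not None)   # True: gradients reach the latent weights

model.eval()
out_eval = model(x)  # plain quantized forward; matches `out` up to float rounding
```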

MNIST with a 4-layer feed-forward neural net

Fashion-MNIST with a 6-block ViT (128-dim embeddings)

CIFAR-100 with an 8-block ViT (128-dim embeddings)

In the experiments above, except for the first one, the 2-bit and 1-bit variants performed as well as the full-precision variants of the model. It's possible that catastrophic forgetting occurred in the quantized models in the first experiment, and that the residual connections in the remaining (ViT) experiments circumvent it.

Of course, these experiments are not performed with LLMs, but they are enough for a first test of the paper's claim that such a system can compete with full-precision models.

My implementation does not store the quantized weights as 2-bit matrices, and all computation is still performed in fp32. To actually see a computational speedup, we would need specialized compute kernels. My implementation only validates the accuracy side of the paper's claims.
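To illustrate what such kernels could exploit (a sketch of the arithmetic only, not how BitNet kernels are actually written): with ternary weights, a matrix-vector product needs no weight multiplications at all, just additions and subtractions of input entries followed by one final scaling. The function `ternary_matvec` is hypothetical, and this Python loop is of course far slower than an fp32 matmul.

```python
import torch
import torch.nn.functional as F


def ternary_matvec(x: torch.Tensor, w_q: torch.Tensor, scale: float) -> torch.Tensor:
    """Compute scale * (w_q @ x) for w_q in {-1, 0, 1} using only adds and subtracts."""
    out = torch.empty(w_q.shape[0])
    for i, row in enumerate(w_q):
        # +1 weights add the input entry, -1 weights subtract it, 0 weights skip it
        out[i] = x[row == 1].sum() - x[row == -1].sum()
    return scale * out  # a single multiplication per output for the scale


torch.manual_seed(0)
x = torch.randn(16)
w_q = torch.randint(-1, 2, (4, 16)).float()
y = ternary_matvec(x, w_q, scale=0.37)
print(torch.allclose(y, F.linear(x, 0.37 * w_q), atol=1e-5))  # True
```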

All the code for the experiments above, including the module code, is in my GitHub repo: https://github.com/TheSeriousProgrammer/SimpleBitNet
The model training logs are on Weights & Biases: https://wandb.ai/chidha1434/BitNet?workspace=user-chidha1434

Leave a star on GitHub and a clap here if you like my work. If you notice any inconsistencies in my implementation, feel free to raise them in a GitHub issue or in the comments here (let's learn together).

If you are super impressed and want me to work with you, I am always open to new research opportunities. Kindly contact me via LinkedIn DM: https://www.linkedin.com/in/chidha1434
