Constructing LoRA: Low Rank Adaptation with Mathematical Insights and Practical Implementation from Scratch in PyTorch

Vipul Sarode
6 min read · Mar 27, 2024


In this article, we will learn what LoRA is, look into the maths that makes LoRA fine-tune large models efficiently, and finally create our own LoRA from scratch and use it to fine-tune our model.

How does LoRA work?

LLMs and other large models like Stable Diffusion have billions of parameters. Fine-tuning these models for a specific business use case is only possible for teams with a big enough budget and heaps of computational resources. Hence, in the paper “LoRA: Low-Rank Adaptation of Large Language Models”, Microsoft proposed LoRA to solve this problem.

To fine-tune with LoRA, rather than loading the whole model into the GPU, backpropagating through it, and updating all of its weights, we freeze the initial weights, W, of the model and create two additional matrices, A and B, on top of it. These matrices are created such that, when multiplied together, they produce a new matrix with the same dimensions as the initial weight matrix W. During training, after the loss is calculated, it is backpropagated only to the LoRA matrices.

This is a classic case of matrix decomposition. Using this process, we only need to train and save a small number of parameters. Fewer parameters mean less storage, less computation, faster backpropagation, and consequently less training time. Now, let’s look at the maths behind LoRA to understand how it works.

The Maths behind LoRA

As mentioned earlier, LoRA works the way it does because of matrix decomposition. This is the forward pass formula while training with LoRA:

h = W₀x + ΔWx = W₀x + BAx

where h is the hidden layer output, W₀ is the frozen pre-trained weight matrix of shape (d × k), ΔW is the weight update tracked by LoRA, and B and A are the new matrices of dimensions (d × r) and (r × k), with the rank r much smaller than d and k. B and A are created in such a way that their matrix product has the shape (d × k).
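To make the shapes concrete, here is a quick sketch with made-up toy dimensions, verifying that W₀x and BAx produce outputs of the same shape:

import torch

d, k, r = 8, 16, 2        # output dim, input dim, and a small rank (toy values)
W0 = torch.randn(d, k)    # frozen pre-trained weights of shape (d x k)
B = torch.randn(d, r)     # LoRA matrix B of shape (d x r)
A = torch.randn(r, k)     # LoRA matrix A of shape (r x k)
x = torch.randn(k)        # an input vector

h = W0 @ x + (B @ A) @ x  # forward pass: h = W0x + BAx
print(h.shape)            # torch.Size([8]), the same as W0 @ x alone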

Consider a small example: to represent a (3 × 3) matrix of updated weights, which has 9 parameters, at rank r = 1 we only need to train 6 parameters: 3 in B and 3 in A. That may not seem significant for this one matrix, but in an actual model the number of trainable parameters drops drastically.
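The savings grow with the size of the layer. As an illustration (with hypothetical dimensions, not taken from the model in this article), consider a 4096 × 4096 weight matrix at rank r = 8: a full update trains d × k parameters, while LoRA trains only r × (d + k):

d, k, r = 4096, 4096, 8
full_update = d * k         # 16,777,216 parameters for a full weight update
lora_update = r * (d + k)   # 65,536 parameters for B (d x r) plus A (r x k)
print(f"LoRA trains {lora_update / full_update:.3%} of the full update")  # ~0.391%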

There have been recent developments building on LoRA, such as LoRA+, proposed by Soufiane Hayou, Nikhil Ghosh, and Bin Yu of UC Berkeley. LoRA+ is the same as LoRA but uses different learning rates for the B and A matrices.
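Since LoRA+ only changes the optimizer setup, it is easy to sketch: put A and B into separate parameter groups with different learning rates. The ratio below is a hypothetical hyperparameter, and lora stands for an instance of the LoRA class we build later in this article:

import torch

base_lr = 1e-4
optimizer = torch.optim.AdamW([
    {"params": [lora.matrix_A], "lr": base_lr},       # A keeps the base learning rate
    {"params": [lora.matrix_B], "lr": base_lr * 16},  # B gets a larger one (the ratio is tunable)
])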

Now that we are well-acquainted with LoRA and its internal workings, let’s build one from scratch and deepen our understanding.

Following the original paper, A is initialized from a random Gaussian distribution and B is initialized with zeros, so ΔW = BA is zero when fine-tuning begins.
import torch
import torch.nn as nn

#Let's create a LoRA class that will add two new matrices A and B to the
#original weights and return them such that B x A yields the same dimensions as W

class LoRA(nn.Module):
    def __init__(self, features_in, features_out, rank, alpha, device='cpu'):
        super().__init__()
        #A is (rank x features_out) and B is (features_in x rank), so that
        #B @ A has shape (features_in x features_out), matching the weight matrix
        self.matrix_A = nn.Parameter(torch.zeros((rank, features_out)).to(device))
        self.matrix_B = nn.Parameter(torch.zeros((features_in, rank)).to(device))
        self.scale = alpha / rank
        #As in the paper, A is Gaussian-initialized while B stays zero,
        #so B @ A = 0 and training starts from the pre-trained weights
        nn.init.normal_(self.matrix_A, mean=0, std=1)

    def forward(self, W):
        #Return the frozen weights plus the scaled low-rank update
        return W + torch.matmul(self.matrix_B, self.matrix_A).view(W.shape) * self.scale
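A quick sanity check of the class above, using a toy 10 × 5 weight matrix:

W = torch.randn(10, 5)
lora = LoRA(features_in=10, features_out=5, rank=2, alpha=1)
out = lora(W)
print(out.shape)               # torch.Size([10, 5]), the same shape as W
print(torch.allclose(out, W))  # True: B is zero-initialized, so B @ A = 0 at the start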

Now that we have created our LoRA class that sets up the matrices on top of the original weights, we need a function that replaces the original weights in a layer with the output of the LoRA class, i.e., with the two new matrices added on top of the original weight matrix.

#This function takes a layer as input and sets features_in/features_out
#equal to the shape of its weight matrix. This helps the LoRA class
#initialize the A and B matrices.

def layer_parametrization(layer, device, rank=1, lora_alpha=1):
    #Note: nn.Linear stores its weight as (out_features, in_features);
    #the names here simply follow the weight's row/column order
    features_in, features_out = layer.weight.shape
    return LoRA(features_in, features_out, rank=rank, alpha=lora_alpha, device=device)

We can easily apply this function to the layers of our model using register_parametrization() from torch.nn.utils.parametrize. See the PyTorch parametrization tutorial (reference 4) if you are not familiar with it.

We have successfully created a LoRA from scratch. All that remains is to put it into practice and see LoRA’s efficiency first-hand. For simplicity, we will train a model to classify the MNIST digits and then fine-tune it on a specific digit. For the full code, please go to my GitHub repository.

The original classifier, without any fine-tuning, misclassified 67 test images of the digit 7 (the full results are in the repository).

import torch.nn.utils.parametrize as parametrize

# Here we apply parametrization such that whenever the model wants to access
# weights from the original linear layers of the model,
# it returns the original weights plus LoRA matrices so that we can freeze the
# original weights and then train the LoRA matrices.

parametrize.register_parametrization(exp.linear1, 'weight', layer_parametrization(exp.linear1, device))
parametrize.register_parametrization(exp.linear2, 'weight', layer_parametrization(exp.linear2, device))
parametrize.register_parametrization(exp.linear3, 'weight', layer_parametrization(exp.linear3, device))
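To confirm the parametrization is active, we can check that the .weight property now returns the original weights plus the scaled LoRA product (a quick sketch using the model exp from the repository):

original = exp.linear1.parametrizations["weight"].original  # the frozen weights W
lora_mod = exp.linear1.parametrizations["weight"][0]        # our LoRA module
recomputed = original + torch.matmul(lora_mod.matrix_B, lora_mod.matrix_A).view(original.shape) * lora_mod.scale
print(torch.allclose(exp.linear1.weight, recomputed))       # True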

Just for reference, LoRA increases the parameter count by only about 0.242%, which is what makes fine-tuning with it so efficient.
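The exact percentage depends on the architecture, but the count itself is easy to reproduce, since the LoRA parameters are the only ones with 'matrix_' in their names:

lora_params = sum(p.numel() for n, p in exp.named_parameters() if 'matrix_' in n)
base_params = sum(p.numel() for n, p in exp.named_parameters() if 'matrix_' not in n)
print(f"LoRA adds {lora_params:,} parameters "
      f"({lora_params / base_params:.3%} of the original {base_params:,})")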

Putting all of the code here would not be an efficient use of space, so please go to my GitHub for the full code used in this article; I have only attached the main snippets.

Now, let’s fine-tune the model to increase its accuracy in classifying the digit 7. We will load the MNIST dataset again, this time keeping only the digit 7, and continue training the model, but with LoRA this time.
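One simple way to build that digit-7-only loader with torchvision (a sketch; the transform and batch size are placeholders and may differ from the repository):

import torch
from torchvision import datasets, transforms

transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.1307,), (0.3081,))])
mnist_train = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

#Keep only the examples labeled 7
mask = mnist_train.targets == 7
mnist_train.data = mnist_train.data[mask]
mnist_train.targets = mnist_train.targets[mask]

train_loader = torch.utils.data.DataLoader(mnist_train, batch_size=10, shuffle=True)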

#Freezing the non-LoRA parameters (everything except matrix_A and matrix_B)

for name, param in exp.named_parameters():
    if 'mat' not in name:
        print(f'Freezing non-LoRA parameter {name}')
        param.requires_grad = False

#Make sure the LoRA modules themselves stay trainable
for layer in [exp.linear1, exp.linear2, exp.linear3]:
    layer.parametrizations["weight"][0].requires_grad_(True)

# Train the network with LoRA only on the digit 7 and only for 100 batches
train(train_loader, exp, epochs=1, total_iterations_limits=100)

These are the results we get after training the model with LoRA.

The wrong counts for the digit 7 went down from 67 to just 5. This is indeed a great improvement with so few additional parameters.
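The “wrong counts” metric is just a per-digit tally of misclassifications over the test set. A minimal sketch, assuming a test_loader and an MLP classifier that takes flattened 28 × 28 images:

def wrong_counts_per_digit(model, test_loader, device='cpu'):
    wrong = [0] * 10
    model.eval()
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            preds = model(images.view(images.size(0), -1)).argmax(dim=1)  # flatten for the MLP
            for label, pred in zip(labels, preds):
                if label != pred:
                    wrong[label.item()] += 1
    return wrong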

Thank you for reading! Most of the code is inspired by Umar Jamil’s tutorial, which is a great one and has helped me learn a lot. I have written this article from the perspective of introducing LoRA to newcomers and documenting my own learning journey. Also, the best way to learn is by teaching. Please let me know in the comments if you have any questions or feedback. See you in the next one.

References

  1. Microsoft LoRA GitHub repository
  2. Umar Jamil’s LoRA tutorial
  3. Hu et al., “LoRA: Low-Rank Adaptation of Large Language Models”
  4. PyTorch parametrization tutorial
  5. Krish Naik
