LoRA: Low-Rank Adaptation from Scratch — Code and Theory

Transformer models can have an enormous number of parameters, which makes fine-tuning them an expensive and time-consuming endeavor, and sometimes outright impossible on consumer hardware due to memory constraints.

Low-rank adaptation (LoRA) of the linear projection weights can help alleviate these issues by reducing the number of parameters in the model that need updating during the fine-tuning process.

To get a better understanding of how this works, let's first outline the steps involved in traditional fine-tuning:

  1. Forward Pass: Compute some set of predictions by passing input data forward through the network.
  2. Loss: Compare the predictions against the ground truth to obtain a measure of loss or error.
  3. Backward Pass: Compute the gradient of the loss with respect to the weights and biases by working backward through the network.
  4. Update Weights: Nudge the weights and biases in the opposite direction of the gradient to reduce the loss.
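
For concreteness, here is a minimal PyTorch sketch of those four steps, with a single linear layer standing in for a full transformer (the layer sizes and data here are purely illustrative):

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 2)                 # toy stand-in for a transformer
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
loss_fn = nn.CrossEntropyLoss()

x = torch.randn(8, 10)                   # a batch of inputs
y = torch.randint(0, 2, (8,))            # ground-truth labels

logits = model(x)                        # 1. forward pass
loss = loss_fn(logits, y)                # 2. loss
loss.backward()                          # 3. backward pass
optimizer.step()                         # 4. update weights
optimizer.zero_grad()
```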

If we focus on just the weights of a single linear projection in the model, nudging the weights (step 4) is typically formulated as:

W’ = W + ΔW

where W is the weight matrix, ΔW is the change-in-weights matrix, and W’ is the new weight matrix.

Weight update in the traditional fine-tuning of a linear layer.

For LoRA we are going to formulate this slightly differently, but equivalently.

The difference is that we keep the weight matrix and the change-in-weights matrix separate throughout the fine-tuning process.

So, the forward pass (step 1) is now formulated as follows:

h = W₀x + ΔWx

where W₀ is the weight matrix holding the weights at the start of fine-tuning (these weights remain frozen and do not change), ΔW is the change-in-weights matrix and x is the input vector. Notice that x is independently multiplied with both W₀ and ΔW, and then the two products are added together.

Reformulated forward pass and weight update.
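
Because matrix multiplication distributes over addition, the split formulation is numerically identical to the original one, which we can verify with a quick check:

```python
import torch

W0 = torch.randn(4, 3)  # frozen pretrained weights
dW = torch.randn(4, 3)  # change-in-weights matrix
x = torch.randn(3)      # input vector

# (W0 + dW) x  ==  W0 x + dW x
assert torch.allclose((W0 + dW) @ x, W0 @ x + dW @ x, atol=1e-5)
```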

Because we have this separation, we are also going to make the updates to the change-in-weight matrix (step 4) rather than the frozen weight matrix.

For simpler notation let’s also denote the change-in-weight matrix as Wᵩ and formulate the update of it as follows:

Wᵩ’ = Wᵩ + ΔWᵩ

where Wᵩ’ is the new change-in-weight matrix and ΔWᵩ is the change in the change-in-weight matrix.

The reason for this new formulation has to do with the LoRA method operating under the assumption that the change-in-weight matrix has a low intrinsic rank.

To understand the concept of low intrinsic rank, we should first talk a little about rank.

The rank of a matrix is the maximum number of linearly independent rows or columns in that matrix.

The rank is important because it defines the lowest-dimensional space in which all the columns or rows can fit.

Rank-deficient matrices — that is, matrices with linear dependencies — have redundancy in them, and things with redundancy can usually be represented more compactly.

So, low intrinsic rank refers to the idea that the information contained within the matrix can be represented using fewer dimensions than the original matrix might suggest.
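
We can see this in code: a large matrix built as the product of two thin factors can never have a rank higher than the inner dimension, no matter how large its outer dimensions are:

```python
import torch

B = torch.randn(512, 4)  # 512 x 4
A = torch.randn(4, 512)  # 4 x 512
W = B @ A                # 512 x 512, yet its rank is at most 4

print(torch.linalg.matrix_rank(W))  # tensor(4)
```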

And that is the aim of LoRA — to take the change-in-weight matrix and approximate it with the product of two lower-rank matrices.

It is formulated as follows:

h = W₀x + Wᵩx = W₀x + BAx

where the change-in-weight matrix Wᵩ is now represented by the product of two lower-rank matrices B and A.

LoRA adaptation of a linear layer.

Notice in the diagram below that the dimension r is a hyperparameter that we define so that BA has fewer trainable parameters than Wᵩ. The smaller the value of r, the more compressed the representation.

BA as an approximation of the change-in-weight matrix (Wᵩ).
Comparison of rank-4 and rank-1 adaptation. LoRA results in fewer trainable parameters.
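
To make the savings concrete, consider a hypothetical 768×768 projection (the hidden size used in BERT-base) adapted with r = 4:

```python
d, k, r = 768, 768, 4  # example dimensions; r is the LoRA rank

full = d * k           # 589,824 trainable params in a full change-in-weight matrix
lora = d * r + r * k   # 6,144 trainable params in B (d x r) and A (r x k)

print(full // lora)    # 96 -> roughly 96x fewer trainable parameters
```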

Now the beauty of this method is that B and A are trainable parameters. We do not have to solve for them with matrix decomposition algorithms or anything like that; their contents are learned during the fine-tuning process.

Furthermore, after the weights of B and A are learned, the BA matrix can be merged back with the frozen weights through simple addition.
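
Here is a minimal sketch of that merge, assuming the alpha/r scaling that appears in the forward pass described below:

```python
import torch

d, k, r = 768, 768, 4
alpha = 8                     # scaling hyperparameter (covered below)

W0 = torch.randn(d, k)        # frozen pretrained weights
B = torch.randn(d, r)         # learned low-rank factors
A = torch.randn(r, k)

# Fold the learned update into the pretrained weights; after this,
# inference costs exactly the same as the original model.
W_merged = W0 + (alpha / r) * (B @ A)
```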

Let’s now put this all into action with code.

The LinearLoRA class defines the architecture of a low-rank adapted linear layer.
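
Below is a minimal sketch of such a class, consistent with the walkthrough that follows (the argument names are illustrative; the full version can be found in the lora_from_scratch.py script on GitHub, mentioned at the end of this post):

```python
import math
import torch
import torch.nn as nn

class LinearLoRA(nn.Module):
    """A linear layer with a low-rank (LoRA) adapter."""

    def __init__(self, in_dim: int, out_dim: int, r: int, lora_alpha: float = 1.0):
        super().__init__()
        assert r > 0, "r must be greater than 0"

        # Frozen pretrained projection; the actual pretrained weights are
        # copied in by a method outside this class.
        self.pretrained = nn.Linear(in_dim, out_dim, bias=True)
        self.pretrained.weight.requires_grad = False  # bias stays trainable

        # Low-rank factors: A is (r, in_dim), B is (out_dim, r).
        self.lora_A = nn.Parameter(torch.empty(r, in_dim))
        self.lora_B = nn.Parameter(torch.zeros(out_dim, r))

        # Kaiming uniform: PyTorch's default init for linear-layer weights.
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))

        self.scaling = lora_alpha / r

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # h = W0 x + (alpha / r) * B A x
        return self.pretrained(x) + self.scaling * (x @ self.lora_A.T @ self.lora_B.T)
```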

In this class we assert that the value of r is greater than 0, since the corresponding dimension in the B and A matrices would be non-existent otherwise.

Note that the value of r determines the maximum possible rank of the approximation BA. The actual values of the rows and columns of B and A are learned by the model, so there could end up being linear dependencies, resulting in a rank even lower than what we define. But the rank can never be higher than r.

The frozen pretrained linear projection is assigned to self.pretrained; however, the actual weights of the specific linear module we are adapting will be copied over with a method outside of this class.

We only freeze the weights; the bias is kept as a trainable parameter. Whether or not to freeze the bias shouldn't have much impact, but it is something that can be experimented with.

Matrix A is initialized with a Kaiming uniform distribution, since that is how linear layers are initialized by default in PyTorch. While often overlooked, using the right weight initialization helps the model learn through better gradient flow.

Matrix B is initialized to all zeros because we want the product BA to also be zero, so that it has no effect on the output for the very first forward pass of the fine-tuning process.

The output from BA is scaled by alpha/r before it is added to the frozen pretrained output. Alpha acts like a learning rate, amplifying or dampening how much the adapter updates the effective weights.

There are some additional methods that can be found in the lora_from_scratch.py script on GitHub, which have to do with injecting these adapted linear layers into a model before the fine-tuning process and merging the weights after fine-tuning is complete.

These will be covered in the next part of this series (coming soon), when we fine-tune our BERT model.

See you soon!
