LoRA: Low-Rank Adaptation from First Principles

Shrish
10 min read · Jun 22, 2023


At the heart of all deep learning models lies a sequence of matrix multiplications, interspersed with non-linear activation functions such as sigmoid, ReLU, and GELU. The recent surge in large language models (LLMs) has confirmed their exceptional potential in a variety of applications. However, deploying and fine-tuning these massive models presents significant cost and efficiency challenges for data science teams.

The pace of advancement in this field is staggering, with a steady stream of new research papers and software packages aimed at reducing the cost and duration of training LLMs and running inference with them. Given this rapid progress, it wouldn't be surprising to see LLMs running on our handheld devices in the near future.

In today's discussion, I'll explore a groundbreaking paper, "LoRA: Low-Rank Adaptation of Large Language Models" (Hu et al., 2021). It addresses the cost and efficiency challenges posed by LLMs by freezing the pre-trained model weights and injecting trainable rank decomposition matrices into each layer of the Transformer architecture. This dramatically reduces the number of trainable parameters for downstream tasks: compared with fine-tuning GPT-3 175B with Adam, the paper reports 10,000 times fewer trainable parameters and a 3 times smaller GPU memory requirement, together with higher training throughput.

By harnessing the power of linear algebra, LoRA makes fine-tuning LLMs far more feasible, without adding inference latency or compromising model quality. The paper also presents empirical evidence of LoRA's efficacy, paving the way for a more streamlined approach to LLM adaptation.

We’ll delve deeper into the details of this paper, aiming to illuminate the foundational concepts behind LoRA and understand its impact and potential applications. Our exploration will also include a discussion on the newly released tools and packages that facilitate the integration of LoRA with PyTorch models. These resources provide invaluable support to data scientists seeking to leverage the advantages of LoRA in their work with RoBERTa, DeBERTa, and GPT-2 models, among others.

Join me as we navigate through the intricate landscape of large language models and explore how LoRA is redefining the way we approach model fine-tuning and deployment.


What is the rank of a matrix?

The rank of a matrix is a fundamental concept in linear algebra. It is defined as the maximum number of linearly independent columns (or, equivalently, rows) in the matrix. In other words, it is the dimension of the space spanned by the matrix's columns (or rows). Loosely speaking, you can think of it as the amount of independent information contained in the matrix.

If the rank of a matrix is equal to its smallest dimension (either the number of rows or columns), the matrix is said to have "full rank". A matrix that does not have full rank is said to be rank deficient.
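To make "maximum number of linearly independent rows" concrete, here is a minimal sketch that computes the rank by Gaussian elimination: it reduces the rows to echelon form and counts the pivots. The helpers matrix_rank_exact and is_full_rank are hypothetical names chosen for illustration, not library functions. Exact Fraction arithmetic is used so that no floating-point tolerance is needed, unlike float-based routines such as numpy.linalg.matrix_rank.

```python
from fractions import Fraction


def matrix_rank_exact(rows):
    """Rank via Gaussian elimination over exact rationals (illustrative helper)."""
    m = [[Fraction(v) for v in row] for row in rows]
    n_rows = len(m)
    n_cols = len(m[0]) if m else 0
    rank = 0
    for col in range(n_cols):
        # Find a pivot: a row at or below position `rank` with a nonzero entry in this column
        pivot = next((i for i in range(rank, n_rows) if m[i][col] != 0), None)
        if pivot is None:
            continue  # no independent direction contributed by this column
        m[rank], m[pivot] = m[pivot], m[rank]
        # Eliminate the entries below the pivot
        for i in range(rank + 1, n_rows):
            factor = m[i][col] / m[rank][col]
            m[i] = [a - factor * b for a, b in zip(m[i], m[rank])]
        rank += 1  # each pivot is one linearly independent row
        if rank == n_rows:
            break
    return rank


def is_full_rank(rows):
    """Full rank means the rank equals the smaller of the two dimensions."""
    return matrix_rank_exact(rows) == min(len(rows), len(rows[0]))


print(matrix_rank_exact([[1, 2, 3], [2, 4, 6], [3, 6, 9]]))  # 1
print(is_full_rank([[1, 2, 3], [4, 0, 6], [7, 8, 9]]))       # True
```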

Here's a simple example:

Consider a 3x3 matrix A:

A = [[1, 2, 3],
     [2, 4, 6],
     [3, 6, 9]]

Here, the second row is 2 times the first and the third row is 3 times the first, so all three rows are linearly dependent. The rank of this matrix is 1, because only one row is linearly independent.
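A rank-1 matrix like this one can always be written as a single column times a single row (an outer product), which is the simplest case of the low-rank factorization used later in this article. A small NumPy sketch:

```python
import numpy as np

u = np.array([1, 2, 3])  # column factor: row i of A is u[i] times [1, 2, 3]
v = np.array([1, 2, 3])  # row factor
A = np.outer(u, v)       # (3 x 1) times (1 x 3) gives a 3 x 3 matrix

print(A)
print(np.linalg.matrix_rank(A))  # 1
```

Nine entries are described by just six numbers here; that saving is exactly what low-rank factorization exploits at scale.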

Let's take another example:

B = [[1, 2, 3],
     [4, 0, 6],
     [7, 8, 9]]

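In this matrix, no row can be written as a linear combination of the other rows (equivalently, its determinant, 60, is nonzero), so all three rows are linearly independent and the rank of B is 3, as there are 3 linearly independent rows (or columns). Not being a multiple of another row is not sufficient on its own; what matters is that no row is a combination of the others.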

To calculate the rank of a matrix in Python, you can use the numpy.linalg.matrix_rank() function. The script below checks A and B, plus a third matrix D that differs from B only in its centre entry:

import numpy as np

A = np.array([[1, 2, 3], [2, 4, 6], [3, 6, 9]])
print("Rank of A:", np.linalg.matrix_rank(A))

B = np.array([[1, 2, 3], [4, 0, 6], [7, 8, 9]])
print("Rank of B:", np.linalg.matrix_rank(B))

D = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
print("Rank of D:", np.linalg.matrix_rank(D))

Output:
Rank of A: 1
Rank of B: 3
Rank of D: 2

Note that even though no row or column of D is an exact multiple of another, D is still rank-deficient, because its third row is a linear combination of the other two: [7, 8, 9] = 2 x [4, 5, 6] - [1, 2, 3]. This demonstrates how the rank can be less than the number of rows or columns.
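Under the hood, numpy.linalg.matrix_rank counts the singular values that exceed a small tolerance, which NumPy documents as S.max() * max(M, N) * eps. The sketch below reimplements that rule to show why [[1, 2, 3], [4, 5, 6], [7, 8, 9]] comes out as rank 2: its third singular value is numerically zero. The helper name rank_from_svd is hypothetical.

```python
import numpy as np


def rank_from_svd(M):
    """Count singular values above NumPy's documented default matrix_rank tolerance."""
    M = np.asarray(M, dtype=float)
    s = np.linalg.svd(M, compute_uv=False)  # singular values, largest first
    tol = s.max() * max(M.shape) * np.finfo(M.dtype).eps
    return int(np.sum(s > tol)), s


M = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
rank, s = rank_from_svd(M)
print(rank)  # 2
print(s)     # inspect: the third singular value falls below the tolerance

# The dependency that makes M rank-deficient: row 3 = 2 * row 2 - row 1
print(np.array_equal(2 * M[1] - M[0], M[2]))  # True
```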

Low-rank approximation of a matrix

A matrix A of rank r can be factored into the product of two matrices, A = CR. Here, A is an m x n matrix, C is an m x r matrix whose rank is also r, and R is an r x n matrix, also with rank r. This factorization is not unique: for any invertible r x r matrix M, the pair CM and M⁻¹R gives the same product.

Equivalently, the rank of a matrix is the smallest value of r for which A (m x n) = C (m x r) R (r x n) holds exactly. If we choose r smaller than the actual rank, the equation can no longer hold exactly, but we can still find C and R whose product approximates A.

The Rank Factorization Theorem allows us to represent a matrix A as the product of two smaller matrices, capturing the essential structure of A in a potentially much more compact form: C and R together hold r(m + n) numbers instead of mn. The idea is particularly useful when A has high rank but can be well approximated by a matrix of lower rank, which is often the case in applications such as image processing and machine learning.
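A standard way to get such an approximation is a truncated SVD: keeping the k largest singular values gives the best rank-k approximation in the Frobenius norm (the Eckart-Young theorem), and the error equals the root-sum-square of the discarded singular values. The sketch below checks this on a synthetic 50 x 40 matrix built as rank 5 plus a little noise; the matrix, the helper name low_rank_approx, and k = 5 are assumptions for illustration.

```python
import numpy as np


def low_rank_approx(A, k):
    """Best rank-k approximation of A as a product C @ R via truncated SVD."""
    U, s, Vt = np.linalg.svd(A, full_matrices=False)
    C = U[:, :k] * s[:k]  # m x k: columns of U scaled by the singular values
    R = Vt[:k, :]         # k x n
    return C, R, s


rng = np.random.default_rng(0)
# Synthetic data: an exactly rank-5 matrix plus small noise, so A itself is full rank
A = rng.standard_normal((50, 5)) @ rng.standard_normal((5, 40))
A = A + 0.01 * rng.standard_normal((50, 40))

k = 5
C, R, s = low_rank_approx(A, k)
err = np.linalg.norm(A - C @ R)            # Frobenius norm of the residual
expected = np.sqrt(np.sum(s[k:] ** 2))     # Eckart-Young: norm of discarded singular values

print(f"relative error: {err / np.linalg.norm(A):.4f}")
print("numbers stored:", A.size, "->", C.size + R.size)  # 2000 -> 450
```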

Let’s put this into practice with an example and Python code:

Consider a 5x5 matrix A:

import numpy as np

A = np.array([[19,  9, 12, 19,  8],
              [ 0,  0,  0,  0,  0],
              [ 3,  1,  0,  3,  0],
              [ 6,  2,  0,  6,  0],
              [25, 11, 12, 25,  8]])

print("Rank of A:", np.linalg.matrix_rank(A))

Output:
Rank of A: 2

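Suppose we have found a C (5 x 2) and an R (2 x 5) that satisfy A = CR. One such pair is shown below. In general, such factors can be computed with singular value decomposition (SVD), although SVD produces floating-point factors rather than the tidy integer ones used here.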

C = np.array([[4, 1],
              [0, 0],
              [0, 1],
              [0, 2],
              [4, 3]])

R = np.array([[4, 2, 3, 4, 2],
              [3, 1, 0, 3, 0]])

print("CR is :", C @ R)

Output:
CR is : [[19  9 12 19  8]
 [ 0  0  0  0  0]
 [ 3  1  0  3  0]
 [ 6  2  0  6  0]
 [25 11 12 25  8]]
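A quick self-contained check of this factorization: the integer factors reproduce A exactly, and a rank-2 factorization taken from the SVD (as mentioned above) also reproduces A, with different, floating-point C and R. Either way, the two factors store 20 numbers instead of 25.

```python
import numpy as np

A = np.array([[19,  9, 12, 19,  8],
              [ 0,  0,  0,  0,  0],
              [ 3,  1,  0,  3,  0],
              [ 6,  2,  0,  6,  0],
              [25, 11, 12, 25,  8]])
C = np.array([[4, 1], [0, 0], [0, 1], [0, 2], [4, 3]])
R = np.array([[4, 2, 3, 4, 2], [3, 1, 0, 3, 0]])

# 1) The hand-picked integer factors reproduce A exactly
exact = np.array_equal(C @ R, A)

# 2) A rank-2 factorization obtained from the SVD
U, s, Vt = np.linalg.svd(A, full_matrices=False)
C_svd = U[:, :2] * s[:2]  # 5 x 2
R_svd = Vt[:2, :]         # 2 x 5
svd_ok = np.allclose(C_svd @ R_svd, A)

print(exact, svd_ok)             # True True
print(A.size, C.size + R.size)   # 25 20
```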

LoRA for a linear regression model

Let us try the same trick to compress a simple linear model, y = Wx + b:

import torch

torch.manual_seed(0)

# Dimensions
n, m = 10, 10  # n: input dimension, m: output dimension

# ------------------------------------------------------------
# Setup only: build a rank-deficient W (m x n) as the product of an
# (m x 2) and a (2 x n) random matrix. With random Gaussian factors,
# W has rank 2 with probability 1.
true_rank = 2
W = torch.randn(m, true_rank) @ torch.randn(true_rank, n)
# ------------------------------------------------------------
print("W looks like:\n", W)
b = torch.randn(m)

r = int(torch.linalg.matrix_rank(W))
print("Rank of W:", r)

# Random input x
x = torch.randn(n)

# Compute y = Wx + b
y = W @ x + b

# ------------------------------------------------------------
# Exact rank-r factorization W = CR using a truncated SVD
U, S, Vh = torch.linalg.svd(W, full_matrices=False)  # W = U diag(S) Vh

# Keep only the first r singular values and the matching singular vectors
C = U[:, :r] @ torch.diag(S[:r])  # m x r
R = Vh[:r, :]                     # r x n
# ------------------------------------------------------------

# Compute y' = CRx + b
y_prime = (C @ R) @ x + b

print("Original y using W:\n", y)
print("y' computed using CR:\n", y_prime)

print("Total parameters of W:\n", W.numel())
print("Total parameters of C and R :\n", C.numel() + R.numel())

Output:
W looks like:

tensor([[-1.0797, 0.5545, 0.8058, -0.7140, -0.1518, 1.0773, 2.3690, 0.8486,
-1.1825, -3.2632],
[-0.3303, 0.2283, 0.4145, -0.1924, -0.0215, 0.3276, 0.7926, 0.2233,
-0.3422, -0.9614],
[-0.5256, 0.9864, 2.4447, -0.0290, 0.2305, 0.5000, 1.9831, -0.0311,
-0.3369, -1.1376],
[ 0.7900, -1.1336, -2.6746, 0.1988, -0.1982, -0.7634, -2.5763, -0.1696,
0.6227, 1.9294],
[ 0.1258, 0.1458, 0.5090, 0.1768, 0.1071, -0.1327, -0.0323, -0.2294,
0.2079, 0.5128],
[ 0.7697, 0.0050, 0.5725, 0.6870, 0.2783, -0.7818, -1.2253, -0.8533,
0.9765, 2.5786],
[ 1.4157, -0.7814, -1.2121, 0.9120, 0.1760, -1.4108, -3.1692, -1.0791,
1.5325, 4.2447],
[-0.0119, 0.6050, 1.7245, 0.2584, 0.2528, -0.0086, 0.7198, -0.3620,
0.1865, 0.3410],
[ 1.0485, -0.6394, -1.0715, 0.6485, 0.1046, -1.0427, -2.4174, -0.7615,
1.1147, 3.1054],
[ 0.9088, 0.1936, 1.2136, 0.8946, 0.4084, -0.9295, -1.2294, -1.1239,
1.2155, 3.1628]])

Rank of W: 2

Original y using W:
tensor([ 1.6207, 2.1148, 2.3849, -2.3917, -1.2117, -4.9171, -3.9770, 0.5812,
-3.2889, -2.9090])

y' computed using CR:
tensor([ 1.6207, 2.1148, 2.3849, -2.3917, -1.2117, -4.9171, -3.9770, 0.5812,
-3.2889, -2.9090])

Total parameters of W:
100

Total parameters of C and R :
40

This example shows that two models, one with 100 parameters and one with just 40, produce identical outputs. This works because W has rank 2, so it can be factored exactly into C (10 x 2) and R (2 x 10), giving 10 x 2 + 2 x 10 = 40 parameters. If we set r below the actual rank of the weight matrix, the factorization is no longer exact, but when the discarded singular values are small the output is still an acceptable approximation.

Now suppose we wanted to train this model, learning the weight matrix W from training examples with a loss function. It would be considerably cheaper and quicker to train the smaller matrices C and R instead, because there are far fewer parameters to update. We apply similar logic when training a large language model (LLM), or in fact any other big model, to make training cost-effective.
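To illustrate training C and R instead of W, here is a minimal PyTorch sketch that fits y = (CR)x + b to data generated by a hidden rank-2 weight matrix, using gradient descent on C, R, and b only. The synthetic data, the initialization scale, the Adam optimizer, and the learning rate are all assumptions chosen for illustration.

```python
import torch

torch.manual_seed(0)
n, m, r = 10, 10, 2

# Hypothetical "true" model with a rank-2 weight matrix, used only to generate data
W_true = torch.randn(m, r) @ torch.randn(r, n)
b_true = torch.randn(m)
X = torch.randn(256, n)
Y = X @ W_true.T + b_true

# Trainable low-rank factors instead of a full m x n matrix
C = torch.nn.Parameter(0.1 * torch.randn(m, r))
R = torch.nn.Parameter(0.1 * torch.randn(r, n))
b = torch.nn.Parameter(torch.zeros(m))
opt = torch.optim.Adam([C, R, b], lr=0.02)


def loss_fn():
    pred = X @ (C @ R).T + b  # the same linear model, with W replaced by C @ R
    return torch.mean((pred - Y) ** 2)


initial_loss = loss_fn().item()
for _ in range(2000):
    opt.zero_grad()
    loss = loss_fn()
    loss.backward()
    opt.step()
final_loss = loss_fn().item()

# 20 (C) + 20 (R) + 10 (b) = 50 trainable values, versus 100 + 10 for a full W and b
n_params = C.numel() + R.numel() + b.numel()
print(f"loss {initial_loss:.3f} -> {final_loss:.5f}, trainable params: {n_params}")
```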

LoRA in LLMs

Research has shown that large pre-trained language models have a low "intrinsic dimension", and the LoRA paper hypothesizes that the weight updates learned during fine-tuning also have a low "intrinsic rank". To see what that buys us, suppose a 1000 x 1000 weight update can be represented at rank 10. The full matrix has 10⁶ trainable parameters, while its rank-10 factors have just 20,000 (1000 x 10 + 10 x 1000).

For more clarity, consider fine-tuning a large language model whose embedding vectors have 1000 dimensions. The query, key, and value projection matrices (W_q, W_k, W_v) in each attention layer are then 1000 x 1000, so each one holds 10⁶ trainable parameters.

If this seems complex, simply remember that LLMs involve a series of matrix multiplications, and we aim to represent the updates to these matrices at a lower rank to reduce the number of parameters that need to be trained.

In this view, fine-tuning shifts every weight matrix W in the model to a new matrix W'. We keep W frozen, which avoids disrupting the pre-trained base model. Instead, we learn a change dW and decompose it into two low-rank matrices, A and B (dW = AB). With r = 8, the number of trainable parameters works out as follows:


W' (1000 x 1000) = W (1000 x 1000) + dW (1000 x 1000)
dW (1000 x 1000) = A (1000 x 8) B (8 x 1000)
1,000,000 trainable parameters => 1000 x 8 + 8 x 1000 = 16,000 trainable parameters

Figure from the paper: "Our reparametrization. We only train A and B." In the example above, r = 8 and d = 1000.

In practice, we don't have to adapt every matrix in an LLM. We can still achieve competitive performance by training adapters for just a few target modules (for example Q, K, and V, or any pair of them) and keeping the rest frozen. Importantly, LoRA training does not try to find the exact rank of each update and then factor it. Instead, r is a hyperparameter chosen in advance, and we are content with the approximate solution it yields.
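The parameter arithmetic behind these numbers is simple enough to script. A minimal sketch, where full_params and lora_params are hypothetical helper names: a d_out x d_in matrix has d_out * d_in entries, while its LoRA factors A (d_out x r) and B (r x d_in) have r * (d_out + d_in). The counts are per matrix and per layer; a real model multiplies them by its number of layers, which is not modelled here.

```python
def full_params(d_out, d_in):
    """Entries in a full d_out x d_in weight matrix."""
    return d_out * d_in


def lora_params(d_out, d_in, r):
    """Entries in the LoRA factors A (d_out x r) and B (r x d_in)."""
    return r * (d_out + d_in)


d = 1000
print(full_params(d, d), lora_params(d, d, 8))  # 1000000 16000
print(lora_params(d, d, 10))                    # 20000

# One attention layer with r = 8: full Q, K, V vs LoRA on Q, K, V vs LoRA on Q and V only
print(3 * full_params(d, d), 3 * lora_params(d, d, 8), 2 * lora_params(d, d, 8))  # 3000000 48000 32000
```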

The algorithm, therefore, is straightforward:

1. Freeze the base model (W).
2. For a chosen set of target modules, approximate the change in W (dW) as the product of two matrices A and B, with a preselected rank r.
3. Train only the adapter (A and B), keeping the base model frozen.
4. Merge the trained adapter into the base model (W' = W + AB) and proceed with inference.
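The four steps above can be sketched from scratch in PyTorch as a wrapper around nn.Linear. This is a minimal illustration under stated assumptions, not the peft library's implementation: the class name LoRALinear is hypothetical, it follows this article's dW = AB notation (the paper writes the same update as BA), and it omits the alpha / r scaling factor that the paper applies to the update. As in the paper, the output-side factor starts at zero, so dW = 0 and training begins exactly at the base model.

```python
import torch
import torch.nn as nn


class LoRALinear(nn.Module):
    """y = W x + b + (A B) x, with W and b frozen and only A, B trainable (illustrative)."""

    def __init__(self, base: nn.Linear, r: int = 8):
        super().__init__()
        self.base = base
        for p in self.base.parameters():  # step 1: freeze W (and its bias)
            p.requires_grad_(False)
        out_f, in_f = base.weight.shape
        # Step 2: dW = A @ B with A: out x r (zeros) and B: r x in (small random values)
        self.A = nn.Parameter(torch.zeros(out_f, r))
        self.B = nn.Parameter(0.01 * torch.randn(r, in_f))

    def forward(self, x):
        # Base output plus the low-rank update, without forming the full A @ B matrix
        return self.base(x) + (x @ self.B.T) @ self.A.T

    @torch.no_grad()
    def merge(self) -> nn.Linear:
        # Step 4: fold the adapter into a plain Linear layer with W' = W + A B
        merged = nn.Linear(self.base.in_features, self.base.out_features,
                           bias=self.base.bias is not None)
        merged.weight.copy_(self.base.weight + self.A @ self.B)
        if self.base.bias is not None:
            merged.bias.copy_(self.base.bias)
        return merged


torch.manual_seed(0)
layer = LoRALinear(nn.Linear(1000, 1000), r=8)
trainable = sum(p.numel() for p in layer.parameters() if p.requires_grad)
print(trainable)  # 16000

# Step 3: the optimizer only ever sees A and B (training loop not shown)
opt = torch.optim.AdamW([p for p in layer.parameters() if p.requires_grad], lr=1e-3)

x = torch.randn(4, 1000)
same_at_init = torch.allclose(layer(x), layer.base(x))  # dW = 0 before training

with torch.no_grad():
    layer.A.normal_()  # stand-in for training having moved A away from zero
merged = layer.merge()
merged_ok = torch.allclose(layer(x), merged(x), atol=1e-4)
print(same_at_init, merged_ok)  # True True
```

Because the merged layer is an ordinary nn.Linear, inference after step 4 costs exactly the same as with the original model.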

Performance

According to the paper, LoRA does not sacrifice accuracy. On the benchmarks it reports, LoRA matches or even exceeds full fine-tuning, despite training less than 1 percent of the original number of parameters.

Table from the paper: Performance of different adaptation methods on GPT-3 175B. We report the logical form validation accuracy on WikiSQL, validation accuracy on MultiNLI-matched, and Rouge-1/2/L on SAMSum. LoRA performs better than prior approaches, including full fine-tuning.

Table from the paper: Validation accuracy on WikiSQL and MultiNLI with different rank r. To our surprise, a rank as small as one suffices for adapting both Wq and Wv on these datasets while training Wq alone needs a larger r.

Code implementations

Code as simple as the following can implement the above algorithm using Hugging Face's peft library. The full version can be accessed here.

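from accelerate import Accelerator
from transformers import AutoModelForCausalLM, Trainer, TrainingArguments
from peft import LoraConfig, PeftModel, get_peft_model

# Load the base model (StarCoder is gated, so a Hugging Face access token is required)
model = AutoModelForCausalLM.from_pretrained(
    "bigcode/starcoder",
    token=True,
    device_map={"": Accelerator().process_index},
)

# LoRA hyperparameters: rank r and the modules that receive an adapter
lora_config = LoraConfig(
    r=8,
    target_modules=["c_proj", "c_attn", "q_attn"],
    task_type="CAUSAL_LM",
)

# Freeze the base weights and inject the trainable A/B matrices
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

training_args = TrainingArguments(
    ...
)

# train_data and val_data are the tokenized datasets (preparation not shown)
trainer = Trainer(model=model, args=training_args,
                  train_dataset=train_data, eval_dataset=val_data)

print("Training...")
trainer.train()

# Save only the adapter weights; peft_model_path is the adapter output directory
model.save_pretrained(peft_model_path)

# Plug the adapter back into the base model (from_pretrained expects a model object, not a name)
base_model = AutoModelForCausalLM.from_pretrained("bigcode/starcoder", token=True)
model = PeftModel.from_pretrained(base_model, peft_model_path)
# Optionally fold the adapter into the weights (W' = W + AB) for latency-free inference
model = model.merge_and_unload()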

Thank you for taking the time to read my writing. If you enjoyed it, I invite you to follow me for future content. Additionally, feel free to connect with me on LinkedIn.
