LoRA Serving on Amazon SageMaker — Serve 100’s of Fine-Tuned LLMs For the Price of 1

Joaopcmoura
Jan 26, 2024

In the past year, Low-Rank Adaptation (LoRA) has become widely popular for fine-tuning Large Language Models (LLMs) efficiently and cost-effectively. It offers significant savings compared to full-parameter fine-tuning while maintaining similar levels of performance. If you’re new to this concept, I recommend familiarizing yourself with LoRA through existing resources (beginner’s primer, deep-dive for tuning LoRA performance).

Can we translate the cost benefits of LoRA into the deployment world? If so, can we avoid reinventing the wheel and being burdened by the mostly undifferentiated aspects of infrastructure management?

Turns out we can! This blog post tries to answer the questions of why and how. In it, we will focus on:

  • Understanding the potential and motivation behind serving adapters
  • Discussing how recent advancements made efficient multi-adapter serving technically feasible
  • Presenting a practical deployment example using LoRAX Server on Amazon SageMaker
  • Exploring the performance optimization potential and possible pitfalls of multi-adapter serving
Fig.1: Generated image, each LoRA is a brain. Prompt: Create an image of a GPU galactic overlord with 1000’s of tiny brains stacked on top of it, handling multiple streams of information into each brain. Make the entity look like a real GPU, mechatronic-looking.

The dawn of multi-adapter serving

While the training cost reductions brought about by LoRA are commendable, to truly capitalize on the economic viability of using LLMs in real-life scenarios we must also address inference costs, which surpass training expenses at any reasonable scale. As a reference, some sources claim that the inference costs of ChatGPT exceed (or used to exceed) its total training costs on a weekly basis. This reality has led to numerous innovations aimed solely at improving inference efficiency and cost-effectiveness, such as Grouped Query Attention, Flash Attention and PagedAttention, which are now widely adopted.

Back to LoRA. Adapters are trained on top of, and are therefore coupled to, a specific base model. If we have several tasks or user-specific datasets we want to fine-tune on, a reasonable tactic is to select a capable base model and “specialize” individual adapters on those tasks. Adapters are typically around 1% of the base model’s size, so instead of storing a full model for every task or user, we only store a single full base model plus its small task-specific adapters. At large scale, the training and storage cost benefits are obvious; how to efficiently serve many adapters on top of the same base model is not.
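To see roughly where the “around 1%” figure comes from, here is a back-of-the-envelope sketch. A rank-r adapter on a d × k weight matrix adds a d × r matrix B and an r × k matrix A, i.e. r(d + k) parameters next to the d·k frozen ones. The 4096 × 4096 shape and the ranks below are illustrative assumptions. The ratio for a real adapter also depends on which modules it targets and on its rank.

```python
def lora_param_ratio(d: int, k: int, r: int) -> float:
    """Extra parameters of a rank-r LoRA on a d x k matrix, relative to the matrix itself."""
    base_params = d * k          # frozen W0 (d x k)
    lora_params = r * (d + k)    # B (d x r) plus A (r x k)
    return lora_params / base_params


# Illustrative (assumed) shape: a square 4096 x 4096 projection
for rank in (8, 16, 64):
    print(f"rank {rank:>2}: {lora_param_ratio(4096, 4096, rank):.2%} of the base matrix")
```

Because the cost grows linearly with r while the base matrix grows with d·k, low ranks keep adapters in the low single-digit percent range or below.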

Fig.2: LoRA in a nutshell. The input is forwarded through the base model and the low-rank adapter at each layer it is applied to, and the outputs are summed to create the final hidden state. Source: original LoRA paper.
Eq.1: Forward pass with LoRA. Source: original LoRA paper.

Refer to Fig.2 and Eq.1 from the original LoRA paper to understand how the input is processed through the base model and the low-rank adapter, and the outputs are combined to create the final hidden state.
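Written out, the LoRA forward pass from the original paper reads as follows. W₀ is the frozen base weight, and B and A are the adapter’s low-rank matrices of rank r. The paper additionally scales the adapter term B A x by a constant α/r, which is left out here.

```latex
h = W_0 x + \Delta W x = W_0 x + B A x,
\qquad W_0 \in \mathbb{R}^{d \times k},\;
B \in \mathbb{R}^{d \times r},\;
A \in \mathbb{R}^{r \times k},\;
r \ll \min(d, k)
```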

There are two methods to compute h using LoRA:

  1. Merge the base model and LoRA weights first, W = W₀ + B . A, then compute W . x
  2. Calculate the base model and adapter contributions to h separately, then sum them
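As a tiny numeric check that both methods compute the same h for a single adapter, the sketch below uses plain Python lists and made-up 2 × 3 weights with rank r = 1. All values are arbitrary. Method 1 pays for the merge once and then runs one matrix-vector product. Method 2 keeps W₀ untouched and adds the adapter term on the fly.

```python
def matvec(M, x):
    """Matrix-vector product M . x for a list-of-rows matrix."""
    return [sum(m * xi for m, xi in zip(row, x)) for row in M]


def matmul(P, Q):
    """Matrix product P . Q for list-of-rows matrices."""
    cols = list(zip(*Q))
    return [[sum(p * q for p, q in zip(row, col)) for col in cols] for row in P]


def matadd(P, Q):
    """Element-wise sum of two equally shaped matrices."""
    return [[a + b for a, b in zip(rp, rq)] for rp, rq in zip(P, Q)]


W0 = [[1.0, 2.0, 0.0], [0.0, 1.0, 3.0]]  # frozen base weight, d=2 x k=3
B = [[0.5], [-1.0]]                       # d x r, with r=1
A = [[2.0, 0.0, 1.0]]                     # r x k
x = [1.0, 2.0, 3.0]

# Method 1: merge the adapter into the weights, then one matrix-vector product
h_merged = matvec(matadd(W0, matmul(B, A)), x)

# Method 2: compute base and adapter contributions separately, then sum them
h_separate = [b + a for b, a in zip(matvec(W0, x), matvec(B, matvec(A, x)))]

print(h_merged, h_separate)  # both [7.5, 6.0]
```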

Although there are approaches that explore combining the contributions of different adapters, naively merging several adapters’ weights into a base model at the same time is not the same as composing the capabilities each adapter learned. There is no guarantee that individual task performance is preserved: adapters are trained independently, each optimized toward a different (perhaps competing) local minimum of the selected loss function. Therefore, if we go with approach 1), we are limited to processing requests for a single adapter (or user) at any given time. In the worst case, the batch size will be 1. That is bad news for throughput and GPU utilization, and ultimately for the user, who has to wait while queued requests are processed sequentially.

Approach 2) seems more suitable for a multi-task/multi-tenant scenario. We can compute the left term, W₀ . x, once, compute the right term, B . A . x, for every adapter n, and sum them as needed. Now we can batch the base model’s forward pass independently of which adapters we want to use! What about the adapters’ forward pass? Are we doomed to compute them sequentially for every adapter?
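The sketch below illustrates approach 2 on a heterogeneous batch, using made-up adapter names and weights. The base term is shared by every request. It is written here as a list comprehension, but in a real server it would be one batched matrix multiply. The adapter term is looked up per request. That per-request loop is exactly the sequential part the question above worries about.

```python
from typing import Dict, List, Optional, Sequence, Tuple

Matrix = List[List[float]]
Vector = List[float]


def matvec(M: Matrix, x: Vector) -> Vector:
    """Matrix-vector product M . x for a list-of-rows matrix."""
    return [sum(m * xi for m, xi in zip(row, x)) for row in M]


def forward_heterogeneous(
    W0: Matrix,
    adapters: Dict[str, Tuple[Matrix, Matrix]],
    batch: Sequence[Tuple[Vector, Optional[str]]],
) -> List[Vector]:
    """Approach 2: h = W0 . x + B . A . x, with a possibly different adapter per request.

    batch holds (x, adapter_id) pairs; adapter_id=None means base model only.
    """
    # Shared base term: independent of which adapter each request uses
    base = [matvec(W0, x) for x, _ in batch]
    out = []
    for h, (x, adapter_id) in zip(base, batch):
        if adapter_id is None:
            out.append(h)
            continue
        B, A = adapters[adapter_id]
        delta = matvec(B, matvec(A, x))  # per-request adapter term B . A . x
        out.append([hi + di for hi, di in zip(h, delta)])
    return out


W0 = [[1, 0], [0, 1]]  # identity base weight, just for illustration
adapters = {"red": ([[1], [0]], [[1, 1]]), "blue": ([[0], [1]], [[2, 0]])}
batch = [([1, 2], "red"), ([1, 2], "blue"), ([3, 4], None)]
print(forward_heterogeneous(W0, adapters, batch))  # [[4, 2], [1, 4], [3, 4]]
```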

Systems optimization to the rescue

“No, we’re not,” says Punica (and, a couple of weeks later, S-LoRA*). The Punica paper and project introduced a custom CUDA kernel that enables batching sequences for different adapters: the Segmented Gather Matrix-Vector Multiplication (SGMV) kernel. For usage patterns where each request targets a different adapter, Punica achieves 12x the throughput of state-of-the-art inference servers like vLLM. It also keeps latency and throughput nearly constant as the number of concurrent adapters grows.

*S-LoRA added further improvements, such as a unified memory pool to reduce memory fragmentation and a novel tensor parallel strategy to minimize the communication cost of the added LoRA computation.
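To make the idea behind SGMV concrete, the sketch below gives a slow, pure-Python reference for what a segmented gather matrix-vector product computes. Requests are reordered so that those sharing an adapter form contiguous segments, and each segment is multiplied by its own adapter matrix. This only illustrates the semantics. It is not Punica’s CUDA kernel, whose point is to handle all segments in a single fused launch. The helper names (`build_segments`, `sgmv_reference`, `apply_per_adapter`) are hypothetical. As I understand it, Punica applies this kind of operation twice per LoRA layer: once for the down-projection A and once for the up-projection B.

```python
from typing import Dict, List, Sequence, Tuple

Matrix = List[List[float]]


def build_segments(adapter_ids: Sequence[str]) -> Tuple[List[int], List[str], List[int]]:
    """Sort request indices by adapter; return (order, segment adapter ids, segment starts)."""
    order = sorted(range(len(adapter_ids)), key=lambda i: adapter_ids[i])
    seg_ids: List[str] = []
    seg_starts = [0]
    for pos, i in enumerate(order):
        if not seg_ids or adapter_ids[i] != seg_ids[-1]:
            if seg_ids:
                seg_starts.append(pos)  # previous segment ends here
            seg_ids.append(adapter_ids[i])
    seg_starts.append(len(order))  # sentinel: end of the last segment
    return order, seg_ids, seg_starts


def sgmv_reference(X: Matrix, weights: Sequence[Matrix], seg_starts: Sequence[int]) -> Matrix:
    """Rows X[seg_starts[s]:seg_starts[s+1]] are multiplied by weights[s] (stored in x out)."""
    out: Matrix = []
    for s, W in enumerate(weights):
        for x in X[seg_starts[s]:seg_starts[s + 1]]:
            out.append([sum(x[j] * W[j][c] for j in range(len(x))) for c in range(len(W[0]))])
    return out


def apply_per_adapter(X: Matrix, adapter_ids: Sequence[str], weights_by_id: Dict[str, Matrix]) -> Matrix:
    """Gather rows into adapter segments, run the segmented product, scatter back to request order."""
    order, seg_ids, seg_starts = build_segments(adapter_ids)
    X_sorted = [X[i] for i in order]  # gather
    Y_sorted = sgmv_reference(X_sorted, [weights_by_id[a] for a in seg_ids], seg_starts)
    Y: Matrix = [[] for _ in X]
    for pos, i in enumerate(order):  # scatter
        Y[i] = Y_sorted[pos]
    return Y


X = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
ids = ["b", "a", "b"]
# Hypothetical rank-1 "shrink" weights (in=2 x out=1): "a" keeps x[0], "b" keeps x[1]
A_by_id = {"a": [[1.0], [0.0]], "b": [[0.0], [1.0]]}
print(apply_per_adapter(X, ids, A_by_id))  # [[2.0], [3.0], [6.0]]
```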

Evolution of production-ready tools for multi-adapter serving

While Punica covers the low-level engine need for efficient multi-adapter inference, an engine is not enough for production-grade model serving. We need a whole car, preferably a fast one. In October 2023, Predibase announced and later open-sourced their LoRAX server, originally forked from HuggingFace’s Text Generation Inference (TGI) server — version 0.9.4. TGI at this point was already a production-ready LLM server, with support for continuous batching (likely the #1 advancement in high throughput LLM serving), efficient CUDA kernels for attention, tensor parallelism, quantization, distributed tracing and more. On top of this solid foundation, LoRAX added adapter-specific features, namely:

  • Heterogeneous continuous batching: schedules requests for different adapters in the same batch; powered by Punica, with the same benefits mentioned above (nearly constant latency and throughput as the number of adapters grows beyond one);
  • Dynamic adapter loading: add a target LoRA adapter ID to your request, and LoRAX will download it from the HuggingFace Hub or S3 (if it is not locally available) and load it to the GPU just-in-time, without blocking other requests;
  • Adapter offloading between GPU and CPU memory: adapters are loaded to GPU memory as they are needed, and are offloaded to CPU memory according to a scheduling policy when GPU memory is saturated. In theory, this means you can serve as many adapters as you can store on disk (not in parallel, of course :))! To understand the scheduling in greater detail, read the official Predibase blog post.
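A toy model can help build intuition for the adapter offloading described in the last bullet. The sketch below assumes a fixed number of GPU “slots” and a simple least-recently-used (LRU) eviction policy, and it treats CPU memory as unbounded. LoRAX’s real scheduler is more involved (see the Predibase post mentioned above), so this class and its behavior are hypothetical.

```python
from collections import OrderedDict
from typing import List, Set, Tuple


class TieredAdapterCache:
    """Hypothetical model of adapter placement: GPU slots (LRU), host memory, then disk/S3."""

    def __init__(self, gpu_slots: int) -> None:
        self.gpu_slots = gpu_slots
        self.gpu: "OrderedDict[str, None]" = OrderedDict()  # insertion order tracks recency
        self.cpu: Set[str] = set()                         # adapters offloaded to host memory
        self.events: List[Tuple[str, str]] = []            # what happened, for inspection

    def request(self, adapter_id: str) -> None:
        """Make adapter_id resident on the GPU before its request joins a batch."""
        if adapter_id in self.gpu:
            self.gpu.move_to_end(adapter_id)  # hit: mark as most recently used
            self.events.append(("hit", adapter_id))
            return
        # Miss: reload from host memory if it was offloaded, otherwise fetch from disk/S3
        source = "cpu" if adapter_id in self.cpu else "disk"
        self.cpu.discard(adapter_id)
        if len(self.gpu) >= self.gpu_slots:
            evicted, _ = self.gpu.popitem(last=False)  # least recently used adapter
            self.cpu.add(evicted)
            self.events.append(("offload_to_cpu", evicted))
        self.gpu[adapter_id] = None
        self.events.append((f"load_from_{source}", adapter_id))


cache = TieredAdapterCache(gpu_slots=2)
for adapter in ["a", "b", "a", "c", "b"]:
    cache.request(adapter)
print(list(cache.gpu), cache.cpu)  # ['c', 'b'] {'a'}
```

In this toy run, adapter "b" is evicted to CPU memory when "c" arrives, and it is later reloaded from CPU memory rather than downloaded again.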

You can find documentation on the full set of features in the LoRAX GitHub repo. Fig.3 depicts its heterogeneous continuous batching. A simple mask ensures that the correct adapter is applied to each request in the batch when computing the activations for each layer (look at Eq.1 above: the mask zeroes out B_blue . A_blue . x for inputs targeted at the red adapter, and vice versa).

Fig.3: LoRAX heterogeneous continuous batching. Source: LoRAX launch blog.
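Here is one way to read the masking idea in Fig.3, as a sketch. Every adapter term is computed for every row of the batch and then multiplied by a 0/1 mask, so each row keeps only the contribution of the adapter it targets. This wastes compute on rows that get masked out, which segmented kernels like SGMV avoid. Treat it as an illustration of the math, not of LoRAX’s implementation. The adapters and numbers are made up, with “red” and “blue” borrowed from the figure.

```python
from typing import Dict, List, Optional, Sequence, Tuple

Matrix = List[List[float]]
Vector = List[float]


def matvec(M: Matrix, x: Vector) -> Vector:
    """Matrix-vector product M . x for a list-of-rows matrix."""
    return [sum(m * xi for m, xi in zip(row, x)) for row in M]


def masked_lora_batch(
    W0: Matrix,
    adapters: Dict[str, Tuple[Matrix, Matrix]],
    X: Sequence[Vector],
    targets: Sequence[Optional[str]],
) -> List[Vector]:
    """Apply every adapter to every row, keeping only the term selected by each row's mask."""
    out = [matvec(W0, x) for x in X]  # shared base term W0 . x
    for adapter_id, (B, A) in adapters.items():
        mask = [1.0 if t == adapter_id else 0.0 for t in targets]
        for row, x in enumerate(X):
            delta = matvec(B, matvec(A, x))  # B . A . x, computed even for masked rows
            out[row] = [h + mask[row] * d for h, d in zip(out[row], delta)]
    return out


W0 = [[1, 0], [0, 1]]
adapters = {"red": ([[1], [0]], [[1, 1]]), "blue": ([[0], [1]], [[2, 0]])}
X = [[1, 2], [1, 2], [3, 4]]
targets = ["red", "blue", None]  # None: base model only
print(masked_lora_batch(W0, adapters, X, targets))  # [[4.0, 2.0], [1.0, 4.0], [3.0, 4.0]]
```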

LoRAX seems to have taken large strides towards production readiness. However, dealing with the underlying infrastructure, GPUs and appropriate scaling based on real traffic patterns still poses serious challenges, which even large enterprises and seasoned teams struggle with.

Amazon SageMaker Hosting can greatly help on this front, providing managed infrastructure provisioning, auto-scaling and load-balancing, GPU sharing, built-in configurable deployment strategies, and many other useful features. It allows you to focus on the product/feature you’re building, instead of the “must-have” capabilities of hosting a real-time endpoint.

With the above in mind, the rest of this blog post will detail how to deploy LoRAX on Amazon SageMaker, marrying the benefits of the two.

Edit: a day after this blog post was published, multi-LoRA serving support was added to vLLM (see here), a popular high-performance LLM serving library. For those already familiar with or using vLLM, this is worth exploring and comparing to the approach and capabilities of LoRAX. I hope to exemplify multi-LoRA serving with vLLM on SageMaker in a future blog post.

Tutorial — Deploy LoRAX on SageMaker

This section assumes you have an AWS account, as well as the AWS CLI installed and configured. It also assumes you run the code blocks in Jupyter notebook cells.

What we’ll do:

  1. Set up our environment
  2. Build a new LoRAX container image compatible with SageMaker, push it to Amazon ECR
  3. Download adapters from the HuggingFace Hub and upload them to S3
  4. Deploy the extended LoRAX container to SageMaker
  5. Compare outputs of the base model and the adapter model
  6. Benchmark our deployed endpoint under different traffic patterns: the same adapter vs. random access to many adapters

You can find the full example notebook on GitHub here.

Set up environment

We’ll use the SageMaker Python SDK to deploy LoRAX. Let’s install it first:

!pip install -U boto3 sagemaker 

If you are running in your local environment rather than in SageMaker Studio/Notebooks, you will need access to an IAM role with the required permissions for SageMaker. If that is your case, follow these instructions before proceeding:

import sagemaker
import boto3
sess = sagemaker.Session()

sagemaker_session_bucket = None
if sagemaker_session_bucket is None and sess is not None:
    # set to default bucket if a bucket name is not given
    sagemaker_session_bucket = sess.default_bucket()

try:
    role = sagemaker.get_execution_role()
except ValueError:
    iam = boto3.client('iam')
    role = iam.get_role(RoleName='<REPLACE SAGEMAKER ROLE NAME HERE>')['Role']['Arn']

sess = sagemaker.Session(default_bucket=sagemaker_session_bucket)

Define new LoRAX container image with SageMaker entrypoint

The original TGI server that LoRAX forked from was already compatible with SageMaker Hosting (providing the required /invocations and /ping paths), which makes our task pretty simple. We just need to make sure that the server is launched on port 8080, via the container’s ENTRYPOINT instruction. For reference, here are the basic interfaces required to adapt any container for deployment on SageMaker Hosting.
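For intuition about that contract, here is a minimal, hypothetical stand-in written with Python’s standard library. It answers `GET /ping` with HTTP 200 and `POST /invocations` with a JSON body, and listens on port 8080 by default. It only echoes the prompt. The response shape mimics the `generated_text` list the real endpoint returns later in this post. This is not how LoRAX or TGI implement their server.

```python
import json
from http.server import BaseHTTPRequestHandler, HTTPServer


class ToyInferenceHandler(BaseHTTPRequestHandler):
    """Hypothetical handler implementing only SageMaker's /ping and /invocations routes."""

    def _send_json(self, status: int, obj) -> None:
        body = json.dumps(obj).encode()
        self.send_response(status)
        self.send_header("Content-Type", "application/json")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def do_GET(self) -> None:
        if self.path == "/ping":  # health check: HTTP 200 means the container is ready
            self._send_json(200, {"status": "healthy"})
        else:
            self.send_error(404)

    def do_POST(self) -> None:
        if self.path != "/invocations":  # inference requests arrive on this path
            self.send_error(404)
            return
        length = int(self.headers.get("Content-Length", 0))
        payload = json.loads(self.rfile.read(length) or b"{}")
        self._send_json(200, [{"generated_text": f"echo: {payload.get('inputs', '')}"}])

    def log_message(self, *args) -> None:  # silence per-request logging
        pass


def make_server(port: int = 8080) -> HTTPServer:
    """SageMaker routes traffic to port 8080 inside the container."""
    return HTTPServer(("0.0.0.0", port), ToyInferenceHandler)
```

Calling `make_server().serve_forever()` as a container’s entrypoint would satisfy both routes. In our image, `lorax-launcher --port 8080` plays that role.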

LoRAX’s interface is very plug-and-play, as you don’t need to write any inference code yourself. The server’s features are exposed to the user via environment variables and CLI parameters, which you can find here.

Let’s replicate TGI’s sagemaker_entrypoint.sh script, adding an environment variable that is needed to enable dynamic loading of adapters from S3 (not documented at the time of writing).

First we write the entry point to a bash script:

%%bash
mkdir -p sagemaker_lorax/

cat <<EOF > sagemaker_lorax/sagemaker_entrypoint.sh
#!/bin/bash

if [[ -z "\${HF_MODEL_ID}" ]]; then
  echo "HF_MODEL_ID must be set"
  exit 1
fi
export MODEL_ID="\${HF_MODEL_ID}"

if [[ -n "\${HF_MODEL_REVISION}" ]]; then
  export REVISION="\${HF_MODEL_REVISION}"
fi

if [[ -n "\${SM_NUM_GPUS}" ]]; then
  export NUM_SHARD="\${SM_NUM_GPUS}"
fi

if [[ -n "\${HF_MODEL_QUANTIZE}" ]]; then
  export QUANTIZE="\${HF_MODEL_QUANTIZE}"
fi

if [[ -n "\${HF_MODEL_TRUST_REMOTE_CODE}" ]]; then
  export TRUST_REMOTE_CODE="\${HF_MODEL_TRUST_REMOTE_CODE}"
fi

if [[ -z "\${ADAPTER_BUCKET}" ]]; then
  echo "Warning: ADAPTER_BUCKET not set. Only able to load local or HuggingFace Hub models."
else
  export PREDIBASE_MODEL_BUCKET="\${ADAPTER_BUCKET}"
fi

lorax-launcher --port 8080
EOF

And create the new Dockerfile:

%%bash
cat <<EOF > sagemaker_lorax/Dockerfile
FROM ghcr.io/predibase/lorax:latest

COPY sagemaker_entrypoint.sh entrypoint.sh
RUN chmod +x entrypoint.sh

ENTRYPOINT ["./entrypoint.sh"]
EOF

Build container and push to ECR

Now we build the container and push it to our own Amazon Elastic Container Registry (ECR) repo. SageMaker will pull and run this container when we deploy a real-time endpoint. Note that SageMaker also supports private Docker registries.

Run the following, replacing region with the AWS region you want to deploy your endpoint on.

%%bash
algorithm_name="lorax" # name of your algorithm
tag="sagemaker"
region="us-east-1"

account=$(aws sts get-caller-identity --query Account --output text)

image_uri="${account}.dkr.ecr.${region}.amazonaws.com/${algorithm_name}:${tag}"

# If the repository doesn't exist in ECR, create it.
aws ecr describe-repositories --repository-names "${algorithm_name}" --region "${region}" > /dev/null 2>&1

if [ $? -ne 0 ]
then
    aws ecr create-repository --repository-name "${algorithm_name}" --region "${region}" > /dev/null
fi

cd sagemaker_lorax/ && docker build -t ${algorithm_name}:${tag} .

# Authenticate Docker to an Amazon ECR registry
aws ecr get-login-password --region ${region} | docker login --username AWS --password-stdin ${account}.dkr.ecr.${region}.amazonaws.com

# Tag the image
docker tag ${algorithm_name}:${tag} ${image_uri}

# Push the image to the repository
docker push ${image_uri}

# Save image name to tmp file to use when deploying endpoint
echo $image_uri > /tmp/image_uri

Download adapter from HuggingFace Hub and push it to S3

We are going to simulate storing our adapter weights on S3 and having LoRAX load them dynamically as we invoke them. This enables most scenarios, including deployment after you’ve fine-tuned your own adapter and pushed it to S3, as well as securing deployments with no internet access inside your VPC, as detailed in this blog post.

We first download an adapter trained with Mistral Instruct v0.1 as the base model to a local directory. This particular adapter was trained on GSM8K, a grade school math dataset.

!pip install huggingface_hub --quiet

from pathlib import Path
from huggingface_hub import snapshot_download

HF_MODEL_ID = "vineetsharma/qlora-adapter-Mistral-7B-Instruct-v0.1-gsm8k"
# create model dir
model_dir = Path('mistral-adapter')
model_dir.mkdir(exist_ok=True)

# Download model from Hugging Face into model_dir
snapshot_download(
    HF_MODEL_ID,
    local_dir=str(model_dir),  # download to model dir
    local_dir_use_symlinks=False,  # copy real files instead of symlinking into the HF cache, so the folder can be uploaded as-is
    revision="main",  # use a specific revision, e.g. refs/pr/21
)

Then we copy this same adapter n_adapters times to different S3 prefixes in our SageMaker session bucket, simulating a large number of adapters we want to serve on the same endpoint and underlying GPU.

import os

s3 = boto3.client('s3')

def upload_folder_to_s3(local_path, s3_bucket, s3_prefix):
    for root, dirs, files in os.walk(local_path):
        for file in files:
            local_file_path = os.path.join(root, file)
            s3_object_key = os.path.join(s3_prefix, os.path.relpath(local_file_path, local_path))
            s3.upload_file(local_file_path, s3_bucket, s3_object_key)

# Upload the folder n_adapters times under different prefixes
n_adapters = 50
base_prefix = 'lorax/mistral-adapters'
for i in range(1, n_adapters + 1):
    prefix = f'{base_prefix}/{i}'
    upload_folder_to_s3(model_dir, sagemaker_session_bucket, prefix)
    print(f'Uploaded folder to S3 with prefix: {prefix}')

Deploy SageMaker endpoint

Finally, we deploy the endpoint, pointing to our SageMaker session bucket as the ADAPTER_BUCKET env variable, which enables downloading adapters from S3. We deploy on a single g5.xlarge instance, backed by a 24GB A10G GPU.

import json

from sagemaker import Model
from sagemaker.serializers import JSONSerializer
from sagemaker.deserializers import JSONDeserializer

# Retrieve image_uri from tmp file
image_uri = !cat /tmp/image_uri
# Increased health check timeout to give time for model download
health_check_timeout = 800
number_of_gpu = 1
instance_type = "ml.g5.xlarge"
endpoint_name = 'sm-lorax'

# Model and Endpoint configuration parameters
config = {
    'HF_MODEL_ID': "mistralai/Mistral-7B-Instruct-v0.1", # model_id from hf.co/models
    'SM_NUM_GPUS': json.dumps(number_of_gpu), # Number of GPU used per replica
    'MAX_INPUT_LENGTH': json.dumps(1024), # Max length of input text
    'MAX_TOTAL_TOKENS': json.dumps(4096), # Max length of the generation (including input text)
    'ADAPTER_BUCKET': sagemaker_session_bucket,
}

lorax_model = Model(
    image_uri=image_uri[0],
    role=role,
    env=config
)

lorax_predictor = lorax_model.deploy(
    endpoint_name=endpoint_name,
    initial_instance_count=1,
    instance_type=instance_type,
    container_startup_health_check_timeout=health_check_timeout,
    serializer=JSONSerializer(),
    deserializer=JSONDeserializer()
)

Invoke base model and adapter, compare outputs

Once the endpoint deployment process is complete, we can invoke the base Mistral model, as well as any of the adapters in our bucket! LoRAX will take care of downloading them, continuously batching requests for different adapters, and managing GPU and CPU memory by loading/offloading adapters.

Let’s inspect the difference between the base model’s response and the adapter’s response:

prompt = '[INST] Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May? [/INST]'

payload_base = {
    "inputs": prompt,
    "parameters": {
        "max_new_tokens": 64,
    }
}

payload_adapter = {
    "inputs": prompt,
    "parameters": {
        "max_new_tokens": 64,
        "adapter_id": f'{base_prefix}/10',
        "adapter_source": "s3"
    }
}

response_base = lorax_predictor.predict(payload_base)
response_adapter = lorax_predictor.predict(payload_adapter)

print(f'Base model output:\n-------------\n {response_base[0]["generated_text"]}')
print(f'Adapter output:\n-------------\n {response_adapter[0]["generated_text"]}')

Result:

Base model output:
-------------
Let's break down the problem:

1. In April, Natalia sold clips to 48 of her friends.
2. In May, she sold half as many clips as in April, which means she sold 48/2 = 24 clips in May.

Adapter output:
-------------
Natalia sold 48/2 = <<48/2=24>>24 clips in May.
In total, Natalia sold 48 + 24 = <<48+24=72>>72 clips in April and May.
#### 72

Nice! The fine-tuned adapter is clearly much more direct in solving this problem. In fact, the base model wasn’t even able to give us a final answer within the 64 max_new_tokens budget we gave it.
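Outside a notebook, you may prefer to call the endpoint with plain boto3 rather than the SageMaker SDK’s predictor. The sketch below assumes the same request format used above: `inputs` plus `parameters`, with `adapter_id` and `adapter_source` present only when targeting an adapter. It sends the request with the `sagemaker-runtime` client’s `invoke_endpoint` call. The helper names `build_payload` and `invoke` are hypothetical.

```python
import json
from typing import Any, Dict, Optional


def build_payload(prompt: str, adapter_id: Optional[str] = None,
                  max_new_tokens: int = 64, adapter_source: str = "s3") -> Dict[str, Any]:
    """Request body in the same shape as payload_base / payload_adapter above."""
    parameters: Dict[str, Any] = {"max_new_tokens": max_new_tokens}
    if adapter_id is not None:  # omit adapter fields to query the base model
        parameters["adapter_id"] = adapter_id
        parameters["adapter_source"] = adapter_source
    return {"inputs": prompt, "parameters": parameters}


def invoke(endpoint_name: str, payload: Dict[str, Any], region_name: Optional[str] = None) -> Any:
    """Send one JSON request to a SageMaker real-time endpoint and parse the JSON reply."""
    import boto3  # imported lazily so build_payload works without boto3 installed

    runtime = boto3.client("sagemaker-runtime", region_name=region_name)
    response = runtime.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType="application/json",
        Body=json.dumps(payload),
    )
    return json.loads(response["Body"].read())
```

For example, `invoke("sm-lorax", build_payload(prompt, "lorax/mistral-adapters/10"))[0]["generated_text"]` would query adapter 10 on the endpoint deployed above.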

Benchmark single adapter vs. random access to adapters

As a final step, let’s check out what performance looks like (request latency and throughput) when we call the same adapter, vs. any one of the n_adapters at random, in parallel.

First, we call each of the adapters individually, in sequence, to make sure they have already been downloaded to the endpoint instance’s disk. We want to exclude S3 download latency from the benchmark metrics.

for i in range(1, n_adapters + 1):
    payload_adapter = {
        "inputs": prompt,
        "parameters": {
            "max_new_tokens": 64,
            "adapter_id": f'{base_prefix}/{i}',
            "adapter_source": "s3"
        }
    }
    print(lorax_predictor.predict(payload_adapter))

Now we’re ready to benchmark. First, we call a single adapter total_requests times from num_threads clients (threads) in parallel. Then, we repeat the experiment, but this time each client can call any of the adapters at random. We make sure each adapter is called total_requests // num_adapters times. To make this benchmark accurate, we should measure the number of output tokens per request, because even with greedy decoding, outputs are not guaranteed to be deterministic. In this case we’ll skip that step: all adapters are the same, the input prompt is rather straightforward, our bound for max_new_tokens is low, and I’m lazy 🦥. You can check in the previous cell’s outputs that all generated sequences are the same.
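If you do want the per-token accounting mentioned above, one option is to ask the server for generation details. The sketch below assumes LoRAX keeps TGI’s `"details": True` request parameter, so that each response looks like `[{"generated_text": ..., "details": {"generated_tokens": N, ...}}]`. You would also need to add that parameter to the payload and collect each `response_adapter` in `invoke_adapter`, which the benchmark below does not do. The `summarize` helper is hypothetical. It also reports latency percentiles, which often say more than the average.

```python
import statistics
from typing import Any, Dict, List, Optional, Sequence


def summarize(latencies: Sequence[float],
              responses: Optional[List[Any]] = None,
              wall_time_s: Optional[float] = None) -> Dict[str, float]:
    """Latency percentiles, plus output tokens/s when detailed responses are available."""
    # 99 cut points interpolated within the observed range; index 89 is p90, index 98 is p99
    cuts = statistics.quantiles(latencies, n=100, method="inclusive")
    summary = {
        "p50_s": statistics.median(latencies),
        "p90_s": cuts[89],
        "p99_s": cuts[98],
    }
    if responses is not None and wall_time_s:
        # Assumes each response was requested with "details": True (TGI-style output)
        tokens = sum(r[0]["details"]["generated_tokens"] for r in responses)
        summary["output_tokens_per_s"] = tokens / wall_time_s
    return summary
```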

import threading
import time
import random


# Configuration
total_requests = 300
num_adapters = 50
num_threads = 20 # Adjust based on your system capabilities


# Shared counters and lock
adapter_counters = [total_requests // num_adapters] * num_adapters
counters_lock = threading.Lock()

def invoke_adapter(aggregate_latency, single_adapter=False):
    global total_requests
    latencies = []
    while True:
        with counters_lock:
            if single_adapter:
                adapter_id = 1
                if total_requests > 0:
                    total_requests -= 1
                else:
                    break
            else:
                # Find an adapter that still needs to be called
                remaining_adapters = [i for i, count in enumerate(adapter_counters) if count > 0]
                if not remaining_adapters:
                    break
                adapter_id = random.choice(remaining_adapters) + 1
                adapter_counters[adapter_id - 1] -= 1

        prompt = '[INST] Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May? [/INST]'
        payload_adapter = {
            "inputs": prompt,
            "parameters": {
                "max_new_tokens": 64,
                "adapter_id": f'{base_prefix}/{adapter_id}',
                "adapter_source": "s3"
            }
        }
        start_time = time.time()
        response_adapter = lorax_predictor.predict(payload_adapter)
        latency = time.time() - start_time
        latencies.append(latency)

    aggregate_latency.extend(latencies)

def benchmark_scenario(single_adapter=False):
    threads = []
    all_latencies = []
    start_time = time.time()

    for _ in range(num_threads):
        thread_latencies = []
        all_latencies.append(thread_latencies)
        thread = threading.Thread(target=invoke_adapter, args=(thread_latencies, single_adapter))
        threads.append(thread)
        thread.start()

    for thread in threads:
        thread.join()

    # Measure wall-clock time once, so Total Time and Throughput use the same value
    elapsed = time.time() - start_time
    total_latency = sum([sum(latencies) for latencies in all_latencies])
    total_requests_made = sum([len(latencies) for latencies in all_latencies])
    average_latency = total_latency / total_requests_made
    throughput = total_requests_made / elapsed

    print(f"Total Time: {elapsed}s")
    print(f"Average Latency: {average_latency}s")
    print(f"Throughput: {throughput} requests/s")

# Run benchmarks
print("Benchmarking: Single Adapter Multiple Times")
benchmark_scenario(single_adapter=True)

print("\nBenchmarking: Multiple Adapters with Random Access")
benchmark_scenario()

And the result:

Benchmarking: Single Adapter Multiple Times
Total Time: 42.34047770500183s
Average Latency: 2.7960794830322264s
Throughput: 7.085418524516227 requests/s

Benchmarking: Multiple Adapters with Random Access
Total Time: 42.60287928581238s
Average Latency: 2.8158644016583763s
Throughput: 7.041777732977514 requests/s

Amazing! It seems we got virtually the same performance from invoking 20 different, random adapters in parallel (from a total set of 50) as we did invoking a single adapter. All of this on a single, cheap g5 instance, backed by a 24GB A10G GPU 🤯.
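Little’s law offers a quick sanity check on these numbers. In a closed-loop test where each of the 20 client threads waits for its reply before sending the next request, throughput should be roughly concurrency divided by average latency. The sketch below plugs in the averages reported above.

```python
def closed_loop_throughput(concurrency: int, avg_latency_s: float) -> float:
    """Little's law (L = lambda * W) rearranged: lambda = L / W, with L in-flight requests."""
    return concurrency / avg_latency_s


# Average latency (s) and measured throughput (req/s) from the benchmark output above
runs = {
    "single adapter": (2.7960794830322264, 7.085418524516227),
    "random adapters": (2.8158644016583763, 7.041777732977514),
}
for name, (avg_latency, measured) in runs.items():
    predicted = closed_loop_throughput(20, avg_latency)
    print(f"{name}: predicted {predicted:.2f} req/s, measured {measured:.2f} req/s")
```

The prediction lands slightly above the measured throughput in both cases. That is plausibly because fewer than 20 requests are in flight while the threads ramp up and drain. Since throughput tracks concurrency over latency, the near-identical latencies of the two scenarios are what produce their near-identical throughput.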

If all the adapters had been fine-tuned on different data instead of being copies of the same adapter, we would effectively be serving 50 different models on a single A10G.

Further performance considerations — a rant

You can explore pushing this experiment to its limits with greater concurrency, larger models (parallelism across GPUs might be needed) and instance sizes, more adapters, among other variations, to find the optimal configuration for your traffic pattern.

The above is very broad, however. So, based on the state-of-the-art in autoregressive, attention-based model serving, I’ll finish with some thoughts on the important factors to consider regarding the performance of multi-adapter serving systems. I hope these might be of use to whoever finds themselves grappling with taking the next step, and actually optimizing such a deployment beyond my “naive” example.

Loading more adapters to the GPU for concurrent execution means you will have less GPU memory available to store the KV cache, which can grow quickly for large batches and/or long sequences (see this excellent blog post by my colleague Pierre Lienhart to dive deeper into KV caching, a standard in LLM serving). Having enough GPU memory for batching is extremely important for throughput optimization. It lets you operate at an optimal cost-per-token point, where latency degradation is minimal and hardware utilization is maximized. This point corresponds to the boundary between the memory-bound and compute-bound regimes (more on this topic here), and it is where you get the maximum “bang for your buck”.

We did not observe performance degradation with 20 concurrent clients and the particularly small input/output payloads used in this example. At higher concurrency and longer sequences, however, we would be limited to a smaller batch size than if serving the base model alone. It’s a clear tradeoff between bigger batches (better resource utilization) and a higher diversity of models served concurrently. As with any performance optimization endeavour, the optimal configuration will be defined by the specific requirements of your workload and your envisioned end-user experience. I will leave this exercise for a future blog post.
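To put the adapter-versus-KV-cache tradeoff in numbers, here is a back-of-the-envelope sketch. It assumes a Mistral-7B-style configuration (32 layers, hidden size 4096, 8 key/value heads of dimension 128, fp16 values). It also assumes a hypothetical rank-16 adapter on the q, k, v and o attention projections. The actual rank and target modules of the GSM8K adapter used above may differ.

```python
def kv_cache_bytes_per_token(n_layers: int, n_kv_heads: int, head_dim: int, bytes_per_value: int = 2) -> int:
    """Bytes of K and V cached per token across all layers (fp16/bf16: 2 bytes per value)."""
    return 2 * n_layers * n_kv_heads * head_dim * bytes_per_value  # leading 2: one K and one V


def lora_adapter_bytes(n_layers: int, rank: int, shapes, bytes_per_value: int = 2) -> int:
    """Bytes of a LoRA adapter: rank * (in + out) parameters per adapted projection, per layer."""
    params_per_layer = sum(rank * (d_in + d_out) for d_in, d_out in shapes)
    return n_layers * params_per_layer * bytes_per_value


# Assumed Mistral-7B-style config: 32 layers, hidden 4096, 8 KV heads of dim 128 (KV width 1024)
kv_per_token = kv_cache_bytes_per_token(n_layers=32, n_kv_heads=8, head_dim=128)
# Hypothetical adapter: rank 16 on the q, k, v and o projections, as (in_features, out_features)
attn_shapes = [(4096, 4096), (4096, 1024), (4096, 1024), (4096, 4096)]
adapter_bytes = lora_adapter_bytes(n_layers=32, rank=16, shapes=attn_shapes)

print(f"KV cache per token: {kv_per_token / 1024:.0f} KiB")
print(f"One adapter: {adapter_bytes / 2**20:.0f} MiB, "
      f"or the KV cache of about {adapter_bytes // kv_per_token} tokens")
```

Under these assumptions, each resident adapter costs about as much GPU memory as 208 tokens of KV cache. That is small per adapter, but it adds up when many adapters must sit on the GPU at once.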

To bring this rant to a close: optimizing inference performance is a tough nut to crack even for a single LLM, where the model’s footprint in GPU memory is constant. Multi-adapter serving optimization requires accounting for an extra dimension: the dependency between batch size/concurrency and the extra MBs of adapter weights that must sit in GPU memory to process a heterogeneous batch. There will surely be other factors to consider that I did not cover; I’d be happy to hear the reader’s thoughts on the matter.

Conclusion

In this blog post, we explored the potential of LoRA serving, and how it has been recently realized by swift technical advancements in the field.

We deployed and benchmarked LoRAX Server on Amazon SageMaker, observing nearly identical performance for same-adapter traffic and for random invocation of different adapters.

Finally, we explored basic considerations for those aiming to deeply optimize a multi-adapter deployment.

This practical example shows that you can unlock the power and economic viability of smaller language models beyond training, into the production deployment world.

Special thanks to Travis Addair at Predibase for help on some (minor) LoRAX S3 integration issues I ran into while drafting this blog.

Disclaimer: This blog post and the accompanying code was written by João Moura, an AI Solutions Architect working for AWS. All opinions shared are the author’s personal opinions, and may not represent the official view of AWS.
