LLM Model Sharding

Sharath S Hebbar
Feb 18, 2024


Introduction

Large Language Models (LLMs) represent a significant advancement in artificial intelligence and natural language processing.

Large Language Models

Models such as OpenAI’s GPT (Generative Pre-trained Transformer) series, Google’s Gemini, PaLM, and T5, along with many open-source models, have achieved remarkable capabilities in understanding and generating human-like text.

However, as these models grow larger to improve performance, they also pose challenges in terms of scalability, resource requirements, and ethical considerations.

A major challenge is simply using such models. Running an LLM in a Colab or Kaggle notebook, or locally on a machine with limited RAM, is hard enough; even just loading these huge models requires a large amount of RAM, which is often not feasible.

One such solution is model sharding: the model's weights are split into smaller chunks (shards), which makes loading these huge models faster and less demanding on hardware.
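To make the idea concrete, here is a minimal sketch of how a checkpoint can be split into size-capped shards. It works on a plain mapping of parameter names to byte sizes instead of real tensors, and the helpers `split_into_shards` and `build_weight_map` are hypothetical; this is a simplified greedy strategy, not the exact code used by `transformers` or `accelerate`. The shard file names only mimic the `model-00001-of-00002.safetensors` pattern seen in Hugging Face repositories.

```python
from typing import Dict, List


def split_into_shards(param_sizes: Dict[str, int], max_shard_bytes: int) -> List[Dict[str, int]]:
    """Greedily pack parameters, in order, into shards of at most max_shard_bytes.

    A single parameter larger than the limit ends up in a shard of its own.
    """
    shards: List[Dict[str, int]] = []
    current: Dict[str, int] = {}
    current_size = 0
    for name, size in param_sizes.items():
        # Close the current shard if adding this parameter would exceed the cap
        if current and current_size + size > max_shard_bytes:
            shards.append(current)
            current, current_size = {}, 0
        current[name] = size
        current_size += size
    if current:
        shards.append(current)
    return shards


def build_weight_map(shards: List[Dict[str, int]]) -> Dict[str, str]:
    """Map each parameter name to the (hypothetical) file name of its shard."""
    total = len(shards)
    weight_map: Dict[str, str] = {}
    for i, shard in enumerate(shards, start=1):
        file_name = f"model-{i:05d}-of-{total:05d}.safetensors"
        for name in shard:
            weight_map[name] = file_name
    return weight_map


# Toy example: four parameters (sizes in bytes) with a 200-byte cap
sizes = {"embed": 150, "layer0": 120, "layer1": 60, "lm_head": 300}
shards = split_into_shards(sizes, max_shard_bytes=200)
print([list(s) for s in shards])  # [['embed'], ['layer0', 'layer1'], ['lm_head']]
print(build_weight_map(shards))
```

The weight map plays the role of a checkpoint index: a loader can read it to find which file holds a given tensor and then open one shard at a time.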

Here we will walk through model sharding using the open-source LLM Mistral 7B, which is freely hosted on the Hugging Face Hub.


Code

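import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from accelerate import Accelerator, load_checkpoint_and_dispatch

model_name = "mistralai/Mistral-7B-v0.1"
save_directory = "/content/model/"  # local folder that will hold the shards

# Load the tokenizer and the full model in 16-bit floating point
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16
)

# Save the weights as a sharded checkpoint, each shard at most 200MB
accelerator = Accelerator()
accelerator.save_model(
    model=model,
    save_directory=save_directory,
    max_shard_size="200MB"
)

# Place every module on the CPU
device_map = {"": "cpu"}

# Reload the weights from the sharded checkpoint.
# no_split_module_classes only matters when the device map is inferred (e.g. "auto");
# for Mistral, the layer class that should stay whole is MistralDecoderLayer.
model = load_checkpoint_and_dispatch(
    model,
    checkpoint=save_directory,
    device_map=device_map,
    no_split_module_classes=["MistralDecoderLayer"]
)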
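# Push the tokenizer and the sharded model to the Hugging Face Hub
new_model = "<Name of the model>"
HF_TOKEN = "<Your HF Token>"

tokenizer.push_to_hub(
    new_model,
    token=HF_TOKEN
)

# push_to_hub re-serializes the weights, so pass max_shard_size again;
# otherwise the library's default shard size is used for the uploaded files
model.push_to_hub(
    new_model,
    token=HF_TOKEN,
    max_shard_size="200MB"
)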

Loading Sharded Model

Original Model

Loading the original model in 16-bit floating point took 16 GB of RAM.
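A back-of-the-envelope calculation explains the order of magnitude: in 16-bit floating point each parameter takes 2 bytes, so the weights alone need about twice the parameter count in bytes. This sketch assumes roughly 7.24 billion parameters for Mistral 7B (an approximate figure) and ignores buffers and framework overhead, which helps explain why the measured figure is somewhat higher.

```python
# Bytes used to store a single parameter in common dtypes
BYTES_PER_PARAM = {"float32": 4, "float16": 2, "bfloat16": 2, "int8": 1}


def weight_memory_gb(num_params: int, dtype: str = "float16") -> float:
    """Approximate memory needed just to hold the weights, in gigabytes (1e9 bytes)."""
    return num_params * BYTES_PER_PARAM[dtype] / 1e9


# Assumption: Mistral 7B has roughly 7.24 billion parameters
mistral_params = 7_240_000_000
for dtype in ("float32", "float16"):
    print(dtype, round(weight_memory_gb(mistral_params, dtype), 2), "GB")
```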

%%time
model_name = "mistralai/Mistral-7B-v0.1"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16
)

CPU times: user 36.8 s, sys: 48.5 s, total: 1min 25s
Wall time: 3min 30s
Sharded Model

Loading the sharded model in 16-bit floating point took 3 GB of RAM, and the wall-clock load time dropped from 3 min 30 s to 1 min 49 s.

%%time
model_name = "Sharathhebbar24/Mistral-7B-v0.1"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16
)

CPU times: user 23 s, sys: 48.7 s, total: 1min 11s
Wall time: 1min 49s
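Using the wall-clock times reported above, the speedup can be worked out directly. The parser `to_seconds` is a hypothetical helper that only understands the `Xmin Ys` format printed by `%%time` in these two runs. Each number comes from a single run, so the ratio is indicative rather than a benchmark.

```python
import re


def to_seconds(wall_time: str) -> int:
    """Convert a %%time wall-clock string such as '3min 30s' into seconds."""
    minutes = re.search(r"(\d+)\s*min", wall_time)
    seconds = re.search(r"(\d+)\s*s\b", wall_time)
    total = int(minutes.group(1)) * 60 if minutes else 0
    total += int(seconds.group(1)) if seconds else 0
    return total


original = to_seconds("3min 30s")  # original Mistral-7B-v0.1 checkpoint
sharded = to_seconds("1min 49s")   # re-uploaded checkpoint with 200MB shards
speedup = original / sharded
reduction = 1 - sharded / original
print(f"{original}s -> {sharded}s: {speedup:.2f}x faster, {reduction:.0%} less wall time")
# 210s -> 109s: 1.93x faster, 48% less wall time
```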

