Sharding Large Models for Parallel Inference

Shashank Jain
3 min read · Jul 21, 2023


Introduction
With the advent of deep learning and the development of increasingly powerful models, the size of pre-trained language models has grown significantly. While these models have shown remarkable performance in various natural language processing (NLP) tasks, their sheer size poses challenges for inference on resource-constrained devices and large-scale distributed systems. To tackle these challenges, sharding, or splitting large models into smaller pieces, has emerged as a promising approach for achieving more efficient and faster distributed inference.
In this blog post, we will delve into the concept of sharding large models, exploring its benefits, use cases, and implementation details. We will also discuss popular libraries and tools, such as ‘accelerate’, that facilitate sharding and make it easier to perform distributed inference.

Understanding Sharding Large Models
Sharding, in the context of large models, refers to dividing the model's weights into smaller pieces, or shards. Each shard holds a subset of the original model's parameters. Because shards can be loaded one at a time and placed on different devices or processors, the full model never has to fit in a single device's memory, and the computation can be spread across the available hardware for faster and more efficient inference.
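To make "dividing the weights into shards" concrete, here is a minimal sketch of cutting a checkpoint into size-capped pieces. It assumes a simple greedy rule: walk the parameters in model order, keep adding them to the current shard, and start a new shard when the next tensor would push it over the limit. The helper `split_into_shards` and the toy byte sizes are hypothetical; real libraries may differ in the details.

```python
from typing import Dict, List


def split_into_shards(param_sizes: Dict[str, int], max_shard_bytes: int) -> List[Dict[str, int]]:
    """Greedily pack parameters, in order, into shards of at most max_shard_bytes.

    A tensor is never split across shards, so a single parameter larger than
    the limit ends up alone in its own (oversized) shard.
    """
    shards: List[Dict[str, int]] = []
    current: Dict[str, int] = {}
    current_size = 0
    for name, size in param_sizes.items():
        # Close the current shard if this tensor would overflow it.
        if current and current_size + size > max_shard_bytes:
            shards.append(current)
            current, current_size = {}, 0
        current[name] = size
        current_size += size
    if current:
        shards.append(current)
    return shards


# Hypothetical parameter sizes (in bytes), listed in model order.
sizes = {"embed": 6, "block.0": 4, "block.1": 4, "block.2": 4, "lm_head": 6}
shards = split_into_shards(sizes, max_shard_bytes=8)
print([list(s) for s in shards])
# → [['embed'], ['block.0', 'block.1'], ['block.2'], ['lm_head']]
```

Keeping parameters in model order means neighbouring layers tend to land in the same file, which suits loaders that read the shards one after another.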

Benefits of Sharding

Memory Efficiency: Sharding enables running large models on devices with limited memory. Instead of loading the entire model into memory, sharding allows loading and processing only the necessary parts, reducing memory requirements significantly.

Faster Inference: By distributing the computation across multiple devices, sharding helps achieve parallelism, resulting in faster inference times. This is particularly beneficial when dealing with massive models that would otherwise be slow to run on a single device.

Scalability: Sharding facilitates the deployment of large models on distributed systems with multiple GPUs or even across a cluster, making it feasible to handle massive workloads and larger-scale tasks.

Distributed Inference: Sharding is essential for large-scale distributed systems where the processing power is spread across multiple nodes or GPUs. Placing shards sensibly makes good use of the available compute and keeps communication between devices, which adds overhead, as low as possible.
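To see where the parallel speed-up in the points above can come from, consider a model whose layers are split into stages, one stage per device. A single input still passes through those stages one after another, so the gain appears when several inputs are in flight at once and each device works on a different batch at the same time, a scheme usually called pipeline parallelism. The sketch below counts time steps for that scheme under idealized assumptions (every stage takes exactly one step, communication is free); the function names are hypothetical.

```python
from typing import List, Tuple


def sequential_steps(num_stages: int, num_batches: int) -> int:
    """Each batch passes through every stage before the next batch starts."""
    return num_stages * num_batches


def pipelined_steps(num_stages: int, num_batches: int) -> int:
    """Idealized pipeline: after the first batch fills all stages,
    one batch completes on every subsequent step."""
    return num_stages + num_batches - 1


def pipeline_schedule(num_stages: int, num_batches: int) -> List[List[Tuple[int, int]]]:
    """For each time step, list the (stage, batch) pairs being processed.

    Stage s works on batch t - s at step t, whenever that batch exists.
    """
    schedule = []
    for t in range(pipelined_steps(num_stages, num_batches)):
        schedule.append([(s, t - s) for s in range(num_stages) if 0 <= t - s < num_batches])
    return schedule


# Four devices, each holding a quarter of the layers, and 8 batches.
print(sequential_steps(4, 8), pipelined_steps(4, 8))  # 32 11

# A small schedule: 2 stages, 3 batches.
for t, work in enumerate(pipeline_schedule(2, 3)):
    print(t, work)
# 0 [(0, 0)]
# 1 [(0, 1), (1, 0)]
# 2 [(0, 2), (1, 1)]
# 3 [(1, 2)]
```

In practice stages rarely take equal time and moving activations between devices costs something, so real speed-ups are smaller than this idealized count suggests.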

Implementing Sharding with ‘Accelerate’
‘Accelerate’, a library from Hugging Face, simplifies the process of sharding large models for distributed inference. Here’s how you can implement sharding using ‘accelerate’:

Install ‘Accelerate’: First, install the ‘accelerate’ library with pip, together with ‘transformers’ and ‘sentencepiece’, which the T5 tokenizer relies on.
Load the Model and Tokenizer: Load the pre-trained language model and tokenizer using the ‘transformers’ library. Choose a model that fits your specific use case.
Shard the Model: Use the ‘accelerate’ library to split the model's weights into smaller pieces, each no larger than a maximum shard size you choose. With ‘accelerate’, this split happens when the checkpoint is saved.
Save the Sharded Model: Save the sharded model to a specific directory. This produces several shard files, plus an index that records which file holds each parameter, so the model can later be loaded one shard at a time.
Load and Dispatch the Model: Load the sharded model using ‘accelerate’ and dispatch it to the appropriate device, such as CPU or multiple GPUs, based on your hardware setup.
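The load-and-dispatch step comes down to deciding which device holds which layer. ‘accelerate’ can work this out for you (for example by passing `device_map="auto"`); the sketch below is a deliberately simplified, hypothetical stand-in, not accelerate's actual algorithm. It walks the layers in model order and fills each device up to an assumed memory budget before moving on to the next, never splitting a layer, which is the same guarantee `no_split_module_classes` gives for the block classes you list. Layer sizes, budgets, and device names are made up for illustration.

```python
from typing import Dict, List, Tuple


def assign_devices(layers: List[Tuple[str, int]], budgets: List[Tuple[str, int]]) -> Dict[str, str]:
    """Map each layer to a device, filling devices in order.

    layers:  (layer_name, size) pairs in model order.
    budgets: (device_name, capacity) pairs, tried in order.
    Raises ValueError if the layers do not fit on the given devices.
    """
    device_map: Dict[str, str] = {}
    dev_idx = 0
    used = 0
    for name, size in layers:
        # Skip to the next device while this layer does not fit on the current one.
        while dev_idx < len(budgets) and used + size > budgets[dev_idx][1]:
            dev_idx += 1
            used = 0
        if dev_idx == len(budgets):
            raise ValueError(f"not enough memory to place {name}")
        device_map[name] = budgets[dev_idx][0]
        used += size
    return device_map


# Hypothetical layer sizes and per-device budgets (same unit, e.g. GB).
layers = [("embed", 2), ("block.0", 3), ("block.1", 3), ("block.2", 3), ("lm_head", 2)]
budgets = [("cuda:0", 6), ("cuda:1", 6), ("cpu", 100)]
print(assign_devices(layers, budgets))
# → {'embed': 'cuda:0', 'block.0': 'cuda:0', 'block.1': 'cuda:1', 'block.2': 'cuda:1', 'lm_head': 'cpu'}
```

Filling devices in model order keeps consecutive layers together, so activations only cross a device boundary a few times per forward pass.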

Here is the code:

```shell
pip install transformers sentencepiece accelerate
```

```python
from transformers import AutoTokenizer, T5ForConditionalGeneration
from accelerate import Accelerator, load_checkpoint_and_dispatch

model_name = "declare-lab/flan-alpaca-xl"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)

save_directory = "/content/model"
accelerator = Accelerator()

# Here I am taking a fine-tuned Flan-T5 XL model (around 10 GB) and saving it
# as shards of at most 2 GB each
accelerator.save_model(model=model, save_directory=save_directory, max_shard_size="2GB")

# Choosing the CPU as the device: the key "" means the root module, so the
# whole model is placed on the CPU. The shards are read one at a time, and
# PyTorch spreads the computation over the machine's cores (7 in my case).
device_map = {"": "cpu"}

# `model` is reused here for simplicity; to avoid holding a full copy in memory,
# the model skeleton could instead be built under accelerate's init_empty_weights().
model = load_checkpoint_and_dispatch(
    model,
    checkpoint=save_directory,
    device_map=device_map,
    no_split_module_classes=["T5Block"],  # never split a T5 block across devices
)

raw_inputs = "Tell me about alpacas."
inputs = tokenizer(raw_inputs, padding=True, truncation=True, return_tensors="pt")

outputs = model.generate(**inputs, max_new_tokens=100, return_dict_in_generate=True, output_scores=True)
# With return_dict_in_generate=True, the generated token ids are in `sequences`
print(tokenizer.decode(outputs.sequences[0], skip_special_tokens=True))
```
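After `save_model` runs, the output directory holds the shard files plus a JSON index whose `weight_map` records which file stores each parameter; that index is what lets a loader read the shards one at a time. The sketch below groups parameter names by shard using a made-up index. The helper `params_by_shard`, the file names, and the sizes are placeholders, not what `save_model` actually writes.

```python
import json
from collections import defaultdict
from typing import Dict, List


def params_by_shard(index: dict) -> Dict[str, List[str]]:
    """Group parameter names by the shard file that stores them."""
    groups: Dict[str, List[str]] = defaultdict(list)
    for param_name, shard_file in index["weight_map"].items():
        groups[shard_file].append(param_name)
    return dict(groups)


# A made-up index with the same overall shape as a sharded checkpoint's index file.
index = json.loads("""
{
  "metadata": {"total_size": 16},
  "weight_map": {
    "shared.weight": "shard-1.bin",
    "encoder.block.0.weight": "shard-1.bin",
    "decoder.block.0.weight": "shard-2.bin"
  }
}
""")

# A loader can open each file once and load only the parameters it lists.
for shard_file, names in sorted(params_by_shard(index).items()):
    print(shard_file, names)
# shard-1.bin ['shared.weight', 'encoder.block.0.weight']
# shard-2.bin ['decoder.block.0.weight']
```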

Conclusion
Sharding large language models has become a crucial technique for enabling efficient distributed inference and deploying models on resource-constrained devices. By dividing large models into smaller, manageable pieces, sharding allows us to harness the full potential of deep learning models even when no single device has enough memory to hold them.

The ‘accelerate’ library, along with other related tools, streamlines the sharding process, making it easier for developers to implement distributed inference efficiently. As the field of NLP and deep learning continues to advance, sharding will play an increasingly vital role in leveraging the full capabilities of large models in real-world applications.
