Finetuning LLM efficiently: Part 1 — Simple fixes to the Dataloader

Timothy Lim
3 min read · Sep 15, 2023


In this blog series, we will delve into specific tips and tricks for fine-tuning LLMs efficiently. Fine-tuning LLMs for custom applications has garnered a lot of interest from businesses and researchers because of its potential to transform and improve many of the tools we already use every day.

One prevalent challenge when training or fine-tuning LLMs is the considerable time it can consume. In this post I want to share some simple fixes, or rather things to be aware of in your training loop, as you go about fine-tuning your own custom LLM.

Photo by Chris Liverani on Unsplash

Assumed reader background: familiarity with LLMs. If you are not, I recommend this blog to get up to speed.

Fix the DataLoader collate function

In a few open-source repositories, the DataLoader is set up with a default collator that pads every sequence up to the predefined maximum context length. For example, the training loop written by Meta for llama-recipes looks like this:

from transformers import (
    LlamaForCausalLM,
    LlamaTokenizer,
    LlamaConfig,
    default_data_collator,
)
.
.
.
.
# Create DataLoaders for the training and validation dataset
train_dataloader = torch.utils.data.DataLoader(
    dataset_train,
    batch_size=train_config.batch_size_training,
    num_workers=train_config.num_workers_dataloader,
    pin_memory=True,
    sampler=train_sampler if train_sampler else None,
    drop_last=True,
    collate_fn=default_data_collator,
)

Using default_data_collator here is a bad idea for fine-tuning transformers: there is no point in padding every sequence up to the maximum context length, and it ends up wasting compute and time.

The better approach is to pad each sequence only up to the longest sequence in the current batch, so that less compute is needed. Pad tokens do not contribute to the loss, so they are unnecessary compute that we should avoid whenever possible.

To illustrate this, imagine a batch of 4 sequences (tokens):

# batch of 4 sequences of token
[1,2,3]
[1,2,3,4,5,6,7,8,9]
[1,2,3,4,5]
[1]

Let’s assume that your maximum length is 15 tokens and the pad token ID is set to 0:

# Using default collator:

[1,2,3,0,0,0,0,0,0,0,0,0,0,0,0]
[1,2,3,4,5,6,7,8,9,0,0,0,0,0,0]
[1,2,3,4,5,0,0,0,0,0,0,0,0,0,0]
[1,0,0,0,0,0,0,0,0,0,0,0,0,0,0]

# Using custom collator:

[1,2,3,0,0,0,0,0,0]
[1,2,3,4,5,6,7,8,9]
[1,2,3,4,5,0,0,0,0]
[1,0,0,0,0,0,0,0,0]

As we can see, padding only up to the maximum length of the current batch requires less compute, because the batch ends up shorter.
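As an aside, the transformers library itself ships dynamic-padding collators such as DataCollatorForSeq2Seq, which pad each batch only to its longest sequence and fill the label padding with -100. Here is a minimal sketch, assuming a Hugging Face tokenizer is already loaded (the checkpoint name below is just an example):

from transformers import AutoTokenizer, DataCollatorForSeq2Seq

# Example checkpoint; swap in whatever model/tokenizer you are fine-tuning.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer.pad_token = tokenizer.eos_token  # Llama tokenizers ship without a pad token

# padding=True pads each batch only up to its longest sequence,
# and label_pad_token_id=-100 keeps padded label positions out of the loss.
dynamic_collator = DataCollatorForSeq2Seq(
    tokenizer,
    padding=True,
    label_pad_token_id=-100,
)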

That said, writing our own custom collator is straightforward. For llama-recipes, it can look like this:

import torch
from torch.nn.utils.rnn import pad_sequence
.
.
.
.
.
def collate_fn(batch):
    # pad_sequence expects a list of tensors; convert in case the
    # dataset yields plain Python lists.
    input_ids = [torch.as_tensor(item["input_ids"]) for item in batch]
    labels = [torch.as_tensor(item["labels"]) for item in batch]
    attention_mask = [torch.as_tensor(item["attention_mask"]) for item in batch]

    # Pad only up to the longest sequence in this batch.
    padded_input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
    # -100 is ignored by PyTorch's cross-entropy loss (ignore_index),
    # so padded label positions do not contribute to the loss.
    padded_labels = pad_sequence(labels, batch_first=True, padding_value=-100)
    # Padded positions get attention_mask = 0 so the model does not attend to them.
    padded_attention_mask = pad_sequence(
        attention_mask, batch_first=True, padding_value=0
    )

    return {
        "input_ids": padded_input_ids,
        "labels": padded_labels,
        "attention_mask": padded_attention_mask,
    }
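To use it, pass this function as the collate_fn of the DataLoader in place of default_data_collator. A sketch following the llama-recipes snippet from earlier (dataset_train, train_config, and train_sampler are assumed to be defined as above):

# Same DataLoader as before; only the collate_fn changes.
train_dataloader = torch.utils.data.DataLoader(
    dataset_train,
    batch_size=train_config.batch_size_training,
    num_workers=train_config.num_workers_dataloader,
    pin_memory=True,
    sampler=train_sampler if train_sampler else None,
    drop_last=True,
    collate_fn=collate_fn,  # our dynamic-padding collator
)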

The reason this works so well in reducing training time is that your custom dataset will most probably contain many sequences that fall well below the maximum context length.

For example, imagine that you are fine-tuning a custom LLM with a maximum length of 1024 on these datasets:

1. Alpaca (Instruction Tuning)
2. Subset of OASST (Conversational Data)

You can observe that most sequences are actually quite far from our predefined maximum length (even if you set it to 512).
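You can run a quick check like this on your own data to see how token lengths are distributed. This is only a sketch: texts stands in for whatever field your dataset uses, and tokenizer is assumed to be the tokenizer you fine-tune with.

import numpy as np

# `texts` is a hypothetical list of your training strings;
# `tokenizer` is a Hugging Face tokenizer.
lengths = [len(tokenizer(t)["input_ids"]) for t in texts]

print("mean length:", np.mean(lengths))
print("95th percentile:", np.percentile(lengths, 95))
print("share of sequences under 512 tokens:",
      np.mean(np.array(lengths) < 512))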

Lastly, to make training even more efficient along with this custom collator, you can pre-sort your data by sequence length so that sequences in the same batch have similar lengths. With sorting, your training time drops further, because the total amount of padding needed across the dataset shrinks considerably.
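One simple way to do this, sketched under the assumption that dataset_train is indexable and already tokenized, is to sort the example indices by length and feed batches of neighbouring indices to the DataLoader through a batch_sampler:

# Sort example indices by tokenized length so each batch contains
# sequences of similar length (dataset_train is assumed to return
# dicts with an "input_ids" field).
sorted_indices = sorted(
    range(len(dataset_train)),
    key=lambda i: len(dataset_train[i]["input_ids"]),
)

batch_size = train_config.batch_size_training
batches = [
    sorted_indices[i : i + batch_size]
    for i in range(0, len(sorted_indices), batch_size)
]

train_dataloader = torch.utils.data.DataLoader(
    dataset_train,
    batch_sampler=batches,  # replaces batch_size / sampler / drop_last
    num_workers=train_config.num_workers_dataloader,
    pin_memory=True,
    collate_fn=collate_fn,
)

In practice you may also want to shuffle the order of the batches each epoch, so the model still sees them in a random order while each individual batch stays length-homogeneous.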

Do not be surprised if you get a 10x speed-up fine-tuning on your custom dataset after changing the way you collate, as I did!

Continue to Part 2 of this series
