Fine-tune Llama 3 with function calling via MLX-LM

Anchen
3 min read · Apr 27, 2024


With the recent release of Llama 3, Meta has delivered an openly available model that combines impressive performance with a compact size. The Llama 3 8B model in particular stands out: its small footprint and high-quality outputs make it a strong choice for running an LLM on-device.

In this blog post, I’ll walk you through fine-tuning the Llama 3 8B model for function calling using the MLX-LM library. The MLX team has been steadily improving this tool, and it now makes LoRA fine-tuning on Apple silicon straightforward. By the end of this tutorial, you’ll know how to adapt Llama 3 8B to your own tasks.

Step 1: Preparing the Training Data

First, we need training data. The glaiveai/glaive-function-calling-v2 dataset is designed specifically for training language models to handle function calls. We’ve put together a script that converts this dataset into the Llama 3 chat format; alternatively, you can download the preprocessed dataset directly from https://huggingface.co/datasets/mzbac/function-calling-llama-3-format-v1.1/tree/main and save it into a “data” folder, which the fine-tuning config points to.
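For reference, the Llama 3 chat format wraps each turn in special header tokens and closes it with <|eot_id|>. The Python sketch below renders a conversation into that format and produces one {"text": ...} JSON line per conversation, a layout MLX-LM’s LoRA trainer accepts in train.jsonl and valid.jsonl. It is not the conversion script mentioned above: the helper names, the sample conversation, and the way the function call is written in the assistant turn are hypothetical, and parsing the raw glaive records into messages is left out.

```python
import json

# Special tokens from the Llama 3 chat template.
BEGIN_OF_TEXT = "<|begin_of_text|>"
HEADER = "<|start_header_id|>{role}<|end_header_id|>\n\n"
END_OF_TURN = "<|eot_id|>"


def to_llama3_text(messages):
    """Render [{"role": ..., "content": ...}, ...] as one Llama 3 chat string."""
    parts = [BEGIN_OF_TEXT]
    for message in messages:
        parts.append(HEADER.format(role=message["role"]))
        parts.append(message["content"].strip())
        parts.append(END_OF_TURN)
    return "".join(parts)


def to_jsonl_line(messages):
    """One line of train.jsonl / valid.jsonl: {"text": "<rendered chat>"}."""
    return json.dumps({"text": to_llama3_text(messages)})


# Hypothetical conversation; the function schema and call format are illustrative only.
example = [
    {"role": "system",
     "content": 'You can call: {"name": "get_weather", "parameters": {"city": "string"}}'},
    {"role": "user", "content": "What's the weather in Paris?"},
    {"role": "assistant",
     "content": '{"name": "get_weather", "arguments": {"city": "Paris"}}'},
]

if __name__ == "__main__":
    print(to_jsonl_line(example))
```

Writing one such line per conversation into data/train.jsonl and data/valid.jsonl gives the folder layout the config expects.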

Step 2: Installing the MLX-LM Package

Next up, let’s get the mlx-lm package installed. It’s as easy as running:

pip install mlx-lm

The package provides command-line tools for LoRA fine-tuning (mlx_lm.lora), fusing adapters (mlx_lm.fuse), and text generation (mlx_lm.generate), so the whole workflow runs locally on Apple silicon.

Step 3: Creating the LoRA Config

Now it’s time to set up the LoRA (Low-Rank Adaptation) configuration for fine-tuning the Llama 3 8B model. We’ve made two key changes aimed at output quality:

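  1. Using the unquantized 16-bit model instead of QLoRA (which trains on a quantized base model), to avoid potential quality degradation from quantization and de-quantization.
  2. Setting lora_layers to 32, so every transformer block of Llama 3 8B is adapted, and applying LoRA to all of the linear projections in each block, aiming for results closer to full fine-tuning.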
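To put the second change in numbers: a rank-r LoRA adapter on a linear layer with d_in inputs and d_out outputs adds r × (d_in + d_out) trainable weights. The sketch below applies that to all seven projections listed under keys in the config, across 32 layers at rank 128. The layer dimensions are assumptions taken from Llama 3 8B’s published configuration (hidden size 4096, MLP size 14336, 8 key/value heads), and the result counts adapter weights only.

```python
# Llama 3 8B dimensions (assumed from the model's published config.json).
HIDDEN = 4096          # hidden_size
INTERMEDIATE = 14336   # intermediate_size
N_HEADS = 32           # num_attention_heads
N_KV_HEADS = 8         # num_key_value_heads (grouped-query attention)
N_LAYERS = 32          # num_hidden_layers
HEAD_DIM = HIDDEN // N_HEADS      # 128
KV_DIM = N_KV_HEADS * HEAD_DIM    # 1024

# (input_dims, output_dims) of each linear layer named in lora_parameters.keys
PROJECTIONS = {
    "self_attn.q_proj": (HIDDEN, HIDDEN),
    "self_attn.k_proj": (HIDDEN, KV_DIM),
    "self_attn.v_proj": (HIDDEN, KV_DIM),
    "self_attn.o_proj": (HIDDEN, HIDDEN),
    "mlp.gate_proj": (HIDDEN, INTERMEDIATE),
    "mlp.up_proj": (HIDDEN, INTERMEDIATE),
    "mlp.down_proj": (INTERMEDIATE, HIDDEN),
}


def lora_params(rank, lora_layers=N_LAYERS, keys=PROJECTIONS):
    """Trainable adapter weights: rank * (d_in + d_out) per adapted layer."""
    per_block = sum(rank * (d_in + d_out) for d_in, d_out in keys.values())
    return per_block * lora_layers


if __name__ == "__main__":
    total = lora_params(rank=128)
    print(f"{total:,} trainable LoRA parameters")  # 335,544,320
```

That is about 335.5M trainable parameters, roughly 4% of the model’s ~8B, versus about 3.4M for a rank-8 adapter on q_proj and v_proj alone across the same 32 layers.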

Here’s an example of what your lora_config.yaml file might look like:

# The path to the local model directory or Hugging Face repo.
model: "meta-llama/Meta-Llama-3-8B-Instruct"
# Whether or not to train (boolean)
train: true

# Directory with {train, valid, test}.jsonl files
data: "data"

# The PRNG seed
seed: 0

# Number of layers to fine-tune
lora_layers: 32

# Minibatch size.
batch_size: 1

# Iterations to train for.
iters: 6000

# Number of validation batches, -1 uses the entire validation set.
val_batches: 25

# Adam learning rate.
learning_rate: 1e-6

# Number of training steps between loss reporting.
steps_per_report: 10

# Number of training steps between validations.
steps_per_eval: 200

# Load path to resume training with the given adapter weights.
resume_adapter_file: null

# Save/load path for the trained adapter weights.
adapter_path: "adapters"

# Save the model every N iterations.
save_every: 1000

# Evaluate on the test set after training
test: false

# Number of test set batches, -1 uses the entire test set.
test_batches: 100

# Maximum sequence length.
max_seq_length: 8192

# Use gradient checkpointing to reduce memory use.
grad_checkpoint: true

# LoRA parameters can only be specified in a config file
lora_parameters:
  # The layer keys to apply LoRA to.
  # These will be applied for the last lora_layers
  keys: ['mlp.gate_proj', 'mlp.down_proj', 'self_attn.q_proj', 'mlp.up_proj', 'self_attn.o_proj', 'self_attn.v_proj', 'self_attn.k_proj']
  rank: 128
  alpha: 256
  scale: 10.0
  dropout: 0.05

# Schedule can only be specified in a config file, uncomment to use.
# lr_schedule:
#   name: cosine_decay
#   warmup: 100 # 0 for no warmup
#   warmup_init: 1e-7 # 0 if not specified
#   arguments: [1e-6, 1000, 1e-7] # passed to scheduler
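If uncommented, the lr_schedule block would replace the constant learning rate with a warmup followed by cosine decay. As I read its arguments, [1e-6, 1000, 1e-7] are the initial rate, the number of decay steps, and the final rate for MLX’s cosine_decay, preceded by a 100-step linear warmup starting at 1e-7. The pure-Python sketch below approximates that shape; it is not MLX’s implementation, and off-by-one details at the warmup boundary may differ.

```python
import math


def warmup_cosine_lr(step, init_lr=1e-6, decay_steps=1000, end_lr=1e-7,
                     warmup_steps=100, warmup_init=1e-7):
    """Learning rate at a given step: linear warmup, then cosine decay.

    Defaults mirror the commented-out lr_schedule values in lora_config.yaml.
    """
    if step < warmup_steps:
        # Linear ramp from warmup_init up to init_lr.
        return warmup_init + (init_lr - warmup_init) * step / warmup_steps
    # Cosine decay from init_lr to end_lr over decay_steps, then hold end_lr.
    t = min(step - warmup_steps, decay_steps)
    cosine = 0.5 * (1.0 + math.cos(math.pi * t / decay_steps))
    return end_lr + (init_lr - end_lr) * cosine


if __name__ == "__main__":
    for step in (0, 50, 100, 600, 1100, 6000):
        print(step, f"{warmup_cosine_lr(step):.2e}")
```

If MLX behaves like this sketch, the rate reaches 1e-7 around step 1,100 and stays there for the rest of a 6,000-iteration run, so you may want the decay steps closer to iters when enabling the schedule.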

Step 4: Running the Fine-Tuning Process

With your data and configuration in place, start fine-tuning the Llama 3 8B model by running:

mlx_lm.lora --config lora_config.yaml

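MLX-LM will then report the training loss every 10 steps (steps_per_report) and the validation loss every 200 steps (steps_per_eval), saving adapter weights to the adapters folder every 1,000 iterations (save_every).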
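Before launching a long run, it can save time to check that the data folder matches what the config expects. The sketch below is a hypothetical pre-flight helper, not part of MLX-LM: it assumes, per the config comment, that training reads train.jsonl and valid.jsonl from the data folder, and that each line is a JSON object with a "text" field.

```python
import json
from pathlib import Path


def check_data_dir(data_dir="data", splits=("train", "valid")):
    """Return a list of problems found in {split}.jsonl files under data_dir."""
    problems = []
    for split in splits:
        path = Path(data_dir) / f"{split}.jsonl"
        if not path.is_file():
            problems.append(f"missing {path}")
            continue
        records = 0
        with path.open(encoding="utf-8") as f:
            for lineno, line in enumerate(f, start=1):
                if not line.strip():
                    continue  # ignore blank lines
                records += 1
                try:
                    record = json.loads(line)
                except json.JSONDecodeError:
                    problems.append(f"{path}:{lineno} is not valid JSON")
                    continue
                if not isinstance(record, dict) or "text" not in record:
                    problems.append(f"{path}:{lineno} has no 'text' field")
        if records == 0:
            problems.append(f"{path} has no records")
    return problems


if __name__ == "__main__":
    issues = check_data_dir("data")
    print("\n".join(issues) if issues else "data/ looks ready for mlx_lm.lora")
```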

Step 5 (Optional): Fusing the Trained Adapters

If you want to distribute your fine-tuned model, you can fuse the trained adapters into the original Llama 3 8B weights, producing a standalone model in Hugging Face format:

mlx_lm.fuse --model meta-llama/Meta-Llama-3-8B-Instruct
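Conceptually, fusing folds each adapter’s low-rank update back into the base weight, so the fused model needs no extra LoRA computation at inference time. The pure-Python sketch below checks that identity on small random matrices: with y = x·W + scale·(x·A)·B, the fused weight W′ = W + scale·A·B gives the same output. The helper names are hypothetical and are not mlx_lm.fuse internals; for simplicity W is stored as (d_in, d_out), whereas MLX’s nn.Linear stores the transpose. Dropout is omitted because it only applies during training.

```python
import random


def matmul(a, b):
    """Multiply matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def add_scaled(a, b, scale):
    """Element-wise a + scale * b."""
    return [[x + scale * y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


def lora_forward(x, w, lora_a, lora_b, scale):
    """Unfused LoRA layer: y = x @ W + scale * (x @ A) @ B."""
    return add_scaled(matmul(x, w), matmul(matmul(x, lora_a), lora_b), scale)


def fuse(w, lora_a, lora_b, scale):
    """Fold the adapter into the base weight: W' = W + scale * (A @ B)."""
    return add_scaled(w, matmul(lora_a, lora_b), scale)


def rand_matrix(rng, rows, cols):
    """Matrix of standard-normal samples from the given random.Random."""
    return [[rng.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]


if __name__ == "__main__":
    rng = random.Random(0)
    d_in, d_out, rank, scale = 4, 3, 2, 2.0
    x = rand_matrix(rng, 2, d_in)        # a batch of 2 inputs
    w = rand_matrix(rng, d_in, d_out)    # frozen base weight
    a = rand_matrix(rng, d_in, rank)     # LoRA down-projection
    b = rand_matrix(rng, rank, d_out)    # LoRA up-projection
    print(lora_forward(x, w, a, b, scale))
    print(matmul(x, fuse(w, a, b, scale)))  # same values up to rounding
```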

PS: The fine-tuned model can be found here
