With the recent release of Llama 3, Meta has delivered a game-changing open-source model that combines impressive performance with a compact size. The Llama 3 8B model, in particular, is a true gem — its small footprint and high-quality outputs make it the perfect choice for on-device LLM deployment.
In this blog post, I’ll guide you through the process of fine-tuning the Llama 3 8B model using the incredible MLX-LM library. The MLX team has been working tirelessly to improve this tool, and their efforts have paid off in spades. By the time you finish this tutorial, you’ll be a pro at adapting the Llama 3 8B model to your specific needs, unlocking its full potential for your projects.
Step 1: Preparing the Training Data
First things first, let’s get our hands on some top-notch training data. The glaiveai/glaive-function-calling-v2 dataset is a fantastic resource designed specifically for training language models to handle function calls. We’ve put together a handy script that will convert this dataset into the Llama 3 chat format, or you can download the preprocessed dataset directly from https://huggingface.co/datasets/mzbac/function-calling-llama-3-format-v1.1/tree/main and save it into a “data” folder for easy access during fine-tuning.
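To make the conversion concrete, here is a minimal pure-Python sketch of turning one record into the Llama 3 chat format. The "system"/"chat" field names, the "USER:"/"ASSISTANT:" turn markers, and the sample record are assumptions modeled on the glaive dataset's layout, so inspect the real rows (and handle function responses) before relying on this:

```python
import json
import re

# One record shaped like a glaive-function-calling-v2 row; the field
# names and "USER:"/"ASSISTANT:" markers are assumptions, so verify
# them against the real dataset.
record = {
    "system": "SYSTEM: You are a helpful assistant with access to functions.",
    "chat": "USER: What's the weather in Paris? ASSISTANT: Let me check. <|endoftext|>",
}

ROLES = {"USER:": "user", "ASSISTANT:": "assistant"}

def to_llama3(rec):
    """Render one record with Llama 3's chat special tokens."""
    text = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
    text += rec["system"].removeprefix("SYSTEM:").strip() + "<|eot_id|>"
    role = None
    # Split the flat chat transcript on the role markers, keeping them.
    for part in re.split(r"(USER:|ASSISTANT:)", rec["chat"]):
        part = part.strip()
        if part in ROLES:
            role = ROLES[part]
        elif part and role:
            content = part.replace("<|endoftext|>", "").strip()
            text += f"<|start_header_id|>{role}<|end_header_id|>\n\n{content}<|eot_id|>"
    return {"text": text}

line = json.dumps(to_llama3(record))
```

In practice you would run this over every record and write one `json.dumps(...)` per line into train.jsonl, valid.jsonl, and test.jsonl inside the "data" folder, which is the layout mlx-lm expects.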
Step 2: Installing the MLX-LM Package
Next up, let’s get the mlx-lm package installed. It’s as easy as running:
pip install mlx-lm
This powerful library provides a user-friendly interface for fine-tuning LLMs, taking the hassle out of the process and helping you achieve better results.
Step 3: Creating the LoRA Config
Now, it’s time to set up the LoRA (Low-Rank Adaptation) configuration for fine-tuning the Llama 3 8B model. We’ve made a few key changes to optimize performance and efficiency:
- Training in fp16 rather than QLoRA, to avoid the potential quality degradation caused by quantizing and then de-quantizing the weights.
- Setting lora_layers to 32 and applying LoRA to all of the linear projection layers, which yields results that rival full fine-tuning.
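To see why this setup approaches full fine-tuning, here is a back-of-the-envelope count of the trainable parameters it adds. The dimensions (hidden size 4096, MLP size 14336, 1024-dim k/v projections from grouped-query attention) come from the published Llama 3 8B config; the arithmetic is simply rank * (d_in + d_out) per adapted matrix:

```python
# Back-of-the-envelope LoRA parameter count for Llama 3 8B.
hidden, mlp, kv = 4096, 14336, 1024
rank, layers = 128, 32

# (in_features, out_features) of each projection LoRA is applied to
shapes = {
    "self_attn.q_proj": (hidden, hidden),
    "self_attn.k_proj": (hidden, kv),
    "self_attn.v_proj": (hidden, kv),
    "self_attn.o_proj": (hidden, hidden),
    "mlp.gate_proj": (hidden, mlp),
    "mlp.up_proj": (hidden, mlp),
    "mlp.down_proj": (mlp, hidden),
}

# Each LoRA pair (A: d_in x r, B: r x d_out) adds rank * (d_in + d_out).
per_layer = sum(rank * (d_in + d_out) for d_in, d_out in shapes.values())
total = per_layer * layers
print(f"{total / 1e6:.0f}M trainable parameters")
```

That works out to roughly 336M trainable parameters, around 4% of the 8B base model, which is a lot of adapter capacity compared to a typical low-rank setup.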
Here’s an example of what your lora_config.yaml file might look like:
# The path to the local model directory or Hugging Face repo.
model: "meta-llama/Meta-Llama-3-8B-Instruct"
# Whether or not to train (boolean)
train: true
# Directory with {train, valid, test}.jsonl files
data: "data"
# The PRNG seed
seed: 0
# Number of layers to fine-tune
lora_layers: 32
# Minibatch size.
batch_size: 1
# Iterations to train for.
iters: 6000
# Number of validation batches, -1 uses the entire validation set.
val_batches: 25
# Adam learning rate.
learning_rate: 1e-6
# Number of training steps between loss reporting.
steps_per_report: 10
# Number of training steps between validations.
steps_per_eval: 200
# Load path to resume training with the given adapter weights.
resume_adapter_file: null
# Save/load path for the trained adapter weights.
adapter_path: "adapters"
# Save the model every N iterations.
save_every: 1000
# Evaluate on the test set after training
test: false
# Number of test set batches, -1 uses the entire test set.
test_batches: 100
# Maximum sequence length.
max_seq_length: 8192
# Use gradient checkpointing to reduce memory use.
grad_checkpoint: true
# LoRA parameters can only be specified in a config file
lora_parameters:
  # The layer keys to apply LoRA to.
  # These will be applied for the last lora_layers
  keys: ['mlp.gate_proj', 'mlp.down_proj', 'self_attn.q_proj', 'mlp.up_proj', 'self_attn.o_proj', 'self_attn.v_proj', 'self_attn.k_proj']
  rank: 128
  alpha: 256
  scale: 10.0
  dropout: 0.05
# Schedule can only be specified in a config file, uncomment to use.
# lr_schedule:
#   name: cosine_decay
#   warmup: 100 # 0 for no warmup
#   warmup_init: 1e-7 # 0 if not specified
#   arguments: [1e-6, 1000, 1e-7] # passed to scheduler
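As a sanity check on what the commented-out schedule would do, here is a pure-Python sketch of linear warmup followed by cosine decay. The exact semantics are an assumption modeled on typical cosine-decay schedulers such as the one in mlx.optimizers, so consult the MLX docs before trusting the details:

```python
import math

def lr_at(step, init=1e-6, decay_steps=1000, end=1e-7,
          warmup=100, warmup_init=1e-7):
    """Learning rate at a given step: linear warmup, then cosine decay.
    Parameter values mirror the commented lr_schedule in the config;
    the schedule semantics themselves are an assumption."""
    if step < warmup:
        # Linear ramp from warmup_init up to the peak rate.
        return warmup_init + (init - warmup_init) * step / warmup
    # Cosine decay from init down to end over decay_steps.
    t = min(step - warmup, decay_steps)
    return end + (init - end) * 0.5 * (1 + math.cos(math.pi * t / decay_steps))

print(f"{lr_at(0):.1e} -> {lr_at(100):.1e} -> {lr_at(1100):.1e}")
```

Under these assumptions the rate climbs from 1e-7 to the peak 1e-6 over the first 100 steps, then decays back to 1e-7 over the next 1,000.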
Step 4: Running the Fine-Tuning Process
With your data and configuration ready to go, it’s time for the main event — fine-tuning the Llama 3 8B model! Just run:
mlx_lm.lora --config lora_config.yaml
Then sit back and let MLX-LM work its magic. Based on the config above, you’ll see a loss report every 10 steps, a validation pass every 200 steps, and adapter checkpoints saved to the “adapters” folder every 1,000 iterations.
Step 5 (Optional): Fusing the Trained Adapters
If you want to distribute your fine-tuned model, you can easily fuse the trained adapters with the original Llama 3 8B model in Hugging Face format:
mlx_lm.fuse --model meta-llama/Meta-Llama-3-8B-Instruct
By default, the command picks up the trained adapter weights from the “adapters” folder, matching the adapter_path in our config.
PS: The fine-tuned model can be found here