Fine-Tuning Stable Diffusion 3 with Multiple Prompts on a GPU with 16GB VRAM

Filippo Santiano
Jul 10, 2024

Introduction

In my previous post, I explained how to use quantization to fine-tune Stable Diffusion 3 Medium (SD3-M) with only 16GB of VRAM. This technique lowers training costs and makes fine-tuning more accessible to anyone looking to personalise a text-to-image model. If you followed those instructions, you will have trained SD3-M with a single prompt shared across all of your training images.

The Need for Unique Prompts

What if you want to train your model with a unique prompt for each image? This is useful when you want to give each prompt more context. For example, instead of “A cat called Lily”, you could write “A cat called Lily, sitting on the grass outside”. These extra details tell the model more about what each image shows, which can lead to more accurate outputs from the fine-tuned model.

VRAM Usage

Unfortunately, if you train on multiple prompt-image pairs with the script from the previous post, VRAM usage exceeds 16GB. This is because every embedded prompt is stored in VRAM by default, so memory use grows with the number of prompts.

To address this, the script has been updated to store the embedded prompts in RAM and move only those needed for the current training batch to VRAM. For example, if --train_batch_size is set to five, only five embedded prompts are held in VRAM at a time and the rest stay in RAM. This lets us fine-tune SD3-M on different prompt-image pairs without exceeding 16GB of VRAM.
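The RAM-to-VRAM pattern described above can be sketched in PyTorch as follows. This is a minimal illustration, not the code in train_multiple_prompts.py: the function names are hypothetical, random tensors with small made-up shapes stand in for real text-encoder outputs, and batches use sequential indices for simplicity (a real training loop would take shuffled indices from its DataLoader).

```python
import torch


def precompute_embeddings_on_cpu(num_prompts, seq_len=8, dim=16, pooled_dim=4):
    """Stand-in for encoding every caption once and keeping the results in system RAM.

    Random tensors with small, made-up shapes replace real text-encoder outputs.
    """
    prompt_embeds = torch.randn(num_prompts, seq_len, dim)  # allocated on the CPU (RAM)
    pooled_embeds = torch.randn(num_prompts, pooled_dim)    # allocated on the CPU (RAM)
    return prompt_embeds, pooled_embeds


def batch_to_device(prompt_embeds, pooled_embeds, indices, device):
    """Copy only the rows for the current batch to the training device (VRAM on a GPU)."""
    idx = torch.as_tensor(indices, dtype=torch.long)
    return prompt_embeds[idx].to(device), pooled_embeds[idx].to(device)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
prompt_embeds, pooled_embeds = precompute_embeddings_on_cpu(num_prompts=20)

train_batch_size = 5
num_prompts = prompt_embeds.shape[0]
for start in range(0, num_prompts, train_batch_size):
    indices = list(range(start, min(start + train_batch_size, num_prompts)))
    batch_prompt, batch_pooled = batch_to_device(prompt_embeds, pooled_embeds, indices, device)
    # The forward and backward passes for this batch would use batch_prompt and batch_pooled here.
    del batch_prompt, batch_pooled  # drop the device copies so they can be freed before the next batch
```

The full set of embeddings never leaves RAM; at any moment the device only holds one batch's worth of copies, which is what keeps VRAM usage independent of the number of prompts.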

Requirements

All the necessary code can be found here. For the following steps to work, you must have completed the steps in my previous post up to “Fine-tuning the Model”.

Steps

Organise Your Images and Prompts

Set up your training data directory with the following format:

training_dir/
├── image_01.png
├── image_02.png
├── image_03.png
├── ...
├── image_n.png
└── metadata.jsonl

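metadata.jsonl must contain one JSON object per line, in the format below. Keep the keys "file_name" and "caption" exactly as shown; only their values should change. (The "..." line stands for the remaining entries and should not appear in the real file.)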

{"file_name": "image_01.png", "caption": "Prompt for image_01"}
{"file_name": "image_02.png", "caption": "Prompt for image_02"}
{"file_name": "image_03.png", "caption": "Prompt for image_03"}
...
{"file_name": "image_n.png", "caption": "Prompt for image_n"}
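Writing metadata.jsonl by hand is error-prone once captions contain quotes or the dataset grows, so a small helper can generate and check it. The sketch below is a hypothetical standard-library helper, not part of the repo: write_metadata and check_metadata are illustrative names, and the set of image extensions it looks for is an assumption you may need to adjust.

```python
import json
from pathlib import Path

IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg"}  # assumed image types; adjust to match your data


def write_metadata(training_dir, captions):
    """Write metadata.jsonl with one {"file_name", "caption"} object per line.

    `captions` maps an image file name (e.g. "image_01.png") to its prompt.
    """
    training_dir = Path(training_dir)
    with open(training_dir / "metadata.jsonl", "w", encoding="utf-8") as f:
        for file_name in sorted(captions):
            record = {"file_name": file_name, "caption": captions[file_name]}
            # json.dumps escapes quotes and newlines, so each record stays on a single line
            f.write(json.dumps(record, ensure_ascii=False) + "\n")


def check_metadata(training_dir):
    """Return a list of problems: bad keys, images without captions, or captions without images."""
    training_dir = Path(training_dir)
    images = {p.name for p in training_dir.iterdir() if p.suffix.lower() in IMAGE_SUFFIXES}
    listed = set()
    problems = []
    with open(training_dir / "metadata.jsonl", encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            if not line.strip():
                continue  # ignore blank lines
            record = json.loads(line)
            if set(record) != {"file_name", "caption"}:
                problems.append(f"line {line_no}: unexpected keys {sorted(record)}")
                continue
            listed.add(record["file_name"])
    problems += [f"{name}: missing from metadata.jsonl" for name in sorted(images - listed)]
    problems += [f"{name}: listed but not found" for name in sorted(listed - images)]
    return problems
```

For example, calling write_metadata("training_dir", {"image_01.png": "A cat called Lily, sitting on the grass outside"}) and then check_metadata("training_dir") returns an empty list when every image in the folder has exactly one caption.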

New Training Script

I have added a new training script to the repo, train_multiple_prompts.py, which includes the changes needed for this process. As described above, it keeps only train_batch_size embedded prompts in VRAM at a time and stores the rest in RAM.

Fine-tuning the Model

As before, we need to set MODEL_NAME and OUTPUT_DIR, but not INSTANCE_DIR. Instead, the training dataset is passed with the --dataset_name flag in the training command; point it at your training_dir.

export MODEL_NAME="stabilityai/stable-diffusion-3-medium-diffusers"
export OUTPUT_DIR="./fine_tuned_model"

We can now run our training script!

accelerate launch train_multiple_prompts.py \
--dataset_name="path/to/training_images" \
--pretrained_model_name_or_path=${MODEL_NAME} \
--output_dir=${OUTPUT_DIR} \
--mixed_precision="bf16" \
--resolution=512 \
--train_batch_size=4 \
--sample_batch_size=4 \
--gradient_accumulation_steps=3 \
--learning_rate=0.0001 \
--report_to="wandb" \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--max_train_steps=2000 \
--weighting_scheme="logit_normal" \
--seed="42" \
--use_8bit_adam \
--gradient_checkpointing \
--prior_generation_precision="bf16" \
--caption_column="caption"

Note: If the total number of prompt-image pairs is not divisible by train_batch_size, the last batch of each epoch contains only the remainder. Training still works in this case.
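The batch arithmetic behind this note, and how it interacts with --gradient_accumulation_steps and --max_train_steps, can be worked through with a short sketch. It assumes a single GPU, a DataLoader that keeps the final partial batch, and the epoch calculation used by the upstream diffusers DreamBooth scripts, which I am assuming train_multiple_prompts.py inherits; training_schedule is a hypothetical helper name.

```python
import math


def training_schedule(num_pairs, train_batch_size, gradient_accumulation_steps, max_train_steps):
    """Step arithmetic for one GPU, assuming the last partial batch is kept (drop_last=False)."""
    batches_per_epoch = math.ceil(num_pairs / train_batch_size)
    remainder = num_pairs % train_batch_size
    # The final batch of each epoch holds the remainder, or a full batch if it divides evenly
    last_batch_size = remainder if remainder else train_batch_size
    # Optimizer updates per epoch: one update per gradient_accumulation_steps batches
    steps_per_epoch = math.ceil(batches_per_epoch / gradient_accumulation_steps)
    # Epochs needed to reach max_train_steps optimizer updates
    num_epochs = math.ceil(max_train_steps / steps_per_epoch)
    return {
        "batches_per_epoch": batches_per_epoch,
        "last_batch_size": last_batch_size,
        "effective_batch_size": train_batch_size * gradient_accumulation_steps,
        "steps_per_epoch": steps_per_epoch,
        "num_epochs": num_epochs,
    }
```

With a hypothetical dataset of 30 prompt-image pairs and the settings from the command above (batch size 4, 3 accumulation steps, 2000 steps), training_schedule(30, 4, 3, 2000) gives 8 batches per epoch, a final batch of 2, a nominal effective batch size of 12, 3 optimizer steps per epoch, and 667 epochs.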

Running Inference

Once training is complete, you can run inference by following the “Running Inference” section of my previous post.

Summary

This post covered the steps needed to fine-tune SD3-M with a different prompt for each image. I hope you found it useful and that it helps you improve the accuracy of your model's outputs whilst keeping VRAM usage low.