Fine-Tuning Stable Diffusion With Validation

Damian Stewart / @damian0815
7 min read · Apr 10, 2023

This article is part of a series on fine-tuning Stable Diffusion models. See also The LR Finder Method for Stable Diffusion.

“[fine-tuning without validation] is like trying to hit a piñata blindfolded”
- @watusi, EveryDream2 Discord

When you’re training a machine learning model, you need some way of knowing when it’s finished. Normally you do this by checking graphs of the model’s loss value, which is a way of measuring how well the model has incorporated the training data. The loss starts high, and as training progresses it decreases:

a nicely decreasing loss graph
As you train for more steps, the loss decreases.

During training, the trainer tries at each training step to modify the model so that the loss decreases. The size of each change is determined by the learning rate (LR). Generally, a higher learning rate trains faster and a lower learning rate trains slower. Set the learning rate too high and your training will not be effective; set it too low and you’ll waste far more energy (and carbon) than you need to train the model effectively.
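If you’ve never seen this mechanic in code, here’s a toy sketch — plain PyTorch gradient descent, nothing Stable Diffusion specific — showing how the learning rate scales each update:

```python
import torch

# A toy gradient-descent loop -- not Stable Diffusion, just the mechanic:
# each step nudges the weights against the gradient, scaled by the learning rate.
torch.manual_seed(0)
weights = torch.randn(10, requires_grad=True)
target = torch.randn(10)

learning_rate = 0.1  # try 1.1 (too high: the loss explodes) or 1e-4 (too low: it barely moves)
optimizer = torch.optim.SGD([weights], lr=learning_rate)

for step in range(100):
    loss = torch.nn.functional.mse_loss(weights, target, reduction="sum")
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if step % 20 == 0:
        print(f"step {step}: loss {loss.item():.4f}")
```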

The Problem of Overfitting

Unfortunately, AIs love to cheat. If the model can find any way at all to make the loss decrease, it will use it, even if the results end up looking worse than the person training the model would like. For Stable Diffusion, here’s a subtle example from a recent training test, using a dataset generously donated to me by one of the EveryDream2 Discord users. It includes these images, captioned “ancient temple” and “ancient galley”:

two illustrations of an “ancient temple” and one of an “ancient galley”

And after what turned out to be too many epochs at too high a learning rate, the model I was training generated this image for the prompt “ancient temple”:

Do you notice the weird curved shape of the steps leading into the temple? This only started appearing after a large number of training epochs, and it indicates that the model has overfit the training data:

It has learnt that if it draws shapes that are curved like the sail of the ship, then it can cheat the training process — because the positive score it gets for getting the sails right on the “ancient galley” images outweighs the negative score it gets for messing up the “ancient temple” steps.

This happens due to quirks in the dataset that aren’t apparent or important to our eyes, but which turn out to be loopholes or shortcuts in the maths. In technical terms, the model has overfit to the training data. It has learnt things that we don’t actually want it to learn, just because they allow it to cheat and therefore get a better score.

How Do We Prevent Overfitting?

Overfitting happens when you train a model too much, whether that’s for too many epochs, or with a learning rate that is too high, or both. To prevent it, you need to know when to stop. One of the best ways to do this is using validation. Validation works by withholding a portion of your dataset’s images from the training process; this held-out portion is then used to check that the model isn’t cheating.

Recall our original training loss graph:

a nicely decreasing loss graph

If we add the validation loss to this graph and plot it alongside the training loss, the curves together look like this:

As long as training is progressing well, the validation loss and the training loss follow each other closely. However, as the model begins to overfit, the validation loss levels off and starts to rise.
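If you’re logging losses yourself rather than reading them off TensorBoard, a minimal matplotlib sketch produces the same kind of chart. The numbers here are synthetic stand-ins for real logs:

```python
import math
import matplotlib.pyplot as plt

# Synthetic curves for illustration -- substitute your trainer's logged values.
steps = list(range(0, 600, 10))
train_loss = [0.5 * math.exp(-s / 200) + 0.05 for s in steps]
# The validation loss tracks the training loss, then rises once overfitting starts.
val_loss = [t + max(0.0, (s - 350) / 1500) for s, t in zip(steps, train_loss)]

plt.plot(steps, train_loss, label="loss/train")
plt.plot(steps, val_loss, label="loss/val")
plt.xlabel("step")
plt.ylabel("loss")
plt.legend()
plt.show()
```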

From the example above with the “ancient galley” and “ancient temple” images: suppose the validation set also has images of boats, but with differently shaped sails. When the model starts to overfit on the “ancient galley” sail in the training set, the loss score on the validation set will start to get worse, because the model is now trying to draw the wrong kind of sail. As the model overfits more and more, the validation score will get worse and the curve will start to point upwards, rather than following the training curve further down. When this starts to happen, it’s time to stop training.
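In code, “stop when the validation curve turns upward” is classic early stopping. A minimal, framework-agnostic tracker might look like this (the class and names here are my own sketch, not EveryDream2’s):

```python
# Minimal early-stopping tracker: stop once validation loss has failed to
# improve for `patience` consecutive validation checks.
class EarlyStopping:
    def __init__(self, patience: int = 5):
        self.patience = patience
        self.best_loss = float("inf")
        self.bad_checks = 0

    def should_stop(self, val_loss: float) -> bool:
        if val_loss < self.best_loss:
            self.best_loss = val_loss  # new low: a good moment to save a checkpoint
            self.bad_checks = 0
        else:
            self.bad_checks += 1       # validation loss is levelling off or rising
        return self.bad_checks >= self.patience
```

Whenever `should_stop` returns True, the checkpoint saved at the last improvement is your candidate for the best model.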

The best model will be found at around the bottom of the validation loss curve, which would put it somewhere between these green dotted lines:

Note that this can only work if the validation dataset is kept entirely separate from the training data. If there is any overlap between the validation images and the training images, validation can’t check whether the model is learning things that we don’t really want it to learn. In practice, you will want to use about 15% of your dataset for validation.
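Carving off that held-out portion can be as simple as a seeded shuffle-and-slice over your image paths. This is just a sketch of the idea; EveryDream2 has its own configuration for setting up validation:

```python
import random

# Split a list of image paths into ~85% training / ~15% validation, with no
# overlap. Fixing the seed keeps the split identical across training runs.
def split_dataset(image_paths, val_fraction=0.15, seed=42):
    rng = random.Random(seed)
    paths = list(image_paths)
    rng.shuffle(paths)
    n_val = max(1, int(len(paths) * val_fraction))
    return paths[n_val:], paths[:n_val]  # (train, validation)
```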

Validating Latent Diffusion Models

Because of the way latent diffusion models like Stable Diffusion are trained (each step samples random noise and a random timestep, so the measured loss jumps around), the default loss curves don’t tend to trace a smooth line. Instead, they typically look more like this:

To get around this, the validator in EveryDream2 produces a special loss/val curve that you can turn on by following the instructions here. This will get you a nice smooth curve that should look more like this:

Although there’s a brief bump upwards at around step 180, the curve generally traces a smooth line downwards over the first 300 steps, before slowing down and stagnating over steps 350–500, and then beginning to clearly overfit after step 500:

The best model is likely to be found somewhere roughly between step 350 and step 450, in the valley of the validation curve:

Using validation means you get to skip the guesswork involved in deciding whether a model is ready. Of course, you should always try your model out in your Stable Diffusion web UI of choice (I recommend InvokeAI, though of course I would, because I contribute code to it). Having validation graphs nevertheless gives you an independent measure of how well trained your model is.
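How does a validator get a smooth curve out of such a noisy process? One common approach is to make the validation pass deterministic: seed the noise and the timesteps so that every validation pass measures the model against exactly the same inputs. Here is a sketch of that idea in the style of Hugging Face diffusers — my own approximation, not EveryDream2’s actual code:

```python
import torch


@torch.no_grad()
def validation_loss(unet, scheduler, val_latents, val_embeddings, seed=555):
    """Measure validation loss with fixed noise and timesteps, so that runs
    at different training steps are directly comparable."""
    gen = torch.Generator(device=val_latents.device).manual_seed(seed)
    noise = torch.randn(val_latents.shape, generator=gen, device=val_latents.device)
    timesteps = torch.randint(
        0, scheduler.config.num_train_timesteps,
        (val_latents.shape[0],), generator=gen, device=val_latents.device,
    )
    noisy_latents = scheduler.add_noise(val_latents, noise, timesteps)
    pred = unet(noisy_latents, timesteps, encoder_hidden_states=val_embeddings).sample
    # For epsilon-prediction models (e.g. SD 1.x), the target is the noise itself.
    return torch.nn.functional.mse_loss(pred, noise).item()
```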

What if My Model Isn’t Ready Yet?

You may find that the model still isn’t trained to a degree you’re happy with by the time your val graph starts pitching upwards. In this situation, you have two main options:

1. Lower your learning rate and train again.

In my experience, the final quality of a model is directly related to the number of training steps it took to get there: as long as you’re not overfitting, more steps make a better model. If you’re reaching the valley of the validation curve too early, try lowering the LR and training again. The best strategy is in fact to decide how many epochs you’re willing to wait, then tune your learning rate so that the model reaches the valley closer to the end of training than to the start (a rough way to pick the new LR is sketched below, after option 2).

2. Continue training and allow the model to overfit.

If you don’t mind losing “flexibility”, you can just continue training. The model’s output will look more and more like the training data, but it will become increasingly difficult to get it to do anything else. Keep going for too long and the model will learn how to copy your training images exactly, but by that stage it will struggle to draw anything else. Effectively, you will have nerfed the magical drawing abilities of Stable Diffusion, and what you’ll have instead is a very very large and awkward ZIP file. You can see this happening in this very large image (click “download” to see it): https://huggingface.co/damian0815/pashahlis-val-test-1e-6-ep30/blob/main/grates/pashahlis-val-test_as-received_lr1e-6-768x768.jpg. Note how the “ancient temple” columns have lost all of the variety they had from the four different seeds by epoch 50, and from then on they produce very samey images.
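Returning to option 1: how much lower should the new learning rate be? A crude rule of thumb — my own heuristic, not an EveryDream2 feature — is to scale the LR by how early the valley arrived relative to where you wanted it:

```python
# Crude, hypothetical heuristic: if the valley arrived at step 400 but you
# wanted it near step 900, scale the learning rate down proportionally
# and retrain from scratch.
def rescale_lr(current_lr, valley_step, target_step):
    return current_lr * (valley_step / target_step)

print(rescale_lr(1e-6, valley_step=400, target_step=900))  # ~4.4e-07
```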

How Can I Skip Even More Guesswork?

Every dataset needs a different learning rate. To avoid a lot of guesswork, you can quickly find a likely-good learning rate using the “LR finder” method. Read more about that in my article on using the LR Finder method for Stable Diffusion.
