Using Stable Diffusion to Improve Image Segmentation Models

Alex Browne
Edge Analytics
9 min read · Jan 12, 2023

As ML engineers, the teams at Edge Analytics and Infinity AI are very familiar with the challenges of obtaining high-quality labeled images for computer vision applications. With the release of generative image models such as Stability AI's open-source Stable Diffusion, we explored whether generative models could improve performance on a specific semantic segmentation task.

Stable Diffusion is an impressively powerful text-to-image model released by Stability AI in 2022. In this blog post, we’ll explore a technique for augmenting training data with Stable Diffusion to improve performance on an image segmentation task. This approach is especially useful in applications where data is limited or would otherwise require tedious human labeling.

In the context of computer vision models, image segmentation refers to splitting an image into two or more components based on its contents. In contrast to “image classification”, the goal of segmentation is not only to identify what an image contains, but also which parts of the image belong to each class.

Specifically, we will look at the DeepGlobe Road Extraction Dataset, which consists of around 6,000 aerial photographs of rural roads. The task for this dataset is to separate images into two classes: “road” and “background”. The dataset also comes with training labels in the form of mask images, in which roads are white and the background is black.
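To make the difference from classification concrete: a classifier produces one label per image, while a segmentation label has one entry per pixel. A minimal sketch of turning a mask image into per-pixel labels is below. It assumes the mask has already been loaded as an 8-bit array (for example with `np.asarray(PIL.Image.open(path).convert("L"))`); the helper name `mask_to_labels` and the threshold of 128 are our own illustrative choices, the threshold being there because resizing or compression can leave gray values between 0 and 255.

```python
import numpy as np


def mask_to_labels(mask: np.ndarray, threshold: int = 128) -> np.ndarray:
    """Convert an 8-bit mask (white = road, black = background) into
    per-pixel class labels: 1 = road, 0 = background."""
    if mask.ndim == 3:
        # RGB masks: all channels are assumed to be identical, so keep one.
        mask = mask[..., 0]
    return (mask >= threshold).astype(np.uint8)


# Toy 3x4 mask with a vertical "road" in the second column. The 250 and 3
# values mimic the near-white / near-black pixels left behind by resizing.
toy_mask = np.array([
    [0, 255, 0, 0],
    [0, 255, 0, 0],
    [0, 250, 3, 0],
], dtype=np.uint8)

labels = mask_to_labels(toy_mask)
print(labels)
```

The label array has the same height and width as the image, which is exactly what a segmentation model such as U-Net is trained to predict.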

Example data from the DeepGlobe Road Extraction Dataset
Left: input image. Right: label with “roads” in white and “background” in black.

Baseline Training Procedure

Let’s jump into the nitty-gritty details. We split the dataset 70/15/15 into train/test/validation sets and, as a baseline, trained a basic U-Net model on the full training set of 4,358 images. We trained this model (and all later models) for 50 epochs, then restored the weights from the epoch with the best val_loss. To get to the interesting part faster, we won’t show the code for loading the data and training the baseline model here; the full code is available in this Colab Notebook.
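The notebook's exact splitting code isn't reproduced here, so the following is only a minimal sketch of a seeded 70/15/15 split over sample indices. The function name `split_indices`, the seed, and the total of 6,226 images are our assumptions; that total is chosen because 70% of it rounds down to the 4,358 training images mentioned above.

```python
import random


def split_indices(n, train_frac=0.70, test_frac=0.15, seed=0):
    """Shuffle n sample indices and split them into train/test/validation.

    Whatever remains after the train and test slices becomes the
    validation set, so every index ends up in exactly one split."""
    indices = list(range(n))
    random.Random(seed).shuffle(indices)  # fixed seed -> reproducible split
    n_train = int(n * train_frac)
    n_test = int(n * test_frac)
    train = indices[:n_train]
    test = indices[n_train:n_train + n_test]
    val = indices[n_train + n_test:]
    return train, test, val


# Assumption: 6,226 labeled images in total.
train, test, val = split_indices(6226)
print(len(train), len(test), len(val))
```

Fixing the seed matters here: every later experiment (the 100-image subset and the augmented datasets) should be evaluated against the same held-out test set.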

We measured model performance against the held-out test set with two different metrics: binary cross-entropy (also used as the loss function during training) and intersection over union (IoU), the latter of which was used to judge winners in the original DeepGlobe Challenge.
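For readers who want to compute these metrics themselves, here is a minimal NumPy sketch of both, written for this post rather than taken from the notebook. It assumes `y_prob` holds the model's per-pixel road probabilities and `y_true` holds 0/1 labels. The 0.5 threshold for IoU, the clipping epsilon, and the convention of returning 1.0 when both masks are empty are our assumptions; the official DeepGlobe scoring may aggregate IoU differently (for example per image versus over the whole dataset).

```python
import numpy as np


def binary_cross_entropy(y_true, y_prob, eps=1e-7):
    """Mean per-pixel binary cross-entropy. Probabilities are clipped so
    that log(0) never occurs."""
    y_true = np.asarray(y_true, dtype=np.float64)
    y_prob = np.clip(np.asarray(y_prob, dtype=np.float64), eps, 1.0 - eps)
    per_pixel = y_true * np.log(y_prob) + (1.0 - y_true) * np.log(1.0 - y_prob)
    return float(-np.mean(per_pixel))


def iou(y_true, y_prob, threshold=0.5):
    """Intersection over union of the "road" class after thresholding."""
    pred = np.asarray(y_prob) >= threshold
    true = np.asarray(y_true).astype(bool)
    union = np.logical_or(pred, true).sum()
    if union == 0:
        return 1.0  # both masks empty: treated as a perfect match (a convention)
    return float(np.logical_and(pred, true).sum() / union)


y_true = np.array([[1, 1, 0, 0]])
y_prob = np.array([[0.9, 0.4, 0.2, 0.6]])
print(iou(y_true, y_prob), binary_cross_entropy(y_true, y_prob))
```

In this toy example only one of the three pixels that are "road" in either the prediction or the label is road in both, so the IoU is 1/3. Binary cross-entropy, by contrast, also rewards confident correct probabilities, which is why the two metrics can move in different directions.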

Baseline model performance
Baseline model using the full training dataset.

To build intuition for what these numbers mean, let’s look at some examples of model predictions on the held-out test set. The model seems to be doing a pretty good job of finding the roads.

Images showing the model predictions
Model predictions for images in the test set.

Simulating Limited Data

Baseline performance is useful as a point of comparison, but what we’re really interested in for the purposes of this blog post is whether Stable Diffusion can help in data-limited applications — cases where we only have tens or hundreds of samples to train on. Suppose you didn’t have the time or money to capture and tediously label thousands of images. Could Stable Diffusion help you do more with less?

To answer this question, we’re going to simulate a limited dataset by training a model on just 100 images, then augment that limited dataset by generating “image variants” with Stable Diffusion. First, though, we need performance metrics for a new model trained on only those 100 images.
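Simulating the data-limited setting amounts to drawing a fixed random subset from the training split only, leaving the test and validation sets untouched so the metrics stay comparable. A minimal sketch follows; the `sample_subset` helper, the seed, and the file names are hypothetical stand-ins, not the notebook's code.

```python
import random


def sample_subset(train_items, k=100, seed=42):
    """Draw a reproducible random subset of k training items (without
    replacement) to simulate a data-limited setting. The input list is
    not modified."""
    return random.Random(seed).sample(train_items, k)


# Hypothetical (image, mask) file-name pairs for the 4,358 training samples.
train_items = [(f"img_{i}.jpg", f"mask_{i}.png") for i in range(4358)]
limited = sample_subset(train_items)
print(len(limited), limited[0])
```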

Metrics for a model trained on a limited dataset of just 100 images.

Unsurprisingly, this extreme reduction in dataset size makes both performance metrics worse. Looking at some sample predictions, the model is clearly not doing a very good job at all.

Example predictions for the data-limited model.

Augmenting Data with Stable Diffusion

Now, the fun part! Let’s use Stable Diffusion to generate “variants” for each of the 100 images in our limited training set. Unfortunately, in this case we can’t simply ask Stable Diffusion to generate images from scratch. That’s because training an image segmentation model requires both the input images and the corresponding labels/masks. While Stable Diffusion is a powerful model, it can’t accurately generate both the input images and the mask images (and if it could, we wouldn’t need to train our own U-Net model in the first place!). Instead, we’re going to use a clever trick based on in-painting.

In-painting lets Stable Diffusion replace or redraw only specific, masked parts of an image. For our use case, we’re going to use Stable Diffusion to create image “variants” by redrawing only the “background” part of the image. Since the “road” part of the image remains unchanged, we can re-use the existing mask images in the training dataset as labels! Here’s an example of what a generated image variant looks like:

Left: image mask. Middle: original image. Right: image variant generated by Stable Diffusion.
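The reason the old mask stays valid can be shown without Stable Diffusion at all. The sketch below fakes the "generated" image with a constant array and keeps road pixels from the original photo, background pixels from the generated one. The helper names are hypothetical, and it uses a hard threshold, whereas `PIL.Image.composite` (used in the code that follows) blends proportionally wherever the mask is gray.

```python
import numpy as np


def invert_mask(road_mask):
    """Swap black and white in an 8-bit mask, like PIL.ImageOps.invert."""
    return 255 - road_mask


def keep_original_roads(original, generated, road_mask):
    """Take background pixels from the generated image and road pixels from
    the original, so the unchanged road mask is still a correct label."""
    repaint = invert_mask(road_mask) >= 128  # True wherever background is
    return np.where(repaint[..., None], generated, original)


original = np.zeros((2, 2, 3), dtype=np.uint8)            # stand-in source photo
generated = np.full((2, 2, 3), 200, dtype=np.uint8)       # stand-in in-painted output
road_mask = np.array([[255, 0], [0, 0]], dtype=np.uint8)  # one road pixel, top-left
variant = keep_original_roads(original, generated, road_mask)
print(variant[..., 0])
```

Only the top-left pixel keeps its original value; everything labeled "background" comes from the generated image.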

Here’s the key piece of code we used to generate the image variants (remember, the full code is available on Google Colab). Note that we don’t want Stable Diffusion to generate new roads, because that would make the existing labels wrong. We only want it to generate background scenery. To dissuade the model from adding roads, we included “road” and “highway” in the negative prompt.

```python
import PIL.Image
import PIL.ImageOps
import torch
from diffusers import StableDiffusionInpaintPipeline

# Load and configure the Stable Diffusion in-painting model (half precision).
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting",
    revision="fp16",
    torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")

# Create prompts for image generation.
BACKGROUND_PROMPT = "Overhead aerial photograph of a rural area, 8K, high resolution, satellite image, extremely detailed"
# Note the inclusion of "road" and "highway" in the negative prompt. We don't
# want Stable Diffusion to add any more roads to the image because we are not
# changing the mask during training.
BACKGROUND_NEGATIVE_PROMPT = "road, highway, fisheye"


def create_image_variant(image_path, mask_path):
    # Use PIL to resize the images to 512x512, the optimal size for Stable
    # Diffusion.
    init_image = PIL.Image.open(image_path).convert("RGB").resize((512, 512))
    mask_image = PIL.Image.open(mask_path).convert("RGB").resize((512, 512))
    # Note we are inverting the mask in this case because we want to re-draw
    # everything that is *not* part of the road.
    inverted_mask = PIL.ImageOps.invert(mask_image)

    # Use Stable Diffusion to re-paint the background scenery.
    gen_background_image = pipe(
        prompt=BACKGROUND_PROMPT,
        negative_prompt=BACKGROUND_NEGATIVE_PROMPT,
        image=init_image,
        mask_image=inverted_mask,
    ).images[0]

    # For some reason, this version of Stable Diffusion does not completely
    # respect the image mask and will distort/repaint portions of the image
    # that it is not supposed to. We can work around this by re-applying the
    # mask to the generated image: where the inverted mask is white
    # (background), take the generated pixels; elsewhere keep the original.
    # See https://github.com/runwayml/stable-diffusion/issues/5
    final_image = PIL.Image.composite(
        gen_background_image,
        init_image,
        inverted_mask.convert("L"),
    )

    return final_image
```

That’s it! With just a few lines of code we can start augmenting our training dataset with generated images.

Re-Training with Augmented Datasets

With a plan in place, the next step was to train the model again using the newly generated image variants. To probe the limits of this approach, we tried 1, 4, and 16 variants per source image and compared the results. Importantly, image variants were generated only from the 100 samples in the limited training set, giving us an accurate view of what might happen if we only had those 100 images to start with. We used the same training procedure as before.
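Assembling each augmented training set can be sketched as follows. We assume here that every augmented dataset contains the 100 originals plus N variants per source, and that each variant is paired with its source's unchanged mask. `build_augmented_dataset` and `fake_variant` are hypothetical; in practice `make_variant` would wrap something like `create_image_variant()`, whose output differs on each call because the diffusion sampling is random.

```python
def build_augmented_dataset(source_pairs, n_variants, make_variant):
    """Return the original (image, mask) pairs followed by n_variants
    generated variants per source. Every variant reuses its source's mask,
    because only the background is repainted."""
    dataset = list(source_pairs)
    for image, mask in source_pairs:
        for v in range(n_variants):
            dataset.append((make_variant(image, mask, v), mask))
    return dataset


# Stand-in generator for illustration only: it just tags the file name.
def fake_variant(image, mask, v):
    return f"{image}__variant{v}"


sources = [(f"img_{i}.jpg", f"mask_{i}.png") for i in range(100)]
for n in (1, 4, 16):
    print(n, len(build_augmented_dataset(sources, n, fake_variant)))
```

Under this assumption the 1-, 4-, and 16-variant datasets hold 200, 500, and 1,700 samples, yet they contain only 100 distinct road layouts, a point that becomes important in the next section.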

Including the generated image variants measurably improved performance compared to only using the original 100 images. If we look at the example predictions for the 4-variant model, we see a definite improvement over the previous model. These predictions are still not perfect, but they are much better, especially considering we started with the same set of just 100 images.

Example predictions for the 4-variant model.
For comparison: example predictions for the previous data-limited model.

Limitations

Generally speaking, including more variants per image increased performance. However, at 16 variants per source image we seem to hit the limits of this approach. While the 16-variant model has the best IoU, its binary cross-entropy got worse (even worse than the model trained without any augmented data). Additionally, the loss curves for the 16-variant model show evidence of over-fitting during training. Keep in mind that we didn’t change the “road” part of the image at all for any of the image variants. One plausible explanation, then, is that the model is over-fitting to the shape of the roads, to the individual pixels that make up the roads, or to both.

Loss curves during training of the 16-variant model.
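Reading over-fitting off loss curves, and restoring the best weights by val_loss as described in the training procedure, can both be expressed in a few lines. The sketch below is a rough heuristic of our own rather than how the curves above were analyzed: it flags a run when validation loss bottomed out at least `patience` epochs before the end while training loss kept falling. The loss values are made up for illustration, and `patience=5` is an arbitrary choice.

```python
def best_epoch(val_losses):
    """Index of the epoch with the lowest validation loss; restoring the
    weights saved at this epoch is what "restore best weights" refers to."""
    return min(range(len(val_losses)), key=val_losses.__getitem__)


def looks_overfit(train_losses, val_losses, patience=5):
    """Heuristic: validation loss stopped improving at least `patience`
    epochs before the end, while training loss kept decreasing."""
    best = best_epoch(val_losses)
    stalled = len(val_losses) - 1 - best >= patience
    still_fitting = train_losses[-1] < train_losses[best]
    return stalled and still_fitting


# Made-up curves: training loss keeps falling, validation loss turns upward.
train = [1.0, 0.8, 0.6, 0.5, 0.4, 0.3, 0.25, 0.2, 0.15, 0.1]
val = [1.0, 0.85, 0.7, 0.65, 0.66, 0.7, 0.75, 0.8, 0.85, 0.9]
print(best_epoch(val), looks_overfit(train, val))
```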

Of course, none of our augmented datasets were good enough to surpass the baseline model trained on the full dataset of thousands of images. That was never really our goal in the first place and is not terribly surprising. The reality is that not everyone will have the time or money required to capture and label thousands of images. What we’ve shown here is that you can use generated data to meaningfully improve model performance in data-limited applications.

Possible Improvements

There are a number of traditional techniques that might improve model performance, including basic image augmentation (flipping, zooming, and rotating), hyperparameter optimization, and fine-tuning a larger pre-trained model. These, however, are out of scope for this blog post, so we’ll focus only on improvements based on generated images.
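Although we didn't use basic augmentation here, one detail is specific to segmentation and easy to get wrong: every geometric transform must be applied identically to the image and its mask. A minimal NumPy sketch covering flips and 90-degree rotations only (no zooming) is below; the function name is hypothetical and it is not part of the experiments above.

```python
import numpy as np


def random_flip_rotate(image, mask, rng):
    """Apply the same random flips and 90-degree rotation to an image
    (H x W x C) and its mask (H x W), so the labels stay aligned."""
    if rng.random() < 0.5:
        image, mask = np.fliplr(image), np.fliplr(mask)
    if rng.random() < 0.5:
        image, mask = np.flipud(image), np.flipud(mask)
    k = int(rng.integers(4))  # number of 90-degree turns: 0, 1, 2 or 3
    return np.rot90(image, k, axes=(0, 1)), np.rot90(mask, k, axes=(0, 1))


# Toy data: an asymmetric 3x5 mask and an image whose channels equal it,
# which makes it easy to check that both were transformed the same way.
mask = (np.arange(15).reshape(3, 5) % 2 * 255).astype(np.uint8)
image = np.stack([mask] * 3, axis=-1)
aug_image, aug_mask = random_flip_rotate(image, mask, np.random.default_rng(0))
print(aug_image.shape, aug_mask.shape)
```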

The biggest problem with the approach described here is that because we were not changing the “road” part of the image between variants, the model eventually starts to over-fit. There are a few different ways that this could be addressed:

  1. Use Stable Diffusion to redraw the “road” part of the image as well as the “background” part. We tried doing this, but Stable Diffusion had a hard time generating roads in many cases, possibly because some of the lines in the mask image are very thin (sometimes less than 1 pixel thick after resizing). For other use cases that do not have thin masks, this may be a much easier thing to do.
  2. Break each original image into multiple parts, generate variants for each part, then stitch the different parts together to create new road shapes. For example, you could take the bottom half of one source image and combine it with the top half of a different source image. This would allow us to use different road shapes during training while leveraging Stable Diffusion to hide the seams and generate a final image that looks cohesive.
  3. Instead of using the existing mask images, generate new ones from scratch procedurally (e.g. using a simple script to draw white lines on a black background with different patterns), then use Stable Diffusion to fill in both the “road” and the “background”. We didn’t do this in our blog post because of the time it would take to write the necessary code (remember, the whole point was to reduce the time and effort you need to generate good training data). But for some use-cases, the up-front cost of writing the additional code might be worth it.
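Options 2 and 3 both come down to producing new road layouts before handing the images to Stable Diffusion. The sketch below shows one hypothetical way to build those masks with NumPy: `random_line_mask` draws straight "roads" as thick line segments (a very simple stand-in for the procedural patterns option 3 has in mind), and `stitch_halves` joins the top half of one image/mask pair to the bottom half of another, as in option 2. Neither was part of our experiments; the seam in a stitched image, and the road areas of a procedural mask, would still need to be painted with in-painting.

```python
import numpy as np


def random_line_mask(side=512, n_lines=3, width=8, rng=None):
    """Draw n_lines random straight "roads" (255) on a black (0) square mask.

    A pixel counts as road if it lies within width / 2 pixels of a line
    segment whose endpoints are drawn uniformly inside the image."""
    rng = np.random.default_rng() if rng is None else rng
    ys, xs = np.mgrid[0:side, 0:side]
    mask = np.zeros((side, side), dtype=np.uint8)
    for _ in range(n_lines):
        (x0, y0), (x1, y1) = rng.uniform(0, side, size=(2, 2))
        dx, dy = x1 - x0, y1 - y0
        length_sq = max(dx * dx + dy * dy, 1e-9)
        # Project every pixel onto the segment, clamping to its endpoints.
        t = np.clip(((xs - x0) * dx + (ys - y0) * dy) / length_sq, 0.0, 1.0)
        dist = np.hypot(xs - (x0 + t * dx), ys - (y0 + t * dy))
        mask[dist <= width / 2] = 255
    return mask


def stitch_halves(top_pair, bottom_pair):
    """Build a new (image, mask) pair from the top half of one pair and the
    bottom half of another. Both pairs must have the same shapes."""
    (img_a, mask_a), (img_b, mask_b) = top_pair, bottom_pair
    h = img_a.shape[0] // 2
    image = np.concatenate([img_a[:h], img_b[h:]], axis=0)
    mask = np.concatenate([mask_a[:h], mask_b[h:]], axis=0)
    return image, mask


new_mask = random_line_mask(side=64, rng=np.random.default_rng(0))
print(new_mask.shape, np.unique(new_mask))
```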

Edge Analytics is a consulting company that specializes in data science, machine learning, and algorithm development both on the edge and in the cloud. We partner with our clients, who range from Fortune 500 companies to innovative startups, to turn their ideas into reality. Have a hard problem in mind? Get in touch at info@edgeanalytics.io.

Infinity AI. Build better AI models faster with synthetic data.

Infinity AI is an automated synthetic data platform for ML engineers. Their self-serve API and suite of tools make it easy for engineers to turn one real-world video into hundreds of similar (and perfectly labeled) synthetic videos for ML training. For open-source datasets and more info, visit infinity.ai, follow on LinkedIn, or get in touch at info@toinfinity.ai.
