How To Fine-Tune Segment Anything

Encord
6 min read · Aug 9, 2023

Computer vision is having its ChatGPT moment with the release of the Segment Anything Model (SAM) by Meta last week. Trained on over 1 billion segmentation masks, SAM is a foundation model for predictive AI use cases rather than generative AI. While it has shown remarkable flexibility in segmenting a wide range of image modalities and problem spaces, it was released without “fine-tuning” functionality.

This tutorial will outline some of the key steps to fine-tune SAM using the mask decoder, particularly describing which functions from SAM to use to pre/post-process the data so that it’s in good shape for fine-tuning.

Update: By popular demand, we’ve included a full Colab Notebook with all the code you need to fine-tune SAM. Read on to find the link 👇

What is the Segment Anything Model (SAM)?

The Segment Anything Model (SAM) is a segmentation model developed by Meta AI. It is considered the first foundation model for computer vision. SAM was trained on a huge corpus of data containing millions of images and over a billion masks, making it extremely powerful. As its name suggests, SAM is able to produce accurate segmentation masks for a wide variety of images. SAM’s design allows it to take human prompts into account, making it particularly powerful for human-in-the-loop annotation. These prompts can be multi-modal: they can be points on the area to be segmented, a bounding box around the object to be segmented, or a text prompt describing what should be segmented (text prompts are explored in the SAM paper but are not supported by the released code).

The model is structured into 3 components: an image encoder, a prompt encoder, and a mask decoder.

The image encoder generates an embedding for the image being segmented, whilst the prompt encoder generates an embedding for the prompts. The image encoder is a particularly large component in the model. This is in contrast to the lightweight mask decoder, which predicts segmentation masks based on the embeddings. Meta AI has made the weights and biases of the model trained on the Segment Anything 1 Billion Mask (SA-1B) dataset available as a model checkpoint.

What is Model Fine-Tuning?

Publicly available state-of-the-art models have a custom architecture and are typically supplied with pre-trained model weights. If these architectures were supplied without weights, users would need to train the models from scratch on massive datasets to obtain state-of-the-art performance.

Model fine-tuning is the process of taking a pre-trained model (architecture+weights) and showing it data for a particular use case. This will typically be data that the model hasn’t seen before, or that is underrepresented in its original training dataset.

The difference between fine-tuning the model and starting from scratch is the starting value of the weights and biases. If we were training from scratch, these would be randomly initialized according to some strategy. In such a starting configuration, the model would ‘know nothing’ of the task at hand and perform poorly. By using pre-existing weights and biases as a starting point we can ‘fine tune’ the weights and biases so that our model works better on our custom dataset. For example, the information learned to recognize cats (edge detection, counting paws) will be useful for recognizing dogs.

Why Would I Fine-Tune a Model?

The purpose of fine-tuning a model is to obtain higher performance on data that the pre-trained model has not seen before. For example, an image segmentation model trained on a broad corpus of data gathered from phone cameras will have mostly seen images from a horizontal perspective.

If we tried to use this model for satellite imagery taken from a vertical perspective, it may not perform as well. If we were trying to segment rooftops, the model may not yield the best results. The pre-training is useful because the model will have learned how to segment objects in general, so we want to take advantage of this starting point to build a model which can accurately segment rooftops. Furthermore, it is likely that our custom dataset would not have millions of examples, so we want to fine-tune instead of training the model from scratch.

Fine-tuning is desirable so that we can obtain better performance on our specific use case, without having to incur the computational cost of training a model from scratch.

How to Fine-Tune Segment Anything Model [With Code]

Background & Architecture

We gave an overview of the SAM architecture in the introduction section. The image encoder has a complex architecture with many parameters. In order to fine-tune the model, it makes sense for us to focus on the mask decoder which is lightweight and therefore easier, faster, and more memory efficient to fine-tune.
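One way to check how the components compare in size is to count their parameters. The helper below is a minimal sketch: `count_parameters` is a hypothetical name, and the snippet uses a small stand-in module so it runs without SAM installed.

```python
import torch.nn as nn


def count_parameters(module: nn.Module) -> int:
    """Return the total number of parameters in a module."""
    return sum(p.numel() for p in module.parameters())


# Stand-in module so the snippet runs on its own.
toy = nn.Linear(10, 5)  # 10 * 5 weights + 5 biases
print(count_parameters(toy))  # 55

# With a loaded SAM model (assumed here to be called `sam_model`), the same
# helper could be applied to each component:
#   for name in ("image_encoder", "prompt_encoder", "mask_decoder"):
#       print(name, count_parameters(getattr(sam_model, name)))
```

Running the commented loop on a loaded checkpoint should show the image encoder accounting for most of the parameters.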

In order to fine-tune SAM, we need to extract the underlying pieces of its architecture (image and prompt encoders, mask decoder). We cannot use SamPredictor.predict for two reasons:

  • We want to fine-tune only the mask decoder
  • This function calls SamPredictor.predict_torch, which has the @torch.no_grad() decorator and therefore prevents us from computing gradients

Thus, we need to examine the SamPredictor.predict function and call the appropriate functions with gradient calculation enabled on the part we want to fine-tune (the mask decoder). Doing this is also a good way to learn more about how SAM works.
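The effect of the decorator can be seen on a toy example. This is a sketch only: `decoder` is a hypothetical stand-in layer, not SAM's mask decoder, and the two function names are made up for illustration.

```python
import torch

decoder = torch.nn.Linear(4, 1)  # hypothetical stand-in for the mask decoder
x = torch.randn(2, 4)


@torch.no_grad()
def predict_without_grad(inputs):
    # Behaves like a predict function decorated with no_grad:
    # the output is detached from the autograd graph.
    return decoder(inputs)


def predict_with_grad(inputs):
    # Same computation, but autograd records it.
    return decoder(inputs)


out_no_grad = predict_without_grad(x)
out_grad = predict_with_grad(x)

print(out_no_grad.requires_grad)  # False
print(out_grad.requires_grad)     # True

# Only the second output can be backpropagated through.
out_grad.sum().backward()
print(decoder.weight.grad is not None)  # True
```

Calling `.backward()` on a loss built from `out_no_grad` would raise an error, which is why the decorated predictor cannot be used for training.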

Creating a Custom Dataset

We need three things to fine-tune our model:

  • Images on which to draw segmentations
  • Segmentation ground truth masks
  • Prompts to feed into the model

We chose the stamp verification dataset since it has data that SAM may not have seen in its training (i.e., stamps on documents). We can verify that it performs well, but not perfectly, on this dataset by running inference with the pre-trained weights. The ground truth masks are also extremely precise, which will allow us to calculate accurate losses. Finally, this dataset contains bounding boxes around the segmentation masks, which we can use as prompts to SAM. An example image is shown below. These bounding boxes align well with the workflow that a human annotator would go through when looking to generate segmentations.
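For a dataset that ships masks but no boxes, a box prompt could be derived from each ground truth mask instead. The sketch below assumes a binary NumPy mask and returns the box in the XYXY pixel format that SAM's box prompts use; `mask_to_box` is a hypothetical helper name.

```python
import numpy as np


def mask_to_box(mask: np.ndarray) -> np.ndarray:
    """Return the tight [x_min, y_min, x_max, y_max] box around a binary mask."""
    ys, xs = np.nonzero(mask)
    if ys.size == 0:
        raise ValueError("mask is empty, no box can be derived")
    return np.array([xs.min(), ys.min(), xs.max(), ys.max()])


# Toy 6x8 mask with a filled rectangle.
mask = np.zeros((6, 8), dtype=bool)
mask[2:5, 3:7] = True  # rows 2..4, columns 3..6
print(mask_to_box(mask).tolist())  # [3, 2, 6, 4]
```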

Input Data Preprocessing

We need to preprocess the scans from NumPy arrays to PyTorch tensors. To do this, we can follow what happens inside SamPredictor.set_image and SamPredictor.set_torch_image, which preprocess the image. First, we can use utils.transforms.ResizeLongestSide to resize the image, as this is the transform used inside the predictor. We can then convert the image to a PyTorch tensor and use the SAM preprocess method to finish preprocessing.
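The arithmetic behind the resize step can be sketched in plain Python. This mirrors, as far as we understand it, how ResizeLongestSide picks the output size and rescales box coordinates; `scale_box` is a hypothetical helper, and the 1024 default is an assumption matching the image encoder's input size.

```python
def get_preprocess_shape(old_h: int, old_w: int, long_side: int = 1024):
    """Output (height, width) after scaling the longest side to `long_side`."""
    scale = long_side / max(old_h, old_w)
    new_h = int(old_h * scale + 0.5)  # round to nearest integer
    new_w = int(old_w * scale + 0.5)
    return new_h, new_w


def scale_box(box, old_hw, new_hw):
    """Rescale an [x0, y0, x1, y1] box from the original to the resized image."""
    old_h, old_w = old_hw
    new_h, new_w = new_hw
    sx, sy = new_w / old_w, new_h / old_h
    x0, y0, x1, y1 = box
    return [x0 * sx, y0 * sy, x1 * sx, y1 * sy]


old_hw = (256, 512)
new_hw = get_preprocess_shape(*old_hw)
print(new_hw)  # (512, 1024)
print(scale_box([100, 50, 300, 200], old_hw, new_hw))  # [200.0, 100.0, 600.0, 400.0]
```

The box prompt has to be rescaled the same way as the image, otherwise it would point at the wrong region of the resized input. After resizing, SAM's preprocess method normalizes the pixel values and pads the image to a square.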

Training Setup

We download the model checkpoint for the vit_b model and load it in:

from segment_anything import sam_model_registry
sam_model = sam_model_registry['vit_b'](checkpoint='sam_vit_b_01ec64.pth')

We can set up an Adam optimizer with defaults and specify that the parameters to tune are those of the mask decoder:

import torch
optimizer = torch.optim.Adam(sam_model.mask_decoder.parameters())
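To see what passing only the decoder's parameters to the optimizer does, here is a toy sketch. `ToySam` is a hypothetical stand-in that only borrows SAM's attribute names; it is not the real architecture.

```python
import torch
from torch import nn


class ToySam(nn.Module):
    """Hypothetical stand-in with the same attribute names as SAM."""

    def __init__(self):
        super().__init__()
        self.image_encoder = nn.Linear(8, 4)
        self.mask_decoder = nn.Linear(4, 1)

    def forward(self, x):
        return self.mask_decoder(self.image_encoder(x))


torch.manual_seed(0)
model = ToySam()
# Only the decoder's parameters are handed to the optimizer.
optimizer = torch.optim.Adam(model.mask_decoder.parameters())

encoder_before = model.image_encoder.weight.detach().clone()
decoder_before = model.mask_decoder.weight.detach().clone()

loss = model(torch.randn(16, 8)).pow(2).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()

print(torch.equal(encoder_before, model.image_encoder.weight))  # True: untouched
print(torch.equal(decoder_before, model.mask_decoder.weight))   # False: updated
```

Note that in this toy version gradients are still computed for the encoder even though they are never applied. Running the encoder under torch.no_grad() avoids that extra work and memory.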

At the same time, we can set up our loss function, for example, Mean Squared Error:

loss_fn = torch.nn.MSELoss()
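As a rough illustration of how this loss could be applied to masks, the sketch below compares hypothetical mask logits with a binary ground truth. Passing the logits through a sigmoid first is an assumption of this sketch, chosen because it keeps the prediction differentiable; a hard threshold at 0 would block gradients.

```python
import torch

loss_fn = torch.nn.MSELoss()

# Hypothetical 1x1x4x4 mask logits, standing in for the decoder output.
logits = torch.tensor([[[[ 4.0,  4.0, -4.0, -4.0],
                         [ 4.0,  4.0, -4.0, -4.0],
                         [-4.0, -4.0, -4.0, -4.0],
                         [-4.0, -4.0, -4.0, -4.0]]]], requires_grad=True)

# Binary ground truth with the same shape, stored as float for MSELoss.
gt_mask = torch.zeros(1, 1, 4, 4)
gt_mask[..., :2, :2] = 1.0

# Map logits to (0, 1) and compare with the ground truth.
pred = torch.sigmoid(logits)
loss = loss_fn(pred, gt_mask)
loss.backward()

print(loss.item() < 0.01)       # True: the prediction already matches well
print(logits.grad is not None)  # True: gradients reach the logits
```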

Training Loop

In the main training loop, we will be iterating through our data items, generating masks, and comparing them to our ground truth masks so that we can optimize the model parameters based on the loss function.

In this example, we used a GPU for training since it is much faster than using a CPU. It is important to use .to(device) on the appropriate tensors to make sure that we don’t have certain tensors on the CPU and others on the GPU.
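A minimal sketch of consistent device placement, using hypothetical stand-in tensors and a stand-in layer rather than SAM itself:

```python
import torch

# Fall back to the CPU when no GPU is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical tensors standing in for a preprocessed image and a box prompt.
input_image = torch.zeros(1, 3, 64, 64).to(device)
box = torch.tensor([[10.0, 10.0, 40.0, 40.0]]).to(device)
model = torch.nn.Conv2d(3, 8, kernel_size=3).to(device)

# Everything that meets in a forward pass must live on the same device.
assert input_image.device == box.device == next(model.parameters()).device
features = model(input_image)
print(features.shape)  # torch.Size([1, 8, 62, 62])
```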

We embed images with the encoder wrapped in the torch.no_grad() context manager. We are not looking to fine-tune the image encoder, and tracking gradients through it would otherwise cause memory issues.

with torch.no_grad():
    image_embedding = sam_model.image_encoder(input_image)
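Putting the pieces together, one training step could look roughly like the sketch below. It assumes the input image is already preprocessed, the box prompt is already in the resized image frame, and the ground truth is a float tensor shaped like the upscaled prediction. `training_step` is a hypothetical name, and applying a sigmoid before the loss is an assumption of this sketch rather than something prescribed above.

```python
import torch


def training_step(sam_model, optimizer, loss_fn, input_image, box_torch,
                  gt_mask, input_size, original_size):
    """Run one optimization step on the mask decoder only."""
    # The encoders are not being tuned: no gradients, less memory.
    with torch.no_grad():
        image_embedding = sam_model.image_encoder(input_image)
        sparse_embeddings, dense_embeddings = sam_model.prompt_encoder(
            points=None, boxes=box_torch, masks=None
        )

    # The mask decoder runs with gradients enabled.
    low_res_masks, iou_predictions = sam_model.mask_decoder(
        image_embeddings=image_embedding,
        image_pe=sam_model.prompt_encoder.get_dense_pe(),
        sparse_prompt_embeddings=sparse_embeddings,
        dense_prompt_embeddings=dense_embeddings,
        multimask_output=False,
    )

    # Upscale the low-resolution logits back to the original image size.
    upscaled_masks = sam_model.postprocess_masks(
        low_res_masks, input_size, original_size
    )

    # Sigmoid keeps the prediction differentiable (assumption of this sketch).
    loss = loss_fn(torch.sigmoid(upscaled_masks), gt_mask)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```

In a full loop this function would be called once per image (or batch), with the optimizer built from the mask decoder's parameters only.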

Discover the Colab notebook containing the code to fine-tune SAM by reading the blog How To Fine-Tune Segment Anything!

Originally published at https://encord.com.
