Segment Anything Model (SAM): Explained

Utkarsh Doshi
6 min read · Dec 25, 2023


In simple words, SAM is a foundation model for image segmentation in computer vision. Below are examples of what the model produces when you ask it to provide a mask for everything.

original and masked image

You are probably as stunned as I am.

Thanks to open source, people are now using SAM in numerous ways. One that stands out for me is shown in the image below: mask an object with SAM, write a text prompt, and use an image generation model to inpaint the masked region.

Modified version of SAM being used along with Stable Diffusion for inpainting.

Before you read this article, I recommend watching this video by Letitia Parcalabescu, in which she explains Masked Autoencoders (MAE). SAM's image encoder is a Vision Transformer pre-trained as a Masked Autoencoder, and it is used to extract image embeddings.

In the following sections, I will delve into the architecture of the model, the process of building the dataset, and a bit on its zero-shot capabilities.

Architecture:

In a nutshell, the model consists of:

  1. an image encoder (an MAE pre-trained Vision Transformer) that extracts the image embedding,
  2. a prompt encoder that embeds different types of prompts, and
  3. a lightweight mask decoder that combines the two and predicts a mask.

Large Language Models (LLMs) pre-trained on extensive text data from the web are revolutionizing NLP with their impressive zero-shot and few-shot generalization. These models are usually referred to as “foundation models”. In computer vision, however, there has been limited exploration of foundation models; CLIP and ALIGN are among the few existing examples. So the main motivation for this paper was to develop a foundation model for image segmentation: a model that can take prompts of different kinds and has powerful generalization abilities.

Image Encoder:

The core of the image encoder is a Vision Transformer (ViT) pre-trained as a Masked Autoencoder, chosen because it scales well. They employ ViT-H/16, a "huge" ViT that splits the image into 16×16 patches, with 14×14 windowed attention and four equally spaced global attention blocks.

The output of this encoder is a feature embedding that is 16× downscaled relative to the input image. The image is rescaled and padded to an input resolution of 1024×1024×3, and the encoder transforms it into a dense embedding of size 64×64×256 (1024 / 16 = 64). This downscaling keeps the rest of the model efficient while retaining the essential image features.
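To make the shape arithmetic concrete, here is a minimal stdlib-only Python sketch. It assumes the preprocessing works as I understand the official repo to do it: resize the longest side to 1024, then pad the bottom and right. It also assumes the 64×64 token grid is padded up to a multiple of the 14×14 window before windowed attention. `encoder_shapes` is a hypothetical helper and not part of any SAM API.

```python
import math


def encoder_shapes(h: int, w: int, target: int = 1024, patch: int = 16,
                   window: int = 14, out_channels: int = 256) -> dict:
    """Trace the shapes an (h, w) RGB image goes through in SAM's image encoder (sketch)."""
    # Rescale so the longest side equals `target`, keeping the aspect ratio.
    scale = target / max(h, w)
    new_h, new_w = int(h * scale + 0.5), int(w * scale + 0.5)
    # Pad the shorter side (bottom/right) so the input becomes target x target x 3.
    pad_h, pad_w = target - new_h, target - new_w
    # ViT-H/16: each 16x16 patch becomes one token, so a 1024x1024 input gives a 64x64 grid.
    grid = target // patch
    # Windowed attention: assume the grid is padded to a multiple of the window size.
    windows_per_side = math.ceil(grid / window)
    return {
        "resized": (new_h, new_w),
        "padding": (pad_h, pad_w),
        "tokens": grid * grid,
        "windows": windows_per_side ** 2,
        "embedding": (grid, grid, out_channels),  # 16x downscaled, 256 channels
    }


print(encoder_shapes(480, 640))
# {'resized': (768, 1024), 'padding': (256, 0), 'tokens': 4096, 'windows': 25, 'embedding': (64, 64, 256)}
```

The 4,096 tokens are the same 64×64 grid that reappears later as the image embedding used in the decoder's cross-attention.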

Prompt Encoder:

Prompt encoder

SAM supports different types of prompts, namely points, boxes, masks, and text.

There are two types of prompts:

  1. sparse prompts, which include points, boxes, and text, and
  2. dense prompts, which are masks.

A point is represented as the sum of a positional encoding of the point's location and one of two learned embeddings that indicates whether it is a foreground or a background point.

A box is represented by a pair of embeddings. The first is the positional encoding of its top-left corner summed with a learned embedding representing “top-left corner”. The second is the positional encoding of its bottom-right corner summed with a learned “bottom-right corner” embedding.
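Here is a minimal stdlib-only Python sketch of how sparse prompts become tokens. It assumes random Fourier-feature positional encodings (a fixed random Gaussian projection followed by sin and cos), which is how I understand SAM's positional encoding. The "learned" embeddings are random stand-ins because there are no trained weights here, and details such as pixel-centre offsets are omitted. All names are hypothetical.

```python
import math
import random

EMBED_DIM = 256           # width of SAM's prompt embeddings
NUM_FREQS = EMBED_DIM // 2

rng = random.Random(0)
# Fixed (not learned) random Gaussian projection used for Fourier-feature encoding.
GAUSSIAN = [[rng.gauss(0.0, 1.0) for _ in range(NUM_FREQS)] for _ in range(2)]


def positional_encoding(x: float, y: float, width: int, height: int) -> list[float]:
    # Normalize pixel coordinates to [0, 1], then shift them to [-1, 1].
    u, v = 2.0 * x / width - 1.0, 2.0 * y / height - 1.0
    proj = [2.0 * math.pi * (u * GAUSSIAN[0][i] + v * GAUSSIAN[1][i]) for i in range(NUM_FREQS)]
    return [math.sin(p) for p in proj] + [math.cos(p) for p in proj]


def random_embedding() -> list[float]:
    # Stand-in for a learned embedding (these are trained parameters in the real model).
    return [rng.gauss(0.0, 0.02) for _ in range(EMBED_DIM)]


LEARNED = {name: random_embedding()
           for name in ("foreground", "background", "top_left", "bottom_right")}


def add(a: list[float], b: list[float]) -> list[float]:
    return [x + y for x, y in zip(a, b)]


def encode_point(x, y, is_foreground, width=1024, height=1024):
    # Point token = positional encoding + foreground/background embedding.
    label = "foreground" if is_foreground else "background"
    return add(positional_encoding(x, y, width, height), LEARNED[label])


def encode_box(x0, y0, x1, y1, width=1024, height=1024):
    # Box = two tokens, one per corner, each with its own learned corner embedding.
    return [add(positional_encoding(x0, y0, width, height), LEARNED["top_left"]),
            add(positional_encoding(x1, y1, width, height), LEARNED["bottom_right"])]


sparse = [encode_point(500, 375, True)] + encode_box(100, 80, 900, 700)
print(len(sparse), len(sparse[0]))  # 3 256
```

Because every sparse prompt ends up as a 256-dimensional token, points and box corners can be mixed freely in one sequence for the decoder.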

For free-form text, they use the off-the-shelf text encoder from CLIP; in general, any text encoder could be used.

Dense prompts are masks, so they have spatial correspondence with the image. Masks are input at 4× lower resolution than the input image (256×256 for a 1024×1024 image). Inside the prompt encoder, they are downscaled by a further factor of 4 using two 2×2, stride-2 convolutions with 4 and 16 output channels, respectively. A final 1×1 convolution then maps the channels to 256. The layers are separated by GELU activations and layer normalization. The resulting mask embedding is added element-wise to the image embedding. When no mask prompt is provided, a learned embedding representing ‘no mask’ is added to each image embedding location instead.
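A stdlib-only Python sketch that traces the shapes through the mask downscaler and shows the two ways the dense prompt is combined with the image embedding. It only does shape bookkeeping and element-wise addition on toy vectors. The helper names are hypothetical, and GELU and layer norm are omitted because they do not change shapes.

```python
def conv_out(size: int, kernel: int, stride: int, padding: int = 0) -> int:
    """Output length of a convolution along one spatial axis."""
    return (size + 2 * padding - kernel) // stride + 1


def mask_downscaler_shapes(image_size: int = 1024) -> list[tuple[int, int, int]]:
    """Trace (height, width, channels) through the dense-prompt downscaler (sketch)."""
    h = w = image_size // 4  # the mask prompt is given at 4x lower resolution: 256x256
    shapes = [(h, w, 1)]
    # Two 2x2 stride-2 convs (4 and 16 channels), then a 1x1 conv to 256 channels.
    for kernel, stride, out_ch in [(2, 2, 4), (2, 2, 16), (1, 1, 256)]:
        h, w = conv_out(h, kernel, stride), conv_out(w, kernel, stride)
        shapes.append((h, w, out_ch))
    return shapes


def add_dense_prompt(image_emb, mask_emb=None, no_mask=None):
    """Add the mask embedding element-wise, or broadcast the learned 'no mask' vector."""
    if mask_emb is None:
        return [[a + b for a, b in zip(loc, no_mask)] for loc in image_emb]
    return [[a + b for a, b in zip(loc, m)] for loc, m in zip(image_emb, mask_emb)]


print(mask_downscaler_shapes())
# [(256, 256, 1), (128, 128, 4), (64, 64, 16), (64, 64, 256)]
```

The final shape matches the 64×64×256 image embedding exactly, which is what makes the element-wise addition possible.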

Lightweight Mask Decoder:

output mask decoder

This is where all the magic happens: the decoder maps the image embedding and the prompt embeddings to the final mask. Its design takes inspiration from Transformer segmentation models and uses a modified Transformer decoder.

A key step is inserting a learned output token embedding into the set of prompt embeddings before they are processed by the decoder. This token is read out at the decoder's output to produce the mask, so it ends up carrying the information needed for the segmentation. The idea is similar to the class token in Vision Transformers for image classification, which aggregates information about the whole image. In SAM, the output token plays the analogous role for the mask.
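A toy Python sketch of the idea, under simplifying assumptions. The learned output token is simply prepended to the prompt tokens, the decoder itself is skipped, and the readout takes a plain dot product between the output token and each location of the image embedding. In the paper, an MLP first maps the output token to a dynamic linear classifier and the image embedding is upsampled; both steps are omitted here. The function names are hypothetical.

```python
def assemble_tokens(output_token, prompt_tokens):
    """Insert the learned output token in front of the sparse prompt embeddings."""
    return [output_token] + prompt_tokens


def mask_logits(output_token, image_embedding, height, width):
    """Score every location by its dot product with the output token (simplified readout)."""
    flat = [sum(t * e for t, e in zip(output_token, loc)) for loc in image_embedding]
    # Reshape the flat list of scores back into a height x width grid.
    return [flat[r * width:(r + 1) * width] for r in range(height)]


output_token = [1.0, 0.0]                    # stand-in for the learned output token
prompt_tokens = [[0.2, 0.7], [0.9, -0.3]]    # e.g. two encoded point prompts
tokens = assemble_tokens(output_token, prompt_tokens)
# ... the decoder would update `tokens` here; this sketch reuses tokens[0] unchanged ...
image_embedding = [[2.0, 0.0], [-1.0, 3.0], [0.5, 1.0], [0.0, 0.0]]  # toy 2x2 grid, dim 2
logits = mask_logits(tokens[0], image_embedding, 2, 2)
mask = [[score > 0 for score in row] for row in logits]
print(logits)  # [[2.0, -1.0], [0.5, 0.0]]
print(mask)    # [[True, False], [True, False]]
```

As with the class token, the output token does not correspond to any input; it is a slot that the decoder's attention fills with whatever the mask readout needs.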

Each decoder layer performs 4 steps (as can be seen in the figure above):

(1) self-attention on the tokens,

(2) cross-attention from tokens (as queries) to the image embedding,

(3) a point-wise MLP updates each token, and

(4) cross-attention from the image embedding (as queries) to tokens.

This last step updates the image embedding with prompt information. During cross-attention, the image embedding is treated as a set of 64×64 = 4,096 vectors, each 256-dimensional.
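The four steps can be sketched in stdlib Python as one "two-way" layer. This is a deliberately simplified version. It uses single-head attention with no learned query/key/value projections, no layer norm, no positional encodings re-added at each step, a ReLU stand-in for the MLP, and toy sizes. It illustrates the data flow, not SAM's actual weights or module structure.

```python
import math
import random

Matrix = list[list[float]]


def attention(queries: Matrix, keys: Matrix, values: Matrix) -> Matrix:
    """Single-head scaled dot-product attention without learned projections."""
    d = len(queries[0])
    out = []
    for q in queries:
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in keys]
        top = max(scores)  # subtract the max for a numerically stable softmax
        weights = [math.exp(s - top) for s in scores]
        total = sum(weights)
        out.append([sum(wt * v[j] for wt, v in zip(weights, values)) / total
                    for j in range(len(values[0]))])
    return out


def add(a: Matrix, b: Matrix) -> Matrix:
    """Residual connection."""
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


def mlp(x: Matrix) -> Matrix:
    """Stand-in point-wise MLP: a ReLU applied to each token independently."""
    return [[max(0.0, v) for v in row] for row in x]


def two_way_layer(tokens: Matrix, image: Matrix) -> tuple[Matrix, Matrix]:
    tokens = add(tokens, attention(tokens, tokens, tokens))  # (1) self-attention on tokens
    tokens = add(tokens, attention(tokens, image, image))    # (2) tokens (queries) -> image
    tokens = add(tokens, mlp(tokens))                        # (3) point-wise MLP per token
    image = add(image, attention(image, tokens, tokens))     # (4) image (queries) -> tokens
    return tokens, image


random.seed(0)
dim = 8
tokens = [[random.gauss(0, 1) for _ in range(dim)] for _ in range(3)]  # output token + 2 prompts
image = [[random.gauss(0, 1) for _ in range(dim)] for _ in range(16)]  # toy 4x4 grid, not 64x64
for _ in range(2):  # the paper's decoder stacks two such layers
    tokens, image = two_way_layer(tokens, image)
print(len(tokens), len(image), len(image[0]))  # 3 16 8
```

Step (4) is what makes the layer "two-way". Without it, only the tokens would learn about the image, and the image embedding would never see the prompts.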

Building the dataset:

As the paper describes, building the dataset was a tough task. The typical approach to training foundation models such as LLMs is to collect data from the internet, but segmentation masks are not readily available there (who makes masks of their own images?). So they came up with the following:

Data Engine: They built a model-in-the-loop data engine, in which SAM helps annotate data and the newly annotated data is used to retrain SAM. The data engine has three stages: assisted-manual, semi-automatic, and fully automatic.

  1. In the first (assisted-manual) stage, SAM assists annotators in annotating masks, similar, I would say, to a Roboflow setup.
  2. In the second (semi-automatic) stage, SAM automatically generates masks for a subset of objects once prompted with their likely locations. Annotators can then focus on annotating the remaining objects that SAM did not cover, which increases mask diversity.
  3. In the final (fully automatic) stage, SAM is prompted with a regular 32×32 grid of foreground points, and the resulting masks are filtered to keep confident, stable ones. This produces approximately 100 high-quality masks per image on average.
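The fully automatic stage can be sketched in stdlib Python. The sketch assumes a 32×32 grid with points at cell centres. It implements the paper's stability idea as an IoU: a mask counts as stable if thresholding its probability map at 0.5 − δ and at 0.5 + δ gives similar masks. The δ value and helper names are hypothetical, and the confidence filter and non-maximum suppression are left out.

```python
def point_grid(n: int = 32, size: int = 1024) -> list[tuple[float, float]]:
    """n x n regular grid of (x, y) point prompts at cell centres, in pixels."""
    step = size / n
    return [((col + 0.5) * step, (row + 0.5) * step) for row in range(n) for col in range(n)]


def stability_score(probs: list[list[float]], delta: float = 0.05) -> float:
    """IoU between the masks obtained by thresholding at 0.5 + delta and 0.5 - delta."""
    flat = [p for row in probs for p in row]
    tight = sum(p > 0.5 + delta for p in flat)  # strict mask, a subset of the loose one
    loose = sum(p > 0.5 - delta for p in flat)
    return tight / loose if loose else 0.0


print(len(point_grid()))   # 1024
print(point_grid(2, 100))  # [(25.0, 25.0), (75.0, 25.0), (25.0, 75.0), (75.0, 75.0)]
print(stability_score([[0.9, 0.52], [0.48, 0.1]]))  # 0.3333333333333333
```

Because the strict mask is always contained in the loose one, their IoU reduces to a ratio of pixel counts. A score near 1 means small changes to the threshold barely move the mask boundary.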

One of the contributions of the paper is SA-1B, the dataset produced by this engine. It contains over 1 billion masks on 11 million licensed, privacy-preserving images. The released masks come from the fully automatic stage, and the final model is trained on these automatically generated masks.

Zero shot transfer results:

Below are some examples of zero-shot transfer where SAM excels, taken directly from the paper.

Samples from the 23 diverse segmentation datasets used to evaluate SAM’s zero-shot transfer capabilities.
Zero-shot edge prediction on BSDS500. SAM was not trained to predict edge maps nor did it have access to BSDS images or annotations during training.
Visualization of thresholding the similarities of mask embeddings from SAM’s latent space. A query is indicated by the magenta box; top row shows matches at a low threshold, bottom row at a high threshold. The most similar mask embeddings in the same image can often be semantically similar to the query mask embedding, even though SAM is not trained with explicit semantic supervision.

I have done my best to explain the paper. Please correct me where you think I may be mistaken. I would appreciate any feedback, and I’m open to suggestions for the next paper I should read and write a summary on.

Finally, thank you for taking the time to read this article! 😄

https://www.buymeacoffee.com/utkarsh135a
