Let's look at Meta AI's SAM project and implement it with Python.

Amir Shakiba
4 min read · Apr 12, 2023

The Segment Anything Model is about to change computer vision, in a good way!

The Segment Anything Model (SAM) is a Meta AI model designed to generalize segmentation. In our previous post, we discussed general information about SAM; now let's dive deeper into its technical details.

My other posts related to SAM:
Instance Segmentation with SAM, and SAM and Stable Diffusion

The structure of the SAM model

As shown in the diagram, the image is passed through an image encoder once to obtain its embedding; after that, any number of masks can be generated without re-encoding the image. The prompt can be text, a bounding box, or free-form points. The prompt is encoded and passed, along with the image embedding, to a mask decoder, which generates the masks.
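To make that two-stage flow concrete, here is a minimal sketch using the library's SamPredictor. The model loading and image reading are shown later in this post, so sam and image are assumed here, and the point coordinates are made-up examples: set_image runs the heavy image encoder once, and each subsequent predict call only runs the lightweight prompt encoder and mask decoder.

import numpy as np
from segment_anything import SamPredictor

predictor = SamPredictor(sam)   # `sam` is a loaded SAM model (see below)
predictor.set_image(image)      # the heavy image encoder runs once, here

# Prompt with a single foreground point (hypothetical coordinates);
# only the lightweight decoder runs, so this step is near-instant.
masks, scores, logits = predictor.predict(
    point_coords=np.array([[500, 375]]),  # (x, y) pixel coordinates
    point_labels=np.array([1]),           # 1 = foreground, 0 = background
    multimask_output=True,                # return three candidate masks
)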

One of the most interesting features of SAM is its lightweight prompt encoder and mask decoder: once the image embedding is computed, new prompts can be processed in real time. You can use SAM in Python through the packaged version available on GitHub:
https://github.com/kadirnar/segment-anything-video

However, if you’re experiencing issues working with it, you can use the Colab file available on the original GitHub page. Here’s how you can get started:

using_colab = True
if using_colab:
    import torch
    import torchvision
    print("PyTorch version:", torch.__version__)
    print("Torchvision version:", torchvision.__version__)
    print("CUDA is available:", torch.cuda.is_available())
    import sys
    !{sys.executable} -m pip install opencv-python matplotlib
    !{sys.executable} -m pip install 'git+https://github.com/facebookresearch/segment-anything.git'

!wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

First, import Torch and Torchvision, which are necessary for the project, then install Segment Anything with pip. Finally, download a model checkpoint from https://github.com/facebookresearch/segment-anything#model-checkpoints, which we will use later on.

Next, create an images directory where you can put your test image. You can also use your own image by replacing the URL in the following command:

!mkdir images
!wget -O images/image.jpg https://live.staticflickr.com/65535/49894878561_14a39c6c35_b.jpg

Once you have your image, you can import the necessary packages, including numpy, Torch, Matplotlib, and OpenCV.

import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2

You can use the following function to plot annotations:

def show_anns(anns):
    if len(anns) == 0:
        return
    # Draw larger masks first so smaller ones stay visible on top.
    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
    ax = plt.gca()
    ax.set_autoscale_on(False)
    for ann in sorted_anns:
        m = ann['segmentation']  # boolean mask, same height/width as the image
        # Fill an image with a random color and use the mask as the alpha
        # channel, overlaying each annotation as a semi-transparent region.
        img = np.ones((m.shape[0], m.shape[1], 3))
        color_mask = np.random.random((1, 3)).tolist()[0]
        for i in range(3):
            img[:,:,i] = color_mask[i]
        ax.imshow(np.dstack((img, m*0.35)))

Now, read the image using OpenCV and change the channels from BGR to RGB. Then, display the image using Matplotlib.

image = cv2.imread('images/image.jpg')
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
plt.figure(figsize=(20,20))
plt.imshow(image)
plt.axis('off')
plt.show()
Source: https://live.staticflickr.com/65535/49894878561_14a39c6c35_b.jpg

To create your mask generator, you'll need to define your SAM model and use SamAutomaticMaskGenerator:


import sys
sys.path.append("..")
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

sam_checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"

device = "cuda"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

mask_generator = SamAutomaticMaskGenerator(sam)

The SAM model registry takes the checkpoint and model type and gives us our model; remember to set your Colab runtime to GPU. SamAutomaticMaskGenerator then wraps the model to create your mask_generator.
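If the ViT-H model is too heavy for your GPU, the registry also exposes smaller variants; each one needs its matching checkpoint, downloaded from the model-checkpoints page linked above (the filenames below are the ones listed there).

# Smaller SAM variants trade some mask quality for speed and memory.
sam_l = sam_model_registry["vit_l"](checkpoint="sam_vit_l_0b3195.pth")
sam_b = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")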

All you need to do is pass your image to the generator to get your masks:

masks = mask_generator.generate(image)

Each element of the returned masks list is a dictionary containing the binary segmentation mask along with metadata such as its area, bounding box, predicted IoU, and stability score.
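A quick way to see this structure is to inspect the first mask in the list:

print(len(masks))        # number of masks the generator found
print(masks[0].keys())   # 'segmentation', 'area', 'bbox', 'predicted_iou',
                         # 'point_coords', 'stability_score', 'crop_box'
print(masks[0]['segmentation'].shape)  # boolean array, same H x W as the image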

Let's have a look at the output:

plt.figure(figsize=(20,20))
plt.imshow(image)
show_anns(masks)
plt.axis('off')
plt.show()
Our masked output.

You can also adjust the parameters of your mask generator by changing the following variables:

mask_generator_2 = SamAutomaticMaskGenerator(
    model=sam,
    points_per_side=32,               # density of the sampled point grid
    pred_iou_thresh=0.86,             # filter masks by predicted quality
    stability_score_thresh=0.92,      # filter masks by stability under thresholding
    crop_n_layers=1,                  # also run generation on image crops
    crop_n_points_downscale_factor=2,
    min_mask_region_area=100,         # requires OpenCV to run post-processing
)
masks2 = mask_generator_2.generate(image)
plt.figure(figsize=(20,20))
plt.imshow(image)
show_anns(masks2)
plt.axis('off')
plt.show()
Another output, with different parameters.

Thank you for taking the time to read my post! I hope you found it informative and interesting. If you enjoyed reading it and would like to stay updated on my latest posts, please hit the ‘Follow’ button. And if you found it helpful, please consider giving it a round of applause by hitting the ‘Clap’ button multiple times. Your support is greatly appreciated and will motivate me to continue my quest on studying AI. Thank you again for your time and support!

You can find the source code here:

https://colab.research.google.com/github/facebookresearch/segment-anything/blob/main/notebooks/automatic_mask_generator_example.ipynb
