How to Modify OpenAI’s CLIP Model for Fine-Grained Classification
Fine-grained classification tasks can be tricky, especially when subtle differences matter. In this article, we’ll explore CLIP (Contrastive Language–Image Pretraining), a powerful model developed by OpenAI. Our goal is to equip you with a deep understanding of CLIP and its application to fine-grained classification tasks.
Here’s what we’ll cover:
- What is CLIP?: We’ll introduce the basics of CLIP and how it handles both text and images.
- How CLIP Works: A simple overview of CLIP’s architecture.
- Code Snippets: Easy-to-follow code examples.
- Using CLIP for Image Classification: Practical examples to get you started.
- Challenges in Fine-Grained Tasks: Why CLIP might struggle with certain detailed classifications.
- Solving the Issue with Person Re-Identification: A real-world example of how to improve CLIP.
By the end of this article, you’ll be ready to harness CLIP’s capabilities for your own projects. Let’s get started!
The full code for this article is available on Google Colab.
1] What is CLIP?
CLIP stands for Contrastive Language–Image Pretraining. It’s a model developed by OpenAI that combines natural language understanding with computer vision. Unlike traditional models that specialize in either text or images, CLIP can handle both simultaneously because it learned to associate images and their textual descriptions in a way that allows it to perform various tasks, including image classification, zero-shot learning, and more. The model is trained on a massive dataset of 400 million (image, text) pairs from the internet. Each pair consists of an image and its corresponding caption or description. CLIP can predict the most relevant text snippet given an image, even without direct optimization for specific tasks.
It has a simple architecture with two encoders (an image encoder, such as a Vision Transformer, and a text Transformer) that extract feature vectors (embeddings), followed by a dot product between them. The dot product is meaningful because both encoders project into the same shared embedding space. This enables various image-text tasks through the similarity between image and text embeddings (zero-shot transfer). However, this similarity performs poorly on fine-grained classification, where the model must distinguish between very similar sub-categories within the same broader class.
Generally, its zero-shot capabilities allow it to predict classes it was never explicitly trained to classify.
2] How CLIP Works
CLIP consists of two main components. Here’s a simplified overview:
- Vision Encoder: CLIP processes images using a vision encoder based on either a ResNet or a Vision Transformer (ViT). It converts the raw pixel data into a fixed-size embedding.
- Text Encoder: The text encoder (a Transformer similar to GPT-2) processes textual descriptions (captions, prompts, or tags) and encodes them into a feature vector/embedding.
(1) Contrastive Pre-Training
Contrastive Learning: CLIP (Contrastive Language-Image Pre-Training) uses contrastive learning to align image and text representations. The model learns by contrasting positive pairs (matching image-text pairs) against negative pairs (mismatched pairs). The aim is to bring the embeddings of related image-text pairs closer together in the embedding space while pushing unrelated pairs apart.
- The Text Encoder processes textual descriptions (e.g., “Pepper the Aussie pup”) and converts them into text embeddings T1,T2,…,TN.
- Similarly, the Image Encoder processes images (e.g., a photo of a dog) and converts them into image embeddings I1,I2,…,IN.
- Contrastive Objective: During training, the model learns to maximize the similarity (dot product) between the correct image-text pairs (I1⋅T1,I2⋅T2,…,IN⋅TN) and minimize the similarity between mismatched pairs. The result is a shared embedding space where similar images and text descriptions are closely aligned.
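The contrastive objective is easiest to see on the N×N similarity matrix of a batch. Below is a minimal NumPy sketch (not OpenAI’s training code) that uses made-up toy embeddings in place of real encoder outputs: because caption T_i belongs to image I_i, the positive pairs sit on the diagonal, and the training target for image i is simply index i.

```python
import numpy as np

def l2_normalize(x):
    # Scale each row to unit length so dot products become cosine similarities
    return x / np.linalg.norm(x, axis=1, keepdims=True)

n, d = 4, 8
# Toy image embeddings I_1..I_4: the first four standard basis vectors in R^8
image_emb = l2_normalize(np.eye(n, d))
# Toy text embeddings T_1..T_4: T_j is close to its own image I_j,
# with a small "leak" towards the previous image I_(j-1) in the batch
text_emb = l2_normalize(np.eye(n, d) + 0.1 * np.roll(np.eye(n, d), 1, axis=0))

# similarity[i, j] = I_i . T_j ; the positive (matching) pairs are the diagonal
similarity = image_emb @ text_emb.T
targets = np.arange(n)  # the matching text for image i is text i

print(np.round(similarity, 3))
print("best text per image:", similarity.argmax(axis=1))
```

Training pushes the diagonal entries up and the off-diagonal entries down; here the toy data is already aligned, so each row’s maximum falls on the diagonal.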
(2) Creating a Dataset Classifier from Label Text
CLIP can perform zero-shot classification by using textual descriptions of classes. This method involves embedding the labels into the same space as the images.
- For each class label (e.g., “plane,” “car,” “dog,” “bird”), a prompt like “A photo of a {object}” is used. The text encoder converts these prompts into text embeddings T1,T2,…,TN. Let’s take ImageNet as an example: Embed each of the 1000 possible classes/objects using the prompt “a photo of a {object}” (e.g., “a photo of a dog” or “a photo of a cat”).
- These embeddings represent the classes in the shared embedding space. For example, “A photo of a dog” is embedded into the vector space, creating a representation that can be directly compared with image embeddings.
(3) Using CLIP for Zero-Shot Prediction
CLIP’s zero-shot prediction capability allows it to classify images into categories without having seen labeled examples during training.
- The Image Encoder processes the image to be classified (e.g., a photo of a dog), converting it into an image embedding I1.
- The image embedding is compared with all the class text embeddings by calculating the dot product (Similarity Calculation). The class whose embedding has the highest dot product with the image embedding is chosen as the predicted label.
For instance, if the embedding of “A photo of a dog” (T3) has the highest dot product with the image embedding I1, the image is classified as “dog.” You can also pass the dot products through a softmax function to get predicted probabilities for each class.
In this zero-shot setting, CLIP is not trained on any of the 1.28 million labeled training examples in ImageNet, yet it achieves accuracy comparable to the original ResNet-50, which was trained on exactly that data.
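Steps (2) and (3) can be sketched in a few lines of NumPy, assuming the image and class-prompt embeddings have already been computed. The 3-dimensional embeddings below are hypothetical toy values, and zero_shot_predict is an illustrative helper, not part of the CLIP package; the scale of 100 matches the factor used in the CIFAR-100 example later in this article.

```python
import numpy as np

def zero_shot_predict(image_emb, class_text_emb, class_names, logit_scale=100.0):
    """Zero-shot prediction from precomputed embeddings (illustrative sketch).

    image_emb:      [d] embedding of one image (e.g. I_1)
    class_text_emb: [num_classes, d] embeddings of "a photo of a {object}" prompts
    """
    # Normalize so each dot product is a cosine similarity
    img = image_emb / np.linalg.norm(image_emb)
    W = class_text_emb / np.linalg.norm(class_text_emb, axis=1, keepdims=True)
    # The class text embeddings act as the weights of a linear classifier
    logits = logit_scale * (W @ img)            # one score per class
    probs = np.exp(logits - logits.max())       # numerically stable softmax
    probs /= probs.sum()
    return class_names[int(np.argmax(logits))], probs

# Hypothetical 3-d embeddings for four class prompts and one dog photo
class_names = ["plane", "car", "dog", "bird"]
class_text_emb = np.array([[1.0, 0.0, 0.0],
                           [0.0, 1.0, 0.0],
                           [0.0, 0.2, 1.0],
                           [0.7, 0.0, 0.7]])
image_emb = np.array([0.1, 0.3, 0.9])

label, probs = zero_shot_predict(image_emb, class_text_emb, class_names)
print(label, np.round(probs, 3))
```

Adding a new class only requires embedding one more prompt and appending a row to class_text_emb; no retraining is involved.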
Pseudocode for the core implementation of CLIP
1] Image and Text Encoders:
# image_encoder - ResNet or Vision Transformer
# text_encoder - CBOW or Text Transformer
2] Input Data:
# I[n, h, w, c] - minibatch of aligned images
# T[n, l] - minibatch of aligned texts
# W_i and W_t are learned matrices that project image and text features into a common embedding space.
# W_i[d_i, d_e] - learned proj of image to embed
# W_t[d_t, d_e] - learned proj of text to embed
# t - learned temperature parameter (stored in log space, hence np.exp(t) below) that scales the similarity scores.
3] Feature Extraction: extract feature representations of each modality
I_f = image_encoder(I) #[n, d_i] produce image feature vectors `I_f`.
T_f = text_encoder(T) #[n, d_t] produce text feature vectors `T_f`.
4] joint multimodal embedding [n, d_e]
# Both image and text feature vectors are projected into a common embedding space and then normalized.
# I_e and T_e are the normalized embeddings for images and texts.
I_e = l2_normalize(np.dot(I_f, W_i), axis=1)
T_e = l2_normalize(np.dot(T_f, W_t), axis=1)
5] Calculation of scaled pairwise cosine similarities [n, n]
logits = np.dot(I_e, T_e.T) * np.exp(t) # matrix of similarity scores `logits`
6] Symmetric loss function
# The cross-entropy loss function aims to align the image and text embeddings symmetrically.
# It ensures that matching images and texts have higher similarity scores than mismatched ones.
labels = np.arange(n)
loss_i = cross_entropy_loss(logits, labels, axis=0)
loss_t = cross_entropy_loss(logits, labels, axis=1)
loss = (loss_i + loss_t)/2 # Average of the image and text losses.
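The pseudocode above leaves l2_normalize and cross_entropy_loss undefined. Here is a minimal runnable NumPy version, a sketch that keeps the pseudocode’s names (I_f, T_f, W_i, W_t, t) but substitutes our own helper implementations and toy inputs in place of real encoders; it is not OpenAI’s training code.

```python
import numpy as np

def l2_normalize(x, axis=1):
    return x / np.linalg.norm(x, axis=axis, keepdims=True)

def cross_entropy_loss(logits, labels, axis):
    """Mean cross-entropy where the softmax runs along `axis`.

    axis=1: each row (one image) is a distribution over texts.
    axis=0: each column (one text) is a distribution over images.
    """
    shifted = logits - logits.max(axis=axis, keepdims=True)  # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=axis, keepdims=True))
    idx = np.arange(len(labels))
    picked = log_probs[idx, labels] if axis == 1 else log_probs[labels, idx]
    return -picked.mean()

def clip_loss(I_f, T_f, W_i, W_t, t):
    # Project both modalities into the joint embedding space [n, d_e] and normalize
    I_e = l2_normalize(I_f @ W_i)
    T_e = l2_normalize(T_f @ W_t)
    # Scaled pairwise cosine similarities [n, n]
    logits = (I_e @ T_e.T) * np.exp(t)
    # Symmetric loss: the matching pair for row/column k is index k
    labels = np.arange(logits.shape[0])
    loss_i = cross_entropy_loss(logits, labels, axis=0)
    loss_t = cross_entropy_loss(logits, labels, axis=1)
    return (loss_i + loss_t) / 2

# Toy check with identity projections: aligned pairs vs. deliberately mismatched pairs
n = 4
I_f = np.eye(n)
W_i = W_t = np.eye(n)
aligned = clip_loss(I_f, np.eye(n), W_i, W_t, t=np.log(10.0))
shuffled = clip_loss(I_f, np.eye(n)[::-1], W_i, W_t, t=np.log(10.0))
print(f"aligned: {aligned:.4f}  shuffled: {shuffled:.4f}")
```

With perfectly aligned pairs the loss is close to zero; reversing the text order so that no image sits with its own caption drives it up to about 10 at this temperature.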
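Formally, training aligns visual and textual representations in a shared embedding space through contrastive learning:
- Image-to-Text Similarity (s(V_i, T_i)):
$$s(V_i, T_i) = V_i \cdot T_i = g_V\big(I(x_i)\big) \cdot g_T\big(T(t(y_i))\big)$$
This is a dot product of the projected embeddings, where I(x_i) is the image feature from the image encoder, T(t(y_i)) is the text feature from the text encoder, and g_V and g_T are linear projection layers that map them into the shared embedding space. (In the pseudocode above, these similarities are additionally scaled by the learned temperature exp(t).)
- Image-to-Text Contrastive Loss (L_i2t), where B is the batch size (n in the pseudocode):
$$\mathcal{L}_{i2t}(i) = -\log \frac{\exp\big(s(V_i, T_i)\big)}{\sum_{a=1}^{B} \exp\big(s(V_i, T_a)\big)}$$
- Text-to-Image Contrastive Loss (L_t2i):
$$\mathcal{L}_{t2i}(i) = -\log \frac{\exp\big(s(V_i, T_i)\big)}{\sum_{a=1}^{B} \exp\big(s(V_a, T_i)\big)}$$
- Combined Loss (L), averaged over the batch as in the pseudocode:
$$\mathcal{L} = \frac{1}{2B} \sum_{i=1}^{B} \Big(\mathcal{L}_{i2t}(i) + \mathcal{L}_{t2i}(i)\Big)$$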
- Training Objective: The goal of training is to minimize this combined loss, thereby encouraging the model to assign higher similarity scores to matched image-text pairs compared to unmatched pairs.
In summary, CLIP’s simple yet effective architecture, combined with its joint embeddings and zero-shot capabilities, makes it a valuable asset for multimodal AI systems.
3] Code Snippets
Preparation for Colab or in a Local Machine
Based on where you run this notebook:
1] Preparation for a local machine
Install PyTorch 1.7.1 (or later) and torchvision. Replace cudatoolkit=11.0 with the appropriate CUDA version on your machine or cpuonly when installing on a machine without a GPU.
$ conda install --yes -c pytorch pytorch=1.7.1 torchvision cudatoolkit=11.0
Install additional dependencies (ftfy, regex, tqdm):
$ pip install ftfy regex tqdm
Install OpenAI’s CLIP repo as a Python package:
$ pip install git+https://github.com/openai/CLIP.git
2] Preparation for Colab
Make sure you’re running a GPU runtime; if not, select T4 GPU as the hardware accelerator in Runtime > Change runtime type in the menu. The next cells will install the clip package and its dependencies in your Colab environment.
! pip install ftfy regex tqdm
! pip install git+https://github.com/openai/CLIP.git
Check if PyTorch 1.7.1 or later is installed.
import torch
print("Torch version:", torch.__version__)
device = "cuda" if torch.cuda.is_available() else "cpu"
Output: Torch version: 2.3.0+cu121
Import required Python packages
import os
import numpy as np
from pkg_resources import packaging
from PIL import Image
import matplotlib.pyplot as plt
import IPython.display
Loading the CLIP model
import clip
clip.available_models()
Output:
['RN50',
 'RN101',
 'RN50x4',
 'RN50x16',
 'RN50x64',
 'ViT-B/32',
 'ViT-B/16',
 'ViT-L/14',
 'ViT-L/14@336px']
These are the model architectures available in the CLIP library:
- RN50, RN101: These are ResNet models with 50 and 101 layers respectively. They are good at handling a variety of image recognition tasks.
- RN50x4, RN50x16, RN50x64: These are scaled-up versions of the ResNet-50 model (wider, deeper, and using a higher input resolution) that use roughly 4x, 16x, and 64x the compute of RN50, designed to improve performance on more complex tasks.
- ViT-B/32, ViT-B/16: These are Vision Transformer (ViT) models with ‘B’ indicating the base size and the number after the slash indicating the patch size used for dividing the image into patches for processing.
- ViT-L/14: This is a larger Vision Transformer model with ‘L’ indicating large size and ‘14’ being the patch size.
- ViT-L/14@336px: This is ViT-L/14 adapted to a higher input resolution of 336×336 pixels.
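The patch size in a ViT model name directly determines how many tokens the image encoder processes, which matters for fine-grained tasks because smaller patches preserve more spatial detail at a higher compute cost. This small sketch assumes a 224-pixel input for every name without an “@Npx” suffix (the ViT-B/32 resolution printed below is 224) and counts one extra class token; vit_token_count is a hypothetical helper, not part of the clip package.

```python
import re

def vit_token_count(model_name, default_resolution=224):
    """Number of tokens the ViT image encoder processes for a CLIP model name.

    Parses names like "ViT-B/32" or "ViT-L/14@336px": the number after the slash
    is the patch size, and an optional "@<N>px" suffix gives the input resolution.
    """
    match = re.fullmatch(r"ViT-[BL]/(\d+)(?:@(\d+)px)?", model_name)
    if match is None:
        raise ValueError(f"not a CLIP ViT model name: {model_name}")
    patch = int(match.group(1))
    resolution = int(match.group(2)) if match.group(2) else default_resolution
    grid = resolution // patch          # patches per side
    return grid * grid + 1              # +1 for the class token

for name in ["ViT-B/32", "ViT-B/16", "ViT-L/14", "ViT-L/14@336px"]:
    print(name, vit_token_count(name))
```

This prints 50, 197, 257, and 577 tokens respectively: a 7×7, 14×14, 16×16, and 24×24 patch grid plus the class token.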
Each model has its strengths and suits different tasks depending on their complexity and the level of detail required. Here we choose the model “ViT-B/32”.
model, preprocess = clip.load("ViT-B/32", device=device)
Let’s explore our model:
model.to(device).eval() # clip.load already placed the model on `device` (GPU if available); eval() sets evaluation mode
input_resolution = model.visual.input_resolution # Get the input image resolution that the model expects
context_length = model.context_length # Get the maximum context length of text tokens that the model can handle
vocab_size = model.vocab_size # Get the size of the vocabulary that the model uses for text processing
print("Model parameters:", f"{np.sum([int(np.prod(p.shape)) for p in model.parameters()]):,}")
print("Input resolution:", input_resolution)
print("Context length:", context_length)
print("Vocab size:", vocab_size)
Model parameters: 151,277,313
Input resolution: 224
Context length: 77
Vocab size: 49408
Single-Image, Multi-Text Example
Let’s use a few of the images from skimage to quickly test the CLIP model.
scikit-image (skimage) is a Python library for image processing, offering algorithms and utilities for tasks like image I/O, filtering, and feature extraction.
- skimage.data: A submodule providing easy access to example images via functions (e.g., skimage.data.chelsea()).
- skimage.data_dir: The filesystem directory path where these example images are stored. It’s less commonly used directly.
import skimage
images = [d for d in dir(skimage.data) if not d.startswith('_')]
print(len(images))
print(images)
43
['astronaut', 'binary_blobs', 'brain', 'brick', 'camera', 'cat', 'cell', 'cells3d', 'checkerboard', 'chelsea', 'clock', 'coffee', 'coins', 'colorwheel', 'create_image_fetcher', 'data_dir', 'download_all', 'eagle', 'file_hash', 'grass', 'gravel', 'horse', 'hubble_deep_field', 'human_mitosis', 'image_fetcher', 'immunohistochemistry', 'kidney', 'lbp_frontal_face_cascade_filename', 'lfw_subset', 'lily', 'logo', 'microaneurysms', 'moon', 'nickel_solidification', 'page', 'protein_transport', 'retina', 'rocket', 'shepp_logan_phantom', 'skin', 'stereo_motorcycle', 'text', 'vortex']
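dir(skimage.data) lists 43 public names. Most are example-image loaders, but a few are helpers rather than images (e.g., create_image_fetcher, data_dir, download_all, file_hash). Let’s use the cat image, chelsea.png: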
# Load the Chelsea image directly from skimage's data module
chelsea_image = skimage.data.chelsea()
# Convert the image to a PIL Image and then to RGB
chelsea_pil = Image.fromarray(chelsea_image).convert("RGB")
# Plot the image
plt.imshow(chelsea_pil)
plt.axis('off') # Hide the axes
plt.title("chelsea.png")
plt.show()
# image = preprocess(Image.open("CLIP.png")).unsqueeze(0).to(device)
image = preprocess(chelsea_pil).unsqueeze(0).to(device) # Image Preprocessing
text = clip.tokenize(["a diagram", "a dog", "a cat"]).to(device) # Text Preprocessing
with torch.no_grad():
    image_features = model.encode_image(image) # Image embedding
    text_features = model.encode_text(text) # 3 Text embeddings
    logits_per_image, logits_per_text = model(image, text) # matrices of similarity scores `logits`
    probs = logits_per_image.softmax(dim=-1).cpu().numpy() # probabilities from a softmax
print("Label probs:", probs)
Label probs: [[8.640e-04 6.092e-03 9.932e-01]]
The output Label probs: [[8.640e-04 6.092e-03 9.932e-01]] represents the probabilities that the model assigns to each of the three text labels ("a diagram", "a dog", "a cat") given the input image. Here’s a breakdown:
- The output is a single row of probabilities because there is one image.
- Each value in the row corresponds to the probability of the image matching the respective text label.
Given the output [[8.640e-04 6.092e-03 9.932e-01]]:
- 8.640e-04 (or 0.000864) is the probability that the image is a diagram.
- 6.092e-03 (or 0.006092) is the probability that the image is a dog.
- 9.932e-01 (or 0.9932) is the probability that the image is a cat.
In this case, the model is highly confident (99.32%) that the image is a cat, which makes sense since the image is of a cat (Chelsea). The probabilities for the other labels are very low, indicating that the model correctly identifies the image content.
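Why is the model so confident? In OpenAI’s implementation, logits_per_image is the cosine similarity multiplied by a learned scale, model.logit_scale.exp(), which the CLIP paper caps at 100. The sketch below uses hypothetical cosine similarities (not values measured on Chelsea) to show how that scale turns small similarity gaps into a very peaked softmax.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - np.max(x))
    return e / e.sum()

# Hypothetical cosine similarities between one image and
# the prompts "a diagram", "a dog", "a cat"
cos_sim = np.array([0.20, 0.22, 0.27])

for scale in (1.0, 100.0):
    # CLIP multiplies cosine similarities by a learned scale before the softmax
    probs = softmax(scale * cos_sim)
    print(f"scale={scale:>5}: {np.round(probs, 4)}")
```

Without scaling, the three probabilities are nearly uniform; with a scale of 100, a cosine-similarity gap of only 0.05 already gives the cat prompt more than 99% of the probability mass.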
Multi-Image, Multi-Text Example
Now let’s extend the example to 8 images and 8 text prompts, which produces a full matrix of similarities instead of a single row of probabilities.
1] Image Preprocessing
The second return value from clip.load() is a torchvision Transform that performs the image preprocessing: it resizes the input image so its shorter side is 224 pixels, center-crops it to 224×224 (the resolution the model expects), converts it to RGB and then to a tensor, and finally normalizes the pixel intensities using the dataset mean and standard deviation.
preprocess
Compose(
Resize(size=224, interpolation=bicubic, max_size=None, antialias=True)
CenterCrop(size=(224, 224))
<function _convert_image_to_rgb at 0x7880986cfd90>
ToTensor()
Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
)
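To make the last three steps concrete, here is a minimal NumPy approximation of CenterCrop, ToTensor, and Normalize using the mean and standard deviation shown above. It assumes the image has already been resized so its shorter side is 224 pixels and skips the bicubic resize and RGB conversion; the rounding of the crop offset follows our understanding of torchvision’s center crop and should be treated as an assumption. center_crop_and_normalize is a hypothetical helper.

```python
import numpy as np

# Mean and std from the Normalize step of CLIP's `preprocess` above
CLIP_MEAN = np.array([0.48145466, 0.4578275, 0.40821073])
CLIP_STD = np.array([0.26862954, 0.26130258, 0.27577711])

def center_crop_and_normalize(image_hwc_uint8, size=224):
    """Approximate CenterCrop + ToTensor + Normalize on an already-resized image.

    image_hwc_uint8: [H, W, 3] uint8 array whose shorter side is already `size`
    returns:         [3, size, size] float32 array (channels first, like ToTensor)
    """
    h, w, _ = image_hwc_uint8.shape
    top = int(round((h - size) / 2.0))
    left = int(round((w - size) / 2.0))
    crop = image_hwc_uint8[top:top + size, left:left + size]
    x = crop.astype(np.float32) / 255.0             # ToTensor: scale to [0, 1]
    x = (x - CLIP_MEAN) / CLIP_STD                   # per-channel normalization
    return x.transpose(2, 0, 1).astype(np.float32)  # HWC -> CHW

# A fake 224x336 RGB image (as if Resize(224) had already been applied)
fake = np.full((224, 336, 3), 128, dtype=np.uint8)
out = center_crop_and_normalize(fake)
print(out.shape, out[:, 0, 0])
```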
2] Text Preprocessing
The clip.tokenize() function handles the text preprocessing. It lowercases the input, wraps the BPE tokens between a start-of-text token (49406) and an end-of-text token (49407), and by default zero-pads the output to 77 tokens, the context length CLIP models expect.
clip.tokenize("Hello World!")
tensor([[49406, 3306, 1002, 256, 49407, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0]], dtype=torch.int32)
The tokenizer is case-insensitive: the lowercase version of the same sentence produces exactly the same 77 tokens:
clip.tokenize("hello world!")
tensor([[49406, 3306, 1002, 256, 49407, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0]], dtype=torch.int32)
A longer sentence that also starts with “Hello” and ends with “!” begins and ends the same way: 3306 (“hello”) comes right after the start token, and 256 (“!”) comes right before the end token 49407:
clip.tokenize("Hello I'm Hasitha!")
tensor([[49406, 3306, 328, 880, 560, 564, 4715, 256, 49407, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0]], dtype=torch.int32)
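The wrapping and padding seen in these outputs can be sketched as follows. This is a simplified, hypothetical re-implementation (pad_tokens is not part of the clip package) that takes already-computed BPE ids, copied from the clip.tokenize("Hello World!") output above; our understanding is that clip.tokenize raises an error for over-long inputs unless truncate=True, and the sketch mirrors that.

```python
import numpy as np

SOT, EOT = 49406, 49407   # start/end-of-text ids seen in the outputs above
CONTEXT_LENGTH = 77

def pad_tokens(bpe_ids, context_length=CONTEXT_LENGTH, truncate=False):
    """Wrap BPE ids with SOT/EOT and zero-pad to a fixed length (sketch)."""
    tokens = [SOT] + list(bpe_ids) + [EOT]
    if len(tokens) > context_length:
        if not truncate:
            raise RuntimeError(f"input is too long for context length {context_length}")
        tokens = tokens[:context_length]
        tokens[-1] = EOT          # keep an end-of-text marker after truncation
    result = np.zeros(context_length, dtype=np.int32)
    result[:len(tokens)] = tokens
    return result

# BPE ids for "hello", "world", "!" taken from the clip.tokenize("Hello World!") output
row = pad_tokens([3306, 1002, 256])
print(row[:8])   # starts with 49406, 3306, 1002, 256, 49407, then zeros
```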
Now let’s create a dictionary that maps 8 image names from skimage.data to text prompts describing those images.
from collections import OrderedDict
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
# images in skimage to use and their textual descriptions
descriptions = {
"page": "a page of text about segmentation",
"chelsea": "a facial photo of a tabby cat",
"astronaut": "a portrait of an astronaut with the American flag",
"rocket": "a rocket standing on a launchpad",
"motorcycle_right": "a red motorcycle standing in a garage",
"camera": "a person looking at a camera on a tripod",
"horse": "a black-and-white silhouette of a horse",
"coffee": "a cup of coffee on a saucer"
}
Let’s plot these 8 images with their corresponding text prompts:
original_images = [] # List to store original Image objects
images = [] # List to store preprocessed Image objects
texts = [] # List to store corresponding image descriptions
plt.figure(figsize=(16, 5)) # Create a figure with specific size
# Loop through files in the skimage data directory that end with .png or .jpg
for filename in [filename for filename in os.listdir(skimage.data_dir) if filename.endswith(".png") or filename.endswith(".jpg")]:
    name = os.path.splitext(filename)[0] # Extract filename without extension
    if name not in descriptions: # Skip files not listed in descriptions dictionary
        continue
    image = Image.open(os.path.join(skimage.data_dir, filename)).convert("RGB") # Open image and convert to RGB mode
    plt.subplot(2, 4, len(images) + 1) # Create subplot based on number of images processed
    plt.imshow(image) # Display the image
    plt.title(f"{filename}\n{descriptions[name]}") # Set title with filename and its description
    plt.xticks([]) # Disable x-axis ticks
    plt.yticks([]) # Disable y-axis ticks
    original_images.append(image) # Append original image to list
    images.append(preprocess(image)) # Append preprocessed image to list
    texts.append(descriptions[name]) # Append description to list
plt.tight_layout() # Adjust subplot parameters to give specified padding
image_input = torch.stack(images).to(device) # Stack the preprocessed image tensors into one batch tensor and move it to the device
text_tokens = clip.tokenize(["This is " + desc for desc in texts]).to(device) # Tokenize each description prefixed with "This is" and move it to the device
- with torch.no_grad(): Temporarily disables gradient calculation to save memory and speed up computation, since we are only doing inference.
- image_features = model.encode_image(image_input).float(): Uses the model to encode image_input (the tensor of preprocessed images) into image features, cast to float32.
- text_features = model.encode_text(text_tokens).float(): Uses the model to encode text_tokens (the tokenized descriptions) into text features, cast to float32.
with torch.no_grad():
    image_features = model.encode_image(image_input).float()
    text_features = model.encode_text(text_tokens).float()
Let’s look at these 8 image embedding vectors and 8 text embedding vectors:
print(image_features)
print(text_features)
tensor([[ 0.0760, 0.2096, 0.0692, …, 0.0154, -0.2122, -0.5752],
[ 0.3171, 0.3049, -0.1547, …, -0.1087, -0.2207, 0.1321],
[ 0.4111, 0.4082, -0.1124, …, 0.6768, 0.3218, 0.1974],
…,
[-0.2445, -0.1819, -0.1991, …, 0.8555, -0.1385, -0.4072],
[ 0.4045, 0.4490, -0.0288, …, 0.7642, -0.0145, 0.4531],
[-0.2820, -0.2072, 0.0348, …, 1.1123, -0.2539, -0.1091]],
device='cuda:0')
tensor([[ 0.0688, 0.1191, -0.0773, …, 0.1407, -0.0841, -0.2223],
[-0.1661, 0.0519, -0.1539, …, 0.3145, -0.1359, -0.2651],
[-0.0849, 0.3247, 0.0920, …, 0.2379, 0.3000, 0.4390],
…,
[ 0.0723, 0.0356, -0.0352, …, 0.1862, -0.2832, 0.1443],
[ 0.3804, 0.1042, 0.3396, …, 0.4016, -0.0574, -0.0115],
[-0.0397, -0.0320, -0.0912, …, 0.2252, -0.3806, -0.4329]],
device='cuda:0')
# Normalize image features along the last dimension
image_features /= image_features.norm(dim=-1, keepdim=True)
# Normalize text features along the last dimension
text_features /= text_features.norm(dim=-1, keepdim=True)
# Compute similarity matrix between text and image features
similarity = text_features.cpu().numpy() @ image_features.cpu().numpy().T
- image_features /= image_features.norm(dim=-1, keepdim=True): Normalizes image_features along the last dimension (dim=-1) so that each feature vector has unit norm.
- text_features /= text_features.norm(dim=-1, keepdim=True): Normalizes text_features along the last dimension (dim=-1) so that each feature vector has unit norm.
- similarity = text_features.cpu().numpy() @ image_features.cpu().numpy().T: Computes the similarity matrix between text and image features:
  - text_features.cpu().numpy(): Moves the normalized text features from GPU to CPU and converts them to a NumPy array.
  - image_features.cpu().numpy().T: Moves the normalized image features from GPU to CPU, converts them to a NumPy array, and transposes them.
  - @: Performs matrix multiplication between the text features (shape: [num_texts, feature_dim]) and the transposed image features (shape: [feature_dim, num_images]), resulting in a similarity matrix of shape [num_texts, num_images]. Because both sets of vectors have unit norm, each entry is a cosine similarity.
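As a quick sanity check of these two steps, the NumPy sketch below uses random vectors as hypothetical stand-ins for the 512-dimensional ViT-B/32 features (3 texts and 5 images) and confirms that, after unit-normalization, one matrix product yields a [num_texts, num_images] matrix whose entries match the textbook cosine-similarity formula.

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical stand-ins for CLIP features: 3 texts and 5 images, 512-d each
text_features = rng.normal(size=(3, 512))
image_features = rng.normal(size=(5, 512))

# Same two steps as above: unit-normalize, then one matrix product
text_n = text_features / np.linalg.norm(text_features, axis=-1, keepdims=True)
image_n = image_features / np.linalg.norm(image_features, axis=-1, keepdims=True)
similarity = text_n @ image_n.T                      # [num_texts, num_images]

# Cross-check one entry against cos(a, b) = a.b / (|a| |b|)
t, i = 1, 4
cos = text_features[t] @ image_features[i] / (
    np.linalg.norm(text_features[t]) * np.linalg.norm(image_features[i]))
print(similarity.shape, np.isclose(similarity[t, i], cos))
```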
Visualizing Cosine Similarity Between Text and Image Features
This code block creates a visual representation of cosine similarity between 8 textual descriptions and 8 images:
- count = len(descriptions): Retrieves the number of descriptions.
- plt.figure(figsize=(20, 14)): Creates a large figure with dimensions 20x14 inches.
- plt.imshow(similarity, vmin=0.1, vmax=0.3): Displays the similarity matrix as an image with a color range between 0.1 and 0.3.
- plt.yticks(range(count), texts, fontsize=18): Sets the y-axis ticks to the texts (descriptions) with a specified font size.
- plt.xticks([]): Disables x-axis ticks.
- Annotation:
  - for i, image in enumerate(original_images): Loops through the original images and overlays them above the matrix at the positions given by the extent and origin parameters.
  - for x in range(similarity.shape[1]) and for y in range(similarity.shape[0]): Iterate through the similarity matrix to annotate each cell with its similarity value using plt.text.
- plt.gca().spines[side].set_visible(False): Removes all spines (borders) from the plot.
- plt.xlim([-0.5, count - 0.5]) and plt.ylim([count + 0.5, -2]): Set the limits for the x-axis and y-axis respectively.
- plt.title("Cosine similarity between text and image features", size=20): Sets the title of the plot to describe the content displayed.
count = len(descriptions) # Count of descriptions
plt.figure(figsize=(20, 14)) # Create a figure with specified size
plt.imshow(similarity, vmin=0.1, vmax=0.3) # Display the similarity matrix as an image with specified vmin and vmax values
# plt.colorbar() # Uncomment this line to add a color bar indicating the scale of similarity values
plt.yticks(range(count), texts, fontsize=18) # Set y-axis ticks with text descriptions
plt.xticks([]) # Disable x-axis ticks
# Overlay original images as annotations
for i, image in enumerate(original_images):
    plt.imshow(image, extent=(i - 0.5, i + 0.5, -1.6, -0.6), origin="lower")
# Add text annotations for similarity values
for x in range(similarity.shape[1]):
    for y in range(similarity.shape[0]):
        plt.text(x, y, f"{similarity[y, x]:.2f}", ha="center", va="center", size=12)
# Remove spines (borders) from the plot
for side in ["left", "top", "right", "bottom"]:
    plt.gca().spines[side].set_visible(False)
plt.xlim([-0.5, count - 0.5]) # Set x-axis limits
plt.ylim([count + 0.5, -2]) # Set y-axis limits
plt.title("Cosine similarity between text and image features", size=20) # Set the title of the plot
4] Using CLIP for Image Classification
To classify images, you can follow these steps with Zero-Shot Transfer:
- Encode the Image: Pass the image through the vision encoder to get its representation.
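- Encode the Text: Encode one textual description per candidate class (e.g., “a photo of a cat”) using the text encoder.
- Compute Similarity: Measure the similarity between the image and text representations (e.g., using cosine similarity).
- Predict Labels: Pick the class whose text embedding is most similar to the image embedding, optionally applying a softmax to the similarity scores to get probabilities. In effect, the class text embeddings act as the weights of a linear classifier.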
Here’s a simple Python snippet to perform zero-shot image classification with CLIP (class_names is a placeholder for your dataset’s class labels, e.g. dataset.classes):
import torch
import clip
from PIL import Image
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)
class_names = ["cat", "dog", "car"] # hypothetical labels; replace with your dataset's classes
image = preprocess(Image.open("test.png")).unsqueeze(0).to(device) # Image Preprocessing
text_descriptions = [f"This is a photo of a {label}" for label in class_names]
text = clip.tokenize(text_descriptions).to(device) # Text Preprocessing
with torch.no_grad():
    # model(image, text) encodes both inputs and returns scaled cosine similarities
    logits_per_image, logits_per_text = model(image, text)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy() # probabilities, shape [1, num_classes]
print("Label probs:", probs)
# Threshold for classification
threshold = 0.7
# Predict the most likely label, but only if it is confident enough
best = probs[0].argmax()
if probs[0][best] > threshold:
    predicted_label = class_names[best]
    print(predicted_label)
else:
    print("No class exceeds the threshold")
The example below uses the CIFAR-100 class labels for CLIP’s zero-shot image classification, using the cosine similarity (times 100) as the logits for the softmax:
from torchvision.datasets import CIFAR100 # Import the CIFAR100 dataset class from torchvision.datasets
# Create an instance of the CIFAR100 dataset with specified parameters
cifar100 = CIFAR100(
    root=os.path.expanduser("~/.cache"), # Specify the root directory where the dataset will be downloaded or cached
    # train=True, # Specify that this is the training set (the default)
    transform=preprocess, # Apply the preprocess transformation to each image in the dataset
    download=True # Download the dataset if it's not already cached locally
)
Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to /root/.cache/cifar-100-python.tar.gz
100%|██████████| 169001437/169001437 [00:12<00:00, 13710753.98it/s]
Extracting /root/.cache/cifar-100-python.tar.gz to /root/.cache
# Generate text descriptions for each class label in CIFAR100
text_descriptions = [f"This is a photo of a {label}" for label in cifar100.classes]
# Tokenize text descriptions using CLIP and move them to the device
text_tokens = clip.tokenize(text_descriptions).to(device)
with torch.no_grad():
    text_features = model.encode_text(text_tokens).float() # Encode text tokens into feature vectors using the model and ensure float type
    text_features /= text_features.norm(dim=-1, keepdim=True) # Normalize text features along the last dimension
# image_features: the 8 normalized skimage image embeddings from the previous section
text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1) # Compute softmax probabilities of text-image similarity
top_probs, top_labels = text_probs.cpu().topk(5, dim=-1) # Find top 5 probabilities and corresponding labels
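topk turns the [8, 100] probability matrix into the top-5 class indices that the plot below maps back to names. For readers more familiar with NumPy, here is an equivalent sketch; top_k is a hypothetical helper, and the 6-class probabilities are made up for illustration.

```python
import numpy as np

def top_k(probs, k=5):
    """NumPy counterpart of torch.Tensor.topk(k, dim=-1) for a [num_images, num_classes] array."""
    idx = np.argsort(-probs, axis=-1)[:, :k]          # class indices, highest probability first
    return np.take_along_axis(probs, idx, axis=-1), idx

# Hypothetical probabilities for 2 images over 6 classes
class_names = ["apple", "bus", "cat", "dog", "rocket", "tiger"]
probs = np.array([[0.02, 0.01, 0.70, 0.20, 0.02, 0.05],
                  [0.01, 0.05, 0.01, 0.01, 0.90, 0.02]])
top_probs, top_labels = top_k(probs, k=3)
for row_p, row_l in zip(top_probs, top_labels):
    print([(class_names[j], float(p)) for j, p in zip(row_l, row_p)])
```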
Now let’s analyze each of the 8 skimage.data images from the previous set individually, using the classes in cifar100.classes.
plt.figure(figsize=(16, 16)) # Create a figure with specified size
for i, image in enumerate(original_images):
    plt.subplot(4, 4, 2 * i + 1) # Create subplot for original image
    plt.imshow(image) # Display the original image
    plt.axis("off") # Turn off axis labels and ticks for cleaner visualization
    plt.subplot(4, 4, 2 * i + 2) # Create subplot for probability bar chart
    y = np.arange(top_probs.shape[-1]) # One y position per top-k prediction (5 here)
    plt.grid() # Enable grid lines in the plot
    plt.barh(y, top_probs[i]) # Plot horizontal bar chart of top probabilities
    plt.gca().invert_yaxis() # Invert y-axis to display highest probability at the top
    plt.gca().set_axisbelow(True) # Ensure grid lines are behind the bars
    plt.yticks(y, [cifar100.classes[index] for index in top_labels[i].numpy()]) # Set y-axis labels to class names
    plt.xlabel("probability") # Set x-axis label to "probability"
plt.subplots_adjust(wspace=0.5) # Adjust horizontal space between subplots
plt.show() # Display the plot
5] Challenges in Fine-Grained Classification Tasks
Fine-grained classification tasks, such as distinguishing between similar bird species, car models, or people in person re-identification (ReID), present unique challenges. These tasks require separating highly similar categories that often differ only in subtle details. Here are the main challenges when using CLIP for fine-grained tasks such as re-identification:
- It relies on existing image-text pairs, which may not cover all fine-grained categories.
- The vision encoder might not capture fine details.
Lack of Concrete Text Labels
In many fine-grained classification tasks, especially ReID, labels are often indexes rather than descriptive text. Traditional CLIP leverages descriptive text to create strong visual-language associations. Without concrete text labels, it’s challenging to exploit the full potential of CLIP’s vision-language capabilities.
Variability in Appearance
Fine-grained tasks involve high intra-class variability due to changes in viewpoint, lighting, pose, and occlusions. CLIP, trained on a diverse set of image-text pairs, may not be directly effective in handling these variations without task-specific adaptations.
Overfitting to Irrelevant Regions
CNN-based models, including CLIP’s ResNet-based image encoders, can sometimes focus on irrelevant parts of the image. This issue is exacerbated in fine-grained tasks, where background clutter can mislead the model into learning non-discriminative features.
Small Training Datasets
Fine-grained datasets are often smaller than general classification datasets. Vision transformers, like those used in CLIP, require large amounts of data for effective training, so fine-tuning them on scarce data can lead to overfitting and suboptimal performance in fine-grained tasks.
Now let’s quickly check how CLIP performs on person re-identification (ReID) using the Market-1501 dataset.
We recommend running the following tests on your local machine:
1] Configure the Python environment with CUDA and required libraries. Use the steps below to confirm all installations in the specific environment:
python -c "import sys; print(sys.executable)"
pip list
# torch test on CUDA
python -c "import torch ; print(torch.cuda.device_count()) ; print(torch.cuda.get_device_name(0))"
2] Change to your notebook directory, locate the Market-1501 dataset, and choose the images you want to test:
dir
cd F:\GitHub\MyRepos\CLIP
3] Pip install CLIP and its requirements.
pip install ftfy regex tqdm
pip install git+https://github.com/openai/CLIP.git
4] Let’s use the text prompts [“A boy with a red shirt and blue pants”, “boy 1”, “boy 2”] with an image of a boy from the dataset.
import torch
import clip
from PIL import Image
import matplotlib.pyplot as plt
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)
# Load and preprocess the image
image_path = "0002_c1s1_000776_01.jpg"
image = Image.open(image_path)
# Visualize the image
plt.imshow(image)
plt.axis('off')
plt.show()
# Preprocess the image
image_tensor = preprocess(image).unsqueeze(0).to(device)
# Tokenize the text
text = clip.tokenize(["A boy with a red shirt and blue pants", "boy 1", "boy 2"]).to(device)
with torch.no_grad():
    image_features = model.encode_image(image_tensor)
    text_features = model.encode_text(text)
    logits_per_image, logits_per_text = model(image_tensor, text)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()
print("Label probs:", probs) # prints the probabilities for each label
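To see where these probabilities come from: CLIP L2-normalizes both embeddings, takes their dot product (the cosine similarity), multiplies it by a learned temperature (logit_scale.exp(), which is about 100 in the released checkpoints), and applies a softmax over the prompts. The NumPy sketch below reproduces that scoring step with toy 3-D vectors; it is a simplified illustration under those assumptions, not the clip package's own code, and clip_style_probs is a hypothetical name.

```python
import numpy as np


def clip_style_probs(image_emb, text_embs, logit_scale=100.0):
    """Softmax over prompts of temperature-scaled cosine similarities, mimicking CLIP."""
    image_emb = image_emb / np.linalg.norm(image_emb)
    text_embs = text_embs / np.linalg.norm(text_embs, axis=1, keepdims=True)
    logits = logit_scale * (text_embs @ image_emb)  # one logit per prompt
    logits = logits - logits.max()                  # subtract the max for numerical stability
    weights = np.exp(logits)
    return weights / weights.sum()


# Toy "embeddings": prompt 0 points the same way as the image, prompt 2 is 45 degrees off
image = np.array([1.0, 0.0, 0.0])
prompts = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [1.0, 1.0, 0.0]])
print(clip_style_probs(image, prompts))                   # almost all mass on prompt 0
print(clip_style_probs(image, prompts, logit_scale=1.0))  # a much flatter distribution
```

The second call shows why the temperature matters: cosine similarities between CLIP image and text embeddings often fall in a narrow range, so without the scale factor the softmax would be nearly flat.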
5] Test with 5 images and 5 incorrect text prompts:
import torch
import clip
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)
# Load and preprocess the images
image_paths = ["0002_c1s1_000776_01.jpg", "0002_c2s1_000801_01.jpg", "0023_c3s1_001951_03.jpg", "0010_c6s4_002452_02.jpg", "0011_c3s3_076119_07.jpg"]
original_images = [Image.open(image_path) for image_path in image_paths]
images = [preprocess(image).unsqueeze(0) for image in original_images]
images = torch.cat(images).to(device)  # Combine into a single tensor
# Tokenize the text descriptions
texts = ["red boy", "red girl", "light man", "blue boy", "blue girl"]
text_tokens = clip.tokenize(texts).to(device)
with torch.no_grad():
    image_features = model.encode_image(images)
    text_features = model.encode_text(text_tokens)
    logits_per_image, logits_per_text = model(images, text_tokens)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()
    # Normalize the embeddings so that the dot product is a true cosine similarity
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    # Cosine similarity with shape (num_texts, num_images): rows are texts, columns are images
    similarity = (text_features @ image_features.T).float()
    # Scale by CLIP's learned temperature and apply softmax over the texts,
    # so that each image's column sums to 1 (this is the transpose of probs)
    similarity = (model.logit_scale.exp().float() * similarity).softmax(dim=0).cpu().numpy()
# Visualization
count = len(texts)  # Number of descriptions
plt.figure(figsize=(20, 14))  # Create a figure with the specified size
plt.imshow(similarity, vmin=0, vmax=1)  # Display the similarity matrix with a fixed color range
# plt.colorbar()  # Uncomment to add a color bar showing the scale of the values
plt.yticks(range(count), texts, fontsize=18)  # Label each row with its text description
plt.xticks([])  # Disable x-axis ticks
# Overlay the original images above their columns
for i, image in enumerate(original_images):
    plt.imshow(image, extent=(i - 0.5, i + 0.5, -1.6, -0.6), origin="lower")
# Annotate each cell with its value
for x in range(similarity.shape[1]):
    for y in range(similarity.shape[0]):
        plt.text(x, y, f"{similarity[y, x]:.2f}", ha="center", va="center", size=12)
# Remove spines (borders) from the plot
for side in ["left", "top", "right", "bottom"]:
    plt.gca().spines[side].set_visible(False)
plt.xlim([-0.5, len(original_images) - 0.5])  # Set x-axis limits
plt.ylim([count + 0.5, -2])  # Set y-axis limits
plt.title("Softmax-scaled cosine similarity between text and image features", size=20)
plt.show()
print("Label probs:", probs)  # prints the probabilities for each label
As the plot shows, the results are poor: the prompts do not match the images.
6] Next, let’s use 5 improved text prompts. I wrote my own description of each image, from left to right:
- “A boy wearing a red shirt and blue jeans”
- “A boy with a red shirt and blue pants”
- “A man with a striped shirt and white jeans”
- “A girl with red blouse and light blue skirt”
- “A girl carrying a red backpack and wearing a white shorts”
import torch
import clip
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)
# Load and preprocess the images
image_paths = ["0002_c1s1_000776_01.jpg", "0002_c2s1_000801_01.jpg", "0023_c3s1_001951_03.jpg", "0010_c6s4_002452_02.jpg", "0011_c3s3_076119_07.jpg"]
original_images = [Image.open(image_path) for image_path in image_paths]
images = [preprocess(image).unsqueeze(0) for image in original_images]
images = torch.cat(images).to(device)  # Combine into a single tensor
# Tokenize the text descriptions
texts = ["A boy wearing a red shirt and blue jeans", "A boy with a red shirt and blue pants", "A man with a striped shirt and white jeans", "A girl with red blouse and light blue skirt", "A girl carrying a red backpack and wearing a white shorts"]
text_tokens = clip.tokenize(texts).to(device)
with torch.no_grad():
    image_features = model.encode_image(images)
    text_features = model.encode_text(text_tokens)
    logits_per_image, logits_per_text = model(images, text_tokens)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()
    # Normalize the embeddings so that the dot product is a true cosine similarity
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    # Cosine similarity with shape (num_texts, num_images): rows are texts, columns are images
    similarity = (text_features @ image_features.T).float()
    # Scale by CLIP's learned temperature and apply softmax over the texts,
    # so that each image's column sums to 1 (this is the transpose of probs)
    similarity = (model.logit_scale.exp().float() * similarity).softmax(dim=0).cpu().numpy()
# Visualization
count = len(texts)  # Number of descriptions
plt.figure(figsize=(20, 14))  # Create a figure with the specified size
plt.imshow(similarity, vmin=0, vmax=1)  # Display the similarity matrix with a fixed color range
# plt.colorbar()  # Uncomment to add a color bar showing the scale of the values
plt.yticks(range(count), texts, fontsize=18)  # Label each row with its text description
plt.xticks([])  # Disable x-axis ticks
# Overlay the original images above their columns
for i, image in enumerate(original_images):
    plt.imshow(image, extent=(i - 0.5, i + 0.5, -1.6, -0.6), origin="lower")
# Annotate each cell with its value
for x in range(similarity.shape[1]):
    for y in range(similarity.shape[0]):
        plt.text(x, y, f"{similarity[y, x]:.2f}", ha="center", va="center", size=12)
# Remove spines (borders) from the plot
for side in ["left", "top", "right", "bottom"]:
    plt.gca().spines[side].set_visible(False)
plt.xlim([-0.5, len(original_images) - 0.5])  # Set x-axis limits
plt.ylim([count + 0.5, -2])  # Set y-axis limits
plt.title("Softmax-scaled cosine similarity between text and image features", size=20)
plt.show()
print("Label probs:", probs)  # prints the probabilities for each label
These prompts match the images better than the previous ones, but the scores are still not separated enough to tell matches from non-matches with a threshold value.
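One way to make "separable by a threshold" concrete: assuming, as in the experiment above, that prompt i was written for image i, a single threshold can only split matches from non-matches if the lowest matching score is higher than the highest non-matching score. The helper below is a hypothetical illustration, and the score matrices in it are made up rather than results from this article.

```python
def threshold_margin(scores):
    """scores[i][j] = score of prompt i against image j, where prompt i describes image i.

    Returns min(matching score) - max(non-matching score); a positive margin means a
    single threshold can separate every match from every non-match.
    """
    n = len(scores)
    matches = [scores[i][i] for i in range(n)]
    non_matches = [scores[i][j] for i in range(n) for j in range(n) if i != j]
    return min(matches) - max(non_matches)


# Hypothetical score matrices (not results from this article)
well_separated = [[0.90, 0.05], [0.10, 0.85]]
overlapping = [[0.55, 0.60], [0.45, 0.40]]
print(threshold_margin(well_separated))  # positive: a threshold such as 0.5 works
print(threshold_margin(overlapping))     # negative: no single threshold works
```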
6] Solving the issue with an example of person re-identification.
To address the CLIP model’s challenges in fine-grained classification tasks, such as image re-identification (ReID), classifying flower species, or distinguishing car models, we can enhance its text prompts and fine-tune its image encoder.
The primary issues and their solutions are as follows:
- Optimizing Learnable Text Tokens: The current text prompts may not cover all the necessary detailed categories. To improve the text representation, we need to optimize learnable text tokens. These tokens contribute to better text features, which in turn enhance the overall model performance.
- Fine-Tuning the Image Encoder: The vision encoder might miss fine details. To address this, we should fine-tune the image encoder specifically for fine-grained categories.
This alignment between improved text context and improved image features enables the model to recognize detailed categories more effectively.
Researchers have been actively working on this problem in recent years, particularly on multimodal prompting and on fine-tuning the encoders with minor adjustments to the CLIP architecture. Related papers include:
- CLIP-ReID: Exploiting Vision-Language Model for Image Re-identification without Concrete Text Labels (AAAI)
- TF-CLIP: Learning Text-Free CLIP for Video-Based Person Re-identification (AAAI)
- Conditional Prompt Learning for Vision-Language Models (CVPR)
- Fine-Grained Visual Prompting (NeurIPS)
- Delving into Multimodal Prompting for Fine-Grained Visual Classification (AAAI)
In this article, I refer to the paper “CLIP-ReID: Exploiting Vision-Language Model for Image Re-Identification without Concrete Text Labels” by Siyuan Li, Li Sun, and Qingli Li. It proposes a two-stage training approach that directly tackles the two issues above. In the first stage, learnable text tokens are optimized to enhance the text features. In the second stage, these improved text tokens are used to fine-tune the image encoder. As a result, the method significantly improves performance across various ReID datasets, including Market-1501, MSMT17, DukeMTMC-reID, Occluded-Duke, VehicleID, and VeRi-776.
Overview of this new approach compared to CLIP and CoOp
(a) CLIP Model:
(b) CoOp Model:
- Adaptation Strategy: CoOp (Context Optimization) builds on the pre-trained CLIP model and adapts it to a downstream task without retraining it. CoOp keeps both the image encoder and the text encoder frozen, leveraging their pre-trained capabilities.
- Fine-Tuning Text Prompts: Instead of hand-crafted prompt words, CoOp learns continuous context vectors for the text prompt on the downstream dataset. This adapts the pre-trained model to specific tasks by optimizing the context in which the class names are used, without altering the encoders themselves.
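To make "learnable context" concrete: the text encoder consumes a sequence of token embeddings, and prompt learning replaces some of those embeddings with free vectors that are optimized by gradient descent while the encoders stay frozen. CLIP-ReID, as described in its paper, uses a template along the lines of "A photo of a [X]1 [X]2 ... [X]M person.", with its own set of M learnable tokens per identity. The NumPy sketch below only shows how such a prompt sequence is assembled; the embedding size, M, and the random stand-in word embeddings are placeholders (not values from CLIP or the paper), and start/end tokens are omitted.

```python
import numpy as np

rng = np.random.default_rng(0)
EMBED_DIM = 8        # placeholder; CLIP's real token embeddings are much wider
NUM_CTX_TOKENS = 4   # M learnable tokens per identity (placeholder value)
NUM_IDS = 3          # number of identities in this toy example

# Frozen word embeddings for the fixed words of the template (random stand-ins here)
prefix = rng.normal(size=(4, EMBED_DIM))   # "A photo of a"
suffix = rng.normal(size=(2, EMBED_DIM))   # "person ."

# One block of learnable context vectors per identity; only these would be trained
learnable_ctx = rng.normal(scale=0.02, size=(NUM_IDS, NUM_CTX_TOKENS, EMBED_DIM))


def build_prompt(identity):
    """Token-embedding sequence for 'A photo of a [X]1..[X]M person.' for one identity."""
    return np.concatenate([prefix, learnable_ctx[identity], suffix], axis=0)


prompt = build_prompt(1)
print(prompt.shape)  # (4 + NUM_CTX_TOKENS + 2, EMBED_DIM) -> (10, 8)
```

In a real setup, this sequence would be passed through the frozen CLIP text encoder, and only learnable_ctx would receive gradient updates.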
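(c) CLIP-ReID Method:
- Two-Stage Training Process:
1. First Stage: the image and text encoders are frozen, and only the learnable text tokens are optimized, giving each identity its own text description.
2. Second Stage: the learned text tokens are frozen, and the text features they produce are used to fine-tune the image encoder.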
- Benefits: This two-stage approach allows the model to effectively use the strengths of the pre-trained CLIP encoders while specifically adapting the text features and subsequently the visual features for the ReID task. By fixing the encoders initially and optimizing in stages, CLIP-ReID achieves a better balance between generalization and task-specific adaptation.
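The freezing schedule of the two stages can be illustrated with a deliberately tiny model. The NumPy sketch below is a simplification under several assumptions and is not the paper's algorithm: the "image encoder" is a single linear map, the frozen "text encoder" is another linear map applied to one learnable context vector per identity, and both stages use only an image-to-text cross-entropy loss with hand-derived gradients (the paper's losses are richer; for example, it also uses a text-to-image loss in the first stage and identity and triplet losses in the second). All names and sizes are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
N_IDS, PER_ID, IN_DIM, FEAT_DIM, CTX_DIM = 4, 10, 16, 8, 6
SCALE, LR, STEPS = 5.0, 5e-4, 200

# Toy "images": noisy samples around one mean vector per identity
labels = np.repeat(np.arange(N_IDS), PER_ID)
X = rng.normal(size=(N_IDS, IN_DIM))[labels] + 0.5 * rng.normal(size=(len(labels), IN_DIM))

W_img = rng.normal(size=(IN_DIM, FEAT_DIM)) / np.sqrt(IN_DIM)   # "image encoder"
E_txt = rng.normal(size=(CTX_DIM, FEAT_DIM)) / np.sqrt(CTX_DIM)  # frozen "text encoder"
ctx = 0.1 * rng.normal(size=(N_IDS, CTX_DIM))                    # learnable tokens, one per ID


def i2t_loss_and_grads(W_img, ctx):
    """Image-to-text cross-entropy and its gradients w.r.t. W_img and ctx."""
    F = X @ W_img                       # image features, (N, FEAT_DIM)
    T = ctx @ E_txt                     # text features, (N_IDS, FEAT_DIM)
    logits = SCALE * F @ T.T
    logits -= logits.max(axis=1, keepdims=True)  # numerical stability
    P = np.exp(logits)
    P /= P.sum(axis=1, keepdims=True)
    n = len(labels)
    loss = -np.log(P[np.arange(n), labels]).mean()
    G = SCALE * (P - np.eye(N_IDS)[labels]) / n  # dLoss / d(F @ T.T)
    grad_W = X.T @ (G @ T)
    grad_ctx = (G.T @ F) @ E_txt.T
    return loss, grad_W, grad_ctx


initial_loss = i2t_loss_and_grads(W_img, ctx)[0]

# Stage 1: encoders frozen, only the learnable text tokens are updated
for _ in range(STEPS):
    _, _, grad_ctx = i2t_loss_and_grads(W_img, ctx)
    ctx = ctx - LR * grad_ctx
stage1_loss = i2t_loss_and_grads(W_img, ctx)[0]

# Stage 2: learned tokens and text encoder frozen, only the image encoder is updated
for _ in range(STEPS):
    _, grad_W, _ = i2t_loss_and_grads(W_img, ctx)
    W_img = W_img - LR * grad_W
stage2_loss = i2t_loss_and_grads(W_img, ctx)[0]

print(f"loss: initial {initial_loss:.3f} -> stage 1 {stage1_loss:.3f} -> stage 2 {stage2_loss:.3f}")
```

Each stage minimizes the same loss over a different set of parameters, which mirrors the idea of first adapting the text side to fixed image features and then adapting the image side to the now-fixed text features.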
Here is the algorithm for the training process:
Now let’s look at what happens to 10 randomly selected people from the MSMT17 dataset, shown in 10 different colors (dots are image features and pentagons are text features).
t-SNE visualization of image and text features across the two training stages:
This t-SNE visualization clearly demonstrates the effectiveness of the two-stage training process employed by CLIP-ReID. In the first stage, the learnable text tokens are optimized to create ambiguous descriptions that adapt the text features closer to their corresponding image features. This alignment is evident as the text features (pentagons) start to cluster around the image features (dots). By the second stage, the image encoder is fine-tuned with these optimized text features, leading to a more discriminative and well-separated distribution of image features across different identities. This enhanced alignment and separation significantly improve the performance of the image re-identification task, as the visual representations become more robust and accurately reflective of the unique identities.
Implementation with Market1501 Dataset
Below is a walkthrough of the proposed method on the Market-1501 dataset: training the model, then evaluating it with the pre-trained ViT-CLIP-ReID-SIE-OLP model.
Step-by-Step Guide
1. Setup and Installation
First, clone the repository and install the required dependencies.
git clone https://github.com/Syliz517/CLIP-ReID.git
cd CLIP-ReID
pip install -r requirements.txt
python -c "import sys; print(sys.executable)"
pip install yacs
pip install timm
pip install scikit-image
pip install tqdm
pip install ftfy
pip install regex
2. Data Preparation
Download the Market1501 dataset and place it in the datasets directory. The dataset structure should look like this:
datasets/
└── Market1501/
├── bounding_box_test/
├── bounding_box_train/
├── query/
└── ...
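Market-1501 encodes its labels in each file name. For example, 0002_c1s1_000776_01.jpg is person 0002, seen by camera 1 in video sequence 1 at frame 000776, and 01 is the index of the detected bounding box in that frame; junk images use the ID -1 and distractors use 0000. The helper below is a hypothetical sketch (not the CLIP-ReID data loader) that extracts these fields with a regular expression.

```python
import re
from pathlib import Path
from typing import NamedTuple

MARKET_NAME = re.compile(r"^(-?\d+)_c(\d)s(\d)_(\d+)_(\d+)$")


class MarketLabel(NamedTuple):
    person_id: int   # -1 = junk image, 0 = distractor
    camera: int
    sequence: int
    frame: int
    bbox_index: int


def parse_market1501_name(path):
    """Parse a Market-1501 file name such as '0002_c1s1_000776_01.jpg'."""
    match = MARKET_NAME.match(Path(path).stem)
    if match is None:
        raise ValueError(f"not a Market-1501 file name: {path}")
    return MarketLabel(*(int(group) for group in match.groups()))


print(parse_market1501_name("0002_c1s1_000776_01.jpg"))
# MarketLabel(person_id=2, camera=1, sequence=1, frame=776, bbox_index=1)
```

Person and camera IDs are exactly what the standard evaluation protocol needs, since it ignores gallery images of the same person taken by the same camera as the query.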
3. Training the Model
The training involves two stages. In the first stage, the image and text encoders from CLIP are fixed, and only the text tokens are optimized. In the second stage, the text tokens are fixed, and the image encoder is fine-tuned.
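The repository is driven by YAML configuration files and scripts rather than an importable Python API. As a conceptual summary, the main training settings look something like this (the dictionary is illustrative only; its key names and values are not the repository’s actual configuration):
# Illustrative only: the CLIP-ReID repository does not provide a `clip_reid`
# module with a train() function. Training is launched with train_clipreid.py
# and a YAML config, as shown below.
config = {
    "dataset": "Market1501",
    "data_dir": "./datasets/Market1501",
    "model": "vit_base_patch16_224",
    "epochs_stage1": 10,
    "epochs_stage2": 20,
    "batch_size": 32,
    "learning_rate": 0.0003,
    "output_dir": "./output",
}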
Alternatively, edit configs/person/vit_clipreid.yml to suit your setup and train with train_clipreid.py from the repository root:
cd F:\GitHub\MyRepos\CLIP\CLIP-ReID
dir
Directory of F:\GitHub\MyRepos\CLIP\CLIP-ReID
22/06/2024 08:30 pm <DIR> .
22/06/2024 11:53 pm <DIR> ..
22/06/2024 08:30 pm 201 .gitignore
22/06/2024 08:30 pm <DIR> config
22/06/2024 08:30 pm <DIR> configs
22/06/2024 08:30 pm <DIR> datasets
22/06/2024 08:30 pm <DIR> fig
22/06/2024 08:30 pm 1,086 LICENSE
22/06/2024 08:30 pm <DIR> loss
22/06/2024 08:30 pm <DIR> model
22/06/2024 08:30 pm <DIR> processor
22/06/2024 08:30 pm 9,973 README.md
22/06/2024 08:30 pm <DIR> solver
22/06/2024 08:30 pm 2,650 test.py
22/06/2024 08:30 pm 2,676 test_clipreid.py
22/06/2024 08:30 pm 3,015 train.py
22/06/2024 08:30 pm 3,760 train_clipreid.py
22/06/2024 08:30 pm <DIR> utils
7 File(s) 23,361 bytes
11 Dir(s) 108,303,978,496 bytes free
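Then run the command below (it uses Linux/macOS shell syntax; in the Windows Command Prompt, run set CUDA_VISIBLE_DEVICES=0 first and omit the prefix):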
CUDA_VISIBLE_DEVICES=0 python train_clipreid.py --config_file configs/person/vit_clipreid.yml
4. Evaluation
After training, you can evaluate the model performance on the test set.
# Illustrative only: the repository does not provide a `clip_reid.evaluate` function;
# evaluation is run with test_clipreid.py, as shown below.
eval_config = {
    "dataset": "Market1501",
    "data_dir": "./datasets/Market1501",
    "model_path": "./output/best_model.pth",
}
Then run:
CUDA_VISIBLE_DEVICES=0 python test_clipreid.py --config_file configs/person/vit_clipreid.yml TEST.WEIGHT 'F:\GitHub\MyRepos\CLIP\tests\ViT-B-16_60.pth'
Results
The Market1501 dataset contains 32,668 images of 1,501 identities captured by six cameras, which poses significant challenges due to variations in lighting, occlusion, and background clutter.
The tests demonstrated that the two-stage training process of CLIP-ReID significantly outperforms existing methods. Specifically, CLIP-ReID reached 89.8% mean Average Precision (mAP) and 95.7% Rank-1 (R1) accuracy, a notable improvement over the baseline model.
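For reference, both metrics are computed from a query-by-gallery distance matrix: Rank-1 is the fraction of queries whose closest valid gallery image shows the same person, and mAP averages, over queries, the precision measured at each rank where a correct match appears. The NumPy sketch below follows the usual Market-1501 rule of ignoring gallery images that show the same person from the same camera as the query; it is a simplified illustration (for example, it does not treat junk images separately), not the repository’s evaluation code, and rank1_and_map is a hypothetical name.

```python
import numpy as np


def rank1_and_map(dist, query_ids, gallery_ids, query_cams, gallery_cams):
    """Rank-1 accuracy and mAP from a (num_query, num_gallery) distance matrix."""
    rank1_hits, average_precisions = [], []
    for q in range(dist.shape[0]):
        order = np.argsort(dist[q])  # gallery indices, closest first
        same_id = gallery_ids[order] == query_ids[q]
        same_cam = gallery_cams[order] == query_cams[q]
        matches = same_id[~(same_id & same_cam)]  # drop same-person, same-camera images
        if not matches.any():
            continue  # no valid ground truth for this query
        rank1_hits.append(matches[0])
        hit_ranks = np.flatnonzero(matches) + 1  # 1-based ranks of the correct matches
        precision_at_hits = np.arange(1, len(hit_ranks) + 1) / hit_ranks
        average_precisions.append(precision_at_hits.mean())
    return float(np.mean(rank1_hits)), float(np.mean(average_precisions))


# Toy example: 2 queries, 3 gallery images
dist = np.array([[0.5, 0.1, 0.9],
                 [0.5, 0.1, 0.9]])
rank1, mean_ap = rank1_and_map(dist,
                               query_ids=np.array([1, 2]), gallery_ids=np.array([1, 2, 1]),
                               query_cams=np.array([1, 1]), gallery_cams=np.array([2, 2, 3]))
print(rank1, mean_ap)  # 0.5 and about 0.792
```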
This improvement underscores the ability of CLIP-ReID to improve cross-modal learning effectively, resulting in more robust and discriminative feature representations. The innovative use of learnable text tokens to create ambiguous text descriptions in the first stage, followed by the fine-tuning of the image encoder in the second stage, has proven to be a successful strategy for enhancing image re-identification tasks. These results not only validate the superior performance of CLIP-ReID on fine-grained classification tasks but also highlight its potential for broader applications in various domains requiring precise image recognition. In conclusion, CLIP-ReID represents a significant advancement in the field of image re-identification, offering a robust solution that bridges the gap between vision and language models, and sets a new benchmark for future research and applications.
References
- Radford, A., Kim, J. W., Hallacy, C., Ramesh, A., Goh, G., Agarwal, S., … & Sutskever, I. (2021). Learning Transferable Visual Models From Natural Language Supervision. arXiv preprint arXiv:2103.00020.
- Li, S., Sun, L., & Li, Q. (2023). CLIP-ReID: Exploiting Vision-Language Model for Image Re-Identification without Concrete Text Labels. Proceedings of the AAAI Conference on Artificial Intelligence, 37(7), 10007–10014. Shanghai Key Laboratory of Multidimensional Information Processing, East China Normal University.
- Zhou, K., Yang, J., Loy, C. C., & Liu, Z. (2022). Conditional Prompt Learning for Vision-Language Models. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR) (pp. 16816–16825).
- Zhou, K., Yang, J., Loy, C. C., & Liu, Z. (2022). Learning to Prompt for Vision-Language Models. International Journal of Computer Vision (IJCV), 130, 2337–2348.
- He, S., Luo, H., Wang, P., Wang, F., Li, H., & Jiang, W. (2021). TransReID: Transformer-based Object Re-Identification. In Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV) (pp. 15013–15022).
- Yu, C., Liu, X., Wang, Y., Zhang, P., & Lu, H. (2024). TF-CLIP: Learning Text-Free CLIP for Video-Based Person Re-identification. Proceedings of the AAAI Conference on Artificial Intelligence, 38.
- Zheng, L., Shen, L., Tian, L., Wang, S., Wang, J., & Tian, Q. (2015). Scalable Person Re-identification: A Benchmark. In Proceedings of the IEEE International Conference on Computer Vision (ICCV) (pp. 1116–1124). (Market-1501 dataset)
- Yang, L., Wang, Y., Li, X., Wang, X., & Yang, J. (2023). Fine-Grained Visual Prompting. Advances in Neural Information Processing Systems (NeurIPS), 36.
- Jiang, X., Tang, H., Gao, J., Du, X., He, S., & Li, Z. (2024). Delving into Multimodal Prompting for Fine-Grained Visual Classification. Proceedings of the AAAI Conference on Artificial Intelligence, 38.