Building CLIP Model from scratch using PyTorch | Contrastive Language-Image Pre-Training

Shubh Mishra
Published in The Deep Hub

Hey 👋

In 2021, OpenAI released the paper “Learning Transferable Visual Models From Natural Language Supervision”, which proposed CLIP (Contrastive Language-Image Pre-Training), a powerful deep-learning model designed to understand and interpret images and text in a unified manner. It combines vision and language encoders to connect textual descriptions with visual content. The CLIP model does not generate a description of the image itself, but it can be used to assess the relationship between a text and an image. For example, you can provide an image of a cat along with a list of labels such as “cat” and “dog” to determine which label has the highest probability of matching the image.

This story today covers the implementation of CLIP from scratch using PyTorch.

Image Source: Author

CLIP (Contrastive Language-Image Pre-Training)

  • Traditional machine learning models often require large, task-specific labeled datasets for fine-tuning. For example, a model trained to identify dogs might not perform well in identifying cats unless it’s specifically fine-tuned on cat images.
  • CLIP’s architecture enables zero-shot learning, meaning it can perform tasks it wasn’t directly trained on by leveraging its broad, learned associations between images and text. For instance, it can classify images into categories it was never explicitly trained on, using only textual descriptions of those categories. Quoting the paper: “We match the accuracy of the original ResNet-50 on ImageNet zero-shot without needing to use any of the 1.28 million training examples it was trained on.”
Figure 2: CLIP Model | Source: Paper
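Here is what the zero-shot idea looks like in code. With a trained CLIP, classification reduces to three steps. First, embed one caption per label; the paper uses prompts such as “A photo of a {label}.” Second, compare each caption embedding with the image embedding. Third, take a softmax over the scaled similarities. The sketch below is a minimal, hypothetical PyTorch illustration with these assumptions:

- Random vectors stand in for the encoder outputs.
- zero_shot_probs is a helper name made up for this example.
- The logit scale of 100 is assumed; it is the upper bound the paper clips the scale to.

```python
import torch
import torch.nn.functional as F


def zero_shot_probs(image_emb, label_embs, logit_scale=100.0):
    # Hypothetical helper. image_emb: (D,), label_embs: (num_labels, D)
    image_emb = F.normalize(image_emb, dim=-1)
    label_embs = F.normalize(label_embs, dim=-1)
    # Scaled cosine similarity between the image and each label caption
    logits = logit_scale * (label_embs @ image_emb)  # (num_labels,)
    return logits.softmax(dim=-1)


labels = ["A photo of a cat.", "A photo of a dog."]

# Random stand-ins for encoder outputs; the "image" is built to lie near the cat caption
torch.manual_seed(0)
cat_text, dog_text = torch.randn(2, 512)
image = cat_text + 0.5 * torch.randn(512)

probs = zero_shot_probs(image, torch.stack([cat_text, dog_text]))
for label, p in zip(labels, probs.tolist()):
    print(f"{label} {p:.3f}")
```

The label with the highest probability is the prediction. No label-specific training is involved; only the caption embeddings change.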

CLIP has the following components that we’ll need to build:

  1. Text Encoder
  2. Image Encoder
  3. Custom Dataset (In case you are training)
  4. Symmetric Loss

Text Encoder

As our main motive is to align the embeddings of the textual and visual representations, we need a text encoder model to create features for the textual description of each image. I will not cover how to build a text encoder from scratch; instead, we’ll use the Hugging Face transformers library to create the encoder. This still covers the main idea behind the CLIP implementation.

For simplicity, using the DistilBERT model is a fine approach: it is lightweight, performs almost as well as the standard BERT model, and shares a similar base architecture. One thing to keep in mind is that we are not loading the pre-trained version. DistilBertModel(config=DistilBertConfig()) builds a model with randomly initialized weights.

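import torch.nn as nn
from transformers import DistilBertConfig, DistilBertModel

class TextEncoder(nn.Module):
    def __init__(self, embed_dim, proj_dim):
        super().__init__()
        # Randomly initialized DistilBERT (no pre-trained weights are loaded)
        self.model = DistilBertModel(config=DistilBertConfig())
        # Normalizes the hidden size (embed_dim); proj_dim is used once we add a projection
        self.layer_norm = nn.LayerNorm(embed_dim)

    def forward(self, input_ids, attention_mask):
        # (B, T, E) hidden states, one per token
        x = self.model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        return self.layer_norm(x)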

The TextEncoder class expects two inputs, input_ids and attention_mask, both of which are generated by the tokenizer:

import torch
from transformers import DistilBertTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
texts = ["This is a sample sentence.", "This is another example."]
inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt").to(device)

encoder = TextEncoder(embed_dim=768, proj_dim=256).to(device)
outputs = encoder(inputs['input_ids'], inputs['attention_mask'])

Now, the output of the forward pass of TextEncoder has shape (Batch_Size, Seq_Len, Embed_Size). Seq_Len includes the special [CLS] token that the tokenizer prepends to the original tokens, as well as the [SEP] token it appends. The standard BERT model is pre-trained on two tasks:

  • predicting masked tokens using the information from all the tokens before and after each masked one;
  • next-sentence prediction, which is made from the output of the [CLS] token.

That same [CLS] output is what is generally used when fine-tuning for classification tasks.

Figure: Standard BERT | Source: Wiki

As we are concerned with getting feature embeddings for our textual data, we’ll take only the [CLS] token. We then project it into a common space with the same embedding size as the visual embeddings from the Image Encoder.

import torch.nn as nn
from transformers import DistilBertConfig, DistilBertModel

class TextEncoder(nn.Module):
    def __init__(self, embed_dim, proj_dim):
        super().__init__()
        self.model = DistilBertModel(config=DistilBertConfig())
        # Projects the embed_dim [CLS] embedding into the shared space of size proj_dim
        self.projection = nn.Linear(embed_dim, proj_dim)
        self.layer_norm = nn.LayerNorm(proj_dim)

    def forward(self, input_ids, attention_mask):
        x = self.model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state

        x = x[:, 0, :]  # keep only the [CLS] token: (B, T, E) -> (B, E)

        x = self.projection(x)  # (B, E) -> (B, proj_dim)

        return self.layer_norm(x)
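You can sanity-check the shapes without instantiating DistilBERT. The sketch below substitutes a random tensor for last_hidden_state; this is an assumption purely for illustration, and the sequence length of 9 is arbitrary. It then applies the same [CLS] selection, projection and LayerNorm as the TextEncoder above.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

batch_size, seq_len, embed_dim, proj_dim = 2, 9, 768, 256

# Random stand-in for DistilBERT's last_hidden_state: (B, T, E)
last_hidden_state = torch.randn(batch_size, seq_len, embed_dim)

projection = nn.Linear(embed_dim, proj_dim)
layer_norm = nn.LayerNorm(proj_dim)

# Position 0 holds the [CLS] token the tokenizer prepends: (B, T, E) -> (B, E)
cls_embedding = last_hidden_state[:, 0, :]

# Project into the shared image-text space: (B, E) -> (B, proj_dim)
text_features = layer_norm(projection(cls_embedding))

print(cls_embedding.shape)  # torch.Size([2, 768])
print(text_features.shape)  # torch.Size([2, 256])
```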

Layer Normalization is a very common concept in deep learning, and it’s not the first time I’ve explained it, but let’s go over it one more time.

The inputs to a network, and to each of its layers, can vary a lot in scale. One batch might contain values in the range [0, 2), while the next has values in [0, 100]. This change in the input distribution during training is known as covariate shift (or internal covariate shift when it affects the inputs of hidden layers). If the inputs change drastically, so do the outputs and the loss. Large swings in the loss lead to large, unstable weight updates during backpropagation.

Layer Normalization counters this by normalizing each sample’s feature vector to zero mean and unit variance, followed by a learnable scale and shift. This is done independently of the other samples in the batch. Keeping the scale of the activations consistent gives smoother gradients and faster training, and helps the model focus on learning features.
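The small PyTorch check below shows what nn.LayerNorm actually computes. Each sample’s feature vector is shifted to zero mean and scaled to unit variance. It is then multiplied by a learnable weight and shifted by a learnable bias, which start at 1 and 0. The input values are made up for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Two samples whose features live on very different scales
x = torch.stack([
    torch.rand(8) * 2,    # values in [0, 2)
    torch.rand(8) * 100,  # values in [0, 100)
])

layer_norm = nn.LayerNorm(8)  # weight starts at 1 and bias at 0

# Manual LayerNorm: normalize each sample (row) over its own features
mean = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, unbiased=False, keepdim=True)  # biased variance, as LayerNorm uses
manual = (x - mean) / torch.sqrt(var + layer_norm.eps)

out = layer_norm(x)
print(torch.allclose(out, manual, atol=1e-5))  # True
print(out.mean(dim=-1))                        # close to 0 for both samples
print(out.std(dim=-1, unbiased=False))         # close to 1 for both samples

# Each sample is normalized independently, so the rest of the batch does not matter
print(torch.allclose(layer_norm(x[:1]), out[:1], atol=1e-6))  # True
```

Because the statistics are computed per sample rather than across the batch, LayerNorm behaves identically for any batch size. This includes a batch of one at inference time.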

Image Encoder

CLIP has two image encoder options: a ResNet or a Vision Transformer. We have already developed various Vision Transformers in previous articles, so we will use the standard ViT implementation here.

If you want to study the implementation of ViT from scratch in great detail, you can follow the Vision Transformer Series where I’ve implemented various ViTs and have explained their architectures from the ground up.

If you want to use a ResNet as the image encoder, you can simply swap it in for the Vision Transformer. You can use either PyTorch’s own ResNet models (torchvision) or timm.

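import torch.nn as nn

class ImageEncoder(nn.Module):
    def __init__(self, base_model, embed_dim, proj_dim):
        super().__init__()

        # Any backbone that maps (B, 3, H, W) images to (B, embed_dim) features
        self.model = base_model

        # Train the backbone together with the rest of CLIP
        for param in self.model.parameters():
            param.requires_grad = True

        self.projection = nn.Linear(embed_dim, proj_dim)
        self.layer_norm = nn.LayerNorm(proj_dim)

    def forward(self, x):
        # (B, embed_dim) -> (B, proj_dim), then normalize
        x = self.projection(self.model(x))
        return self.layer_norm(x)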

The encoder class above passes the image tensor through the base model. It then projects the resulting features into the same embedding space as the Text Encoder’s output and applies a normalization layer.
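The wrapper only needs a backbone that maps a batch of images to flat (B, embed_dim) feature vectors. Swapping backbones therefore mostly comes down to setting embed_dim correctly. For instance, a torchvision ResNet-50 with its final fc layer replaced by nn.Identity() outputs 2048-dimensional features.

The sketch below keeps things dependency-free in two ways. It uses a tiny, hypothetical CNN backbone in place of a ViT or ResNet. It also includes a compact copy of the ImageEncoder so that it runs on its own.

```python
import torch
import torch.nn as nn


class ImageEncoder(nn.Module):
    # Same structure as the ImageEncoder above: backbone -> projection -> LayerNorm
    def __init__(self, base_model, embed_dim, proj_dim):
        super().__init__()
        self.model = base_model
        self.projection = nn.Linear(embed_dim, proj_dim)
        self.layer_norm = nn.LayerNorm(proj_dim)

    def forward(self, x):
        return self.layer_norm(self.projection(self.model(x)))


# Hypothetical toy backbone standing in for a ViT or ResNet:
# maps (B, 3, 224, 224) images to flat (B, 64) feature vectors
toy_backbone = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=16, stride=16),  # (B, 64, 14, 14)
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),                      # (B, 64, 1, 1)
    nn.Flatten(),                                 # (B, 64)
)

encoder = ImageEncoder(base_model=toy_backbone, embed_dim=64, proj_dim=256)
images = torch.rand(2, 3, 224, 224)
features = encoder(images)
print(features.shape)  # torch.Size([2, 256])
```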

Custom Dataset

Now, CLIP is a (pretty) large model, so if you want to train it from scratch on limited hardware, you’ll have to train it on a small dataset. This article is only about implementing the architecture from scratch, so we won’t go into the details of creating a dataset. For the sake of example, though, this is how you might do it.

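import torchvision
from PIL import Image
from torch.utils.data import Dataset
from transformers import DistilBertTokenizer

class CustomDataset(Dataset):
    def __init__(self, texts, image_paths):

        self.image_paths = image_paths
        self.texts = texts
        tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
        # Tokenize every caption once; padding=True pads them all to the same length
        self.inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
        # Resize to the 224x224 input the ViT expects, then convert to a (3, H, W) tensor
        self.transform = torchvision.transforms.Compose([
            torchvision.transforms.Resize((224, 224)),
            torchvision.transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert("RGB")  # ensure 3 channels
        image = self.transform(image)

        return {
            "image": image,
            "input_ids": self.inputs["input_ids"][idx],
            "mask": self.inputs["attention_mask"][idx]
        }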

image_paths: The list of paths to the images in your chosen dataset.

texts: The caption or textual description for each image in the dataset.

The CustomDataset class creates the tokenizer and tokenizes all the texts when the dataset is instantiated. We use the DistilBERT tokenizer, which matches our DistilBERT text encoder.
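The captions are padded to a common length when the dataset is created. As a result, PyTorch’s default collate function can stack the per-sample dictionaries into batched tensors with the keys the model expects.

The sketch below checks this with a hypothetical DummyCLIPDataset. It returns random tensors in place of real images and token ids, so it runs without downloading a tokenizer. The dataset size and sequence length are arbitrary choices.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class DummyCLIPDataset(Dataset):
    # Hypothetical stand-in for CustomDataset: same output keys, random contents
    def __init__(self, num_samples=6, seq_len=12, vocab_size=30522):
        # 30522 is the size of the distilbert-base-uncased vocabulary
        self.images = torch.rand(num_samples, 3, 224, 224)
        self.input_ids = torch.randint(0, vocab_size, (num_samples, seq_len))
        self.mask = torch.ones(num_samples, seq_len, dtype=torch.long)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        return {
            "image": self.images[idx],
            "input_ids": self.input_ids[idx],
            "mask": self.mask[idx],
        }


loader = DataLoader(DummyCLIPDataset(), batch_size=4, shuffle=True)
batch = next(iter(loader))

# The default collate function stacks each dict entry along a new batch dimension
for key, value in batch.items():
    print(key, tuple(value.shape))
# image (4, 3, 224, 224)
# input_ids (4, 12)
# mask (4, 12)
```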

Putting it all together

class CLIPModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        ViT = VissionTransformer(
            num_layers=8,
            img_size=224,
            emb_size=768,
            patch_size=16,
            num_head=6,
            num_class=768).to(self.device)
        self.image_encoder = ImageEncoder(base_model=ViT, embed_dim=768, proj_dim=256)
        self.text_encoder = TextEncoder(embed_dim=768, proj_dim=256)

The CLIPModel class is where we put it all together. It computes the embeddings from both the Image and Text encoders, which are then used to calculate the symmetric loss. This is the NumPy-like pseudocode for the core implementation of CLIP, as given in the paper.

Figure 4: Numpy-like pseudocode for the core of an implementation of CLIP.

In our implementation, we are going to calculate the loss within the forward function of our CLIPModel class.

The very first step is to get the image and text embeddings. These are then multiplied together (the image embeddings times the transposed text embeddings) to get the similarity matrix, or logits. Let’s go back to our second figure.

Figure 5

The logits are created by taking the dot product of every image embedding with every text embedding. Since this paper is based on contrastive learning, our primary goal is to align the textual representations with the visual ones. So how does calculating a similarity matrix help?

Each image embedding from the image encoder is multiplied by each embedding from the text encoder. In Figure 5, the image embeddings are I_1, I_2, …, I_n and the text embeddings are T_1, T_2, …, T_n, where n is the batch size. Since both encoders return one vector per sample, the product is (B, Embed) @ (Embed, B) → (B, B).

Our task is then to maximize each diagonal element (I1T1, I2T2, …, InTn). To align our textual and visual representations, each image embedding should relate most strongly to the embedding of its own text. This is done for all the images in the batch, but let’s take a look at an individual row.
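The following toy PyTorch example illustrates the similarity matrix. The embeddings are hypothetical. Each text embedding is built as a slightly perturbed copy of its image embedding, which is what a well-trained CLIP should approximate. The product of two (B, D) matrices gives a (B, B) matrix whose diagonal holds the matching pairs.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, D = 4, 256  # batch size and shared embedding size (proj_dim)

# Hypothetical embeddings: each text is a slightly perturbed copy of its paired image
image_emb = F.normalize(torch.randn(B, D), dim=-1)
text_emb = F.normalize(image_emb + 0.01 * torch.randn(B, D), dim=-1)

# (B, D) @ (D, B) -> (B, B); entry [i, j] compares image i with text j
similarity = image_emb @ text_emb.T
print(similarity.shape)          # torch.Size([4, 4])

# With aligned pairs, each row (image) and each column (text) peaks on the diagonal
print(similarity.argmax(dim=1))  # tensor([0, 1, 2, 3])
print(similarity.argmax(dim=0))  # tensor([0, 1, 2, 3])
```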

Figure 6: The blue cell in the final vector is just an example and doesn’t correspond to the maximum element | Source: Paper

The figure is not really different: we take an image embedding I and compute its dot product with each text embedding in the batch. For instance, when we use I3, we want it to align most strongly with the corresponding text embedding T3. Ideally, the highest value in the I3 row should be the dot product I3⋅T3. Doing the same thing batch-wise amounts to maximizing all the diagonal elements, where each In aligns best with its corresponding Tn.

To achieve this, we need a loss function that measures how much the diagonal value in each row stands out from the other values. Cross-entropy does exactly this when the target for row i is index i. Applying it both row-wise (each image against all texts) and column-wise (each text against all images) gives the symmetric loss.
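To see what this symmetric loss rewards, the sketch below applies it to three hypothetical batches:

  • texts identical to their images;
  • unrelated random texts;
  • texts shifted by one position, so that every image’s best match is the wrong caption.

clip_loss is a helper written for this example, not a library function. It mirrors the paper’s pseudocode and uses the paper’s initial scale of 1/0.07.

```python
import torch
import torch.nn.functional as F


def clip_loss(image_emb, text_emb, logit_scale=1 / 0.07):
    # Symmetric contrastive loss over a batch of (image, text) pairs
    image_emb = F.normalize(image_emb, dim=-1)
    text_emb = F.normalize(text_emb, dim=-1)
    logits = image_emb @ text_emb.T * logit_scale  # (B, B)
    labels = torch.arange(logits.size(0))          # matching pair i sits at column i
    loss_rows = F.cross_entropy(logits, labels)    # each image vs all texts
    loss_cols = F.cross_entropy(logits.T, labels)  # each text vs all images
    return (loss_rows + loss_cols) / 2


torch.manual_seed(0)
B, D = 8, 256
image_emb = torch.randn(B, D)

aligned = clip_loss(image_emb, image_emb.clone())             # text i == image i
random_pairs = clip_loss(image_emb, torch.randn(B, D))        # unrelated texts
mismatched = clip_loss(image_emb, image_emb.roll(1, dims=0))  # text i == image i-1

print(f"aligned:    {aligned.item():.4f}")
print(f"random:     {random_pairs.item():.4f}")
print(f"mismatched: {mismatched.item():.4f}")
```

The aligned batch gives a loss close to zero. The mismatched batch scores even worse than random texts, because every row confidently peaks on the wrong column. This is exactly what the diagonal targets penalize.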

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from vit import VissionTransformer  # Importing ViT from previous implementation (GitHub: ML-Models)

# ImageEncoder and TextEncoder are the classes defined above

class CLIPModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        ViT = VissionTransformer(
            num_layers=8,
            img_size=224,
            emb_size=768,
            patch_size=16,
            num_head=6,
            num_class=False).to(self.device)
        self.image_encoder = ImageEncoder(base_model=ViT, embed_dim=768, proj_dim=256)
        self.text_encoder = TextEncoder(embed_dim=768, proj_dim=256)

        # Learnable log of the logit scale 1/T, initialized with T = 0.07 as in the paper.
        # Keep it a registered nn.Parameter (no .to() here); model.to(device) moves it.
        self.temperature = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

    def forward(self, x):
        # (B, proj_dim) embeddings, L2-normalized so dot products are cosine similarities
        I_t = F.normalize(self.image_encoder(x["image"]), dim=-1)
        T_t = F.normalize(self.text_encoder(x["input_ids"], x["mask"]), dim=-1)

        # exp() of the learned log-scale, clipped at 100 following the paper
        logit_scale = self.temperature.exp().clamp(max=100)
        logits = I_t @ T_t.T * logit_scale  # (B, B); row i = image i vs all texts

        # Matching pairs sit on the diagonal: row/column i has target i
        labels = torch.arange(I_t.size(0), device=logits.device)

        loss_I = F.cross_entropy(logits, labels)    # each image vs all texts (rows)
        loss_T = F.cross_entropy(logits.T, labels)  # each text vs all images (columns)

        loss = (loss_I + loss_T) / 2.0

        return loss, logits

  1. We get I_t and T_t, each of size (B, proj_dim): one 256-dimensional embedding per image and per caption in the batch.
  2. We calculate the logits by taking the dot product as previously discussed and then multiplying by the exponent of a temperature parameter. If you’re familiar with contrastive learning or have read my article on DINO (self-distillation with no labels), you might know that dividing by a small temperature is typically used to sharpen the output distribution. However, instead of directly dividing by the temperature, we multiply by the exponent of a trainable scalar. This scalar is created with nn.Parameter() and initialized to log(1/0.07), matching the paper’s starting temperature of 0.07. Since e^(ln x) = x, exp(log(1/T)) is simply 1/T, so you might wonder why we don’t just learn 1/T directly. Learning it in log space keeps the scale positive no matter how the optimizer updates the parameter. Each gradient step also changes the scale multiplicatively, which gives smoother, more stable updates. Following the paper, the scale is also clipped so that the logits are never multiplied by more than 100.
  3. The labels are simply generated from the batch size: [0, 1, ..., N-1]. As we discussed previously, the goal is to maximize each diagonal element (I1T1, I2T2, ..., InTn). So the target for row i (and for column i) is index i, the position of the matching pair.
  4. As given in the pseudocode, the embeddings are L2-normalized before the dot product, which turns the logits into scaled cosine similarities. The layer norm at the end of each encoder is not a substitute for this: it standardizes each embedding’s features but does not give the embedding unit length.
  5. Following the pseudocode, the cross-entropy loss is computed over both rows and columns. We do this by passing the logits and their transpose along with the labels. Taking the average of the two losses gives us the final loss.
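Here are two practical details about the temperature, sketched in PyTorch with hypothetical module names. First, calling .to(...) on an nn.Parameter returns a plain tensor that the module does not register; this includes .to(device) on a GPU machine. The temperature then silently drops out of model.parameters(), and the optimizer never updates it. Creating the parameter as-is and moving the whole model with model.to(device) avoids this. The dtype change below reproduces the same effect on a CPU-only machine. Second, exp() turns the learned log-scale into the positive logit scale, and clamp caps that scale at 100.

```python
import math

import torch
import torch.nn as nn


class UnregisteredTemperature(nn.Module):
    def __init__(self):
        super().__init__()
        # .to() with a new dtype (or device) returns a plain tensor, not an nn.Parameter,
        # so the module never registers it
        self.temperature = nn.Parameter(torch.ones([])).to(torch.float64)


class RegisteredTemperature(nn.Module):
    def __init__(self):
        super().__init__()
        # Learnable log(1/T), initialized with T = 0.07; move it with model.to(device)
        self.temperature = nn.Parameter(torch.ones([]) * math.log(1 / 0.07))


print(len(list(UnregisteredTemperature().parameters())))  # 0 -> never trained
print(len(list(RegisteredTemperature().parameters())))    # 1

# exp() maps any real log-scale to a positive scale; clamp caps it at 100
scale = RegisteredTemperature().temperature.exp().clamp(max=100)
print(round(scale.item(), 4))  # 14.2857
```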

Setting up the Model

import torch
from torch.utils.data import DataLoader
from transformers import DistilBertTokenizer

texts = ["This is a sample sentence.", "This is another example."]

# You can use the CustomDataset we implemented above for training.
# image_paths and batch_size are placeholders: point them at your own data.
image_paths = ["path/to/image_1.jpg", "path/to/image_2.jpg"]
batch_size = 2
train_data = CustomDataset(texts, image_paths)
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)

# Example usage with a random image batch
device = 'cuda' if torch.cuda.is_available() else 'cpu'
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt").to(device)

test = {
    "image": torch.rand(2, 3, 224, 224).to(device),
    "input_ids": inputs["input_ids"],
    "mask": inputs["attention_mask"]
}

model = CLIPModel().to(device)
loss, logits = model(test)
print("Loss:", loss, "Logits:", logits)

And this is it!

If you are using a custom dataset and dataloader, there’s obviously no need to set up the tokenizer separately; I have formatted a test input only for the sake of clarity. I also have a GitHub repository where I upload all the models we’ve implemented. You can find the entire code for this article here: https://github.com/mishra-18/ML-Models/blob/main/clip.py.

If you liked my work or found this article helpful, please consider giving me a Follow or dropping some claps 👏. This really motivates me to continue writing quality articles and keep uploading meaningful stories.

Thanks for reading.
