Building CLIP From Scratch

Open World Object Recognition on the FashionMNIST Dataset

Matt Nguyen
Correll lab
19 min read · May 16, 2024


Computer vision systems were historically limited to a fixed set of classes. CLIP has been a revolution, enabling open-world object recognition by “predicting which image and text pairings go together”. CLIP is able to do this by learning the cosine similarity between image and text features for batches of training data. This is shown in the contrastive pre-training portion of Figure 1, where the dot product between the image features {I_1 … I_N} and the text features {T_1 … T_N} is taken.
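In equation form (mirroring both the pseudocode in the CLIP paper and the loss we will implement in the “CLIP Model” section below), for a batch of N L2-normalized image embeddings I_e and text embeddings T_e with a learned temperature parameter t, the training objective is:

\text{logits} = (I_e T_e^\top) \cdot e^{t}, \qquad \mathcal{L} = \tfrac{1}{2}\big(\text{CE}(\text{logits}, y) + \text{CE}(\text{logits}^\top, y)\big), \qquad y = (1, \dots, N)

where CE is the cross-entropy loss taken row-wise.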

In this tutorial, we are going to build CLIP from scratch and test it on the Fashion MNIST dataset. Some of the sections in this article are taken from my vision transformers article, which will help you understand how transformer models can be applied to images.

A notebook with the code from this tutorial can be found here.

Figure 1: CLIP Model Overview. Image: CLIP Paper.

Import Libraries and Modules

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as T
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
import matplotlib.pyplot as plt
import numpy as np

We are going to be building CLIP using PyTorch, so we will need to import the library plus others that we will be using in this tutorial.

import torch
import torch.nn as nn

We also need torchvision.transforms in order to resize the input images and convert them to tensors. Resizing the input images is optional (a sketch follows below); you just need to make sure that the image size is divisible by the patch size.

import torchvision.transforms as T
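If you did want to resize, a minimal transform pipeline might look like the following sketch. Note that this is only an illustration: the tutorial itself uses only T.ToTensor(), since Fashion MNIST images are already 28×28, and the 28×28 target size here is just an example of a size divisible by a 14×14 patch.

# Hypothetical example: resize inputs before converting them to tensors
transform = T.Compose([
    T.Resize((28, 28)),  # pick any size divisible by the patch size
    T.ToTensor(),
])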

We are going to be using Adam as our optimizer, so we need to import torch.optim.

import torch.optim as optim

We are going to load the Fashion MNIST dataset from HuggingFace for this tutorial, so we need the datasets library. We are also going to use PyTorch’s Dataset and DataLoader classes to wrap and load the data, so we need to import those as well.

from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader

We are going to import matplotlib in order to display images when doing zero-shot classification.

import matplotlib.pyplot as plt

Finally, we need to import NumPy, which we will use to compute the sine and cosine values when creating the positional encodings.

import numpy as np

Image and Text Encoders

We will first build the image and text encoders. Each embeds its input — an image or a text sequence, respectively — into a single feature vector in a joint embedding space, which can then be used in the contrastive loss computation. If you already know how image and text encoding work, you can skip right to the Section “CLIP Model” further down.

Positional Embedding

class PositionalEmbedding(nn.Module):
    def __init__(self, width, max_seq_length):
        super().__init__()

        pe = torch.zeros(max_seq_length, width)

        for pos in range(max_seq_length):
            for i in range(width):
                if i % 2 == 0:
                    pe[pos][i] = np.sin(pos/(10000 ** (i/width)))
                else:
                    pe[pos][i] = np.cos(pos/(10000 ** ((i-1)/width)))

        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        x = x + self.pe

        return x
Figure 2: Changing patch order can change an O into an X. Image: ViT From Scratch Article.

Unlike models like LSTMs, which take embeddings in sequentially, transformers take embeddings in parallel. While this increases speed, it means transformers are not aware of what order a sequence is supposed to be in. This is a problem, because changing the order of a sequence would most likely alter its meaning. An example of this is Figure 2, which shows that changing an image’s patch order can change the image from an O to something that more closely resembles an X. To fix this problem, positional encodings are added to the embeddings. Each positional encoding is unique to the position it represents, which allows the model to identify which position each embedding belongs to. For the positional encodings to be added to the embeddings, they have to have the same dimension, d_model. We get the positional encodings by using the equations in Figure 3.

Figure 3: Positional encoding equations. Image: Transformer Paper.
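Written out, the sinusoidal positional encoding equations (which the loop below implements, with width playing the role of d_model) are:

PE_{(pos,\,2i)} = \sin\!\left(\frac{pos}{10000^{2i/d_{model}}}\right), \qquad PE_{(pos,\,2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d_{model}}}\right)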
pe = torch.zeros(max_seq_length, width)

for pos in range(max_seq_length):
    for i in range(width):
        if i % 2 == 0:
            pe[pos][i] = np.sin(pos/(10000 ** (i/width)))
        else:
            pe[pos][i] = np.cos(pos/(10000 ** ((i-1)/width)))

self.register_buffer('pe', pe.unsqueeze(0))

Notice that pe is only a local variable and gets added to the module using the register_buffer method. This way, the positional encodings become a non-trainable part of the model.

In the forward method, the positional encodings that we calculated above are added to the input.

x = x + self.pe

return x

Attention Head

class AttentionHead(nn.Module):
    def __init__(self, width, head_size):
        super().__init__()
        self.head_size = head_size

        self.query = nn.Linear(width, head_size)
        self.key = nn.Linear(width, head_size)
        self.value = nn.Linear(width, head_size)

    def forward(self, x, mask=None):
        # Obtaining Queries, Keys, and Values
        Q = self.query(x)
        K = self.key(x)
        V = self.value(x)

        # Dot Product of Queries and Keys
        attention = Q @ K.transpose(-2,-1)

        # Scaling
        attention = attention / (self.head_size ** 0.5)

        # Applying Attention Mask
        if mask is not None:
            attention = attention.masked_fill(mask == 0, float("-inf"))

        attention = torch.softmax(attention, dim=-1)

        attention = attention @ V

        return attention
Figure 4: Scaled Dot-Product Attention and Multi-Head Attention diagrams. Image: Transformer Paper.

Transformers use attention, a communication mechanism that allows the model to focus on the important parts of the input. Attention scores can be calculated using the equation in Figure 5.

Figure 5: Attention equation. Image: Transformer Paper.
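Written out, the scaled dot-product attention equation is:

\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V

where d_k is the size of each attention head (head_size in the code above).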

The first step in calculating attention is obtaining the queries, keys, and values of the tokens. The query of a token is what the token is looking for, the key is what the token contains, and the value is what is communicated between the tokens. The queries, keys, and values can be calculated by passing tokens through linear modules.

def forward(self, x, mask=None):
    # Obtaining Queries, Keys, and Values
    Q = self.query(x)
    K = self.key(x)
    V = self.value(x)

We are able to get the relationship between the tokens in a sequence by getting the dot product of the queries and keys.

# Dot Product of Queries and Keys
attention = Q @ K.transpose(-2,-1)

We need to scale these values to control the variance at initialization, so that the softmax stays diffuse enough for tokens to aggregate information from more than one other token. Scaling is applied by dividing the dot product by the square root of the attention head size.

# Scaling
attention = attention / (self.head_size ** 0.5)

The main difference between transformer encoders and decoders is that decoders apply an attention mask while encoders do not. While CLIP is an encoder-only model, a mask still needs to be applied in the text encoder because of the padding added to the input text during tokenization. Note that the mask is optional, so this attention head can be used in both the text and image encoders.

Figure 6: Example of applying mask to attention scores. Image: Own work.
# Applying Attention Mask
if mask is not None:
    attention = attention.masked_fill(mask == 0, float("-inf"))

We then need to apply a softmax operation to the scaled dot product. Entries set to negative infinity become zero after the softmax and are effectively ignored.

attention = torch.softmax(attention, dim=-1)

Finally, we need to take the dot product between the softmax output and the values matrix. This is what actually communicates the information between the corresponding tokens.

attention = attention @ V

return attention

Multi-Head Attention

class MultiHeadAttention(nn.Module):
    def __init__(self, width, n_heads):
        super().__init__()
        self.head_size = width // n_heads

        self.W_o = nn.Linear(width, width)

        self.heads = nn.ModuleList([AttentionHead(width, self.head_size) for _ in range(n_heads)])

    def forward(self, x, mask=None):
        # Combine attention heads
        out = torch.cat([head(x, mask=mask) for head in self.heads], dim=-1)

        out = self.W_o(out)

        return out

Multi-head attention is just running multiple heads of self-attention in parallel and combining them. We can do this by adding the attention heads into a module list,

self.heads = nn.ModuleList([AttentionHead(width, self.head_size) for _ in range(n_heads)])

passing through the input and concatenating the results.

def forward(self, x, mask=None):
    # Combine attention heads
    out = torch.cat([head(x, mask=mask) for head in self.heads], dim=-1)

We then need to pass the output through another linear module.

out = self.W_o(out)

return out

Transformer Encoder

class TransformerEncoder(nn.Module):
    def __init__(self, width, n_heads, r_mlp=4):
        super().__init__()
        self.width = width
        self.n_heads = n_heads

        # Sub-Layer 1 Normalization
        self.ln1 = nn.LayerNorm(width)

        # Multi-Head Attention
        self.mha = MultiHeadAttention(width, n_heads)

        # Sub-Layer 2 Normalization
        self.ln2 = nn.LayerNorm(width)

        # Multilayer Perceptron
        self.mlp = nn.Sequential(
            nn.Linear(self.width, self.width*r_mlp),
            nn.GELU(),
            nn.Linear(self.width*r_mlp, self.width)
        )

    def forward(self, x, mask=None):
        x = x + self.mha(self.ln1(x), mask=mask)

        x = x + self.mlp(self.ln2(x))

        return x
Figure 7: Transformer Encoder Diagram. Image: Own work.

The transformer encoder is made up of two sub-layers: the first sub-layer performs multi-head attention and the second contains a multi-layer perceptron. The multi-head attention sub-layer handles communication between tokens, while the multi-layer perceptron sub-layer allows each token to individually “think” about what was communicated to it.

Layer normalization is an optimization technique that normalizes each input in the batch independently across its features. For our model, we will pass our inputs through a layer norm module at the beginning of each sub-layer.

# Sub-Layer 1 Normalization
self.ln1 = nn.LayerNorm(width)

# Sub-Layer 2 Normalization
self.ln2 = nn.LayerNorm(width)
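For reference, nn.LayerNorm normalizes each token vector x using the mean μ and variance σ² computed over its own features, together with learnable scale and shift parameters γ and β:

\text{LayerNorm}(x) = \gamma \odot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta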

The MLP will consist of two linear layers with a GELU activation in between. GELU is used instead of ReLU because it is smooth, avoiding ReLU’s non-differentiability at zero.

# Multilayer Perceptron
self.mlp = nn.Sequential(
    nn.Linear(width, width*r_mlp),
    nn.GELU(),
    nn.Linear(width*r_mlp, width)
)
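For reference, GELU is defined as GELU(x) = x · Φ(x), where Φ is the standard normal CDF, and is commonly computed with the tanh approximation:

\text{GELU}(x) \approx 0.5\,x\left(1 + \tanh\!\left[\sqrt{2/\pi}\,(x + 0.044715\,x^3)\right]\right)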

In the forward method for the encoder, the input is passed through the first layer normalization module before performing multi-head attention. The original input is added to the output from performing multi-head attention to create a residual connection.

This is then passed through another layer normalization module before being fed into the MLP. Another residual connection is created by adding the output of the MLP to the output of the first residual connection.

The residual connections help prevent the vanishing gradient problem by creating a path along which the gradient can be back-propagated, unimpeded, to the original input.

def forward(self, x, mask=None):
    # Residual Connection After Sub-Layer 1
    x = x + self.mha(self.ln1(x), mask=mask)

    # Residual Connection After Sub-Layer 2
    x = x + self.mlp(self.ln2(x))

    return x

Tokenization

def tokenizer(text, encode=True, mask=None, max_seq_length=32):
    if encode:
        out = chr(2) + text + chr(3) # Adding SOT and EOT tokens
        out = out + "".join([chr(0) for _ in range(max_seq_length-len(out))]) # Adding Padding
        out = torch.IntTensor(list(out.encode("utf-8"))) # Encoding Text
        mask = torch.ones(len(out.nonzero()))
        mask = torch.cat((mask,torch.zeros(max_seq_length-len(mask)))).type(torch.IntTensor)
    else:
        out = [chr(x) for x in text[1:len(mask.nonzero())-1]]
        out = "".join(out)
        mask = None

    return out, mask

Transformers are unable to process raw text, so the first thing that we need to do is tokenize the input strings before passing them through the text encoder.

In this tutorial, we are going to do a simple version of tokenization that just uses the UTF-8 encoding. We can get away with this naive approach because we are only using simple captions in our examples. For more complex text, you would want a BPE tokenizer: with byte-level UTF-8 encoding the vocabulary is capped at 256 tokens, so richer text produces long input sequences, which is inefficient for attention given a limited context length.

Figure 8: Tokenization Process with Max Sequence Length of 10. Image: Own work.

The first step for our tokenizer is to add the start of text and end of text tokens to the input string.

text = chr(2) + text + chr(3)

After adding the start of text and end of text tokens, we need to pad the length of the sequence to the maximum sequence length.

text = text + "".join([chr(0) for _ in range(10-len(text))])

We complete the tokenization by encoding the text sequence to UTF-8 and converting the output to an IntTensor.

text = torch.IntTensor(list(text.encode("utf-8")))

After tokenizing the text, we need to create a mask for it. While the mask normally used in transformer decoders prevents tokens from communicating with future tokens, the mask we apply here just makes the padding tokens get ignored. It is simply a tensor of length equal to the max sequence length whose elements are 0 wherever there is padding and 1 otherwise.

mask = torch.ones(len(text.nonzero()))
mask = torch.cat((mask,torch.zeros(10-len(mask)))).type(torch.IntTensor)
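As a quick sanity check of the tokenizer defined above (with its default max_seq_length of 32), a call might look like the sketch below. The example caption “An image of a dress” is 19 characters, so the mask should contain 21 ones (SOT + 19 characters + EOT) followed by zeros.

tokens, mask = tokenizer("An image of a dress")

print(tokens.shape, tokens.dtype)   # expected: torch.Size([32]) torch.int32
print(mask.shape, int(mask.sum()))  # expected: torch.Size([32]) 21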

Text Encoder

class TextEncoder(nn.Module):
    def __init__(self, vocab_size, width, max_seq_length, n_heads, n_layers, emb_dim):
        super().__init__()

        self.max_seq_length = max_seq_length

        self.encoder_embedding = nn.Embedding(vocab_size, width)

        self.positional_embedding = PositionalEmbedding(width, max_seq_length)

        self.encoder = nn.ModuleList([TransformerEncoder(width,n_heads) for _ in range(n_layers)])

        # learned proj of text to embed
        self.projection = nn.Parameter(torch.randn(width, emb_dim))

    def forward(self, text, mask=None):
        # Text Embedding
        x = self.encoder_embedding(text)

        # Positional Embedding
        x = self.positional_embedding(x)

        # Transformer Encoder
        for encoder_layer in self.encoder:
            x = encoder_layer(x, mask=mask)

        # Takes features from the EOT Embedding
        x = x[torch.arange(text.shape[0]), torch.sub(torch.sum(mask[:,0],dim=1),1)]

        # joint multimodal embedding
        if self.projection is not None:
            x = x @ self.projection

        x = x / torch.norm(x, dim=-1, keepdim=True)

        return x

For the text encoder, we are going to be using a regular transformer model. The first step in creating the text encoder is creating an embedding table of size (vocab_size, width). This embedding table contains a vector representation with a size equal to the width of the transformer model for each token in the vocabulary.

self.encoder_embedding = nn.Embedding(vocab_size, width)

Before outputting the result of the transformer, we need to project the text features into the joint embedding space. We do this by taking the dot product of the text features with a learned projection matrix, which we create using nn.Parameter.

# learned proj of text to embed
self.projection = nn.Parameter(torch.randn(width, emb_dim))

In the forward method, the first thing that we are going to do is pass the text tokens through the embedding table.

# Text Embedding
x = self.encoder_embedding(text)

We then need to add the positional encodings to the output of the embedding table.

# Positional Embedding
x = self.positional_embedding(x)

With the positional encodings added, we can now pass it through the encoder layers along with the masks.

# Transformer Encoder
for encoder_layer in self.encoder:
    x = encoder_layer(x, mask=mask)

The output of the encoder layers is the text features. We are going to use the features from the EOT embedding: mask[:,0] is the 1D padding mask for each sequence, so summing it and subtracting one gives the index of the EOT token. If you are using something like a BERT model for the text encoder, you would want to use the class token here instead.

# Takes features from the EOT Embedding
x = x[torch.arange(text.shape[0]),torch.sub(torch.sum(mask[:,0],dim=1),1)]

Finally, we embed the text features in the joint embedding space by taking the dot product between the features and the learned projection, and normalize the result by dividing by its L2 norm.

# joint multimodal embedding
if self.projection is not None:
    x = x @ self.projection

x = x / torch.norm(x, dim=-1, keepdim=True)

return x

Image Encoder

class ImageEncoder(nn.Module):
    def __init__(self, width, img_size, patch_size, n_channels, n_layers, n_heads, emb_dim):
        super().__init__()

        assert img_size[0] % patch_size[0] == 0 and img_size[1] % patch_size[1] == 0, "img_size dimensions must be divisible by patch_size dimensions"
        assert width % n_heads == 0, "width must be divisible by n_heads"

        self.n_patches = (img_size[0] * img_size[1]) // (patch_size[0] * patch_size[1])

        self.max_seq_length = self.n_patches + 1

        self.linear_project = nn.Conv2d(n_channels, width, kernel_size=patch_size, stride=patch_size)

        self.cls_token = nn.Parameter(torch.randn(1, 1, width))

        self.positional_embedding = PositionalEmbedding(width, self.max_seq_length)

        self.encoder = nn.ModuleList([TransformerEncoder(width,n_heads) for _ in range(n_layers)])

        # learned proj of image to embed
        self.projection = nn.Parameter(torch.randn(width, emb_dim))

    def forward(self, x):
        # Patch Embedding
        x = self.linear_project(x)
        x = x.flatten(2).transpose(1, 2)

        # Positional Embedding
        x = torch.cat((self.cls_token.expand(x.size()[0], -1, -1), x), dim=1)
        x = self.positional_embedding(x)

        # Transformer Encoder
        for encoder_layer in self.encoder:
            x = encoder_layer(x)

        # Getting Class Tokens
        x = x[:, 0, :]

        # joint multimodal embedding
        if self.projection is not None:
            x = x @ self.projection

        x = x / torch.norm(x, dim=-1, keepdim=True)

        return x
Figure 9: Vision Transformer Model Diagram. Image: ViT Paper.

For the image encoder, we are going to be using a vision transformer. When creating our image encoder, we first need to make sure that the input images can be split evenly into patches of size patch_size and that the dimensionality of the model is divisible by the number of attention heads.

assert img_size[0] % patch_size[0] == 0 and img_size[1] % patch_size[1] == 0, "img_size dimensions must be divisible by patch_size dimensions"
assert width % n_heads == 0, "width must be divisible by n_heads"

We also need to calculate the maximum sequence length for the positional encoding, which will be equal to the number of patches plus one. The number of patches can be found by dividing the product of the height and width of the input image by the product of the height and width of the patch size.

self.n_patches = (img_size[0] * img_size[1]) // (patch_size[0] * patch_size[1])
self.max_seq_length = self.n_patches + 1
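For example, with the 28×28 Fashion MNIST images and 14×14 patches used later in this tutorial, n_patches = (28 · 28) / (14 · 14) = 4 and max_seq_length = 5.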

The vision transformer also needs to be able to stack multiple encoder modules. This is achieved by putting a list of encoder layers inside a ModuleList.

self.encoder = nn.ModuleList([TransformerEncoder(width,n_heads) for _ in range(n_layers)])

Before we pass our inputs through the encoder layers, we first need to split the input image into patches and create a sequence of linear embeddings of these patches. We achieve this by using PyTorch’s Conv2d method.

The Conv2d method takes the input images, splits them into patches and provides a linear projection of a size equal to the width of the model. By setting kernel_size and stride to patch size, we ensure that the patches are the correct size and there is no overlap.

self.linear_project = nn.Conv2d(n_channels, width, kernel_size=patch_size, stride=patch_size)

In the forward method, we pass the input of shape (B, C, H, W) through the linear_project/Conv2d layer and receive an output of shape (B, width, P_col, P_row).

def forward(self, x):
    x = self.linear_project(x) # (B, C, H, W) -> (B, width, P_col, P_row)
Figure 10: Conv2D applied on a single image. Each color represents which patch an element belongs to. Image: ViT From Scratch Article.

We use the flatten method to combine the patch column and patch row dimensions into a single patch dimension, giving us a shape of (B, width, P).

x = x.flatten(2) # (B, width, P_col, P_row) -> (B, width, P)
Figure 11: Flatten applied to Conv2d output. Image: ViT From Scratch Article.

Finally, we use the transpose method to switch the width and patch dimensions to get a shape of (B, P, width).

x = x.transpose(-2, -1) # (B, width, P) -> (B, P, width)
Figure 12: Transpose applied to flatten output. Image: ViT From Scratch Article.
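To make this concrete with the settings used later in this tutorial (28×28 grayscale images, 14×14 patches, vit_width = 9): a batch of shape (B, 1, 28, 28) becomes (B, 9, 2, 2) after the convolution, (B, 9, 4) after flattening, and (B, 4, 9) after the transpose.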

Vision transformers use the standard approach of adding a learnable classification token to the patch embeddings in order to perform classification.

self.cls_token = nn.Parameter(torch.randn(1, 1, width))

Each image in the batch needs its own class token, so we use the expand method to replicate self.cls_token across every image in the batch.

x = torch.cat((self.cls_token.expand(x.size()[0], -1, -1),x), dim=1)

After adding the class tokens, we need to add the positional encodings to the embeddings.

x = self.positional_embedding(x)

With the positional encodings added, we can now pass the embeddings through the encoder layers.

# Transformer Encoder
for encoder_layer in self.encoder:
    x = encoder_layer(x)

From the output of the encoder layers, we only need the information from the learned class tokens.

# Getting Class Tokens
x = x[:, 0, :]

Finally, we embed the image features in the joint embedding space by taking the dot product between the features and the learned projection, and normalize the result by dividing by its L2 norm.

if self.projection is not None:
    x = x @ self.projection

x = x / torch.norm(x, dim=-1, keepdim=True)

return x

CLIP Model

class CLIP(nn.Module):
    def __init__(self, emb_dim, vit_width, img_size, patch_size, n_channels, vit_layers, vit_heads, vocab_size, text_width, max_seq_length, text_heads, text_layers):
        super().__init__()

        self.image_encoder = ImageEncoder(vit_width, img_size, patch_size, n_channels, vit_layers, vit_heads, emb_dim)

        self.text_encoder = TextEncoder(vocab_size, text_width, max_seq_length, text_heads, text_layers, emb_dim)

        self.temperature = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def forward(self, image, text, mask=None):
        I_e = self.image_encoder(image)
        T_e = self.text_encoder(text, mask=mask)

        # scaled pairwise cosine similarities [n, n]
        logits = (I_e @ T_e.transpose(-2,-1)) * torch.exp(self.temperature)

        # symmetric loss function
        labels = torch.arange(logits.shape[0]).to(self.device)

        loss_i = nn.functional.cross_entropy(logits.transpose(-2,-1), labels)
        loss_t = nn.functional.cross_entropy(logits, labels)

        loss = (loss_i + loss_t) / 2

        return loss

When given a batch of images and captions, CLIP is supposed to tell you which caption goes with which image. It does this by training the text and image encoders together to maximize the pairwise cosine similarity scores of the pairs that belong together, while minimizing the similarity of the pairs that do not.

To do this, we first need to get the embedded features from the image and text encoders.

def forward(self, image, text, mask=None):
    I_e = self.image_encoder(image)
    T_e = self.text_encoder(text, mask=mask)

Using the embedded features, we can calculate the scaled pairwise cosine similarities by taking the dot product between the embedded image features and the transposed embedded text features. Since both encoders L2-normalize their outputs, this dot product is exactly the cosine similarity. The cosine similarities should be maximized along the diagonal of the matrix in the figure, where the correct images and texts are paired together.

Figure 13: Calculating cosine similarity. Image: CLIP Paper.
logits = (I_e @ T_e.transpose(-2,-1)) * torch.exp(self.temperature)

This works because I_e and T_e each contain N embeddings, one per item in the batch, resulting in the N×N matrix shown in Figure 13. To maximize the cosine similarity between related pairs, CLIP uses a symmetric/contrastive loss. We calculate this loss by first creating labels that correspond to the items’ positions in the batch.

# symmetric loss function
labels = torch.arange(logits.shape[0]).to(self.device)

We then calculate loss_i, the cross entropy over the image axis: passing the transposed logits means that, for each caption, the model is asked to pick the matching image from the batch.

loss_i = nn.functional.cross_entropy(logits.transpose(-2,-1), labels)

The text loss, loss_t, is the cross entropy over the text axis, calculated from the logits directly: for each image, the model is asked to pick the matching caption.

loss_t = nn.functional.cross_entropy(logits, labels)

We get the final loss by calculating the average between the loss for the images and the loss for the text.

loss = (loss_i + loss_t) / 2

return loss

Dataset

class FashionMNIST(Dataset):
    def __init__(self, train=True):
        self.dataset = load_dataset("fashion_mnist")

        self.transform = T.ToTensor()

        if train:
            self.split = "train"
        else:
            self.split = "test"

        self.captions = {0: "An image of a t-shirt/top",
                         1: "An image of trousers",
                         2: "An image of a pullover",
                         3: "An image of a dress",
                         4: "An image of a coat",
                         5: "An image of a sandal",
                         6: "An image of a shirt",
                         7: "An image of a sneaker",
                         8: "An image of a bag",
                         9: "An image of an ankle boot"}

    def __len__(self):
        return self.dataset.num_rows[self.split]

    def __getitem__(self, i):
        img = self.dataset[self.split][i]["image"]
        img = self.transform(img)

        cap, mask = tokenizer(self.captions[self.dataset[self.split][i]["label"]])

        mask = mask.repeat(len(mask), 1)

        return {"image": img, "caption": cap, "mask": mask}

For this tutorial, we are going to be using the Fashion MNIST dataset from HuggingFace. We have chosen this dataset because it is rather small and keeps training time reasonable.

self.dataset = load_dataset("fashion_mnist")

For each entry in the dataset, we are going to need three things: the image, caption, and text mask.

For the image, the only change we need to make is to transform the image to a tensor.

img = self.dataset[self.split][i]["image"]
img = self.transform(img)

For the caption, we need to pass it through the tokenizer that we created to get the token representation along with the mask for the tokens.

cap, mask = tokenizer(self.captions[self.dataset[self.split][i]["label"]])

The mask that we get from the tokenizer is a 1D tensor of size max_seq_length. In the text encoder, the mask is applied to the attention scores, which have a shape of (max_seq_length, max_seq_length). Because of this, we need to expand the mask so that it is applied to each row of the attention scores.

Figure 14: Mask before and after expanding. Image: Own work.
mask = mask.repeat(len(mask),1)
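With max_seq_length = 32, for example, this turns a mask of shape (32,) into one of shape (32, 32), with one copy of the padding mask per query position.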

The image, caption, and mask are returned from the dataset as a dictionary.

return {"image": img, "caption": cap, "mask": mask}

Training Parameters

emb_dim = 32
vit_width = 9
img_size = (28,28)
patch_size = (14,14)
n_channels = 1
vit_layers = 3
vit_heads = 3
vocab_size = 256
text_width = 32
max_seq_length = 32
text_heads = 8
text_layers = 4
lr = 1e-3
epochs = 10
batch_size = 128

Loading Dataset

train_set = FashionMNIST(train = True)
test_set = FashionMNIST(train = False)

train_loader = DataLoader(train_set, shuffle=True, batch_size=batch_size)
test_loader = DataLoader(test_set, shuffle=False, batch_size=batch_size)

Training

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device: ", device, f"({torch.cuda.get_device_name(device)})" if torch.cuda.is_available() else "")

model = CLIP(emb_dim, vit_width, img_size, patch_size, n_channels, vit_layers, vit_heads, vocab_size, text_width, max_seq_length, text_heads, text_layers).to(device)

optimizer = optim.Adam(model.parameters(), lr=lr)

best_loss = np.inf
for epoch in range(epochs):
    for i, data in enumerate(train_loader, 0):
        img, cap, mask = data["image"].to(device), data["caption"].to(device), data["mask"].to(device)
        loss = model(img, cap, mask)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch [{epoch+1}/{epochs}], Batch Loss: {loss.item():.3f}")

    # Saves model if it performed better than the previous best
    if loss.item() <= best_loss:
        best_loss = loss.item()
        torch.save(model.state_dict(), "/content/drive/MyDrive/clip.pt")
        print("Model Saved.")

Testing

# Loading Best Model
model = CLIP(emb_dim, vit_width, img_size, patch_size, n_channels, vit_layers, vit_heads, vocab_size, text_width, max_seq_length, text_heads, text_layers).to(device)
model.load_state_dict(torch.load("/content/drive/MyDrive/clip.pt", map_location=device))

# Getting dataset captions to compare images to
text = torch.stack([tokenizer(x)[0] for x in test_set.captions.values()]).to(device)
mask = torch.stack([tokenizer(x)[1] for x in test_set.captions.values()])
mask = mask.repeat(1, len(mask[0])).reshape(len(mask), len(mask[0]), len(mask[0])).to(device)

correct, total = 0, 0
with torch.no_grad():
    for data in test_loader:
        images, labels = data["image"].to(device), data["caption"].to(device)
        image_features = model.image_encoder(images)
        text_features = model.text_encoder(text, mask=mask)

        image_features /= image_features.norm(dim=-1, keepdim=True)
        text_features /= text_features.norm(dim=-1, keepdim=True)
        similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
        _, indices = torch.max(similarity, 1)
        pred = torch.stack([tokenizer(test_set.captions[int(i)])[0] for i in indices]).to(device)
        correct += int(sum(torch.sum((pred == labels), dim=1) // len(pred[0])))
        total += len(labels)

print(f'\nModel Accuracy: {100 * correct // total} %')

We tested the model by taking the captions that the model was trained on and comparing the predicted captions to the actual ones. Since training used the same caption template (“An image of a(n) {class}”), this testing stage is essentially the same as evaluating any other image classifier. We achieved a model accuracy of around 85%.

Zero-Shot Classification

# Loading Best Model
model = CLIP(emb_dim, vit_width, img_size, patch_size, n_channels, vit_layers, vit_heads, vocab_size, text_width, max_seq_length, text_heads, text_layers).to(device)
model.load_state_dict(torch.load("/content/drive/MyDrive/clip.pt", map_location=device))

# Captions to compare images to
class_names = ["t-shirt/top",
               "trousers",
               "pullover",
               "dress",
               "coat",
               "sandal",
               "shirt",
               "sneaker",
               "bag",
               "ankle boot"]

text = torch.stack([tokenizer(x)[0] for x in class_names]).to(device)
mask = torch.stack([tokenizer(x)[1] for x in class_names])
mask = mask.repeat(1, len(mask[0])).reshape(len(mask), len(mask[0]), len(mask[0])).to(device)

idx = 1000

img = test_set[idx]["image"][None, :]
plt.imshow(img[0].permute(1, 2, 0), cmap="gray")
plt.title(tokenizer(test_set[idx]["caption"], encode=False, mask=test_set[idx]["mask"][0])[0])
plt.show()

img = img.to(device)
with torch.no_grad():
    image_features = model.image_encoder(img)
    text_features = model.text_encoder(text, mask=mask)

image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
values, indices = similarity[0].topk(5)

# Print the result
print("\nTop predictions:\n")
for value, index in zip(values, indices):
    print(f"{class_names[int(index)]:>16s}: {100 * value.item():.2f}%")

For zero-shot classification, we compare the image to just the class names. We input the class names to compare against the image, and the model returns the top 5 predictions with their predicted likelihoods. This isn’t the greatest showcase of CLIP performing zero-shot classification: using the Fashion MNIST dataset makes the model easy to train, but the captions are not very rich. To truly appreciate the zero-shot capabilities of CLIP, a training set with multiple nouns per caption would be more appropriate. True zero-shot recognition would then allow the model to recognize previously unseen combinations.

Figure 15: Zero-shot classification output. Image: Own work.
