Clustering Image Embeddings
As humans, we have an intuitive sense of when two images are similar, whether they contain the same objects or are set in a similar scene. We easily make sense of the two images below, but getting a computer to do the same can be tricky.
Algorithms, and computers in general, need numbers to work with. This naturally leads to the idea that an image should be converted into a list of numbers, which can then be compared with other lists to measure how similar the image is to others.
WHAT ARE EMBEDDINGS?
These images look similar, but as far as our lists are concerned they have some small yet important differences. If we were to build a simple list of numbers, a vector, signifying the presence of features in some way, it would go like this:
Let's try to understand the figure above. The first image has a watch on the person's wrist; in the second image, there is no watch. Similarly, the first image has a rear-view mirror whereas the second one doesn't. Both images contain a human hand and a steering wheel. That is what the vectors shown above capture, and it is in essence what we intend to capture in an embedding, or vector representation: similarity as perceived by humans translates to closeness in the vector space.
An embedding is considered good when semantic similarity is quantifiable as distance between embeddings in the vector space. Once we have embeddings, we can measure the similarity between any two of them. Taking this a step further, we can group semantically similar embeddings together; the number of groups is a human-decided variable. This process is called clustering, and commonly used methods are K-means and hierarchical clustering. For this post, we will use K-means and explore what kinds of images get grouped together.
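To make this concrete, here is a toy sketch using the hand-built 1/0 feature vectors from the figure above (the feature ordering is an assumption for illustration) and cosine similarity as the distance measure:

import numpy as np

# Hand-built 1/0 feature vectors for the two dashboard images:
# [watch, rear-view mirror, hand, steering wheel]
image_1 = np.array([1, 1, 1, 1])
image_2 = np.array([0, 0, 1, 1])

# Cosine similarity: 1.0 = identical direction, 0.0 = nothing in common
cosine_sim = image_1 @ image_2 / (np.linalg.norm(image_1) * np.linalg.norm(image_2))
print(f"cosine similarity: {cosine_sim:.2f}")  # ~0.71: similar, but not identical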
The 1/0 embedding shown above is easy to understand in terms of features. We will explore embedding creation with both CNNs and transformers (ViT). The embeddings these models generate are large and dense (mostly non-zero numbers). CNNs have traditionally been the go-to for visual data, but lately transformers have been gaining a lot of attention in the computer vision space as well. CNNs process pixel arrays, whereas transformers split the image into visual tokens. Transformers use a self-attention mechanism that models long-range, multi-level dependencies across image regions, which is quite different from a CNN. The results from these two approaches may therefore differ, making it worth trying both.
OUR DATA
For the sake of this blog post, we will be using 300 random YouTube videos from the MovieClips channel. These are short clips from movies of different genres, featuring a variety of movie stars.
Once we break the videos down into frames, we get a huge number of them, since most videos run at 30 fps. Most of these frames are redundant, though, because the scene rarely changes much within a span of 0.5 seconds. To deal with this, we developed a key-frame extraction algorithm (out of scope for this blog post) that reduces the number of frames to be processed by around 99% relative to all the frames in a video. These key frames are what we will use going forward for testing our embeddings.
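The key-frame algorithm itself is beyond this post, but just to illustrate the very first step of going from video to frames, here is a naive sketch that keeps one frame out of every N using OpenCV (the sampling interval and paths are placeholders; our actual key-frame extraction does much more than this):

import cv2

def sample_frames(video_path, every_n_frames=15):
    """Naively keep one frame out of every `every_n_frames` (not our key-frame logic)."""
    frames = []
    cap = cv2.VideoCapture(video_path)
    idx = 0
    while True:
        ok, frame = cap.read()
        if not ok:
            break
        if idx % every_n_frames == 0:
            frames.append(frame)  # BGR numpy array
        idx += 1
    cap.release()
    return frames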
THE CNN WAY:
Convolutional Neural Networks (CNNs) are a class of deep learning architectures usually applied to visual data, and they can be used to transform images into embeddings. For our experiment, we will use EfficientNet, a CNN architecture from Google that, as the name suggests, aims to be extremely efficient in terms of compute. We can take a pre-trained EfficientNet model and keep everything up to (but not including) the final dense layer to get a 1280-dimensional embedding. The code for this is as follows:
import torch
from torch import nn


class FeatureExtractor(nn.Module):
    def __init__(self):
        super(FeatureExtractor, self).__init__()
        # Pre-trained EfficientNet-B0 from NVIDIA's torchhub repo
        efficientnet = torch.hub.load(
            'NVIDIA/DeepLearningExamples:torchhub',
            'nvidia_efficientnet_b0',
            pretrained=True
        )
        echildren = list(efficientnet.children())
        # Everything except the classifier head
        self.feature_extractor = torch.nn.Sequential(*echildren[:-1])
        # The classifier head without its final dense layer
        self.pre_final_stage = torch.nn.Sequential(*echildren[-1][:-1])

    def forward(self, data):
        feats = self.feature_extractor(data)
        feats = self.pre_final_stage(feats)
        return feats
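Images need to be pre-processed into normalized tensors before they go through this module. Here is a minimal sketch of that step, assuming the standard 224x224 ImageNet resize and normalization (the exact transforms and the key-frame path are assumptions, not the production pipeline):

import torch
from PIL import Image
from torchvision import transforms

preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    # Standard ImageNet mean/std (assumption)
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

extractor = FeatureExtractor().eval()
frame = Image.open("key_frame_0001.jpg").convert("RGB")  # hypothetical key-frame path
batch = preprocess(frame).unsqueeze(0)                    # shape: (1, 3, 224, 224)
with torch.no_grad():
    embedding = extractor(batch)                          # 1280-dimensional embedding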
With pre-processing in place, the forward pass returns an embedding for each image, and this way we get an embedding for every key frame from the clips. Now we will try to cluster them using K-means. This goes as follows:
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler

# Standardize each embedding dimension to zero mean and unit variance
embeddings_standardized = StandardScaler().fit_transform(embeddings)

# Fit K-means with a human-chosen number of clusters
kmeanModel = KMeans(n_clusters=num_clusters).fit(embeddings_standardized)
Using the fitted model, we can look at the cluster centers and gather a few of the embeddings nearest to each center to see how visually similar the corresponding images are (a small sketch of this follows below). Will these clusters really group images that a human would consider semantically similar? Here are some of the examples after K-means with num_clusters set to 100.
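As a rough sketch of how the frames nearest to a given center can be pulled (frame_paths here is an assumed list of file paths aligned row-for-row with embeddings_standardized):

import numpy as np

def closest_frames(kmean_model, embeddings_std, frame_paths, cluster_id, top_k=5):
    # Euclidean distance of every embedding to the chosen cluster center
    center = kmean_model.cluster_centers_[cluster_id]
    distances = np.linalg.norm(embeddings_std - center, axis=1)
    nearest = np.argsort(distances)[:top_k]
    return [frame_paths[i] for i in nearest]

# e.g. inspect the frames closest to the center of cluster 0
print(closest_frames(kmeanModel, embeddings_standardized, frame_paths, cluster_id=0))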
These results are amazing: we would have grouped these images as visually similar ourselves. This confirms that CNN-generated embeddings really do work and can help group frames into useful categories, like hats or people sitting inside cars. Let's try the same experiment with transformer-based embeddings now.
THE TRANSFORMER WAY:
As presented here, ViTMAE is a recent paper that uses a Vision Transformer to reconstruct pixel values for masked patches; the vanilla ViT-Huge model achieves the best accuracy (87.8%) among methods that use only ImageNet-1K data. The Hugging Face page explains how to use a pre-trained ViTMAE model to get embeddings out of an image, as in the code below:
from transformers import AutoImageProcessor, ViTMAEModel
from PIL import Image
import requests

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

image_processor = AutoImageProcessor.from_pretrained("facebook/vit-mae-base")
model = ViTMAEModel.from_pretrained("facebook/vit-mae-base")

inputs = image_processor(images=image, return_tensors="pt")
outputs = model(**inputs)
last_hidden_states = outputs.last_hidden_state  # shape: (1, 50, 768)
The problem is that this output can't be used directly: its shape is 50x768, corresponding to the 50 tokens in the output sequence (the unmasked patches plus the class token) times the 768-dimensional hidden state. In essence, turning a 224x224 image into a 50x768 matrix doesn't really reduce the number of bytes needed to store the image representation, and we need a smaller embedding to run K-means clustering in a reasonable time for this experiment.
import torch
from torch import nn
from transformers import ViTMAEModel


class FeatureExtractor(nn.Module):
    def __init__(self):
        super(FeatureExtractor, self).__init__()
        self.model = ViTMAEModel.from_pretrained("facebook/vit-mae-base")

    def forward(self, inputs):
        # Drop any non-tensor metadata we may have carried along
        inputs.pop('image_name', None)
        outputs = self.model(**inputs)
        # Sum over the token (sequence) dimension: (1, 50, 768) -> (1, 768)
        return torch.sum(outputs.last_hidden_state, dim=1)
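For context, a minimal usage sketch for this module, assuming the same facebook/vit-mae-base image processor as above (the key-frame path is a placeholder and batching is omitted):

import torch
from PIL import Image
from transformers import AutoImageProcessor

image_processor = AutoImageProcessor.from_pretrained("facebook/vit-mae-base")
extractor = FeatureExtractor().eval()

frame = Image.open("key_frame_0001.jpg").convert("RGB")  # hypothetical key-frame path
inputs = image_processor(images=frame, return_tensors="pt")
with torch.no_grad():
    embedding = extractor(inputs)  # shape: (1, 768)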
To tackle the size problem, a simple but effective method is to sum across the 50 tokens, as the module above does. This gives us a reasonably sized 768-dimensional embedding to cluster with. Running the same K-means clustering process we used for the CNN embeddings, these are some of the results.
There are some interesting clusters from the ViTMAE embeddings, like explosions and people with weapons, but there are also plenty of clusters that make little sense, where the images are visually quite different from one another. This is possibly due to the sum across all tokens making the embedding less accurate. Some of the unexpected clusters, where the images seem very different semantically, are:
CONCLUSION:
So there you go! As per our experiments above, an image feature extractor works well when paired with a clustering algorithm like K-means, and embeddings can do a brilliant job of finding semantically similar images. It seems we need to change our approach for the transformer-based model so that it returns embeddings that are compact on their own and don't have to be pooled down to match the performance of the CNN-based embeddings. Overall, this can be a holistic approach to exploring what we have in our video corpus.
On another note, we can also flip the process: project a query image into the same semantic space and find the embeddings nearby. This lets us retrieve images similar to a specific image from a large corpus. At GumGum, we use this technique to harvest semantically similar images, reducing annotation effort when we need images of a specific category for model training.
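As a rough sketch of that flipped process, using scikit-learn's NearestNeighbors as one possible implementation (corpus_embeddings, query_embedding, and image_paths are assumed to come from the extractors above):

from sklearn.neighbors import NearestNeighbors

# corpus_embeddings: (num_images, dim) matrix of embeddings for the whole corpus
index = NearestNeighbors(n_neighbors=10, metric="cosine").fit(corpus_embeddings)

# query_embedding: (1, dim) embedding of the query image from the same extractor
distances, indices = index.kneighbors(query_embedding)
similar_images = [image_paths[i] for i in indices[0]]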
The same concepts can be applied to text and audio as well. Have fun doing your own experiments with embeddings.