Semantic Search: Hands-On Text-to-Image Search Using CLIP and FAISS

Shashank Vats
AI monks.io
5 min read · Jun 19, 2023
Photo by Evgeni Tcherkasski on Unsplash

Semantic search improves information retrieval by going beyond keyword matching: it tries to capture both the searcher's intent and the contextual meaning of the terms in the query. Keywords still matter, but a semantic search system also interprets context, handles synonyms and homonyms, recognizes hierarchies of concepts, and, most importantly, infers the intent behind the words in the query.
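To make the contrast concrete, here is a deliberately naive keyword matcher (a hypothetical helper, not part of the pipeline we'll build). It scores a caption by the fraction of query words it contains, so it gives zero credit to a caption that describes a basketball game in different words, while rewarding one that merely mentions a basketball:

```python
import re


def keyword_overlap(query: str, caption: str) -> float:
    """Fraction of query words that also appear in the caption (naive keyword matching)."""

    def tokenize(text: str) -> set:
        # lowercase and keep only alphabetic word tokens
        return set(re.findall(r"[a-z]+", text.lower()))

    query_words, caption_words = tokenize(query), tokenize(caption)
    if not query_words:
        return 0.0
    return len(query_words & caption_words) / len(query_words)


# Hypothetical captions: the first describes basketball being played without using either query word.
print(keyword_overlap("basketball game", "Two men shooting hoops on an outdoor court"))  # 0.0
print(keyword_overlap("basketball game", "A basketball lies on the floor"))  # 0.5
```

A semantic model like CLIP is meant to close exactly this gap by comparing meanings rather than surface words.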

With so much visual content online, finding the exact images you need can be overwhelming. This is where semantic search is especially useful: it lets you sift through large image collections efficiently, surfacing the images that match what you meant rather than only the exact words you typed.

In this article, we'll implement a simple yet effective method for semantic search over a large collection of images.

For this exercise, we'll use OpenAI's CLIP model to compute the embeddings and Facebook's FAISS library for indexing. We'll also use the Flickr30k dataset, available on Kaggle.

If you want to know more about how CLIP works or how embeddings can be used to search in vector space, I have explained both in detail in my previous posts.
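As a quick refresher, searching in vector space usually means ranking items by cosine similarity. The sketch below uses toy 3-dimensional vectors (hypothetical values standing in for real CLIP embeddings, which are 512-dimensional for clip-vit-base-patch32) to show that cosine similarity equals a plain dot product once both vectors are L2-normalized. That equivalence is what lets an inner-product index rank results by cosine similarity:

```python
import numpy as np


def l2_normalize(x: np.ndarray) -> np.ndarray:
    """Scale vectors (along the last axis) to unit L2 length."""
    return x / np.linalg.norm(x, ord=2, axis=-1, keepdims=True)


def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    """cos(a, b) = (a . b) / (|a| |b|)"""
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))


# Toy vectors (hypothetical values) standing in for real embeddings.
a = np.array([3.0, 4.0, 0.0])
b = np.array([4.0, 3.0, 0.0])

print(cosine_similarity(a, b))                   # ~0.96
print(float(l2_normalize(a) @ l2_normalize(b)))  # ~0.96, the same value via a plain dot product
```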

Imports and Setup

Let's import the packages we'll need for this exercise:

import os
import pickle

import PIL.Image  # makes PIL.Image.Image available for type checks below
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
# Note: this Image is the Hugging Face datasets feature type, not PIL's Image;
# encode_images() below relies on this name.
from datasets import Dataset, Image
from torch.utils.data import DataLoader

from typing import List, Union, Tuple

from transformers import CLIPProcessor, CLIPModel

import faiss

Load the model

Next, load the pre-trained CLIP model and its processor, and set the device. Here I'm using "cpu", but if you have a CUDA-capable GPU, you can set the device to "cuda" instead.

Let's also collect the image paths.

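device = "cpu"  # set to "cuda" if a GPU is available

# loading CLIP model and its processor
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)  # model must live on the same device as the inputs
model.eval()  # inference mode
preprocess = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# getting image paths (sorted so that FAISS row ids map back to a stable order)
image_dir = './flicker30k/Images/'
image_path = [image_dir + path for path in os.listdir(image_dir) if path.endswith('.jpg')]
image_path.sort()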
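Collecting the paths can also be wrapped in a small reusable helper. Here is a minimal sketch (list_image_paths and its extensions parameter are hypothetical, not from the original notebook). Sorting matters because a result's row position in the FAISS index is later mapped back with image_path[idx], so the order must be reproducible between runs:

```python
import os
from typing import List, Tuple


def list_image_paths(image_dir: str, extensions: Tuple[str, ...] = (".jpg", ".jpeg")) -> List[str]:
    """Return sorted full paths of files in image_dir whose extension matches (case-insensitive)."""
    paths = [
        os.path.join(image_dir, name)
        for name in os.listdir(image_dir)
        if name.lower().endswith(extensions)  # str.endswith accepts a tuple of suffixes
    ]
    return sorted(paths)


# Usage, assuming the Flickr30k images were extracted to ./flicker30k/Images/:
# image_path = list_image_paths("./flicker30k/Images/")
```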

Encoding Images

The encode_images() function converts a list of images into their corresponding embeddings, processing them in batches. It takes a list of image file paths or PIL Image objects, as well as a batch size.

Inside it, the helper transform_fn() applies CLIP's pre-processing to the images. A Hugging Face Dataset is created from the list of images, and transform_fn() is registered on it with set_transform(), so images are decoded and pre-processed on the fly as they are accessed.

A DataLoader handles batching, and a for loop iterates over it. For each batch, the images are passed through the CLIP model to obtain embeddings, which are detached from the computational graph and moved to CPU memory as NumPy arrays.

A tqdm progress bar tracks the encoding, and the function returns a single NumPy array stacking all the image embeddings.

Finally, encode_images() is called with the list of image paths and a batch size of 32, and the resulting embeddings are stored in the vector_embedding variable.

def encode_images(images: Union[List[str], List[PIL.Image.Image]], batch_size: int):
    def transform_fn(el):
        # el['image'] is a batch: either PIL images or undecoded {'path', 'bytes'} dicts
        if isinstance(el['image'][0], PIL.Image.Image):
            imgs = el['image']
        else:
            imgs = [Image().decode_example(_) for _ in el['image']]
        return preprocess(images=imgs, return_tensors='pt')

    dataset = Dataset.from_dict({'image': images})
    # keep file paths undecoded; transform_fn decodes them batch by batch
    dataset = dataset.cast_column('image', Image(decode=False)) if isinstance(images[0], str) else dataset
    dataset.set_format('torch')
    dataset.set_transform(transform_fn)
    dataloader = DataLoader(dataset, batch_size=batch_size)
    image_embeddings = []
    pbar = tqdm(total=len(dataloader), position=0)  # one tick per batch, including a final partial batch
    with torch.no_grad():
        for batch in dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            image_embeddings.extend(model.get_image_features(**batch).detach().cpu().numpy())
            pbar.update(1)
    pbar.close()
    return np.stack(image_embeddings)

vector_embedding = encode_images(image_path, batch_size=32)
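Under the hood, the DataLoader splits N images into ceil(N / batch_size) batches, with a smaller final batch when N isn't a multiple of the batch size (its default is drop_last=False). The sketch below mimics that rule in plain Python (iter_batches and num_batches are illustrative names, not library functions), and shows why len(images) // batch_size undercounts the batches by one in that case:

```python
from typing import Iterator, List, Sequence, TypeVar

T = TypeVar("T")


def iter_batches(items: Sequence[T], batch_size: int) -> Iterator[List[T]]:
    """Yield consecutive chunks of at most batch_size items (the last chunk may be smaller)."""
    for start in range(0, len(items), batch_size):
        yield list(items[start:start + batch_size])


def num_batches(n_items: int, batch_size: int) -> int:
    """Ceiling division: how many batches iter_batches (or a DataLoader with drop_last=False) yields."""
    return (n_items + batch_size - 1) // batch_size


sizes = [len(batch) for batch in iter_batches(list(range(100)), 32)]
print(sizes)                 # [32, 32, 32, 4]
print(num_batches(100, 32))  # 4
print(100 // 32)             # 3, one short of the real batch count
```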

Store Embeddings

Save the embeddings to disk so we don't have to regenerate them every time.

with open('flicker30k_image_embeddings.pkl', 'wb') as f:
    pickle.dump(vector_embedding, f)
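Pickle works, but for a plain float array NumPy's own .npy format is a simpler alternative: it stores the dtype and shape, and np.load refuses to unpickle objects by default (allow_pickle=False), which is safer when loading files you didn't create. A minimal sketch, assuming you're fine with switching the file to .npy (the helper names are hypothetical; note that np.save appends ".npy" if the name lacks it):

```python
import numpy as np


def save_embeddings(path: str, embeddings: np.ndarray) -> None:
    """Write embeddings as float32 (the dtype FAISS expects) in .npy format."""
    np.save(path, np.asarray(embeddings, dtype=np.float32))


def load_embeddings(path: str) -> np.ndarray:
    """Read embeddings back; np.load rejects pickled objects by default."""
    return np.load(path)


# Usage (file name is an assumption):
# save_embeddings("flicker30k_image_embeddings.npy", vector_embedding)
# vector_embedding = load_embeddings("flicker30k_image_embeddings.npy")
```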

Indexing the Images

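The generated image embeddings can now be indexed with Facebook's FAISS for efficient searching. IndexFlatIP performs exact (brute-force) inner-product search; once all vectors are L2-normalized, the inner product equals cosine similarity, so the image embeddings need to be normalized the same way the query embedding is.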

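vector_embedding = np.ascontiguousarray(vector_embedding, dtype=np.float32)  # FAISS expects contiguous float32
faiss.normalize_L2(vector_embedding)  # in-place L2 normalization, so inner product = cosine similarity
index = faiss.IndexFlatIP(vector_embedding.shape[1])  # exact inner-product index
index.add(vector_embedding)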
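Conceptually, IndexFlatIP is an exhaustive dot-product search. The NumPy class below (FlatIPIndex is a hypothetical name) is a sketch of that computation, not FAISS's implementation: it returns (scores, indices) arrays of shape (n_queries, k), highest score first, like index.search. One difference is that it clips k to the number of stored vectors, whereas FAISS pads missing results with -1. It can be handy for sanity-checking results on a small subset:

```python
from typing import Tuple

import numpy as np


class FlatIPIndex:
    """Brute-force inner-product search over stored vectors (a NumPy sketch)."""

    def __init__(self, dim: int):
        self.dim = dim
        self.vectors = np.empty((0, dim), dtype=np.float32)

    def add(self, x: np.ndarray) -> None:
        # append new rows; row position becomes the vector's id
        self.vectors = np.vstack([self.vectors, np.asarray(x, dtype=np.float32)])

    def search(self, queries: np.ndarray, k: int) -> Tuple[np.ndarray, np.ndarray]:
        scores = np.asarray(queries, dtype=np.float32) @ self.vectors.T  # (n_queries, n_vectors)
        k = min(k, self.vectors.shape[0])
        order = np.argsort(-scores, axis=1)[:, :k]  # ids of the k highest scores, best first
        return np.take_along_axis(scores, order, axis=1), order


index_sketch = FlatIPIndex(dim=2)
index_sketch.add(np.array([[1.0, 0.0], [0.0, 1.0], [0.6, 0.8]]))
scores, ids = index_sketch.search(np.array([[0.0, 1.0]]), k=2)
print(ids)  # [[1 2]]
```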

Query Encoding

The user's natural-language query must be converted into an embedding in the same vector space as the images, so that it can be compared against them.

Let's query for "basketball game" and see what comes back.

def encode_text(text: List[str], batch_size: int):
    dataset = Dataset.from_dict({'text': text})
    # tokenize, padding/truncating every query to CLIP's 77-token context length
    dataset = dataset.map(lambda el: preprocess(text=el['text'], return_tensors="pt",
                                                max_length=77, padding="max_length", truncation=True),
                          batched=True,
                          remove_columns=['text'])
    dataset.set_format('torch')
    dataloader = DataLoader(dataset, batch_size=batch_size)
    text_embeddings = []
    pbar = tqdm(total=len(dataloader), position=0)
    with torch.no_grad():
        for batch in dataloader:
            # uses the global device so the inputs sit on the same device as the model
            batch = {k: v.to(device) for k, v in batch.items()}
            text_embeddings.extend(model.get_text_features(**batch).detach().cpu().numpy())
            pbar.update(1)
    pbar.close()
    return np.stack(text_embeddings)


search_text = "basketball game"
text_search_embedding = encode_text([search_text], batch_size=32)  # encode_text already disables gradients
# L2-normalize the query so inner-product scores are cosine similarities
text_search_embedding = text_search_embedding / np.linalg.norm(text_search_embedding, ord=2, axis=-1, keepdims=True)
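The padding="max_length", truncation=True and max_length=77 arguments give every query exactly 77 token positions, CLIP's text context length, so queries of different lengths can be stacked into one batch tensor. The sketch below is a simplified illustration of that rule on raw token ids, not the real tokenizer: the Hugging Face tokenizer also adds start/end special tokens and keeps them when truncating, and pad_id=0 here is a placeholder rather than CLIP's actual padding id:

```python
from typing import List, Tuple


def pad_or_truncate(token_ids: List[int], max_length: int = 77,
                    pad_id: int = 0) -> Tuple[List[int], List[int]]:
    """Return (input_ids, attention_mask), both exactly max_length long."""
    ids = token_ids[:max_length]          # truncate overly long queries
    mask = [1] * len(ids)                 # 1 = real token
    padding = max_length - len(ids)
    return ids + [pad_id] * padding, mask + [0] * padding  # 0 = padding position


print(pad_or_truncate([5, 6, 7], max_length=5))
# ([5, 6, 7, 0, 0], [1, 1, 1, 0, 0])
```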

Search

Now that the image embeddings are indexed and the query is embedded, we can search the index for the top-k results. Here I'm retrieving the top 2 images, but you can adjust k as needed.

k = 2  # number of top results to return
distances, indices = index.search(text_search_embedding.reshape(1, -1), k)
distances = distances[0]  # for IndexFlatIP these are similarity scores: higher means more similar
indices = indices[0]

indices_distances = list(zip(indices, distances))
indices_distances.sort(key=lambda x: x[1], reverse=True)  # most similar first
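When mapping results back to files, two details are worth handling explicitly. For IndexFlatIP the returned "distances" are similarity scores, so higher is better and results should be read in descending order. And if k exceeds the number of indexed vectors, FAISS fills the missing slots with index -1, which in Python would silently select the last element of image_path. A minimal sketch (results_to_paths is a hypothetical helper):

```python
from typing import List, Sequence, Tuple


def results_to_paths(indices: Sequence[int], scores: Sequence[float],
                     paths: Sequence[str]) -> List[Tuple[str, float]]:
    """Pair each hit with its image path, most similar first, skipping -1 'no result' slots."""
    hits = [(paths[i], float(s)) for i, s in zip(indices, scores) if i != -1]
    return sorted(hits, key=lambda hit: hit[1], reverse=True)


paths = ["a.jpg", "b.jpg", "c.jpg"]
# 0.0 is a placeholder score for the padded slot
print(results_to_paths([2, 0, -1], [0.31, 0.27, 0.0], paths))
# [('c.jpg', 0.31), ('a.jpg', 0.27)]

# With the variables above: results_to_paths(indices, distances, image_path)
```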

Visualizing the Result

Let’s visualize our output images.

from PIL import Image  # re-import PIL's Image, since datasets' Image shadowed the name earlier

for idx, distance in indices_distances:
    path = image_path[idx]
    im = Image.open(path)
    plt.imshow(im)
    plt.show()

And that's it. In a few simple steps, we've built our own semantic image search engine!

Do check out my GitHub repository for the notebook!
