Semantic Search: Hands-On Text-to-Image Search Using CLIP and FAISS
Semantic search goes beyond simple keyword matching. Instead of only looking for the literal terms in a query, it tries to capture both the searcher's intent and the contextual meaning of those terms. Keywords still matter, but a semantic search system also interprets context, handles synonyms and homonyms, recognizes hierarchies between concepts, and, most importantly, tries to understand what the query is actually asking for.
With so much visual content online, finding exactly the image you need can be overwhelming. This is where semantic search is especially valuable: it lets you sift through large image collections efficiently and accurately, surfacing the images that match what you mean rather than just the words you typed.
In this article, we'll walk through and implement a straightforward yet effective method for performing semantic search over large collections of images.
For this exercise, we'll use OpenAI's CLIP model to build the embeddings and Facebook's FAISS library for indexing. We'll also use the Flickr30k dataset available on Kaggle.
If you want to know more about how CLIP works or how embeddings can be used to search in vector space, I have explained them thoroughly in my previous posts.
Imports and Setup
Let's import the packages we'll need for this exercise:
import os
import PIL.Image  # PIL's image module, referenced as PIL.Image below
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np
import pandas as pd
import pickle
import torch
from datasets import Dataset, Image  # this Image is the Hugging Face datasets feature type, not PIL's
from torch.utils.data import DataLoader
from typing import List, Union, Tuple
from transformers import CLIPProcessor, CLIPModel
import faiss
Load the model
Next, load the pre-trained CLIP model and its associated processor, and set the device. Here, I'm setting the device to cpu, but if you have CUDA available, you can set it to cuda instead.
Let's also collect the image paths.
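device = "cpu"  # use "cuda" instead if a GPU is available
# loading CLIP model and its processor, and moving the model to the chosen device
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
preprocess = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
# getting image paths (adjust image_dir to wherever you extracted the dataset)
image_dir = './flicker30k/Images/'
image_path = [image_dir + path for path in os.listdir(image_dir) if path.endswith('.jpg')]
image_path.sort()  # sorted so that row i of the embeddings always matches image_path[i]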
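If your copy of the dataset uses mixed-case extensions (for example .JPG) or contains stray non-image files, a slightly more defensive way to collect and sort the paths is sketched below. This is an illustration rather than the author's code; the helper name list_image_paths and its extensions parameter are hypothetical.

```python
from pathlib import Path
from typing import List, Tuple


def list_image_paths(folder: str, extensions: Tuple[str, ...] = (".jpg", ".jpeg")) -> List[str]:
    """Return the sorted paths of files in `folder` whose extension is in `extensions`.

    Hypothetical helper: the extension match is case-insensitive and subfolders are skipped.
    """
    paths = [
        str(p)
        for p in Path(folder).iterdir()
        if p.is_file() and p.suffix.lower() in extensions
    ]
    # A stable, sorted order keeps row i of the embedding matrix aligned with paths[i].
    return sorted(paths)
```

Usage would look like image_path = list_image_paths('./flicker30k/Images').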
Encoding Images
The encode_images() function converts a batch of images into their corresponding embeddings. It takes a list of image file paths or PIL Image objects, as well as a batch size.
Inside this function, transform_fn() is a helper that applies CLIP's pre-processing to the images. A Hugging Face Dataset is created from the list of images, and the transform function is attached to this dataset.
A DataLoader handles batching, and a for loop iterates over it. For each batch, the images are fed through the CLIP model to obtain embeddings, which are detached from the computational graph and moved to CPU memory (as NumPy arrays) for storage.
A tqdm progress bar tracks the encoding progress. The function returns a NumPy array that stacks all the image embeddings together.
Finally, encode_images() is called with the list of image paths and a batch size of 32, and the resulting embeddings are stored in the vector_embedding variable.
def encode_images(images: Union[List[str], List[PIL.Image.Image]], batch_size: int):
    def transform_fn(el):
        # el['image'] is a batch (a list): either PIL images, or undecoded
        # {'path': ..., 'bytes': ...} dicts when the column was cast with decode=False
        if isinstance(el['image'][0], PIL.Image.Image):
            imgs = el['image']
        else:
            imgs = [Image().decode_example(_) for _ in el['image']]
        return preprocess(images=imgs, return_tensors='pt')

    dataset = Dataset.from_dict({'image': images})
    # file paths are kept undecoded here and only opened inside transform_fn
    dataset = dataset.cast_column('image', Image(decode=False)) if isinstance(images[0], str) else dataset
    dataset.set_transform(transform_fn)  # applied on the fly whenever items are read
    dataloader = DataLoader(dataset, batch_size=batch_size)

    image_embeddings = []
    pbar = tqdm(total=len(dataloader), position=0)  # one tick per batch, including a final partial batch
    with torch.no_grad():
        for batch in dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            image_embeddings.extend(model.get_image_features(**batch).detach().cpu().numpy())
            pbar.update(1)
    pbar.close()
    return np.stack(image_embeddings)

vector_embedding = encode_images(image_path, 32)
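One detail in the batching described above is worth spelling out: a DataLoader with batch_size=32 and the default drop_last=False yields ceil(n / 32) batches, the last one possibly smaller, so floor division (len(images) // batch_size) undercounts by one whenever n is not a multiple of the batch size. The plain-Python sketch below mimics that grouping; batch_slices and num_batches are hypothetical helper names, not PyTorch APIs.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def batch_slices(items: Sequence[T], batch_size: int) -> List[Sequence[T]]:
    """Split `items` into consecutive batches, keeping a smaller final batch
    (the same grouping a DataLoader produces with shuffle=False, drop_last=False)."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    return [items[i:i + batch_size] for i in range(0, len(items), batch_size)]


def num_batches(n_items: int, batch_size: int) -> int:
    """Ceiling division: how many batches cover n_items when the last one may be partial."""
    return (n_items + batch_size - 1) // batch_size
```

With a single item and batch_size=32, floor division gives 0 even though one batch is produced.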
Store Embeddings
Save the embeddings to disk so that we don't have to generate them again.
with open('flicker30k_image_embeddings.pkl', 'wb') as f:
    pickle.dump(vector_embedding, f)  # write the (num_images, 512) embedding matrix
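The snippet above only writes the embeddings; to actually skip re-encoding on later runs, you also need to read them back. A minimal load-or-compute pattern is sketched below. It swaps pickle for NumPy's own .npy format (np.save and np.load), which stores a plain array and, with np.load's default allow_pickle=False, does not execute arbitrary code on load. The helper name load_or_compute and its compute callback are hypothetical.

```python
import os
from typing import Callable

import numpy as np


def load_or_compute(cache_path: str, compute: Callable[[], np.ndarray]) -> np.ndarray:
    """Load a cached embedding matrix, or compute it once and save it.

    Hypothetical helper; `compute` is any zero-argument function returning a 2-D array.
    """
    if not cache_path.endswith(".npy"):
        # np.save would silently append ".npy", breaking the exists() check below
        raise ValueError("cache_path should end with .npy")
    if os.path.exists(cache_path):
        return np.load(cache_path)  # allow_pickle is False by default
    embeddings = np.asarray(compute(), dtype=np.float32)  # FAISS works with float32
    np.save(cache_path, embeddings)
    return embeddings
```

In this article's setting, the call might look like vector_embedding = load_or_compute('flicker30k_image_embeddings.npy', lambda: encode_images(image_path, 32)), where the .npy file name is only an example.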
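Indexing the Images
The generated image embeddings can be indexed with Facebook's FAISS for efficient searching. IndexFlatIP performs exact search by inner product; once the vectors are L2-normalized, the inner product equals cosine similarity, which is how CLIP embeddings are usually compared.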
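faiss.normalize_L2(vector_embedding)  # in place: scale each row to unit length (needs a float32 array)
index = faiss.IndexFlatIP(vector_embedding.shape[1])  # exact inner-product search over 512-d vectors
index.add(vector_embedding)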
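IndexFlatIP does exact, brute-force search: it scores the query against every stored vector by inner product and keeps the k highest. The NumPy sketch below reproduces that behaviour on small arrays, which can be handy for sanity-checking FAISS results. The function flat_ip_search is hypothetical and not part of FAISS; ties may be ordered differently than FAISS orders them, and unlike FAISS (which pads with index -1 when k exceeds the number of stored vectors) it simply returns fewer columns.

```python
import numpy as np


def flat_ip_search(database: np.ndarray, queries: np.ndarray, k: int):
    """Exact top-k by inner product, mirroring the (scores, indices) layout of index.search.

    database: (n, d) array; queries: (q, d) array. Returns two (q, min(k, n)) arrays.
    """
    k = min(k, database.shape[0])
    scores = queries @ database.T  # (q, n) inner products
    # argsort on negated scores gives indices in descending order of similarity
    top = np.argsort(-scores, axis=1, kind="stable")[:, :k]
    top_scores = np.take_along_axis(scores, top, axis=1)
    return top_scores, top
```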
Query Encoding
The user's natural-language query also needs to be converted into an embedding, in the same vector space as the images, so that it can be compared against them.
Let's query for "basketball game" and see what it returns.
def encode_text(text: List[str], batch_size: int):
    dataset = Dataset.from_dict({'text': text})
    # tokenize up front; CLIP's text encoder accepts at most 77 tokens
    dataset = dataset.map(lambda el: preprocess(text=el['text'], return_tensors="pt",
                                                max_length=77, padding="max_length", truncation=True),
                          batched=True,
                          remove_columns=['text'])
    dataset.set_format('torch')
    dataloader = DataLoader(dataset, batch_size=batch_size)
    text_embeddings = []
    pbar = tqdm(total=len(dataloader), position=0)  # one tick per batch
    with torch.no_grad():
        for batch in dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}  # same device as the model
            text_embeddings.extend(model.get_text_features(**batch).detach().cpu().numpy())
            pbar.update(1)
    pbar.close()
    return np.stack(text_embeddings)
search_text = "basketball game"
with torch.no_grad():
    text_search_embedding = encode_text([search_text], batch_size=32)
# scale the query embedding to unit L2 length
text_search_embedding = text_search_embedding / np.linalg.norm(text_search_embedding, ord=2, axis=-1, keepdims=True)
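The normalization step above rescales each row to unit L2 length. For unit vectors, the inner product equals cosine similarity, so an inner-product index ranks by cosine similarity only when both the query and the stored image embeddings are normalized. A small NumPy check of that identity (l2_normalize and cosine_similarity are hypothetical helper names; the eps guard is an addition that avoids dividing by zero):

```python
import numpy as np


def l2_normalize(x: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Scale each row of `x` to unit L2 norm; rows of zeros stay zero."""
    norms = np.linalg.norm(x, ord=2, axis=-1, keepdims=True)
    return x / np.maximum(norms, eps)


def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    """Cosine of the angle between two 1-D vectors."""
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))
```

The image embeddings added to the index need the same treatment (for example with faiss.normalize_L2, which normalizes a float32 array in place) for the search scores to be true cosine similarities.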
Search
Now that we have the image embeddings indexed and the query text embedded, we're ready to search the index for the top k results. Here, I'm retrieving the top 2 images, but you can adjust this to suit your needs.
k = 2  # number of top results to retrieve
scores, indices = index.search(text_search_embedding.reshape(1, -1), k)
scores = scores[0]  # inner-product similarities for our single query
indices = indices[0]
indices_distances = list(zip(indices, scores))
indices_distances.sort(key=lambda x: x[1], reverse=True)  # higher inner product = more similar
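Because IndexFlatIP returns inner-product scores, larger means more similar, and index.search already returns each row in that order. One more edge case: when k exceeds the number of indexed vectors, FAISS fills the missing slots with index -1, which Python would silently interpret as the last image path. A defensive way to turn one row of search output into (path, score) pairs, using the hypothetical helper name ranked_results:

```python
from typing import List, Sequence, Tuple

import numpy as np


def ranked_results(indices: np.ndarray, scores: np.ndarray,
                   paths: Sequence[str]) -> List[Tuple[str, float]]:
    """Pair one query's search output with file paths, best match first.

    Skips the -1 placeholders FAISS uses for empty result slots.
    """
    pairs = [(paths[i], float(s)) for i, s in zip(indices, scores) if i >= 0]
    # Inner-product scores: higher is more similar, so sort in descending order.
    pairs.sort(key=lambda p: p[1], reverse=True)
    return pairs
```

Here, ranked_results(indices, scores, image_path) would give the ordered list to visualize.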
Visualizing the Result
Let’s visualize our output images.
for idx, score in indices_distances:
    path = image_path[idx]
    im = PIL.Image.open(path)  # PIL's Image, not the datasets Image imported earlier
    plt.imshow(im)
    plt.title(f"score: {score:.3f}")
    plt.axis('off')
    plt.show()
And that’s it. In a few simple steps, we’re able to build our own semantic image search engine!
Do check out my GitHub repository for the notebook!
Join AImonks Youtube Channel to get interesting videos.