Image Classification using Few-Shot Learning
In recent years, deep learning-based models have excelled at tasks such as object detection and image recognition. On challenging image classification benchmarks such as ImageNet, whose standard benchmark contains 1,000 object classes, several of these models perform at roughly human level. However, these models rely on the supervised training paradigm, and their performance depends heavily on the availability of labeled training data. In addition, they can only recognize the classes they were trained on.
Because there may not be enough labeled images for every class during training, these models are less useful in real-world settings. We would also like a model to recognize images from classes it never saw during training, since it is practically impossible to train on photographs of every possible object. The problem of learning from only a few examples is called Few-Shot Learning.
What is Few-Shot learning?
Few-shot learning is a subfield of machine learning. It deals with classifying new data when only a few labeled training samples are available for each class. The goal is for a model, such as a computer vision model, to perform well even with such a small number of training examples.
Consider the following scenario: we work in the medical field and need to classify bone diseases from X-ray images. For some rare disorders, there may not be enough images to include them in the training set. A Few-Shot Learning classifier is well suited to a scenario like this.
Variations of Few-Shot Learning
In general, researchers identify four types:
- N-Shot Learning (NSL)
- Few-Shot Learning (FSL)
- One-Shot Learning (OSL)
- Less-than-one-shot or Zero-Shot Learning (ZSL)
When we talk about FSL, we usually mean N-way K-shot classification: N is the number of classes, and K is the number of labeled samples per class available for training.
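To make the N-way K-shot setting concrete, here is a minimal sketch of how a single training episode could be sampled from a labeled dataset. It assumes the dataset is a plain list of (item, label) pairs; the function name `sample_episode` and its `n_query` parameter are illustrative and not taken from any particular library.

```python
import random
from collections import defaultdict


def sample_episode(dataset, n_way, k_shot, n_query, rng=random):
    """Sample one N-way K-shot episode from a list of (item, label) pairs.

    Returns a support set with k_shot items per class and a query set with
    n_query items per class, drawn from n_way randomly chosen classes.
    Labels are remapped to 0..n_way-1 inside the episode.
    """
    by_class = defaultdict(list)
    for item, label in dataset:
        by_class[label].append(item)
    # Only classes with enough examples for both support and query can be used
    eligible = [c for c, items in by_class.items() if len(items) >= k_shot + n_query]
    classes = rng.sample(eligible, n_way)
    support, query = [], []
    for episode_label, c in enumerate(classes):
        picked = rng.sample(by_class[c], k_shot + n_query)  # no duplicates
        support += [(item, episode_label) for item in picked[:k_shot]]
        query += [(item, episode_label) for item in picked[k_shot:]]
    return support, query


# Usage: a toy dataset with 10 classes of 20 items each; one 5-way 1-shot episode
toy_data = [(f"img_{c}_{i}", c) for c in range(10) for i in range(20)]
support, query = sample_episode(toy_data, n_way=5, k_shot=1, n_query=3, rng=random.Random(0))
print(len(support), len(query))  # prints: 5 15
```

Remapping labels per episode reflects the idea that the classes change from one episode to the next, so the model cannot memorize fixed class identities.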
N-Shot Learning is the broadest concept of the four: Few-Shot, One-Shot, and Zero-Shot Learning are all subfields of NSL. Zero-Shot Learning aims to classify unseen classes without any training examples.
In One-Shot Learning, we have only a single sample of each class. Few-Shot Learning typically has two to five samples per class, which makes it a more flexible version of OSL.
Few-Shot learning approaches
Generally, there are two approaches to consider when solving Few-Shot Learning problems:
- Data-level approach (DLA)
- Parameter-level approach (PLA)
Data-level approach
This strategy is relatively straightforward. It is built on the idea that if you don't have enough data to build a solid model and avoid underfitting and overfitting, you should add more data. For this reason, many FSL problems are solved by using additional data from a large base dataset. The defining characteristic of the base dataset is that it does not contain the classes that make up the support set (the few labeled examples) of the Few-Shot task. For instance, if we want to classify a particular bird species, the base dataset may contain images of many other bird species.
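The key property of the base dataset, that it shares no classes with the Few-Shot task, can be illustrated with a class-level split. This is a minimal sketch assuming a labeled dataset stored as a list of (item, label) pairs; the function name `split_base_novel` and the bird data are hypothetical.

```python
def split_base_novel(dataset, novel_classes):
    """Split a labeled dataset into a base set and a novel set by class.

    novel_classes are the classes of the Few-Shot task (the ones we only have
    a few examples of); every other class goes to the base dataset, so the
    two sets never share a class.
    """
    novel_classes = set(novel_classes)
    base = [(x, y) for x, y in dataset if y not in novel_classes]
    novel = [(x, y) for x, y in dataset if y in novel_classes]
    return base, novel


# Usage: bird species 0-99; species 97-99 are the rare ones we want to classify
birds = [(f"bird_{c}_{i}", c) for c in range(100) for i in range(5)]
base, novel = split_base_novel(birds, novel_classes={97, 98, 99})
print(len(base), len(novel))  # prints: 485 15
```

The model is trained on the base split, and the novel split only supplies the support and query sets of the Few-Shot task.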
Parameter-level approach
From a parameter-level perspective, a model can easily overfit the few available samples, because its parameter space is usually large and high-dimensional. Limiting the parameter space, applying regularization, and using appropriate loss functions help address this issue and allow the model to generalize from the small number of training samples.
Alternatively, we can improve performance by guiding the model through the vast parameter space. With so little training data, a standard optimization approach may not produce accurate results.
For this reason, we train the model to learn how to navigate the parameter space so that it reaches good predictions quickly. This approach is known as meta-learning.
Algorithms for Few-Shot image classification
The following meta-learning algorithms can be used to solve Few-Shot image classification problems:
- Model-Agnostic Meta-Learning (MAML)
- Matching Networks
- Prototypical Networks
- Relation Network
Model-Agnostic Meta-Learning
MAML is based on the Gradient-Based Meta-Learning (GBML) principle. In GBML, the meta-learner gains prior experience by training a base model and learning the features shared across all tasks. Whenever there is a new task to learn, the meta-learner is fine-tuned slightly, using its prior experience and the small amount of training data the new task provides.
However, we don't want to start from randomly initialized parameters: from a random starting point, a few updates are not enough for the algorithm to converge to good performance. MAML addresses this issue by learning a good initialization of the model parameters, from which only a few gradient steps are enough to learn a new task quickly without overfitting.
Steps:
- At the beginning of each episode, the meta-learner creates a copy of itself (C),
- C is trained on the episode's support set (with the help of the base model, as discussed above),
- C makes predictions on the query set,
- The loss computed from these predictions is used to update the meta-learner's initial parameters (not just C), by backpropagating through C's training steps,
- This continues until all episodes have been used for training.
The most significant advantage of this technique is that it is agnostic to the choice of model: it can be applied to any model trained with gradient descent. As a result, MAML is widely used with machine learning models that need fast adaptation, especially deep neural networks.
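The MAML steps listed above can be sketched on a toy problem. This is a minimal NumPy sketch, not a reference MAML implementation: it assumes linear-regression tasks instead of image classification, takes a single inner gradient step, and uses hand-derived gradients. For a linear model with mean squared error the Hessian is constant, so the exact second-order meta-gradient can be written down without autograd. The task distribution, hyperparameters, and function names are all hypothetical.

```python
import numpy as np


def mse_grad(w, X, y):
    """Gradient of mean((X @ w - y) ** 2) with respect to w."""
    return 2.0 / len(X) * X.T @ (X @ w - y)


def mse_hessian(X):
    """Hessian of the MSE loss (constant for a linear model)."""
    return 2.0 / len(X) * X.T @ X


def sample_task(rng, mean, k_support=5, k_query=10, task_std=0.1):
    """One toy task: support and query sets from the same linear function y = X @ a."""
    a = mean + task_std * rng.standard_normal(mean.shape)
    Xs = rng.standard_normal((k_support, mean.size))
    Xq = rng.standard_normal((k_query, mean.size))
    return (Xs, Xs @ a), (Xq, Xq @ a)


def adapt(w, Xs, ys, inner_lr):
    """Inner loop: one gradient step on the support set (the copy C in the text)."""
    return w - inner_lr * mse_grad(w, Xs, ys)


def maml_train(rng, mean, steps=500, meta_batch=4, inner_lr=0.1, meta_lr=0.05):
    w = np.zeros_like(mean)  # meta-parameters: the initialization being learned
    for _ in range(steps):
        meta_grad = np.zeros_like(w)
        for _ in range(meta_batch):
            (Xs, ys), (Xq, yq) = sample_task(rng, mean)
            w_task = adapt(w, Xs, ys, inner_lr)
            # Chain rule through the inner step: d(w_task)/d(w) = I - inner_lr * H_support
            jac = np.eye(w.size) - inner_lr * mse_hessian(Xs)
            meta_grad += jac.T @ mse_grad(w_task, Xq, yq)
        # Outer loop: the query-set loss updates the meta-parameters, not the copy
        w -= meta_lr * meta_grad / meta_batch
    return w


def post_adaptation_loss(rng, w_init, mean, n_tasks=200, inner_lr=0.1):
    """Average query loss after one inner step from w_init, on fresh tasks."""
    losses = []
    for _ in range(n_tasks):
        (Xs, ys), (Xq, yq) = sample_task(rng, mean)
        w_task = adapt(w_init, Xs, ys, inner_lr)
        losses.append(np.mean((Xq @ w_task - yq) ** 2))
    return float(np.mean(losses))


rng = np.random.default_rng(0)
task_mean = np.array([2.0, -1.0, 0.5])  # hypothetical center of the task distribution
w_meta = maml_train(rng, task_mean)
maml_loss = post_adaptation_loss(rng, w_meta, task_mean)
scratch_loss = post_adaptation_loss(rng, np.zeros_like(task_mean), task_mean)
print("learned init:", w_meta.round(2))
print(f"query loss after 1 step: MAML init {maml_loss:.3f}, zero init {scratch_loss:.3f}")
```

In this toy setting the learned initialization should end up near the center of the task distribution, so one inner step from it gives a much lower query loss than one step from zeros. With a deep network, the Jacobian of the inner step would instead come from an autodiff library.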
Matching Networks
Matching Networks (MN) were one of the first metric-learning methods designed for FSL problems.
To solve a Few-Shot Learning task with Matching Networks, we need a large base dataset. This dataset is split into episodes: small N-way K-shot tasks, each with its own support set and query set.
After that, for each episode, Matching Networks apply the following procedure:
- Each image from the support and the query set is fed to a CNN that outputs embeddings for them,
- Each query image is classified using a softmax over the cosine similarities between its embedding and the support-set embeddings,
- The Cross-Entropy Loss on the resulting classification is backpropagated through the CNN.
In this way, Matching Networks learn to build image embeddings. MN can then classify images without any specific prior knowledge of their classes: classification relies entirely on comparing the query with the labeled examples of each class.
Because the classes change from episode to episode, Matching Networks learn image features that are useful for distinguishing between classes in general. In contrast, a standard classifier learns features that are specific to the fixed set of classes it was trained on.
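The classification step of the Matching Networks procedure can be sketched in NumPy. This minimal sketch assumes the CNN embeddings are already computed (noisy vectors stand in for them here); the softmax over cosine similarities acts as attention weights over the one-hot support labels, and the cross-entropy is the loss that would be backpropagated through the CNN. Function names are illustrative.

```python
import numpy as np


def cosine_similarity(a, b):
    """Pairwise cosine similarity between rows of a (Q, d) and rows of b (S, d)."""
    a = a / np.linalg.norm(a, axis=-1, keepdims=True)
    b = b / np.linalg.norm(b, axis=-1, keepdims=True)
    return a @ b.T


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def matching_net_predict(support_emb, support_labels, query_emb, n_way):
    """Class probabilities per query: attention-weighted sum of support labels."""
    attention = softmax(cosine_similarity(query_emb, support_emb), axis=-1)  # (Q, S)
    one_hot = np.eye(n_way)[support_labels]                                # (S, N)
    return attention @ one_hot                                             # (Q, N)


def cross_entropy(probs, labels):
    """Mean negative log-likelihood of the true query labels."""
    return float(-np.mean(np.log(probs[np.arange(len(labels)), labels] + 1e-12)))


# Usage: a 3-way 2-shot episode with fake 4-dimensional embeddings clustered per class
rng = np.random.default_rng(0)
centers = np.eye(4)[:3]  # one direction per class
support_labels = np.array([0, 0, 1, 1, 2, 2])
query_labels = np.array([0, 1, 2])
support_emb = centers[support_labels] + 0.1 * rng.standard_normal((6, 4))
query_emb = centers[query_labels] + 0.1 * rng.standard_normal((3, 4))
probs = matching_net_predict(support_emb, support_labels, query_emb, n_way=3)
print(probs.argmax(axis=1), cross_entropy(probs, query_labels))
```

Nothing in this step depends on which classes are in the episode, which is why the same trained embedding network can be reused on classes it never saw.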
Prototypical Networks
Prototypical Networks (PN) are similar to Matching Networks, with minor differences that improve the algorithm's performance; in practice, PN was reported to achieve better results than MN. The procedure is essentially the same, except that query image embeddings are not compared with every image embedding in the support set. Instead, Prototypical Networks take a different approach.
In PN, you create a prototype for each class: a class embedding obtained by averaging the embeddings of that class's support images. Query image embeddings are then compared only with these class prototypes. Note that for One-Shot Learning problems, where each class has a single support image, the prototype is simply that image's embedding, so the technique becomes comparable to Matching Networks.
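Prototype construction and query classification can be sketched in NumPy. This minimal sketch assumes precomputed embeddings (fake 2-D vectors here) and, as in the Prototypical Networks paper, uses a softmax over negative squared Euclidean distances to the prototypes; function names are illustrative.

```python
import numpy as np


def compute_prototypes(support_emb, support_labels, n_way):
    """Prototype of class c = mean embedding of the support images labeled c."""
    return np.stack([support_emb[support_labels == c].mean(axis=0) for c in range(n_way)])


def proto_net_predict(prototypes, query_emb):
    """Softmax over negative squared Euclidean distances to each prototype."""
    sq_dist = ((query_emb[:, None, :] - prototypes[None, :, :]) ** 2).sum(axis=-1)  # (Q, N)
    logits = -sq_dist
    logits -= logits.max(axis=1, keepdims=True)  # numerical stability
    probs = np.exp(logits)
    return probs / probs.sum(axis=1, keepdims=True)


# Usage: a 3-way 5-shot episode with fake 2-D embeddings around three class centers
rng = np.random.default_rng(0)
centers = np.array([[0.0, 0.0], [5.0, 0.0], [0.0, 5.0]])
support_labels = np.repeat(np.arange(3), 5)
support_emb = centers[support_labels] + rng.standard_normal((15, 2))
query_emb = np.array([[0.5, -0.3], [4.6, 0.4], [0.2, 5.3]])
prototypes = compute_prototypes(support_emb, support_labels, n_way=3)
probs = proto_net_predict(prototypes, query_emb)
print(probs.argmax(axis=1))
```

With one support image per class, `compute_prototypes` returns those embeddings unchanged, which is the One-Shot case mentioned above.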
Relation Network
The Relation Network (RN) builds on the research behind Matching and Prototypical Networks. RN is based on the PN idea but includes significant algorithmic improvements.
Instead of using a distance function defined beforehand, RN learns it. This is done by RN's relation module. The overall structure is as follows: the relation module sits on top of the embedding module, which is the part that computes embeddings and class prototypes from the input images.
The relation module takes as input the concatenation of a query image's embedding with each class prototype, and it outputs a relation score for each pair. The predicted class is the one with the highest relation score (in the original paper, the score passes through a sigmoid and is trained with a mean squared error loss).
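The pairing and scoring structure of the relation module can be sketched with a tiny two-layer network in NumPy. This is a forward-pass sketch only: it assumes precomputed query embeddings and prototypes, uses untrained random weights (the names `W1`, `b1`, `W2`, `b2` are hypothetical), and replaces the convolutional relation module of the original paper with a small MLP. In practice the relation and embedding modules are trained jointly.

```python
import numpy as np


def relation_scores(query_emb, prototypes, W1, b1, W2, b2):
    """Relation score in (0, 1) for every (query, class prototype) pair."""
    n_query, n_way = len(query_emb), len(prototypes)
    # Concatenate each query embedding with each prototype: shape (Q, N, 2d)
    pairs = np.concatenate([
        np.repeat(query_emb[:, None, :], n_way, axis=1),
        np.repeat(prototypes[None, :, :], n_query, axis=0),
    ], axis=-1)
    hidden = np.maximum(0.0, pairs @ W1 + b1)  # ReLU layer, shape (Q, N, h)
    logits = (hidden @ W2 + b2)[..., 0]        # shape (Q, N)
    return 1.0 / (1.0 + np.exp(-logits))       # sigmoid


def relation_loss(scores, query_labels):
    """MSE against one-hot targets: 1 for the matching class, 0 otherwise."""
    targets = np.eye(scores.shape[1])[query_labels]
    return float(np.mean((scores - targets) ** 2))


# Usage: 4 queries, 5-way episode, 8-dimensional embeddings, hidden size 16
rng = np.random.default_rng(0)
d, h = 8, 16
W1, b1 = 0.1 * rng.standard_normal((2 * d, h)), np.zeros(h)
W2, b2 = 0.1 * rng.standard_normal((h, 1)), np.zeros(1)
query_emb = rng.standard_normal((4, d))
prototypes = rng.standard_normal((5, d))
scores = relation_scores(query_emb, prototypes, W1, b1, W2, b2)
predictions = scores.argmax(axis=1)  # class with the highest relation score
loss = relation_loss(scores, np.array([0, 1, 2, 3]))
print(scores.shape, predictions, loss)
```

Because the weights here are random, the predictions are meaningless; the sketch only shows how pairs are formed and scored, and what training signal the scores would receive.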
Zero-Shot Learning With OpenAI CLIP
CLIP (Contrastive Language-Image Pre-Training) is a neural network trained on a variety of (image, text) pairs. Given an image, it can be instructed in natural language to predict the most relevant text snippet, without directly optimizing for the task, similar to the zero-shot capabilities of GPT-2 and GPT-3.
CLIP matches the performance of the original ResNet50 on ImageNet “zero-shot” without using any of the original 1.28M labeled examples, overcoming several significant challenges in computer vision.
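The "contrastive" part of CLIP's name refers to its training objective: within a batch of (image, text) pairs, each image should be most similar to its own caption, and each caption to its own image. Below is a minimal NumPy sketch of that symmetric cross-entropy loss, assuming the image and text embeddings are already computed. The fixed temperature of 0.07 is an assumption for illustration (CLIP learns the temperature as a parameter), and the function names are illustrative.

```python
import numpy as np


def log_softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # numerical stability
    return x - np.log(np.exp(x).sum(axis=axis, keepdims=True))


def clip_contrastive_loss(image_emb, text_emb, temperature=0.07):
    """Symmetric cross-entropy over the batch similarity matrix.

    Row i of image_emb and row i of text_emb are a matching pair, so the
    correct "class" for image i is text i (the diagonal), and vice versa.
    """
    img = image_emb / np.linalg.norm(image_emb, axis=1, keepdims=True)
    txt = text_emb / np.linalg.norm(text_emb, axis=1, keepdims=True)
    logits = img @ txt.T / temperature  # (N, N) scaled cosine similarities
    idx = np.arange(len(logits))
    loss_images = -log_softmax(logits, axis=1)[idx, idx].mean()  # image -> text
    loss_texts = -log_softmax(logits, axis=0)[idx, idx].mean()   # text -> image
    return float((loss_images + loss_texts) / 2)


# Usage: 8 pairs of fake 16-dimensional embeddings
rng = np.random.default_rng(0)
image_emb = rng.standard_normal((8, 16))
matched_text = image_emb + 0.1 * rng.standard_normal((8, 16))  # captions close to their images
shuffled_text = np.roll(matched_text, 1, axis=0)              # captions paired with the wrong images
print(clip_contrastive_loss(image_emb, matched_text), clip_contrastive_loss(image_emb, shuffled_text))
```

Zero-shot classification reuses the same similarity: class names are turned into text prompts, and the image is assigned to the prompt whose embedding it matches best.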
Installing Libraries
! pip install ftfy regex tqdm
! pip install git+https://github.com/openai/CLIP.git
import numpy as np
import torch
print("Torch version:", torch.__version__)
Loading the model
import clip
clip.available_models() # it will list the names of available CLIP models
model, preprocess = clip.load("ViT-B/32")
model.cuda().eval()
input_resolution = model.visual.input_resolution
context_length = model.context_length
vocab_size = model.vocab_size
print("Model parameters:", f"{np.sum([int(np.prod(p.shape)) for p in model.parameters()]):,}")
print("Input resolution:", input_resolution)
print("Context length:", context_length)
print("Vocab size:", vocab_size)
Image Preprocessing
Setting up input images and texts
We will feed 8 example images and their textual descriptions to the model and compare the similarity between the corresponding features.
The tokenizer is case-insensitive, and we can freely give suitable textual descriptions.
import os
import skimage
import IPython.display
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
from collections import OrderedDict
import torch
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
# images in skimage to use and their textual descriptions
descriptions = {
"page": "a page of text about segmentation",
"chelsea": "a facial photo of a tabby cat",
"astronaut": "a portrait of an astronaut with the American flag",
"rocket": "a rocket standing on a launchpad",
"motorcycle_right": "a red motorcycle standing in a garage",
"camera": "a person looking at a camera on a tripod",
"horse": "a black-and-white silhouette of a horse",
"coffee": "a cup of coffee on a saucer"
}
original_images = []
images = []
texts = []
plt.figure(figsize=(16, 5))
for filename in [filename for filename in os.listdir(skimage.data_dir) if filename.endswith(".png") or filename.endswith(".jpg")]:
    name = os.path.splitext(filename)[0]
    if name not in descriptions:
        continue
    image = Image.open(os.path.join(skimage.data_dir, filename)).convert("RGB")
    plt.subplot(2, 4, len(images) + 1)
    plt.imshow(image)
    plt.title(f"{filename}\n{descriptions[name]}")
    plt.xticks([])
    plt.yticks([])
    original_images.append(image)
    images.append(preprocess(image))
    texts.append(descriptions[name])
plt.tight_layout()
This produces a 2-by-4 grid showing each image with its file name and description.
Building Features
We normalize the images, tokenize each text input, and run the forward pass of the model to get the image and text features.
image_input = torch.tensor(np.stack(images)).cuda()
text_tokens = clip.tokenize(["This is " + desc for desc in texts]).cuda()
with torch.no_grad():
    image_features = model.encode_image(image_input).float()
    text_features = model.encode_text(text_tokens).float()
Calculating cosine similarity
We normalize the features and calculate the dot product of each pair.
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = text_features.cpu().numpy() @ image_features.cpu().numpy().T
count = len(descriptions)
plt.figure(figsize=(20, 14))
plt.imshow(similarity, vmin=0.1, vmax=0.3)
# plt.colorbar()
plt.yticks(range(count), texts, fontsize=18)
plt.xticks([])
for i, image in enumerate(original_images):
    plt.imshow(image, extent=(i - 0.5, i + 0.5, -1.6, -0.6), origin="lower")
for x in range(similarity.shape[1]):
    for y in range(similarity.shape[0]):
        plt.text(x, y, f"{similarity[y, x]:.2f}", ha="center", va="center", size=12)
for side in ["left", "top", "right", "bottom"]:
    plt.gca().spines[side].set_visible(False)
plt.xlim([-0.5, count - 0.5])
plt.ylim([count + 0.5, -2])
plt.title("Cosine similarity between text and image features", size=20)
Zero-Shot Image Classification
from torchvision.datasets import CIFAR100
cifar100 = CIFAR100(os.path.expanduser("~/.cache"), transform=preprocess, download=True)
text_descriptions = [f"This is a photo of a {label}" for label in cifar100.classes]
text_tokens = clip.tokenize(text_descriptions).cuda()
with torch.no_grad():
    text_features = model.encode_text(text_tokens).float()
    text_features /= text_features.norm(dim=-1, keepdim=True)
text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
top_probs, top_labels = text_probs.cpu().topk(5, dim=-1)
plt.figure(figsize=(16, 16))
for i, image in enumerate(original_images):
    plt.subplot(4, 4, 2 * i + 1)
    plt.imshow(image)
    plt.axis("off")
    plt.subplot(4, 4, 2 * i + 2)
    y = np.arange(top_probs.shape[-1])
    plt.grid()
    plt.barh(y, top_probs[i])
    plt.gca().invert_yaxis()
    plt.gca().set_axisbelow(True)
    plt.yticks(y, [cifar100.classes[index] for index in top_labels[i].numpy()])
    plt.xlabel("probability")
plt.subplots_adjust(wspace=0.5)
plt.show()
Conclusion
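We've explored the core principles that make few-shot learning an exciting frontier for researchers and practitioners alike: the N-way K-shot setting, data-level and parameter-level approaches, meta-learning algorithms such as MAML, Matching Networks, Prototypical Networks, and Relation Networks, and zero-shot classification with CLIP. From novel algorithms to practical applications in domains where labeled data is scarce, such as medical imaging, these techniques expand what AI systems can do.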
The potential of few-shot learning extends beyond classification tasks; it paves the way for more adaptable, efficient, and accessible AI systems that can learn from minimal supervision. As computational resources grow and algorithms become more refined, the promise of few-shot learning becomes more tangible, offering a glimpse of a future where AI can quickly adapt to new challenges, much like a human learner.
While this post illuminates the path few-shot learning has carved in image classification, the journey is far from over. The AI landscape continues to evolve, and with it, the techniques and tools at our disposal. For those captivated by the potential of few-shot learning, the next steps may involve diving deeper into the methodology, experimenting with different models, and innovating in ways we have yet to imagine.
As we conclude, remember that the power of few-shot learning isn’t just in the complexity or the sophistication of the models we build but in the new realms of possibility they unlock. It is an exciting time to be at the forefront of AI research, where each small step taken is a leap towards the vast potential of what our intelligent systems can achieve.
Thank you for joining me in exploring few-shot learning in image classification. Stay tuned for more insights and explorations into the dynamic world of machine learning.
If you found this article helpful, please consider citing it:
@article{jadon2023image,
title={Image Classification Using Few-Shot Learning},
author={Jadon, Aryan},
journal={Medium},
year={2023},
url={https://medium.com/@aryanjadon/image-classification-using-few-shot-learning-286572222b2d}
}
References
- Concept Learners for Few-Shot Learning (research paper; source of Image 1)
- Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks (research paper; source of Image 2)
- Learning to Compare: Relation Network for Few-Shot Learning (research paper; source of Image 3)
- Learning from Few Examples: A Summary of Approaches to Few-Shot Learning (research paper)
- Hands-On One-shot Learning with Python: Learn to implement fast and accurate deep learning models with fewer training samples using PyTorch (book)
- https://github.com/openai/CLIP
- https://neptune.ai/blog/understanding-few-shot-learning-in-computer-vision
- https://deepai.org/publication/meta-learning-algorithms-for-few-shot-computer-vision#S2.SS3
- https://blog.floydhub.com/n-shot-learning/
- https://www.sicara.fr/blog-technique/2019-07-30-image-classification-few-shot-meta-learning