
Vision Transformer (ViT)

Oct 18, 2024

Transformers have already revolutionized the world of NLP, and now they have made their way into computer vision as well, giving rise to the Vision Transformer (ViT), which applies the same architecture to image processing tasks.

In traditional image processing tasks, Convolutional Neural Networks (CNNs) were the go-to method. They excelled at detecting patterns and features through a grid-like structure, making them great for images, but CNNs struggled to model global relationships between different parts of an image.

See how things are changing: I am calling CNNs "traditional" when, just a few years back, before ChatGPT came into the picture, CNNs were the only SOTA I knew (and feared). 😅😅😅

To address these limitations, Transformers, originally designed for text, were adapted for computer vision tasks. Unlike CNNs, Transformers can model long-range dependencies in data, which is essential for both NLP and image understanding. Transformers were designed for sequential data (like sentences), but images are structured as a grid of pixels. The main challenge was to "convert" an image into a format the Transformer could process while leveraging its powerful self-attention mechanism.

To know more about the self-attention mechanism of Transformers, read this: https://medium.com/@RobuRishabh/how-transformers-work-b08627a300cb

The idea behind Vision Transformers (ViT) is to break down an image into patches (smaller blocks), treat each patch like a token (similar to words in a sentence), and then feed these tokens into a Transformer model.

Key Concepts in ViT:

[Figure: ViT model overview. Source: 2010.11929v2.pdf (arxiv.org)]

Patchify the Image:

Split the image into smaller patches. Imagine an image of size 128x128 pixels. ViT splits this image into smaller patches of, say, 16x16 pixels each. For a 128x128 image, this creates 64 patches (128 ÷ 16 = 8 patches along the width and 8 along the height, making 8x8 = 64 patches).
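As a rough illustration (not from the original post), here is how the 128x128 example above could be split into 16x16 patches with plain PyTorch tensor operations; the image tensor is just a dummy placeholder:

import torch

image = torch.randn(3, 128, 128)   # dummy RGB image: [channels, height, width]
P = 16                             # patch size

# Carve out non-overlapping 16x16 blocks along height, then width
patches = image.unfold(1, P, P).unfold(2, P, P)   # shape: [3, 8, 8, 16, 16]
patches = patches.permute(1, 2, 0, 3, 4)          # shape: [8, 8, 3, 16, 16]
patches = patches.reshape(-1, 3, P, P)            # shape: [64, 3, 16, 16] -> 64 patches

print(patches.shape)               # torch.Size([64, 3, 16, 16])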

Linear Projection:

Flatten each patch and project it into a vector (token). Each flattened patch is passed through a learned linear layer, like turning a mini image into a word embedding in a sentence.
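Continuing the same toy setup, a minimal sketch of the flatten-and-project step could look like this (the 768-dimensional token width matches ViT-Base, but the layer itself is just an illustrative nn.Linear):

import torch
import torch.nn as nn

patches = torch.randn(64, 3, 16, 16)       # the 64 patches from the previous step (dummy values)
embed_dim = 768                            # token width used by ViT-Base

projection = nn.Linear(3 * 16 * 16, embed_dim)
tokens = projection(patches.flatten(1))    # flatten each patch to 768 numbers, then project
print(tokens.shape)                        # torch.Size([64, 768])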

Positional Encoding:

Just as words in a sentence need to be understood in order (the word “car” and “fast” have different meanings based on their order), patches in an image need their position to be understood. Positional encoding is added to the patch tokens to inform the Transformer of where each patch is located in the original image.
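A small sketch of what "adding position information" means in code: ViT uses a learnable positional embedding per patch position, added element-wise to the patch tokens (shapes follow the toy example above):

import torch
import torch.nn as nn

num_patches, embed_dim = 64, 768
tokens = torch.randn(1, num_patches, embed_dim)   # batch of one image's patch tokens

# One learnable positional embedding per patch position, added element-wise
pos_embedding = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
tokens = tokens + pos_embedding                   # same content, now position-aware
print(tokens.shape)                               # torch.Size([1, 64, 768])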

Transformer Encoder:

Process the sequence of image tokens through Transformer layers. Each patch, now represented as a token, passes through the Transformer encoder. The self-attention mechanism helps each patch learn about other patches in the image. For example, the patch showing the wheel of a car will learn information about the patch showing the car’s body.
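To make this concrete, here is a rough stand-in built from PyTorch's generic nn.TransformerEncoder with ViT-Base-like settings (12 layers, 12 heads, GELU MLP, pre-norm); the real ViT encoder block differs in some details, so treat this only as a shape-level illustration:

import torch
import torch.nn as nn

tokens = torch.randn(1, 64, 768)   # batch of 1 image, 64 position-aware patch tokens

# Generic PyTorch encoder layers as a stand-in for the ViT-Base encoder
encoder_layer = nn.TransformerEncoderLayer(
    d_model=768, nhead=12, dim_feedforward=3072,
    activation='gelu', norm_first=True, batch_first=True
)
encoder = nn.TransformerEncoder(encoder_layer, num_layers=12)

encoded = encoder(tokens)          # each patch token has now attended to every other patch
print(encoded.shape)               # torch.Size([1, 64, 768])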

Classification Token:

After processing the image tokens, ViT uses a special classification token. This token aggregates the global information from all patches and is passed through a simple neural network to make the final classification (e.g., identifying the object in the image).
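The snippet below only illustrates where the [CLS] token sits in the sequence and how a linear head reads it; in the actual model the token is prepended before the encoder runs, not after:

import torch
import torch.nn as nn

embed_dim, num_classes = 768, 10
patch_tokens = torch.randn(1, 64, embed_dim)             # encoded patch tokens for one image

# Learnable [CLS] token prepended to the patch sequence
cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
sequence = torch.cat([cls_token, patch_tokens], dim=1)   # shape: [1, 65, 768]

# A simple linear head reads only the [CLS] position to classify the image
head = nn.Linear(embed_dim, num_classes)
logits = head(sequence[:, 0])                            # shape: [1, 10]
print(logits.shape)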

Importing Pre-trained ViT and using it for classification in PyTorch

Step 1: Install Required Libraries

To use Hugging Face’s ViT model, you’ll need the transformers, datasets, and torchvision libraries.

pip install transformers datasets torchvision torch

Step 2: Import Pre-trained ViT and Setup

import torch
from transformers import ViTForImageClassification, ViTFeatureExtractor, Trainer, TrainingArguments
from datasets import load_dataset
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
from torch.utils.data import DataLoader

#Check if GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load pre-trained ViT model from Hugging face
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k', num_labels=10)
model.to(device)

# Load ViT feature extractor
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')

Step 3: Load and Prepare Dataset

We’ll use the CIFAR-10 dataset for image classification. The dataset will be normalized, resized, and converted into the format expected by ViT.

# Load CIFAR-10 dataset (replace with your dataset if needed)
dataset = load_dataset("cifar10")

# Define image transformation
def transform(example_batch):
    transforms = Compose([
        Resize((224, 224)),  # ViT expects 224x224 images
        ToTensor(),
        Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)  # ViT-specific normalization
    ])
    # Return only the columns the model needs (the raw PIL images are dropped)
    return {
        'pixel_values': [transforms(img) for img in example_batch['img']],
        'label': example_batch['label'],
    }

# Apply the transformation to the dataset
prepared_dataset = dataset.with_transform(transform)

# Create PyTorch DataLoaders (optional here: the Trainer API below batches the dataset itself)
train_loader = DataLoader(prepared_dataset['train'], batch_size=32, shuffle=True)
test_loader = DataLoader(prepared_dataset['test'], batch_size=32)

Step 4: Fine-tune the Pre-trained ViT Model

We can now fine-tune the pre-trained Vision Transformer on CIFAR-10 using the Hugging Face Trainer API, which simplifies the training loop. The training configuration, including the number of epochs, batch size, and evaluation strategy, is specified via TrainingArguments.

# Define training arguments
training_args = TrainingArguments(
    output_dir='./results',              # output directory
    per_device_train_batch_size=32,      # batch size for training
    per_device_eval_batch_size=32,       # batch size for evaluation
    num_train_epochs=3,                  # number of training epochs
    evaluation_strategy="epoch",         # evaluate after each epoch
    save_strategy="epoch",               # save checkpoint after each epoch
    logging_dir='./logs',                # logging directory
    logging_steps=10,
    load_best_model_at_end=True,         # load best model after training
)

# Define a simple evaluation metric
from sklearn.metrics import accuracy_score
def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    acc = accuracy_score(labels, preds)
    return {"accuracy": acc}

# Define Trainer with model, data, and training arguments
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=prepared_dataset["train"],
    eval_dataset=prepared_dataset["test"],
    tokenizer=feature_extractor,
    compute_metrics=compute_metrics
)
# Fine-tune the model
trainer.train()

Step 5: Use the Fine-tuned Model for Inference

Once the model is fine-tuned, you can use it for making predictions.

# Example of inference on a single image
from PIL import Image

# Load an image (e.g., one saved from the test set)
image = Image.open('path_to_image.jpg')

# Apply the feature extractor's transformations
inputs = feature_extractor(images=image, return_tensors="pt").to(device)

# Use the fine-tuned model to predict
model.eval()
with torch.no_grad():
    outputs = model(**inputs)
    logits = outputs.logits
    predicted_class_idx = logits.argmax(-1).item()

# Print the predicted class
print(f"Predicted class: {predicted_class_idx}")
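If you also want the human-readable label rather than just the index, the CIFAR-10 class names can be looked up from the dataset object loaded earlier (this assumes `dataset` and `predicted_class_idx` from the snippets above are still in scope):

# Map the predicted index back to a CIFAR-10 class name
class_names = dataset["train"].features["label"].names
print(f"Predicted label: {class_names[predicted_class_idx]}")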

References:

Dosovitskiy et al., "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" — https://arxiv.org/abs/2010.11929

If you liked this breakdown of concepts, please follow, subscribe, and clap.


Written by Rishabh Singh

I'm a passionate AI/robotics enthusiast, dedicated to pushing the boundaries of programming and robotics/AI and contributing to innovative advancements in the field.
