Understanding Vision Transformers: A New Era in Image Recognition with PyTorch

Jaskaran Bhatia
6 min read · Jul 11, 2023


In the ever-evolving landscape of artificial intelligence, the Transformer model has been nothing short of a revelation for natural language processing (NLP). Its unique ability to capture global dependencies in data has led to groundbreaking advancements in tasks like machine translation, text summarization, and sentiment analysis. But here’s a thought — what if we could harness the power of the Transformer architecture for image recognition tasks? This is the intriguing premise behind the Vision Transformer (ViT), a novel approach to image recognition that views an image as a sequence of patches, akin to how a sentence is treated as a sequence of words in NLP.

Transformers have been widely used in NLP since the famous 2017 research paper ‘Attention Is All You Need’. The architecture takes sequential input and models the relationships between its elements using a mechanism called self-attention. The idea gained a lot of traction in computer vision from 2020 onwards.

How are Transformers used for images?

  1. Split the image into patches.
  2. Flatten the patches.
  3. Create lower-dimensional linear embeddings from the flattened patches.
  4. Prepend a learnable classification token and add positional embeddings (typically trainable).
  5. Feed the sequence as input to a standard Transformer encoder.
  6. Pre-train the model with image labels (fully supervised on a huge dataset).
  7. Fine-tune on the downstream dataset for image classification or another task.
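Steps 1 to 5 can be sketched in plain PyTorch as below. This is a minimal illustration under my own assumptions, not timm's implementation: the class name `PatchEmbedding`, the initialisation values, and the use of PyTorch's stock `nn.TransformerEncoder` as the encoder are choices made for the example. Steps 6 and 7 are training procedures and are not shown.

```python
import torch
import torch.nn as nn


class PatchEmbedding(nn.Module):
    """Steps 1-4: split into patches, flatten, project, add [CLS] token and positions.

    Hypothetical helper; defaults follow ViT-Base/16 (224x224 RGB, 16x16 patches, D = 768).
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        self.patch_size = patch_size
        num_patches = (img_size // patch_size) ** 2  # 14 * 14 = 196
        # Step 3: one linear layer shared by all flattened patches
        self.proj = nn.Linear(patch_size * patch_size * in_chans, embed_dim)
        # Learnable classification token, prepended to the patch sequence
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # Step 4: learnable position embeddings, one per token (including [CLS])
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches + 1, embed_dim) * 0.02)

    def forward(self, x):  # x: (B, C, H, W)
        B, C, _, _ = x.shape
        p = self.patch_size
        # Step 1: non-overlapping p x p patches -> (B, C, H/p, W/p, p, p)
        patches = x.unfold(2, p, p).unfold(3, p, p)
        # Step 2: flatten each patch -> (B, N, C * p * p)
        patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(B, -1, C * p * p)
        tokens = self.proj(patches)                    # (B, N, D)
        cls = self.cls_token.expand(B, -1, -1)         # (B, 1, D)
        tokens = torch.cat([cls, tokens], dim=1)       # (B, N + 1, D)
        return tokens + self.pos_embed


# Step 5: a stock PyTorch encoder as a stand-in (pre-norm + GELU, like ViT).
# ViT-Base stacks 12 layers; 2 are used here to keep the example light.
layer = nn.TransformerEncoderLayer(d_model=768, nhead=12, dim_feedforward=3072,
                                   dropout=0.0, activation="gelu",
                                   batch_first=True, norm_first=True)
encoder = nn.TransformerEncoder(layer, num_layers=2)

images = torch.randn(2, 3, 224, 224)  # random stand-in for a batch of images
tokens = PatchEmbedding()(images)     # (2, 197, 768)
out = encoder(tokens)                 # (2, 197, 768)
```

Using `unfold` keeps steps 1 and 2 explicit; practical implementations usually fold steps 1 to 3 into a single strided convolution, as described in the deep-dive section.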

More facts about ViTs

  1. Image patches are equivalent to sequence tokens (words in NLP).
  2. We can vary the number of blocks to allow for deeper networks.
  3. ViTs need a lot more training data than CNNs to beat state-of-the-art (SOTA) CNNs.
  4. However, we can pre-train on a larger dataset and fine-tune on a smaller one (replacing the MLP head).
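Replacing the head for fine-tuning might look like the sketch below. It assumes the classifier is exposed as a `head` attribute holding an `nn.Linear`, which is how timm's `VisionTransformer` is organised; `prepare_for_finetuning` and the `ToyModel` stand-in are hypothetical names for illustration. With timm you can also ask for a freshly sized head directly, e.g. `create_model("vit_base_patch16_224", pretrained=True, num_classes=10)`.

```python
import torch
import torch.nn as nn


def prepare_for_finetuning(model: nn.Module, num_classes: int,
                           freeze_backbone: bool = True) -> nn.Module:
    """Swap a pre-trained classifier head for a new one sized for the downstream task.

    Hypothetical helper; assumes the classifier is an nn.Linear stored as `model.head`.
    """
    if freeze_backbone:
        # Freeze every existing parameter, including the old head
        for param in model.parameters():
            param.requires_grad = False
    # The new head is freshly initialised and trainable
    model.head = nn.Linear(model.head.in_features, num_classes)
    return model


class ToyModel(nn.Module):
    """Hypothetical stand-in for a pre-trained ViT: a 'backbone' plus a 1000-class head."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(32, 768)
        self.head = nn.Linear(768, 1000)

    def forward(self, x):
        return self.head(self.backbone(x))


model = prepare_for_finetuning(ToyModel(), num_classes=10)
logits = model(torch.randn(4, 32))  # (4, 10): only the new head will receive gradients
```

Freezing the backbone is optional: it makes fine-tuning on a small dataset cheap, while unfreezing everything (full fine-tuning) is also common when more data is available.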

Understanding ViTs Mathematically

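First, each patch is flattened into a 1D vector and mapped to a D-dimensional embedding by a shared linear projection. A learnable classification token is then prepended to the sequence, and position embeddings are added. This process can be represented mathematically as:

Z_0 = [x_class; X * E] + E_pos

Here, X represents the flattened input image patches (one patch per row), E is the patch embedding matrix, x_class is the learnable classification token, [ ; ] denotes stacking along the sequence dimension, E_pos stands for the position embeddings, and Z_0 is the initial sequence of embedded patches.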

The sequence then goes to the Transformer encoder, which consists of a stack of Transformer layers. Each layer has two sub-layers: a multi-head self-attention mechanism and a position-wise feed-forward network. In ViT, a layer norm is applied before each sub-layer and a residual connection is added around it. Each of the attention heads computes scaled dot-product attention:

Attention(Q, K, V) = softmax(Q * K^T / sqrt(d_k)) * V

In this equation, Q, K, and V are the query, key, and value matrices, respectively, and d_k is the dimensionality of the queries and keys.
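A direct translation of the attention formula into PyTorch, for a single head, is sketched below. The function name is my own and this is for illustration only; PyTorch 2.x ships an optimised `torch.nn.functional.scaled_dot_product_attention` for real use. Multi-head attention runs H of these in parallel on separately projected slices of the input and concatenates the results.

```python
import math

import torch


def scaled_dot_product_attention(q, k, v):
    """softmax(Q K^T / sqrt(d_k)) V for tensors shaped (..., seq_len, d)."""
    d_k = q.size(-1)
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)  # (..., N_q, N_k)
    weights = torch.softmax(scores, dim=-1)             # each row sums to 1
    return weights @ v, weights


# ViT-Base shapes for one head: 197 tokens, d_k = 768 / 12 = 64
q = torch.randn(1, 197, 64)
k = torch.randn(1, 197, 64)
v = torch.randn(1, 197, 64)
out, attn = scaled_dot_product_attention(q, k, v)  # out: (1, 197, 64), attn: (1, 197, 197)
```

The 1/sqrt(d_k) factor keeps the dot products from growing with the head dimension, which would otherwise push the softmax towards one-hot weights.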

The position-wise feed-forward network consists of two linear transformations with a non-linearity in between. ViT uses GELU here (the original Transformer used ReLU):

FFN(x) = GELU(x * W1 + b1) * W2 + b2

Here, x is the input, and W1, W2, b1 and b2 are the parameters of the feed-forward network.
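The feed-forward formula maps onto two `nn.Linear` layers around a GELU. Below is a minimal sketch with ViT-Base sizes (768 in and out, hidden size 3072); the class name is hypothetical and dropout is omitted.

```python
import torch
import torch.nn as nn


class FeedForward(nn.Module):
    """FFN(x) = GELU(x W1 + b1) W2 + b2, applied independently to every token."""

    def __init__(self, dim=768, hidden_dim=3072):  # ViT-Base sizes (4x expansion)
        super().__init__()
        self.fc1 = nn.Linear(dim, hidden_dim)      # W1, b1
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, dim)      # W2, b2

    def forward(self, x):
        return self.fc2(self.act(self.fc1(x)))


ffn = FeedForward()
x = torch.randn(2, 197, 768)
y = ffn(x)  # (2, 197, 768): same shape as the input
```

"Position-wise" means each token is transformed on its own: running the network on one token in isolation gives the same result as that token's slice of the full output.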

The output of the Transformer encoder is a sequence of vectors, one for each image patch. The output corresponding to the classification token is used to make the final prediction using a linear layer:

y = softmax(C * z + b)

In this equation, C is the weight matrix of the linear layer, z is the output of the Transformer encoder corresponding to the classification token, b is the bias, and y is the final prediction.
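The classification step can be sketched as follows, assuming an encoder output of shape (batch, 197, 768). The encoder output here is random placeholder data. Note that the ViT paper also applies a final layer norm to the classification-token output before the head; it is left out here to mirror the equation above.

```python
import torch
import torch.nn as nn

num_classes, dim = 1000, 768
head = nn.Linear(dim, num_classes)       # holds C (weight) and b (bias)

encoder_out = torch.randn(2, 197, dim)   # placeholder encoder output: (B, 1 + 196, D)
z = encoder_out[:, 0]                    # output at the classification token: (B, D)
y = torch.softmax(head(z), dim=-1)       # y = softmax(C z + b): (B, num_classes)
pred = y.argmax(dim=-1)                  # predicted class index per image: (B,)
```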

Vision Transformers using PyTorch

For the PyTorch implementation, we will use timm (PyTorch Image Models), a library created by Ross Wightman that collects SOTA computer vision models, layers, utilities, optimizers, schedulers, dataloaders and a lot more. The copyright of the figures and demo images used in this guide belongs to Hirota Honda.

1. Install the PyTorch Image Models (timm) library
# PyTorch Image Models
!pip install timm

2. Import the necessary libraries and functions

import os
import matplotlib.pyplot as plt
import numpy as np
import PIL.Image  # importing the submodule makes PIL.Image.open available

import torch
import torch.nn.functional as F
import torchvision
import torchvision.transforms as T

from timm import create_model

3. Load one of the vision transformer models from timm

# Load your model
model_name = "vit_base_patch16_224"
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = create_model(model_name, pretrained=True).to(device)
model.eval()  # inference mode: disables dropout and other training-only behaviour

4. Define some inference parameters and transformations

IMG_SIZE = (224, 224)
NORMALIZE_MEAN = (0.5, 0.5, 0.5)
NORMALIZE_STD = (0.5, 0.5, 0.5)

transforms = [
    T.Resize(IMG_SIZE),
    T.ToTensor(),
    T.Normalize(NORMALIZE_MEAN, NORMALIZE_STD),
]

transforms = T.Compose(transforms)
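To see what `NORMALIZE_MEAN` and `NORMALIZE_STD` do: `T.ToTensor()` scales pixel values to [0, 1], and `T.Normalize` computes (x - mean) / std for each channel, so with 0.5 for both the pixels end up in [-1, 1]. The tiny tensor below is made-up input used only to show the arithmetic. These constants should match the preprocessing the checkpoint was trained with; timm can derive them for a given model via `timm.data.resolve_data_config`.

```python
import torch

mean = torch.tensor((0.5, 0.5, 0.5)).view(3, 1, 1)  # NORMALIZE_MEAN, one value per channel
std = torch.tensor((0.5, 0.5, 0.5)).view(3, 1, 1)   # NORMALIZE_STD

# A made-up 3-channel, 1x3-pixel "image" as ToTensor() would produce: values 0, 0.5, 1
img = torch.tensor([0.0, 0.5, 1.0]).repeat(3, 1, 1)  # shape (3, 1, 3)
normalized = (img - mean) / std                      # same arithmetic as T.Normalize
# every channel now holds -1.0, 0.0, 1.0
```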

5. Since this is a classification task, we need the ImageNet class labels and a sample image. The labels file, ilsvrc2012_wordnet_lemmas.txt, can be downloaded as follows:

# ImageNet Labels
!wget https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt
with open('ilsvrc2012_wordnet_lemmas.txt') as f:
    imagenet_labels = dict(enumerate(line.strip() for line in f))  # strip trailing newlines

6. Load the sample image using the Pillow library

img = PIL.Image.open('sample.png').convert('RGB')  # drop any alpha channel; the model expects 3 channels
img_tensor = transforms(img).unsqueeze(0).to(device)

Below is the sample image I used for inference:

An image showing a dome-like structure in Santorini

7. Run inference with the loaded pre-trained ViT

with torch.no_grad():  # gradients are not needed for inference
    output = model(img_tensor)
print(f"Inference Result: {imagenet_labels[int(torch.argmax(output))]}")
plt.imshow(img)
plt.axis('off')

The result obtained is as follows:

Inference Result: church, church_building

As you can see, the pre-trained ViT produces an accurate label for this image without any fine-tuning.
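Printing only the argmax hides how confident the model is. A small helper that lists the top-k classes with their softmax probabilities is sketched below; `top_k_predictions` is a hypothetical name, and the toy logits and labels exist only to show the output format. With the objects from the steps above, you would call it as `top_k_predictions(output, imagenet_labels)`.

```python
import torch


def top_k_predictions(logits, labels, k=5):
    """Return (label, probability) pairs for the k most likely classes of one image.

    logits: tensor of shape (1, num_classes), e.g. the model's `output` above.
    labels: dict mapping class index -> label string, like `imagenet_labels`.
    """
    probs = torch.softmax(logits, dim=-1)[0]  # logits -> probabilities
    top_probs, top_idx = torch.topk(probs, k)
    return [(labels[int(i)], float(p)) for p, i in zip(top_probs, top_idx)]


# Toy check with made-up logits and labels (not real model output)
toy_labels = {0: "cat", 1: "church, church_building", 2: "dome"}
toy_logits = torch.tensor([[0.1, 3.0, 2.0]])
for label, prob in top_k_predictions(toy_logits, toy_labels, k=2):
    print(f"{label}: {prob:.3f}")
```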

Deep Dive into Vision Transformers

The complete Vision Transformer pipeline
  1. Split Image into Patches
    A Conv2d layer with a 16x16 kernel and stride (16, 16) splits the 224x224 input image into 14 x 14 = 196 patches and embeds each one as a 768-dimensional vector.
  2. Add Position Embeddings
    A learnable classification token is prepended to the 196 patch embedding vectors, and learnable position embedding vectors are added to all 197 vectors before they are fed to the transformer encoder.
  3. Transformer Encoder
    The embedding vectors are encoded by the transformer encoder. The input and output vectors have the same dimension.
  4. MLP (Classification) Head
    The 0th output of the encoder, which corresponds to the classification token, is fed to the MLP head to produce the final classification result.
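The Conv2d in step 1 is the "flatten each patch, then apply one shared linear layer" recipe written as a convolution: a 16x16 kernel with stride 16 sees each patch exactly once. The sketch below checks this numerically by copying the convolution's weights into an `nn.Linear`; the random input and weights are placeholders.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
p, D = 16, 768
conv = nn.Conv2d(3, D, kernel_size=p, stride=p)  # patch embedding as a strided convolution

x = torch.randn(1, 3, 224, 224)
conv_tokens = conv(x).flatten(2).transpose(1, 2)  # (1, 768, 14, 14) -> (1, 196, 768)

# Same computation as "flatten each patch, then apply one shared linear layer"
linear = nn.Linear(3 * p * p, D)
with torch.no_grad():
    linear.weight.copy_(conv.weight.reshape(D, -1))  # kernel flattened in (C, kh, kw) order
    linear.bias.copy_(conv.bias)
patches = x.unfold(2, p, p).unfold(3, p, p)                             # (1, 3, 14, 14, 16, 16)
patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(1, -1, 3 * p * p)   # (1, 196, 768)
linear_tokens = linear(patches)                                         # (1, 196, 768)
```

Both paths give the same 196 tokens; the convolution is simply the more efficient way to express it.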

Understanding Transformer Encoder in ViTs

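How the 12 encoder blocks in series work in a ViT
  1. N (=197) embedded vectors are fed to L (=12) encoder blocks connected in series.
  2. Inside each block, the vectors go through a layer norm and are expanded by an fc layer (768 to 3 x 768 in ViT-Base), whose output is split into query, key and value.
  3. q, k and v are each further split into H (=12) parts of 64 dimensions and fed to the parallel attention heads.
  4. The outputs of the attention heads are concatenated into vectors with the same shape as the encoder input.
  5. The concatenated vectors pass through an fc layer and are added to the block input (a residual connection); the result then goes through a layer norm and an MLP block with two fc layers, with another residual connection around it.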
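The steps above can be sketched as a pre-norm encoder block. This uses PyTorch's `nn.MultiheadAttention`, which internally performs the q/k/v projection, the split into heads, the concatenation and the output fc layer. timm's own implementation differs in detail (for example, it supports dropout and stochastic depth, both omitted here), and `EncoderBlock` is a hypothetical name.

```python
import torch
import torch.nn as nn


class EncoderBlock(nn.Module):
    """Pre-norm ViT encoder block: x + MSA(LN(x)), then x + MLP(LN(x))."""

    def __init__(self, dim=768, num_heads=12, mlp_ratio=4):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        # q/k/v projection, split into num_heads heads (64 dims each for ViT-Base),
        # concatenation of head outputs and the output fc layer
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(          # MLP block with two fc layers
            nn.Linear(dim, mlp_ratio * dim),
            nn.GELU(),
            nn.Linear(mlp_ratio * dim, dim),
        )

    def forward(self, x):  # x: (B, N, D)
        h = self.norm1(x)
        attn_out, _ = self.attn(h, h, h, need_weights=False)
        x = x + attn_out                      # residual connection
        return x + self.mlp(self.norm2(x))    # residual connection


# L (=12) blocks in series; the output has the same shape as the input
encoder = nn.Sequential(*[EncoderBlock() for _ in range(12)])
tokens = torch.randn(1, 197, 768)
out = encoder(tokens)  # (1, 197, 768)
```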

Visualizing Position Embeddings

To make patches position-aware, learnable ‘position embedding’ vectors are added to the patch embedding vectors. During training, the position embeddings learn to encode distance within the image, so the embeddings of neighboring patches end up highly similar.

If we take the position embedding of one patch and plot its cosine similarity with the position embeddings of all the other patches, we get a map like the one below:

Cosine similarity between one positional embedding and all the other positional embeddings
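A sketch of how such a similarity map can be computed. It assumes the position embeddings are stored as a (1, 197, 768) tensor with the classification token's entry first, which as far as I know matches the `pos_embed` parameter of timm's ViT models. The helper name is hypothetical, and the random tensor only stands in for learned embeddings, so it will not show the neighborhood structure of a trained model.

```python
import torch
import torch.nn.functional as F


def position_similarity_map(pos_embed, patch_index, grid_size=14):
    """Cosine similarity between one patch's position embedding and all patches.

    pos_embed: (1, 1 + grid_size**2, D), classification-token entry first (assumed layout).
    Returns a (grid_size, grid_size) map that can be shown with plt.imshow.
    """
    patch_pos = pos_embed[0, 1:]                          # drop the [CLS] entry -> (196, D)
    query = patch_pos[patch_index].unsqueeze(0)           # (1, D)
    sims = F.cosine_similarity(query, patch_pos, dim=-1)  # broadcasts to (196,)
    return sims.reshape(grid_size, grid_size)


# With the model loaded earlier: position_similarity_map(model.pos_embed.detach().cpu(), 0)
toy_pos_embed = torch.randn(1, 197, 768)  # random stand-in, not learned embeddings
sim = position_similarity_map(toy_pos_embed, patch_index=20)  # patch at row 1, column 6
```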

The future of Vision Transformers

The Vision Transformer (ViT) represents a significant step in applying Transformer models to computer vision tasks. It opens up many possibilities for future research and has the potential to revolutionize the field of image recognition. As the saying goes, “a picture is worth a thousand words”, but in the case of ViT, “an image is worth 16x16 words”, as the title of the original ViT paper puts it. The future of image recognition looks bright, and I am excited to see where this journey takes us.

Jaskaran Bhatia

ML Engineer @ J-Squared Technologies | MScAC @ University of Toronto | Ex-JPMorgan