Ch 9. Vision Transformer Part I— Introduction and Fine-Tuning in PyTorch

How using self-attention for image classification reduces the inductive biases inherent to CNNs, such as translation equivariance and locality, and thus improves performance over ResNets when pre-trained on much larger datasets such as ImageNet-21k

Lucrece (Jahyun) Shin
Jan 28, 2022

*This post’s associated Colab Notebook contains step-by-step code for downloading pre-trained ViT model checkpoints, defining a model instance, and fine-tuning ViT.

One important thing I have learned as a machine learning practitioner is that the algorithms proposed in the newest, state-of-the-art research papers do not always magically solve problems that existing ones could not. Sometimes they work well; sometimes they don’t. Either way, it’s important to record the process and results of every experiment, so that I can refer back to them and reuse or improve them in related projects in the future.

With that in mind, I’d like to introduce a fairly new image classification model called Vision Transformer in this post and discuss its performance on Xray scanner threat detection data. Although its final optimized version performed worse than ResNet50 trained with adversarial discriminative domain adaptation (ADDA), it was interesting to work with a vision framework entirely different from CNNs.

Here is the list of topics for this post. If you are already familiar with transformers and ViT, you can skip to Vision Transformer Part II — Iterative Erasing of Unattended Image Regions in PyTorch, where I discuss an effective prediction heuristic I invented that uses ViT’s attention weights.

  1. Transformer (Self-Attention)
  2. Vision Transformer (ViT)
  3. ViT Model Fine-Tuning in PyTorch
  4. Brief Intro to Xray Threat Detection Project
  5. ViT — Initial Performance

1. Transformer (Self-Attention)

1.1 Overview

The Attention Is All You Need paper (2017) introduced the transformer and self-attention, which drew a great deal of attention in natural language processing. It gave rise to successful transformer-based language models such as Google’s BERT (a stack of transformer encoders used as a language representation model) and OpenAI’s GPT-2 (a stack of transformer decoders used as an autoregressive language model), which achieved top performance on a wide range of language tasks. I won’t go into the technical details of the transformer architecture, but you can review this blog post by Jay Alammar, which I personally think is the best transformer tutorial on the internet.

Transformer’s high-level structure, containing 6 encoders and 6 decoders (Source)

As complicated as it sounds, the transformer is just another mechanism that encodes a sequence of input tokens and decodes the encoded representation to match a sequence of target tokens. The original transformer is composed of 6 encoders and 6 decoders stacked as above, where each encoder and decoder contains self-attention, normalization, and feed-forward components (each decoder additionally attends to the encoder’s output):

More detailed view inside transformer’s encoder and decoder (Source)
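The tutorial skips the internals, but the core of each self-attention component is scaled dot-product attention, softmax(QKᵀ/√d_k)V, from the Attention Is All You Need paper. Below is a minimal single-head PyTorch sketch; the class name and the randomly initialized projections are my own illustration rather than code from a particular library, and a real transformer runs several such heads in parallel.

```python
import math

import torch
import torch.nn as nn


class SingleHeadSelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention (illustrative sketch)."""

    def __init__(self, dim: int):
        super().__init__()
        # Learnable projections that turn each token into a query, key, and value.
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)
        self.scale = 1.0 / math.sqrt(dim)  # 1 / sqrt(d_k)

    def forward(self, x: torch.Tensor):
        # x: (batch, n_tokens, dim)
        q, k, v = self.q_proj(x), self.k_proj(x), self.v_proj(x)
        # Similarity of every token (query) with every token (key).
        scores = q @ k.transpose(-2, -1) * self.scale  # (batch, n_tokens, n_tokens)
        weights = scores.softmax(dim=-1)  # each row sums to 1
        # Each output token is an attention-weighted mix of all value vectors.
        return weights @ v, weights


torch.manual_seed(0)
tokens = torch.randn(1, 5, 16)  # e.g. a 5-token sentence with 16-dim embeddings
attn = SingleHeadSelfAttention(dim=16)
out, weights = attn(tokens)
print(out.shape, weights.shape)  # torch.Size([1, 5, 16]) torch.Size([1, 5, 5])
```

Row i of `weights` says how much token i draws from every other token when building its new representation.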

1.2 Transformers in Natural Language Processing

Source: http://jalammar.github.io/illustrated-transformer/, Tensor2Tensor notebook

The figure above illustrates how self-attention is applied to a sentence. Different colours represent different attention heads, which expand the model’s ability to focus on different positions in the sentence. As the word “it” is encoded, one attention head (orange) focuses most on “the animal”, while another (green) focuses on “tired”. So the model’s representation of the word “it” effectively includes some of the representations of both “the animal” and “tired”.
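A minimal sketch of how per-head attention weights like the ones in the figure can be inspected, using PyTorch’s built-in nn.MultiheadAttention (the average_attn_weights argument assumes PyTorch 1.11 or later). The embeddings are random stand-ins for the 11 tokens of “The animal didn’t cross the street because it was too tired”, so, unlike in a trained model, the heads’ choices here carry no meaning.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
embed_dim, num_heads = 64, 8  # 8 heads, each working in a 64 / 8 = 8-dim subspace
mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

sentence = torch.randn(1, 11, embed_dim)  # random stand-in embeddings for 11 tokens
# Self-attention: query, key, and value all come from the same sequence.
out, per_head = mha(sentence, sentence, sentence,
                    need_weights=True, average_attn_weights=False)
print(out.shape)       # torch.Size([1, 11, 64])
print(per_head.shape)  # torch.Size([1, 8, 11, 11]): (batch, head, query, key)

it_index = 7  # position of "it" in the 11-token sentence
# For each head, the token that "it" attends to most.
print(per_head[0, :, it_index].argmax(dim=-1))
```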

1.3 Transformers in Computer Vision

Although Convolutional Neural Networks (CNNs) are still the most popular building blocks in computer vision, recent research has explored using transformer encoders to encode images, which I will introduce next.

2. Vision Transformer (ViT)

The paper An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale from Google Research (2020) introduced a model called Vision Transformer (ViT for short).

[Model Overview] We split an image into fixed-size patches, linearly embed each of them, add position embeddings, and feed the resulting sequence of vectors to a standard transformer encoder. In order to perform classification, an extra learnable “classification token” is added to the sequence. (Source: ViT Paper)

2.1 Patching, Flattening, and Embedding

Visualization of 16 by 16 grid of image tokens for which attention to one another will be computed. You can see that each image token covers only a small area of the image.

The standard transformer takes a 1D sequence of token embeddings as input. To handle 2D images, an input image of size H x W x C is first divided into non-overlapping patches of P x P pixels as above, creating N = HW/P² smaller “image tokens” (e.g. N = 196 for a 224 x 224 image with P = 16). Each 3D patch (height x width x channels, i.e. P x P x 3 for RGB) is then flattened into a 1D vector of length 3P². Next, since the transformer uses a constant latent vector size D through all of its layers, the flattened patches are mapped to D dimensions with a trainable linear projection. The resulting sequence is called the patch embeddings.
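The patching, flattening, and projection steps can be sketched in PyTorch as follows, assuming 224 x 224 RGB inputs, P = 16, and D = 768 (the ViT-Base setting). The names patchify and PatchEmbedding are hypothetical. The second half checks a common shortcut: implementing the same projection as a Conv2d whose kernel size and stride both equal the patch size.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def patchify(images: torch.Tensor, patch_size: int) -> torch.Tensor:
    """(B, C, H, W) -> (B, N, C * P * P) flattened, non-overlapping patches."""
    # With stride == kernel_size, unfold cuts the image into non-overlapping P x P blocks.
    patches = F.unfold(images, kernel_size=patch_size, stride=patch_size)  # (B, C*P*P, N)
    return patches.transpose(1, 2)  # one row per image token


class PatchEmbedding(nn.Module):
    """Trainable linear projection of flattened patches to the latent size D."""

    def __init__(self, patch_size: int = 16, in_channels: int = 3, hidden_dim: int = 768):
        super().__init__()
        self.patch_size = patch_size
        self.proj = nn.Linear(in_channels * patch_size * patch_size, hidden_dim)

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        return self.proj(patchify(images, self.patch_size))  # (B, N, D)


torch.manual_seed(0)
images = torch.randn(2, 3, 224, 224)  # a batch of two RGB images
embed = PatchEmbedding()
tokens = embed(images)
print(tokens.shape)  # torch.Size([2, 196, 768]): 196 = (224 // 16) ** 2 patches

# Equivalent Conv2d with kernel_size == stride == patch_size, sharing the same weights.
conv = nn.Conv2d(3, 768, kernel_size=16, stride=16)
with torch.no_grad():
    conv.weight.copy_(embed.proj.weight.view(768, 3, 16, 16))
    conv.bias.copy_(embed.proj.bias)
    conv_tokens = conv(images).flatten(2).transpose(1, 2)  # (B, D, 14, 14) -> (B, 196, D)
print(torch.allclose(tokens, conv_tokens, atol=1e-4))  # True
```

For P = 16, each flattened patch happens to hold 16 · 16 · 3 = 768 values, the same as D for ViT-Base; the projection is still a learned 768 x 768 matrix, not an identity.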

2.2 [class] token for Classification

Similar to BERT’s [CLS] token, a learnable classification token embedding is prepended to the sequence of patch embeddings. The state of this token at the output of the transformer encoder serves as the image representation.
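A minimal sketch of prepending the learnable [class] token, assuming patch embeddings of shape (batch, 196, 768). The module name is hypothetical, and initializing the token to zeros is an assumption made for illustration.

```python
import torch
import torch.nn as nn


class PrependClassToken(nn.Module):
    """Prepends one learnable [class] token to a batch of patch embeddings."""

    def __init__(self, hidden_dim: int = 768):
        super().__init__()
        # A single learnable vector shared by every image (zero init is an assumption).
        self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_dim))

    def forward(self, patch_embeddings: torch.Tensor) -> torch.Tensor:
        batch_size = patch_embeddings.shape[0]
        cls = self.cls_token.expand(batch_size, -1, -1)  # (B, 1, D) view, no copy
        return torch.cat([cls, patch_embeddings], dim=1)  # (B, 1 + N, D)


patch_embeddings = torch.randn(2, 196, 768)
prepend = PrependClassToken()
seq = prepend(patch_embeddings)
print(seq.shape)  # torch.Size([2, 197, 768]); the [class] token sits at index 0
```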

2.3 Classification Head

The classification head is a simple MLP with one hidden layer at pre-training time and a single linear layer at fine-tuning time. Its input is the [class] token’s state at the output of the transformer encoder, and its output is a vector of logits for the M classes.
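The two head variants might be sketched like this. The tanh activation and the hidden width of the pre-training MLP are assumptions for illustration; the zero-initialized fine-tuning layer follows the paper’s description of replacing the pre-trained head with a zero-initialized D x K layer (K being the number of downstream classes).

```python
import torch
import torch.nn as nn


def pretraining_head(hidden_dim: int, num_classes: int, mlp_dim: int) -> nn.Module:
    # MLP with one hidden layer; tanh and mlp_dim are illustrative assumptions.
    return nn.Sequential(
        nn.Linear(hidden_dim, mlp_dim), nn.Tanh(), nn.Linear(mlp_dim, num_classes)
    )


def finetuning_head(hidden_dim: int, num_classes: int) -> nn.Module:
    # Single linear layer, zero-initialized so fine-tuning starts from uniform predictions.
    head = nn.Linear(hidden_dim, num_classes)
    nn.init.zeros_(head.weight)
    nn.init.zeros_(head.bias)
    return head


cls_state = torch.randn(4, 768)  # [class] token outputs for a batch of 4 images
pre_logits = pretraining_head(768, num_classes=1000, mlp_dim=768)(cls_state)
ft_logits = finetuning_head(768, num_classes=3)(cls_state)  # e.g. 3 target classes
print(pre_logits.shape, ft_logits.shape)  # torch.Size([4, 1000]) torch.Size([4, 3])
```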

2.4 Position Embeddings

To retain positional information, standard learnable 1D position embeddings (one per sequence position, including the [class] token) are added to the patch embeddings. The resulting sequence of vectors is passed as input to the transformer encoder.
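A minimal sketch of learnable 1D position embeddings: one trainable D-dimensional vector per sequence position (196 patches + 1 [class] token), indexed by position in the sequence rather than by (row, column). The module name and the 0.02 initialization scale are assumptions.

```python
import torch
import torch.nn as nn


class AddPositionEmbedding(nn.Module):
    """Adds a learnable 1D position embedding to each position of the token sequence."""

    def __init__(self, num_tokens: int = 197, hidden_dim: int = 768):
        super().__init__()
        # One learnable D-dim vector per position (small random init is an assumption).
        self.pos_embedding = nn.Parameter(torch.randn(1, num_tokens, hidden_dim) * 0.02)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        # Broadcasting adds the same position vectors to every image in the batch.
        return tokens + self.pos_embedding


tokens = torch.zeros(2, 197, 768)  # [class] token + 196 patch embeddings (zeros for clarity)
add_pos = AddPositionEmbedding()
out = add_pos(tokens)
print(out.shape)  # torch.Size([2, 197, 768])
```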

2.5 Transformer Encoder

As shown on the right side of the ViT model overview figure above, the transformer encoder consists of alternating layers of multi-headed self-attention and MLP blocks. Layer normalization is applied before every block, and residual connections are added after every block. The transformer decoder is not used in ViT, since ViT’s main objective is image classification, for which decoding the encoded image representation is unnecessary.
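One pre-norm encoder layer can be sketched with PyTorch’s nn.MultiheadAttention, following the structure described above: LayerNorm before each block, a residual connection after each block, and a two-layer MLP with a GELU non-linearity (as in the ViT paper). Dropout is omitted for brevity, the class name is hypothetical, and the attention weights returned here are averaged over heads, which is PyTorch’s default.

```python
import torch
import torch.nn as nn


class EncoderBlock(nn.Module):
    """One pre-norm ViT encoder layer: z = z + MSA(LN(z)); z = z + MLP(LN(z))."""

    def __init__(self, hidden_dim: int = 768, num_heads: int = 12, mlp_dim: int = 3072):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.attn = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(hidden_dim)
        # Two-layer MLP with a GELU non-linearity.
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, mlp_dim), nn.GELU(), nn.Linear(mlp_dim, hidden_dim)
        )

    def forward(self, z: torch.Tensor):
        h = self.norm1(z)  # LayerNorm *before* self-attention
        attn_out, attn_weights = self.attn(h, h, h, need_weights=True)
        z = z + attn_out  # residual connection
        z = z + self.mlp(self.norm2(z))  # LayerNorm before the MLP, then residual
        return z, attn_weights  # weights averaged over heads: (B, T, T)


torch.manual_seed(0)
blocks = nn.ModuleList([EncoderBlock() for _ in range(12)])  # ViT-Base depth
final_norm = nn.LayerNorm(768)  # ViT applies one more LayerNorm to the encoder output
z = torch.randn(2, 197, 768)  # [class] token + 196 patch tokens
attention_maps = []
with torch.no_grad():
    for block in blocks:
        z, weights = block(z)
        attention_maps.append(weights)
    z = final_norm(z)
print(z.shape, len(attention_maps), attention_maps[0].shape)
```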

2.6 Three Model Variants

Three ViT Variants and associated hyperparameters

The paper proposes three ViT variants, listed in the table above. Here are some key hyperparameters:

  • Layers: Number of encoder layers, each consisting of a multi-headed self-attention block followed by an MLP block.
  • Hidden size D: Latent vector size used throughout the encoder, which is also the dimension of the ViT encoder’s final feature embeddings.
  • Heads: Number of attention heads in each self-attention block.

I used ViT-Base for my experiments.
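To make the table concrete, here is a small sketch that records the three variants’ hyperparameters from the paper and estimates ViT-Base’s encoder parameter count from them, assuming 16 x 16 patches on 224 x 224 inputs and excluding the classification head. The helper is my own back-of-the-envelope formula, not code from the paper.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ViTConfig:
    layers: int
    hidden_size: int  # D
    mlp_size: int
    heads: int


# Hyperparameters of the three variants from the ViT paper.
VARIANTS = {
    "ViT-Base": ViTConfig(layers=12, hidden_size=768, mlp_size=3072, heads=12),
    "ViT-Large": ViTConfig(layers=24, hidden_size=1024, mlp_size=4096, heads=16),
    "ViT-Huge": ViTConfig(layers=32, hidden_size=1280, mlp_size=5120, heads=16),
}


def encoder_params(cfg: ViTConfig, patch_size: int = 16, image_size: int = 224,
                   channels: int = 3) -> int:
    """Rough parameter count of the encoder, excluding the classification head."""
    d, m = cfg.hidden_size, cfg.mlp_size
    attention = 4 * (d * d + d)  # Q, K, V and output projections, with biases
    mlp = (d * m + m) + (m * d + d)  # two linear layers, with biases
    norms = 2 * 2 * d  # two LayerNorms, each with a scale and a shift
    per_layer = attention + mlp + norms
    n_patches = (image_size // patch_size) ** 2
    patch_embed = patch_size * patch_size * channels * d + d
    cls_and_pos = d + (n_patches + 1) * d  # [class] token + position embeddings
    final_norm = 2 * d
    return cfg.layers * per_layer + patch_embed + cls_and_pos + final_norm


for name, cfg in VARIANTS.items():
    print(f"{name}: {cfg.hidden_size // cfg.heads} dims per head")
print(encoder_params(VARIANTS["ViT-Base"]))  # 85798656, about 86M
```

The estimate of about 85.8M is consistent with the 86M the paper lists for ViT-Base. Nearly all of it comes from the 12 encoder layers, each contributing roughly 12D² weights: 4D² for attention and 8D² for the MLP, since the MLP size is 4D in all three variants.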

2.7 ViT vs. ResNet

The paper notes that ViT lacks some of the inductive biases inherent to CNNs, such as translation equivariance and locality. As a result, ViT does not generalize as well as ResNets when trained on insufficient amounts of data. However, it outperforms ResNets when pre-trained on much larger datasets such as ImageNet-21k, which contains over 14M images and 21K classes, compared to ~1M images and 1K classes for ImageNet.
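Both inductive biases can be checked directly on a plain convolution layer. In the minimal sketch below (my own illustration, not from the paper), circular padding makes shifting the input shift the output by exactly the same amount (translation equivariance), and the gradient shows that each output pixel depends only on a 3 x 3 input neighbourhood (locality). ViT’s self-attention layers have no such built-in constraints, since every token can attend to every other token, so these spatial relationships have to be learned from data.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Circular padding makes the equivariance exact, even at the image borders.
conv = nn.Conv2d(3, 8, kernel_size=3, padding=1, padding_mode="circular", bias=False)
x = torch.randn(1, 3, 32, 32)

# Translation equivariance: shifting the input shifts the output by the same amount.
shift = (5, 7)
with torch.no_grad():
    shift_then_conv = conv(torch.roll(x, shifts=shift, dims=(2, 3)))
    conv_then_shift = torch.roll(conv(x), shifts=shift, dims=(2, 3))
print(torch.allclose(shift_then_conv, conv_then_shift, atol=1e-5))  # True

# Locality: a single output pixel depends only on a 3 x 3 neighbourhood of the input.
x_req = x.clone().requires_grad_(True)
conv(x_req)[0, 0, 10, 10].backward()
touched = (x_req.grad[0].abs().sum(dim=0) > 0).nonzero()  # (row, col) with nonzero gradient
print(touched.min(dim=0).values.tolist(), touched.max(dim=0).values.tolist())  # [9, 9] [11, 11]
```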

Next, I will go into the detailed implementation of ViT fine-tuning in PyTorch.

3. ViT Model Fine-Tuning in PyTorch

The entire step-by-step code for ViT fine-tuning is shown in my Colab Notebook. I will list some key code snippets here.

3.1 Loading the Pre-trained ViT

To define and fine-tune ViT, I used this GitHub repo, which is written in PyTorch. The model loading procedure is as follows.

1. Clone the GitHub repo and copy all of its files into the current directory:

2. Download the ViT-Base model checkpoint:

Google’s official pre-trained model checkpoints are available here for download. I used:

  • ViT-Base
  • pre-trained on ImageNet-21k and fine-tuned on the smaller ImageNet2012
  • takes in input images of size 224 by 224
  • divides input images into 16 by 16 pixel patches, i.e. a 14 by 14 grid (196 image tokens)

3. Define a ViT-Base model instance and load the downloaded checkpoint:
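How the repo reads Google’s checkpoint files is specific to that repo and is not reproduced here (the official checkpoints are NumPy .npz files). As a generic illustration of a step that fine-tuning usually involves, the sketch below loads pre-trained weights into a new model while skipping the old classification head, using PyTorch’s load_state_dict(strict=False). ToyBackboneWithHead is a hypothetical stand-in for a ViT model, and the in-memory buffer stands in for a downloaded checkpoint.

```python
import io

import torch
import torch.nn as nn


class ToyBackboneWithHead(nn.Module):
    """Hypothetical stand-in for a ViT model: a backbone plus a classification head."""

    def __init__(self, num_classes: int, hidden_dim: int = 768):
        super().__init__()
        self.backbone = nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.LayerNorm(hidden_dim))
        self.head = nn.Linear(hidden_dim, num_classes)


# Pretend checkpoint: a model with a 1000-class head, saved as a state dict.
buffer = io.BytesIO()
torch.save(ToyBackboneWithHead(num_classes=1000).state_dict(), buffer)
buffer.seek(0)
checkpoint = torch.load(buffer, map_location="cpu")

# New model for 3 target classes: reuse every pre-trained weight except the head.
model = ToyBackboneWithHead(num_classes=3)
backbone_only = {k: v for k, v in checkpoint.items() if not k.startswith("head.")}
result = model.load_state_dict(backbone_only, strict=False)
print(result.missing_keys)     # ['head.weight', 'head.bias'] keep their fresh initialization
print(result.unexpected_keys)  # []
```

Checking missing_keys and unexpected_keys after a non-strict load is a cheap way to confirm that only the intended layers were left out.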

4. Define the ViT encoder and a separate fully connected classifier:

Here I use only the encoder part of the ViT model, which outputs class-discriminative embeddings of images. I created a separate classifier using torch.nn.Linear that classifies the embeddings into one of the target classes.
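A minimal sketch of this step, assuming an encoder that maps images to (embeddings, attention weights) as described in section 3.2. StandInViTEncoder is a hypothetical, heavily simplified placeholder (no MLP or LayerNorm blocks), not the repo’s class; only its input and output shapes mirror ViT-Base. The classifier is a plain torch.nn.Linear from D = 768 to the number of classes (3 here, matching the Xray project).

```python
import torch
import torch.nn as nn


class StandInViTEncoder(nn.Module):
    """Hypothetical stand-in: images -> (embeddings, attention weights).
    Only the interface mirrors the post; the layers are heavily simplified."""

    def __init__(self, hidden_dim=768, patch_size=16, image_size=224, num_layers=2, num_heads=12):
        super().__init__()
        n_patches = (image_size // patch_size) ** 2
        self.patch_embed = nn.Conv2d(3, hidden_dim, kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_dim))
        self.pos_embedding = nn.Parameter(torch.zeros(1, n_patches + 1, hidden_dim))
        self.layers = nn.ModuleList(
            [nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True) for _ in range(num_layers)]
        )

    def forward(self, images):
        x = self.patch_embed(images).flatten(2).transpose(1, 2)  # (B, 196, D)
        cls = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat([cls, x], dim=1) + self.pos_embedding  # (B, 197, D)
        attn_weights = []
        for attn in self.layers:
            out, w = attn(x, x, x, need_weights=True)  # w averaged over heads here
            x = x + out
            attn_weights.append(w)  # (B, 197, 197) per layer
        return x, torch.stack(attn_weights, dim=1)  # (B, layers, 197, 197)


encoder = StandInViTEncoder()
classifier = nn.Linear(768, 3)  # separate classifier: D -> number of classes

images = torch.randn(2, 3, 224, 224)
embeddings, attn_weights = encoder(images)
logits = classifier(embeddings[:, 0])  # classify from the [class] token embedding
print(embeddings.shape, attn_weights.shape, logits.shape)
```

Keeping the classifier separate from the encoder makes it easy to swap the head, or to give the two parts different learning rates, without touching the encoder code.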

3.2 Fine-tuning ViT

The training function for ViT is not much different from transfer learning with CNN-based models such as ResNet50: each epoch consists of a training stage with forward and backward propagation, followed by a validation stage. The outputs of the ViT encoder, however, must be handled in a particular way. Inside each training epoch, the code does the following:

  • Encoder forward pass: Given input images, the ViT encoder yields two outputs: (1) the final embedding vectors of all input tokens (image tokens + [class] token) and (2) a stack of attention weights for all layers and all heads. I found it useful to look at the attention weights while debugging the model, but they aren’t needed for training and validation. The embeddings tensor has dimensions [batch size, n_tokens, embedding_dim], where n_tokens is the total number of input tokens. For example, if ViT splits 224 x 224 input images into 16 x 16 pixel patches, there are 14 x 14 = 196 image tokens, so n_tokens = 196 + 1 ([class] token) = 197. embedding_dim is the hidden size D from section 2.6, i.e. the final feature embedding dimension of the ViT encoder: 768 for ViT-Base, 1024 for ViT-Large, and 1280 for ViT-Huge.
  • [class] token extraction: Next, we extract only the [class] token’s embedding vector, located at index 0, to use as the image representation for classification (as discussed in section 2.2).
  • Classification: Finally, we pass the [class] token’s embedding vector into the classifier to produce logits whose length equals the number of classes.
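The steps above can be sketched as one training or validation pass. ToyEncoder is a hypothetical, tiny stand-in that only mimics the encoder’s (embeddings, attention weights) interface so the loop runs quickly on CPU; the Adam optimizer, the learning rate, and the random data are assumptions, not the settings used in the Colab notebook.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


class ToyEncoder(nn.Module):
    """Hypothetical stand-in for the ViT encoder: returns (embeddings, attention weights)."""

    def __init__(self, hidden_dim=32, patch_size=8):
        super().__init__()
        self.patch_embed = nn.Conv2d(3, hidden_dim, kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_dim))
        self.attn = nn.MultiheadAttention(hidden_dim, num_heads=4, batch_first=True)

    def forward(self, images):
        x = self.patch_embed(images).flatten(2).transpose(1, 2)  # (B, N, D)
        x = torch.cat([self.cls_token.expand(x.shape[0], -1, -1), x], dim=1)  # (B, 1 + N, D)
        out, weights = self.attn(x, x, x)
        return x + out, weights


def run_epoch(encoder, classifier, loader, optimizer=None, device="cpu"):
    """One training pass if an optimizer is given, otherwise one validation pass."""
    training = optimizer is not None
    encoder.train(training)
    classifier.train(training)
    criterion = nn.CrossEntropyLoss()
    total_loss, correct, seen = 0.0, 0, 0
    with torch.set_grad_enabled(training):
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            embeddings, attn_weights = encoder(images)  # attention weights unused here
            cls_embedding = embeddings[:, 0]  # [class] token sits at index 0
            logits = classifier(cls_embedding)  # (B, num_classes)
            loss = criterion(logits, labels)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            total_loss += loss.item() * labels.size(0)
            correct += (logits.argmax(dim=1) == labels).sum().item()
            seen += labels.size(0)
    return total_loss / seen, correct / seen


torch.manual_seed(0)
data = TensorDataset(torch.randn(16, 3, 32, 32), torch.randint(0, 3, (16,)))
loader = DataLoader(data, batch_size=4, shuffle=True)
encoder, classifier = ToyEncoder(), nn.Linear(32, 3)
optimizer = torch.optim.Adam(list(encoder.parameters()) + list(classifier.parameters()), lr=1e-3)
train_loss, train_acc = run_epoch(encoder, classifier, loader, optimizer)
val_loss, val_acc = run_epoch(encoder, classifier, loader)  # toy data reused for validation
print(f"train loss {train_loss:.3f}, val loss {val_loss:.3f}, val acc {val_acc:.2f}")
```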

4. Brief Intro to Xray Threat Detection Project

Here I’ll briefly introduce the project for which I used ViT.

4.1 Given Task

For my master’s research at the University of Toronto, I worked on developing automatic threat detection for airport Xray baggage scanners. Given Xray scan images like the ones below, the model had to detect any gun or knife present.

Samples of the three classes from web (source) and Xray (target) domains

4.2 Datasets

An international airport provided me with 450 Xray baggage scan images from three classes: gun (117 images), knife (33 images), and benign/not harmful (300 images). Since this is far too few images to train a neural network without overfitting, I collected a large number (~1300 images per class) of non-Xray, stock-photo-like images of the same object classes from the internet to fine-tune ViT.

4.3 Suggested Research Path

My research supervisor advised me to frame the task as image classification (classifying the whole image as a class), as opposed to object detection (predicting bounding boxes around objects) or object segmentation (classifying each pixel as belonging to a class or not), to keep the model complexity moderate.

He also suggested that I take a domain adaptation approach, since there were not enough Xray images to train a neural network without overfitting. For this, I first collected a large number of non-Xray, stock-photo-like images of the same object classes from the internet, used them to train the model, and then adapted the model to generalize well to Xray images. For more detailed project background, please refer to this project introduction post and the list of my project posts.

5. ViT — Initial Performance

The table below shows the initial results of ViT fine-tuning compared to other models that use ResNet50. Source-only means that the model is fine-tuned only with web images (source domain) of guns and knives, and NO Xray images (target domain) are used during fine-tuning. You can read more about domain adaptation and ADDA (Adversarial Discriminative Domain Adaptation) in this post if you’re interested.

Recall table for Xray images comparing various models

ViT performs marginally better than ResNet50, but its recall values are still far from an acceptable safety range.

Next Steps

Instead of stopping here, I wanted to get a more intuitive sense of how humans perceive and detect the presence of a gun or knife in an image, and to apply that to the computer. With that in mind, I invented a prediction algorithm that iteratively erases unattended regions of the image, making it easier for the model to focus on the object of interest. I will discuss it in my Vision Transformer Part II post.
