ViT: An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale (ICLR’21)

Momal Ijaz
Published in AIGuys, Jan 30, 2022

This article is the first in the “Transformers in Vision” series, which summarizes recent (2020–2022) papers on transformers in vision published at top conferences.

*NerdFacts-🤓 contain extra, more intricate details. You can skip them and still follow the high-level flow of the paper!

Transformers

Transformers were introduced by a Google research team in 2017 for neural machine translation (NMT). They proved to be a blast, not only because of their exceptional NMT performance, but also because they generalized well to almost any NLP task, including question answering, text generation, search, and more! Transformers use self-attention to capture long-range dependencies, which gives them a stronger semantic understanding of human language. For a complete summary of Transformers, check this out.
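To make “self-attention” concrete, here is a minimal single-head sketch of scaled dot-product attention in NumPy. The weight matrices are random stand-ins for learned parameters, and the function names and sizes are illustrative assumptions, not taken from any library.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row max for numerical stability before exponentiating.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(x, w_q, w_k, w_v):
    """Single-head scaled dot-product self-attention.

    x: (n_tokens, d_model); w_q, w_k, w_v: (d_model, d_head).
    Returns the attended values (n_tokens, d_head) and the weights (n_tokens, n_tokens).
    """
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    scores = q @ k.T / np.sqrt(k.shape[-1])   # every token scores every other token
    weights = softmax(scores, axis=-1)         # each row sums to 1
    return weights @ v, weights

rng = np.random.default_rng(0)
n_tokens, d_model, d_head = 5, 8, 4
x = rng.normal(size=(n_tokens, d_model))
w_q, w_k, w_v = (rng.normal(size=(d_model, d_head)) for _ in range(3))
out, weights = self_attention(x, w_q, w_k, w_v)
print(out.shape, weights.shape)  # (5, 4) (5, 5)
```

Because every row of `weights` spans all tokens, the first token can draw on the last one in a single step; that is what “long-range dependencies” means here.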

Dramatic re-enactment of Transformers’ entry into the NMT world

Vision Transformer (ViT)

After the blooming success of transformers in NLP, researchers started applying them in the vision domain too, where CNN-based models still dominated high-level tasks such as object detection, segmentation, and classification. Google Brain’s research team jumped in again and published the paper that introduced the Vision Transformer (ViT), which is what this article summarizes.

ViT did not give satisfactory results when trained on smaller datasets, but when pre-trained on large datasets it outperformed the state of the art in image classification, typically by about a percentage point or less. Specifically, ViTs did well when pre-trained on large datasets and then fine-tuned on smaller ones. Pre-trained ViTs outperformed EfficientNet- and ResNet-based SOTA networks on benchmarks including ImageNet, ImageNet ReaL, CIFAR-100, and the 19-task VTAB suite.

1. Model Architecture

Vision Transformer architecture (source: ViT paper)

The ViT architecture is very simple and intentionally kept as close as possible to the original Transformer. Below is a step-by-step walkthrough of each part of the model.

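  1. First, the input image is divided into N patches of size P x P, so each patch has P² pixels per channel and an H x W image yields N = HW/P² patches. Each patch is treated as a token/word in a sentence. [NerdFact-🤓: The number of patches is inversely proportional to the square of the patch size, so smaller patches produce longer input sequences; because self-attention cost grows quadratically with sequence length, this makes the model much more computationally expensive.]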
Crop image into N patches

2. Flatten the patches and pass them through a linear layer that projects each one to a D-dimensional embedding. This linear layer has size (P² x channels) x D; for example, an RGB 16x16 patch flattens to 16·16·3 = 768 values.

Pass the embeddings through the Linear layer and add pos encodings

3. Add positional encodings to the patch embeddings. In ViT these are learnable 1D position embeddings of dimension D, so they can be added directly to the embedding vectors. Dividing the image into patches loses spatial information, and adding positional encodings restores part of it. P1-P9 are the positional encodings in the figure above. [NerdFact-🤓: This is the only point where spatial location information is added to the inputs; the authors expected all other spatial relations to be learned from scratch by attention!]

Learnable Class token

4. An additional learnable embedding vector (the class token), with the same dimension D as the patch embeddings, is prepended to the sequence of patch embeddings and trained along with the rest of the network. The encoder output at this token’s position is the only output used for classification. [NerdFact-🤓: This extra classification token is inspired by BERT’s [class] token for sequence classification.]

Passing processed patches through encoder stack

5. Next, these pre-processed, D-dimensional tokens are fed into the first Transformer encoder (see this article if you need background on the encoder). It is not a single encoder but a stack of encoders, and each one preserves the input shape: N+1 tokens of dimension D go in, and N+1 tokens of dimension D come out. The stack uses self-attention to enrich each token’s representation with information from all the other tokens.

Multilayer Perceptron Classification head

6. Finally, the encoder outputs one D-dimensional vector per input token, but all of them are discarded except the output at the class-token position. This vector is passed through a classification head: in the paper, an MLP with one hidden layer during pre-training and a single linear layer during fine-tuning. The head’s input layer has D nodes and its output layer has C nodes, where C is the number of classes. The C output scores are passed through a softmax to get a probability distribution over all classes for a given sample.
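Steps 1 to 6 can be strung together into a small forward pass. This is a minimal NumPy sketch under loudly stated simplifications, not the authors’ implementation: single-head attention instead of multi-head, LayerNorm without learnable scale/shift, random weights instead of trained ones, a single linear head, and hypothetical toy sizes and names throughout.

```python
import numpy as np

rng = np.random.default_rng(0)

def patchify(img, p):
    """Split an (H, W, channels) image into (H/p)*(W/p) flattened patches of length p*p*channels."""
    h, w, c = img.shape
    assert h % p == 0 and w % p == 0, "image size must be divisible by the patch size"
    x = img.reshape(h // p, p, w // p, p, c)      # (rows, p, cols, p, channels)
    x = x.transpose(0, 2, 1, 3, 4)                 # (rows, cols, p, p, channels)
    return x.reshape(-1, p * p * c)                # (N, p*p*channels), patches in row-major order

def layer_norm(x, eps=1e-6):
    # LayerNorm without learned scale/shift, to keep the sketch short.
    mu = x.mean(-1, keepdims=True)
    return (x - mu) / np.sqrt(x.var(-1, keepdims=True) + eps)

def softmax(x):
    e = np.exp(x - x.max(-1, keepdims=True))
    return e / e.sum(-1, keepdims=True)

def gelu(x):
    # tanh approximation of GELU, the MLP nonlinearity used in ViT
    return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x ** 3)))

def encoder_block(z, prm):
    """Pre-norm encoder block (single-head attention + MLP, both residual); output shape == input shape."""
    h = layer_norm(z)
    q, k, v = h @ prm["wq"], h @ prm["wk"], h @ prm["wv"]
    att = softmax(q @ k.T / np.sqrt(q.shape[-1]))
    z = z + (att @ v) @ prm["wo"]
    return z + gelu(layer_norm(z) @ prm["w1"]) @ prm["w2"]

# Hypothetical toy sizes: 32x32 RGB image, 8x8 patches, D=16, 2 blocks, 10 classes.
H = W = 32; CHANNELS = 3; P = 8; D = 16; DEPTH = 2; NUM_CLASSES = 10
N = (H // P) * (W // P)                                  # 16 patches

w_embed = rng.normal(0, 0.02, (P * P * CHANNELS, D))     # step 2: (P^2 * channels) x D projection
cls_token = rng.normal(0, 0.02, (1, D))                  # step 4: learnable class token
pos_embed = rng.normal(0, 0.02, (N + 1, D))              # step 3: one learnable vector per position
shapes = {"wq": (D, D), "wk": (D, D), "wv": (D, D), "wo": (D, D), "w1": (D, 4 * D), "w2": (4 * D, D)}
blocks = [{k: rng.normal(0, 0.02, s) for k, s in shapes.items()} for _ in range(DEPTH)]
w_head = rng.normal(0, 0.02, (D, NUM_CLASSES))           # step 6: linear head (fine-tuning style)

def vit_forward(img):
    x = patchify(img, P) @ w_embed                       # (N, D)
    z = np.concatenate([cls_token, x]) + pos_embed       # (N+1, D)
    for prm in blocks:                                   # step 5: encoder stack, shape preserved
        z = encoder_block(z, prm)
    return softmax(layer_norm(z[0]) @ w_head)            # keep only the class-token output

probs = vit_forward(rng.random((H, W, CHANNELS)))
print(probs.shape, round(float(probs.sum()), 6))         # (10,) 1.0
```

Swapping in multi-head attention or a deeper stack would not change any of the shapes traced in the comments; only the class-token row ever reaches the head.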

2. Hybrid Architecture

To experiment with the network architecture, the paper also proposed and tested a hybrid architecture. In this variant of ViT, instead of raw image patches, the authors extracted patches from the feature map of an intermediate layer of a CNN (a ResNet), since raw pixel values or colors might not carry many meaningful correlations for self-attention to extract.

[NerdFact- 🤓: For the feature-map input, the authors used patches of size 1x1, which amounts to simply flattening the spatial dimensions of the feature map and projecting each position to the Transformer dimension D.]
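A minimal NumPy sketch of the 1x1-patch case: each spatial position of the CNN feature map becomes one token, so “patch extraction” is just a reshape followed by the linear projection to D. The 14x14x1024 feature-map shape, D = 768, and the function name are illustrative assumptions.

```python
import numpy as np

def feature_map_to_tokens(fmap, w_proj):
    """Turn a CNN feature map (H', W', C') into a token sequence (H'*W', D).

    With 1x1 "patches", each spatial position becomes one token,
    so patch extraction reduces to flattening the spatial dimensions.
    """
    h, w, c = fmap.shape
    tokens = fmap.reshape(h * w, c)   # (H'*W', C'), positions in row-major order
    return tokens @ w_proj            # project C' channels to the Transformer width D

rng = np.random.default_rng(0)
fmap = rng.normal(size=(14, 14, 1024))            # hypothetical feature-map shape
w_proj = rng.normal(0, 0.02, size=(1024, 768))    # hypothetical D = 768
tokens = feature_map_to_tokens(fmap, w_proj)
print(tokens.shape)  # (196, 768)
```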

3. Experiments

For their experiments, the authors compared three kinds of models:

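  1. ViT: the original ViT, with image patches as input
  2. Hybrid: ViT with feature-map patches taken from stage 3 or stage 4 of a ResNet-50 as input
  3. Baseline CNN: a ResNet, with Batch Normalization replaced by Group Normalization and with standardized convolutions (as in BiT)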

The authors pre-trained these models on larger datasets (ImageNet, ImageNet-21k, and JFT-300M; Sec. 4.1 in the paper), removed their prediction heads, and attached new zero-initialized classification heads for fine-tuning and evaluating them on smaller downstream datasets. This kind of analysis measures the model’s representation-learning capability.
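Here is a small NumPy sketch of what a zero-initialized head does at the start of fine-tuning, assuming a D = 768 class-token feature and K = 100 classes (both illustrative; the function name is hypothetical). With all weights and biases at zero, every logit is 0, so the first prediction is exactly uniform. The gradient of the cross-entropy loss with respect to the head weights is the outer product of the feature and (probabilities minus one-hot target), so it is non-zero and training can start right away.

```python
import numpy as np

def new_zero_head(d_model, num_classes):
    """Fresh classification head for fine-tuning: a zero-initialized D x K linear layer."""
    return np.zeros((d_model, num_classes)), np.zeros(num_classes)

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

rng = np.random.default_rng(0)
cls_feature = rng.normal(size=768)         # stand-in for the pre-trained class-token output
w, b = new_zero_head(768, 100)             # e.g. K = 100 classes
probs = softmax(cls_feature @ w + b)
print(probs[:3])                           # [0.01 0.01 0.01]

# Cross-entropy gradient w.r.t. W is outer(feature, probs - one_hot): non-zero even though W = 0.
one_hot = np.eye(100)[7]
grad_w = np.outer(cls_feature, probs - one_hot)
print(np.abs(grad_w).max() > 0)            # True
```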

4. Results

Comparing these models with state-of-the-art models gave the following results:

ViT results for comparison with SOTA models on benchmark datasets

ViT comes in three size variants: Base, Large, and Huge. ViT-H/14 is the biggest model, with 32 layers, 16 attention heads, 632M parameters, and an input patch size of 14x14. ViT-L/16 is the Large variant, with 24 layers, 16 heads, a 16x16 patch size, and 307M parameters. (See Table 1 in the paper for details.)
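To see how the variant names translate into sizes, here is a small pure-Python sketch using the Base/Large/Huge configurations from Table 1 of the paper (layers, hidden size D, MLP size, heads). It computes the encoder sequence length for an assumed 224x224 input, plus a rough weight count of 4·D² (attention projections) and 2·D·MLP (the two MLP matrices) per layer. The helper names are hypothetical, and the count is an estimate, not the paper’s exact accounting.

```python
# Table 1 configurations of the ViT paper: (layers, hidden size D, MLP size, heads).
VARIANTS = {
    "ViT-B": (12, 768, 3072, 12),
    "ViT-L": (24, 1024, 4096, 16),
    "ViT-H": (32, 1280, 5120, 16),
}

def sequence_length(image_size, patch_size):
    """Number of tokens fed to the encoder: N patches plus one class token."""
    return (image_size // patch_size) ** 2 + 1

def approx_encoder_params(layers, d, mlp):
    """Rough encoder weight count per layer: 4*D^2 for the Q, K, V and output
    projections plus 2*D*MLP for the two MLP matrices. Biases, LayerNorms,
    embeddings and the head are ignored, so this is only an estimate."""
    return layers * (4 * d * d + 2 * d * mlp)

for name, patch in [("ViT-B", 16), ("ViT-L", 16), ("ViT-H", 14)]:
    layers, d, mlp, heads = VARIANTS[name]
    print(f"{name}/{patch}: {sequence_length(224, patch)} tokens at 224x224, "
          f"~{approx_encoder_params(layers, d, mlp) / 1e6:.0f}M encoder weights")
```

It reports 197 tokens for the /16 models and 257 for ViT-H/14, and roughly 85M, 302M and 629M weights, slightly below the 86M, 307M and 632M in Table 1 since several parameter groups are left out. It also puts the patch-size NerdFact into numbers: going from 16x16 to 14x14 patches raises the token count by about 30%, and the attention matrix grows with its square.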

As the table above shows, ViT-H/14 pre-trained on the JFT-300M dataset outperformed the SOTA CNNs (BiT-L, a large ResNet, and Noisy Student, an EfficientNet) on all of these datasets (the green boundary in the image). ViT also needed fewer computational resources to pre-train than the CNN baselines, highlighted by the circle in the image. The dataset name above each model name in the first row is the pre-training dataset (e.g., JFT), and the datasets in the first column are the ones used for fine-tuning (e.g., VTAB).

[NerdFact-🤓: The authors used early stopping and reported the best validation accuracy achieved during training.]

Additional Analysis and Patterns Found!

  1. The authors found that large ViT variants are only useful when pre-trained on large datasets; on smaller datasets, ResNet-based models outperformed ViT. This reinforces the idea that the inductive biases of convolutions help on smaller datasets, whereas for larger datasets it is beneficial to learn spatial relations and all other patterns directly from data.
  2. When pre-trained on large datasets, ViTs need 2–4x less compute than ResNets to reach the same performance.
  3. Hybrid models slightly outperformed plain ViT at small computational budgets, but the difference vanished for larger models.
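  4. ViTs can handle input sequences of arbitrary length (within memory limits), but when fine-tuning at a higher resolution than in pre-training, the sequence gets longer and the pre-trained positional embeddings may no longer be meaningful. The authors therefore 2D-interpolate the pre-trained position embeddings according to their location in the original image.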
  5. Thanks to self-attention, each patch can incorporate information from all other patches into its encoding. The average distance in image space across which a head integrates information, weighted by its attention, is called the (mean) attention distance. The authors observed that attention distance increases with network depth, although some heads already attend to most of the image in the lowest layers. [I think: This implies that as the features pass through the encoder stack, they integrate increasingly global information.]
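The attention distance from item 5 can be made concrete with a short NumPy sketch: for one head, weight the pixel distance between each query patch and every key patch by the attention it receives, sum per query, and average over queries. The assumptions here are ours: a single head on a single image, attention restricted to patch tokens (the class token is dropped), Euclidean distance between patch positions, and a hypothetical function name. The paper averages this quantity over a set of images.

```python
import numpy as np

def mean_attention_distance(weights, grid, patch_size):
    """Attention-weighted average pixel distance between query and key patches.

    weights: (grid*grid, grid*grid) attention matrix of one head over patch tokens
             (class token excluded), each row summing to 1.
    """
    rows, cols = np.divmod(np.arange(grid * grid), grid)
    positions = np.stack([rows, cols], axis=1) * patch_size        # patch positions in pixels
    dist = np.linalg.norm(positions[:, None, :] - positions[None, :, :], axis=-1)
    per_query = (weights * dist).sum(axis=1)                        # expected distance per query
    return per_query.mean()

grid, patch = 14, 16                    # 224x224 image, 16x16 patches
n = grid * grid
local = np.eye(n)                       # each patch attends only to itself
global_ = np.full((n, n), 1.0 / n)      # uniform attention over the whole image
print(mean_attention_distance(local, grid, patch))           # 0.0
print(mean_attention_distance(global_, grid, patch) > 0)     # True
```

Identity attention gives a distance of 0, the most local case, while uniform attention gives a large value; heads with small distances behave somewhat like local convolution filters.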

Conclusion

ViT was published at ICLR 2021 (the preprint appeared in late 2020) to explore the potential of Transformer architectures in the vision domain. This work proposed a simple architecture and tested it extensively in different setups to make ViT’s strengths and weaknesses clear. It is data-hungry and needs large datasets, with controlled pre-training and fine-tuning setups, to unlock its full potential and outperform ResNets.

Why do we want to bring in Transformers from NLP to the Vision domain?

Are Transformer-based models gonna be the next general, flexible, and efficient models for vision too?

Is there another perfect architecture for the vision domain out there, like the Transformer but not this Transformer?

Do cross- and self-attention over the words in our speech and the pixels in an image need to be computed with the same scaled dot-product attention?

Do we need to come up with more complex Transformer architectures, or analyze convolution- and attention-based models from an explainable-AI perspective, to fully understand why each one is good or bad at its tasks?

…Happy Learning :-)

Machine Learning Engineer @ Super.ai | ML Researcher | Fulbright scholar'22 | Sitar Player