An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale

Souvik Mandal
Feb 20, 2022


ViT architecture presented in the paper

This is a paper from Google Research. Its main idea is:

Transformers applied directly to image patches and pre-trained on large datasets work really well on image classification.

In this post, we will discuss the Vision Transformer (ViT) architecture in detail, along with the results published in the paper.

Try out the implementations used here with this notebook. Click Open in Colab to run it directly on Colab.

ViT architecture

  • ViT splits an image into a fixed number of patches, turns each patch into an embedding, and passes the resulting sequence through a standard Transformer encoder.
  • Let's first discuss how the NLP Transformer encoder works, then compare it with ViT.

NLP Transformer Encoder 🎨

Transformer encoder
  • We take a sentence as input. Instead of feeding the raw sentence to the model, we use a tokenizer that assigns each token an id. A tokenizer does not have to split naively by word; it can also split a word into parts and assign an id to each part. How it splits depends on how the tokenizer was trained. For example, a simple tokenizer might split a sentence like this:

Another tokenizer might split the same sentence like this:

Next, we have an embedding matrix that maps every possible id to a vector representation. For example, with the second tokenizer, the embedding for "To" is the vector at index 11.

These embeddings can be randomly initialized at the start and learned during training.
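A toy version of the tokenize-then-look-up pipeline can make this concrete. This is a minimal sketch only: the vocabulary, the sentence "Today is a sunny day", and the `tokenize` helper are hypothetical (they are not the tokenizers pictured above), and a real tokenizer would be trained on a corpus. `torch.nn.Embedding` plays the role of the randomly initialized, learnable id-to-vector matrix.

```python
import torch
import torch.nn as nn

# Hypothetical word-level vocabulary (a trained tokenizer would build this from data)
vocab = {"<unk>": 0, "today": 1, "is": 2, "a": 3, "sunny": 4, "day": 5}

def tokenize(sentence):
    """Naive whitespace tokenizer: lowercase, split on spaces, map each word to an id."""
    return [vocab.get(word, vocab["<unk>"]) for word in sentence.lower().split()]

ids = torch.tensor(tokenize("Today is a sunny day"))
print(ids)  # tensor([1, 2, 3, 4, 5])

# Embedding matrix: one row per id, randomly initialized and learned during training
embed_dim = 512
embedding = nn.Embedding(num_embeddings=len(vocab), embedding_dim=embed_dim)
token_vectors = embedding(ids)  # row lookup: token_vectors[i] is embedding.weight[ids[i]]
print(token_vectors.shape)  # torch.Size([5, 512])
```

The lookup is just row indexing, which is why the matrix can be trained with ordinary gradient descent like any other weight.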

Positional embeddings

  • Recurrent Neural Networks (RNNs) process a sentence word by word, sequentially. The Transformer architecture drops recurrence in favour of multi-head self-attention, which processes all tokens in parallel. This reduces training time, but on its own the model has no idea of the position of the words.
  • To solve this, we add an extra piece of information (positional encodings) to the input embeddings.
  • One simple approach is to assign 1 to the first word, 2 to the second, and so on. But at inference time the model might get a sentence longer than any it saw during training, and for long sentences the added values become very large and can swamp the word embeddings.
  • Alternatively, we could map positions into the range [0, 1]: 0 for the first word, 1 for the last, and evenly spaced values in between. For a 3-word sentence this gives 0, 0.5, and 1; for a 4-word sentence, 0, 0.33, 0.67, and 1. The problem is that the step between adjacent positions is not constant: it is 0.5 in the first example but 0.33 in the second.
  • The original Transformer instead uses a d-dimensional positional encoding built from sine and cosine functions of different frequencies. Each position gets a unique vector, the values stay within [-1, 1], and the encoding of a position does not depend on the sentence length.
  • The positional encoding is then added element-wise to the token embedding, so each input vector carries both the token and its position.
  • We will pass these vectors through multi-head attention blocks next.
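For reference, the sinusoidal encoding from the original Transformer paper is PE(pos, 2i) = sin(pos / 10000^(2i/d)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d)). Below is a minimal sketch of that formula, assuming an even `d_model`; the function name `sinusoidal_positional_encoding` is ours. It checks the two properties the naive schemes above lack: the values are bounded, and a position's vector is the same whatever the sentence length.

```python
import math
import torch

def sinusoidal_positional_encoding(seq_len, d_model):
    """Sinusoidal positional encodings (Vaswani et al., 2017); assumes d_model is even."""
    position = torch.arange(seq_len, dtype=torch.float32).unsqueeze(1)  # (seq_len, 1)
    # 1 / 10000^(2i / d_model) for i = 0 .. d_model/2 - 1
    div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float32) * (-math.log(10000.0) / d_model))
    pe = torch.zeros(seq_len, d_model)
    pe[:, 0::2] = torch.sin(position * div_term)  # even dimensions use sine
    pe[:, 1::2] = torch.cos(position * div_term)  # odd dimensions use cosine
    return pe

pe_short = sinusoidal_positional_encoding(seq_len=5, d_model=512)
pe_long = sinusoidal_positional_encoding(seq_len=50, d_model=512)

print(pe_short.shape)                         # torch.Size([5, 512])
print(pe_short.abs().max() <= 1.0)            # values stay in [-1, 1]
print(torch.allclose(pe_short, pe_long[:5]))  # same vectors for positions 0..4 in any sentence

# Encoder input = token embedding + positional encoding (element-wise sum)
token_vectors = torch.randn(5, 512)
encoder_input = token_vectors + pe_short
```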

Multi-head attention 🔥

  • Multi-head attention uses three weight matrices: query (Q), key (K), and value (V). Each maps an embedding to a vector of the same size, so with the 512-dimensional embeddings in our example, all three matrices are 512×512.
  • We multiply each token embedding by all three matrices (Q, K, V), giving 3 intermediate vectors of length 512 per token.
  • If we have n heads, we split each of these vectors into n parts. For example, with 8 heads, each of the 3 intermediate vectors for the word "Today" is split into 8 smaller vectors of dimension 64 (512 / 8).
  • Each head then takes its corresponding segment from all the intermediate vectors. For example, the first head takes the first 64-dimensional segment of all three intermediate vectors (the query, key, and value results) for all five tokens. Similarly, the second head takes the second segment, and so on.
  • In each head, we take dot products between the query vectors and the key vectors. In the figure below, for head 1, we compute the dot product of q1 with every key vector k_i, i ∈ [1, 5]. These scores are scaled by √64 = 8 and passed through a softmax so they sum to 1; each weight then multiplies the corresponding value vector v_i, and the weighted vectors are summed into a single 64-dimensional result. The same happens for q2, q3, q4, and q5, so we get 5 result vectors of dimension 64. Each result vector now carries information from all the other tokens.
Attention logic
  • We then concatenate the result vectors from all the heads: the first result vector from each of the 8 heads is concatenated into the first 512-dimensional vector, and likewise for the other 4 vectors.
  • Finally, we have 5 vectors, each with 512 dimensions. (The original Transformer then applies one more 512×512 linear projection to these concatenated vectors.)
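The steps above fit in a few lines of tensor code. This is a minimal sketch of one multi-head self-attention pass, assuming 5 tokens, 512-dimensional embeddings, and 8 heads as in the example. The random matrices `W_q`, `W_k`, `W_v` stand in for learned projections. The output projection that the original Transformer applies after the concatenation is omitted.

```python
import torch

torch.manual_seed(0)
num_tokens, d_model, num_heads = 5, 512, 8
head_dim = d_model // num_heads  # 64

x = torch.randn(num_tokens, d_model)  # 5 token embeddings (plus positional encodings)
W_q, W_k, W_v = (torch.randn(d_model, d_model) / d_model ** 0.5 for _ in range(3))

# 3 intermediate vectors of length 512 per token
q, k, v = x @ W_q, x @ W_k, x @ W_v  # each (5, 512)

# Split every 512-dim vector into 8 segments of 64: (heads, tokens, head_dim)
q = q.view(num_tokens, num_heads, head_dim).transpose(0, 1)
k = k.view(num_tokens, num_heads, head_dim).transpose(0, 1)
v = v.view(num_tokens, num_heads, head_dim).transpose(0, 1)

# Per head: dot products q_i . k_j, scaled by sqrt(64) = 8, softmax over j
weights = torch.softmax(q @ k.transpose(-2, -1) / head_dim ** 0.5, dim=-1)  # (8, 5, 5)
# Weighted sum of value vectors: 5 result vectors of dim 64 per head
head_outputs = weights @ v  # (8, 5, 64)

# Concatenate the heads: token i becomes [head1_i, head2_i, ..., head8_i]
out = head_outputs.transpose(0, 1).reshape(num_tokens, d_model)  # (5, 512)
print(weights.shape, out.shape)
```

Each row of `weights` sums to 1, so every output vector is a convex combination of the value vectors in its head.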

Add & Norm ☯

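  • These are residual connections, like those in a ResNet block, followed by layer normalization (not batch normalization): each sub-layer's output is added to its input, and the sum is normalized across the features of each token.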
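Below is a minimal sketch of Add & Norm as used in the original (post-norm) Transformer. The sub-layer output is added to its input, and then `nn.LayerNorm` normalizes each token vector across its features. The `sublayer` here is a hypothetical stand-in (a single linear layer) for attention or the feed-forward network. Unlike batch normalization, the statistics are computed per token and do not depend on the other examples in the batch.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model = 512
x = torch.randn(2, 5, d_model)  # (batch, tokens, features)

sublayer = nn.Linear(d_model, d_model)  # hypothetical stand-in for attention / feed-forward
norm = nn.LayerNorm(d_model)

out = norm(x + sublayer(x))  # Add (residual connection), then Norm

# Each token vector now has ~zero mean and ~unit variance over its 512 features
mean = out.mean(dim=-1)
var = ((out - mean.unsqueeze(-1)) ** 2).mean(dim=-1)
print(mean.abs().max(), var.mean())
```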

Feedforward

  • This is a simple feed-forward neural network applied to every attention output vector independently (position-wise): two linear layers with a non-linearity in between.
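Here is a sketch of the position-wise feed-forward network, assuming the sizes from the original Transformer (512 → 2048 → 512 with ReLU). Because it acts only on the last dimension, it transforms each token vector independently. The final check confirms that feeding one token on its own gives the same result as feeding it as part of the sequence.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, d_ff = 512, 2048
ffn = nn.Sequential(
    nn.Linear(d_model, d_ff),
    nn.ReLU(),
    nn.Linear(d_ff, d_model),
)

attn_out = torch.randn(5, d_model)  # 5 attention output vectors
out = ffn(attn_out)                 # the same weights are applied to every token

# Token 2 processed alone matches token 2 processed within the sequence
print(out.shape, torch.allclose(out[2], ffn(attn_out[2]), atol=1e-5))
```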

That is how a simple Transformer encoder works. Next, let's look at the ViT architecture and implement it step by step.

ViT encoder architecture

Embedded patches ☑️

  • To handle 2D images, the image is divided into patches, and each 2D patch is flattened into a 1D vector.
  • We then linearly project each of these vectors into the model's embedding space. Here, each patch becomes a 768-dimensional vector.
import torch
import torch.nn as nn
in_chans = 3 #RGB
embed_dim = 768 # vector dimension in model space
patch_size = 16 # each image patch size 16*16
proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) # kernel = stride = patch size: non-overlapping patches, each projected to embed_dim
img = torch.randn(1, 3, 224, 224) # dummy image
x = proj(img).flatten(2).transpose(1, 2) # BCHW -> BNC
print(x.shape) # torch.Size([1, 196, 768])
  • In the above code, we use an image of size 224×224 and a patch size of 16×16.
  • This results in (224/16) × (224/16) = 14 × 14 = 196 patches.
  • Each flattened patch has 16 × 16 × 3 = 768 values (16×16 pixels, 3 RGB channels). To project it into the model dimension (also 768 for this model), we use 768 output channels in the convolution. Because the kernel size and stride both equal the patch size, this convolution is equivalent to flattening each patch and applying the same linear layer to it. Finally, we flatten the output to BNC, where B = batch, N = number of patches, and C = vector dimension in model space.
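The claim that a strided convolution amounts to "cut into patches, flatten, apply a linear layer" can be checked numerically. This sketch is ours, not from the paper. It cuts the dummy image into 16×16 patches with `Tensor.unfold`, flattens each patch into a 768-value vector, and multiplies by the convolution's own weights reshaped into a 768 × 768 matrix.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
in_chans, embed_dim, patch_size = 3, 768, 16
proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
img = torch.randn(1, in_chans, 224, 224)

# Path 1: strided convolution, as in the code above
conv_out = proj(img).flatten(2).transpose(1, 2)  # (1, 196, 768)

# Path 2: cut the image into non-overlapping 16x16 patches explicitly
patches = img.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
# (B, C, 14, 14, 16, 16) -> (B, 14, 14, C, 16, 16) -> (B, 196, C*16*16)
patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(1, -1, in_chans * patch_size * patch_size)

# Reuse the conv kernel as a linear weight: (768, 3, 16, 16) -> (768, 768)
weight = proj.weight.reshape(embed_dim, -1)
linear_out = patches @ weight.T + proj.bias

print(patches.shape)  # torch.Size([1, 196, 768])
print(torch.allclose(conv_out, linear_out, atol=1e-4))
```

The convolution form is simply a convenient, efficient way to do the patch split and the shared linear projection in one call.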

Class embeddings 🆕

  • ViT prepends a learnable class embedding to the sequence of embedded patches. Its state at the output of the encoder serves as the image representation for classification.
cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) # create class embeddings without batch
cls_token = cls_token.expand(x.shape[0], -1, -1) # add batch
x = torch.cat((cls_token, x), dim=1) # prepend the class token to the patch embeddings
x.shape # torch.Size([1, 197, 768]): 196 patches + 1 class token

Positional embeddings ☑️

  • ViT uses learnable 1D position embeddings. We create a matrix of shape (num_patches + 1) × embed_dim (197 × 768), with one row per token including the class token. Its values are learned during training.
num_patches = 14*14
pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) # +1 for class token
x = x + pos_embed # add position encoding
x.shape
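Why does ViT need this table at all? Self-attention on its own is permutation-equivariant: shuffling the patches just shuffles the outputs in the same way, so the model cannot tell where a patch came from. The sketch below checks this with PyTorch's `nn.MultiheadAttention` (using `batch_first=True`, available in recent PyTorch versions) and random weights. The random `pos` tensor is a stand-in for the learned `pos_embed` above. The first comparison should hold and the second should not, because adding position embeddings breaks the symmetry.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
embed_dim, num_tokens = 768, 197
attn = nn.MultiheadAttention(embed_dim, num_heads=12, batch_first=True).eval()

tokens = torch.randn(1, num_tokens, embed_dim)
perm = torch.randperm(num_tokens)  # shuffle the token order

with torch.no_grad():
    # Without positions: a shuffled input gives an identically shuffled output
    out, _ = attn(tokens, tokens, tokens)
    shuffled = tokens[:, perm]
    out_perm, _ = attn(shuffled, shuffled, shuffled)
    print(torch.allclose(out[:, perm], out_perm, atol=1e-4))

    # With (stand-in) position embeddings: same patches at different positions
    pos = torch.randn(1, num_tokens, embed_dim)
    x = tokens + pos
    x_shuffled = tokens[:, perm] + pos
    out_pos, _ = attn(x, x, x)
    out_pos_perm, _ = attn(x_shuffled, x_shuffled, x_shuffled)
    print(torch.allclose(out_pos[:, perm], out_pos_perm, atol=1e-4))
```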

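Blocks ☑️

  • Depending on the model size, there are L blocks (12 for ViT-Base, 24 for ViT-Large, 32 for ViT-Huge).
  • All blocks have the same structure: a multi-head self-attention layer and an MLP layer. In the paper, LayerNorm is applied before each of them and a residual connection after each. The code below shows the attention and MLP computations on their own.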
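In the paper, each sub-layer is preceded by LayerNorm and followed by a residual connection. A single such pre-norm block might look like the sketch below. It is not the paper's or timm's code. It uses PyTorch's built-in `nn.MultiheadAttention` for the attention sub-layer. The class name `Block` and its defaults are our assumptions; the defaults follow ViT-Base sizes (768 dims, 12 heads, MLP ratio 4).

```python
import torch
import torch.nn as nn

class Block(nn.Module):
    """Pre-norm Transformer block: x + MSA(LN(x)), then x + MLP(LN(x))."""

    def __init__(self, embed_dim=768, num_heads=12, mlp_ratio=4, dropout=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * mlp_ratio),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(embed_dim * mlp_ratio, embed_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        h = self.norm1(x)
        attn_out, _ = self.attn(h, h, h, need_weights=False)  # self-attention over all tokens
        x = x + attn_out                                       # residual connection
        x = x + self.mlp(self.norm2(x))                        # residual connection
        return x

# Stack L blocks with the same structure (L = 12 for ViT-Base)
blocks = nn.Sequential(*[Block() for _ in range(12)])
tokens = torch.randn(2, 197, 768)  # (batch, 196 patches + class token, embed_dim)
print(blocks(tokens).shape)        # torch.Size([2, 197, 768])
```

Each block keeps the sequence shape unchanged, which is what allows the blocks to be stacked freely.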

Attention layer ☑️

  • Same as NLP attention explained before.
  • Let's create the intermediate vectors.
# Transformation from source vector to query vector
fc_q = nn.Linear(embed_dim, embed_dim)
# Transformation from source vector to key vector
fc_k = nn.Linear(embed_dim, embed_dim)
# Transformation from source vector to value vector
fc_v = nn.Linear(embed_dim, embed_dim)
Q = fc_q(x)
K = fc_k(x)
V = fc_v(x)
print(Q.shape, K.shape, V.shape)
  • Split the intermediate vectors to process one part in each head.
num_heads = 8 # ViT-Base uses 12 heads; any divisor of 768 works here
batch_size = 1
Q = Q.view(batch_size, -1, num_heads, embed_dim//num_heads).permute(0, 2, 1, 3) # split the Q matrix for 8 head
K = K.view(batch_size, -1, num_heads, embed_dim//num_heads).permute(0, 2, 1, 3) # split the K matrix for 8 head
V = V.view(batch_size, -1, num_heads, embed_dim//num_heads).permute(0, 2, 1, 3) # split the V matrix for 8 head
print(Q.shape, K.shape, V.shape) # batch_size, num_head, num_patch+1, feature_vec dim per head
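  • Attention matrix multiplication: scaled dot products between queries and keys, a softmax, then a weighted sum of the values.
score = torch.matmul(Q, K.permute(0, 1, 3, 2)) / (embed_dim // num_heads) ** 0.5 # Q*K^T scaled by sqrt(head_dim)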
score = torch.softmax(score, dim=-1)
score = torch.matmul(score, V) # normally we apply dropout layer before this
score.shape # batch_size, num_head, num_patches+1, feature_vector_per_head (embed_dim/num_head)
  • Reshape the results
score = score.permute(0, 2, 1, 3).contiguous()
score.shape # batch_size, num_patches+1, num_head, feature_vector_per_head (embed_dim/num_head)
  • Merge the heads back into a single embed_dim-sized vector per token. (Full implementations then apply one more linear output projection.)
score = score.view(batch_size, -1, embed_dim) # merge the vectors back to original shape
score.shape # batch_size, num_patches+1, embed_dim
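The steps above can be wrapped into a reusable module. This is a sketch under our own naming (`Attention`, `qkv`, `proj`). A single linear layer produces Q, K, and V together. The batch size is read from the input instead of being hard-coded. The dot products are scaled by 1/√head_dim, and a final linear projection mixes the concatenated heads, like the output projection W^O in the original Transformer.

```python
import torch
import torch.nn as nn

class Attention(nn.Module):
    """Multi-head self-attention with a fused QKV projection (illustrative sketch)."""

    def __init__(self, embed_dim=768, num_heads=12):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(embed_dim, embed_dim * 3)  # Q, K, V in one matmul
        self.proj = nn.Linear(embed_dim, embed_dim)     # output projection after concat

    def forward(self, x):
        B, N, C = x.shape  # batch, tokens (num_patches + 1), embed_dim
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4)               # each (B, heads, N, head_dim)
        attn = (q @ k.transpose(-2, -1)) * self.scale      # scaled dot products
        attn = attn.softmax(dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(B, N, C)  # merge heads back
        return self.proj(out)

attn_layer = Attention(embed_dim=768, num_heads=12)
x = torch.randn(4, 197, 768)  # works for any batch size
print(attn_layer(x).shape)    # torch.Size([4, 197, 768])
```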

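MLP ☑️

  • A standard multilayer perceptron applied to each token: two linear layers with a GELU activation in between, and a hidden size of 4 × embed_dim (3072 for ViT-Base).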
act_layer=nn.GELU # activation function
in_features = embed_dim
hidden_features = embed_dim * 4
out_features = in_features
fc1 = nn.Linear(in_features, hidden_features)
act = act_layer()
drop1 = nn.Dropout(0.5)
fc2 = nn.Linear(hidden_features, out_features)
drop2 = nn.Dropout(0.5)
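  • Pass the attention output through the MLP layers. (For simplicity, this walkthrough skips the LayerNorm and residual connections that a full ViT block adds around each sub-layer.)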
x = fc1(score)
x = act(x)
x = drop1(x)
x = fc2(x)
x = drop2(x)
x.shape
  • Take out the class token features (position 0 in the sequence).
cls = x[:,0]

Classifier Head 🆕

  • Create a simple classifier head and pass the class token features to get the predictions.
num_classes = 10 # assume 10 class classification
head = nn.Linear(embed_dim, num_classes)
pred = head(cls)
pred
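To see how the pieces fit together end to end, here is a compact sketch of the whole forward pass: patch embedding, class token, learned position embeddings, a stack of pre-norm encoder blocks, a final LayerNorm, and a linear head on the class token. It is an illustrative assumption rather than the paper's code. Instead of the hand-written layers above, it uses PyTorch's `nn.TransformerEncoderLayer` with `norm_first=True` (available in recent PyTorch releases). The class name `MiniViT` is ours, and its small default sizes are chosen for a quick run.

```python
import torch
import torch.nn as nn

class MiniViT(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=10,
                 embed_dim=192, depth=4, num_heads=3, mlp_ratio=4, dropout=0.0):
        super().__init__()
        num_patches = (img_size // patch_size) ** 2
        # Patch embedding: strided conv = split into patches + shared linear projection
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        # Pre-norm encoder blocks: LayerNorm before attention / MLP, residual after
        self.blocks = nn.Sequential(*[
            nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads,
                                       dim_feedforward=embed_dim * mlp_ratio, dropout=dropout,
                                       activation="gelu", batch_first=True, norm_first=True)
            for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, img):
        x = self.proj(img).flatten(2).transpose(1, 2)    # (B, N, C)
        cls = self.cls_token.expand(x.shape[0], -1, -1)  # (B, 1, C)
        x = torch.cat((cls, x), dim=1) + self.pos_embed  # prepend class token, add positions
        x = self.blocks(x)
        return self.head(self.norm(x[:, 0]))             # classify from the class token

vit = MiniViT()
logits = vit(torch.randn(2, 3, 224, 224))
print(logits.shape)  # torch.Size([2, 10])
```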

Results published in the paper 📈

When trained on mid-sized datasets such as ImageNet without strong regularization, these models yield modest accuracies of a few percentage points below ResNets of comparable size.

Transformers lack some of the inductive biases inherent to CNNs, such as translation equivariance and locality, and therefore do not generalize well when trained on insufficient amounts of data.

However, the picture changes if the models are trained on larger datasets (14M-300M images). We find that large scale training trumps inductive bias.

  • The authors report that when pre-trained on smaller datasets such as ImageNet, ViT-Large models underperform ViT-Base models; with large datasets (JFT-300M), ViT-Large models perform well.
  • Vision Transformer models pre-trained on the JFT-300M dataset outperform ResNet-based baselines on all datasets while requiring substantially fewer computational resources to pre-train.
  • The following table shows the results of ViT pre-trained on the JFT-300M and ImageNet-21k datasets. The columns are models pre-trained on different datasets; the rows are downstream tasks.
Table 2 from the paper

Train a simple ViT with PyTorch Lightning and timm 🎆

import timm
import torch
import pytorch_lightning as pl
import torchvision
import torchvision.transforms as transforms
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
import torchmetrics
seed_everything(42, workers=True)
  • Let's create a simple LightningModule class.
class Model(pl.LightningModule):
    """
    Lightning model
    """
    def __init__(self, model_name, num_classes, lr=0.001, max_iter=20):
        super().__init__()
        self.model = timm.create_model(model_name=model_name, pretrained=True, num_classes=num_classes)
        # Separate metric objects so training and validation accuracy are not mixed
        # (torchmetrics >= 0.11 requires the task and num_classes arguments)
        self.train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
        self.val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
        self.loss = torch.nn.CrossEntropyLoss()
        self.lr = lr
        self.max_iter = max_iter

    def forward(self, x):
        return self.model(x)

    def shared_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        return loss, preds, y

    def training_step(self, batch, batch_idx):
        loss, preds, y = self.shared_step(batch, batch_idx)
        self.train_acc(preds, y)
        self.log('train_loss', loss, on_step=True, on_epoch=True, logger=True, prog_bar=True)
        self.log('train_acc', self.train_acc, on_epoch=True, logger=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, preds, y = self.shared_step(batch, batch_idx)
        self.val_acc(preds, y)
        self.log('val_loss', loss, on_step=True, on_epoch=True, logger=True, prog_bar=True)
        self.log('val_acc', self.val_acc, on_epoch=True, logger=True, prog_bar=True)
        return loss

    def configure_optimizers(self):
        optim = torch.optim.Adam(self.model.parameters(), lr=self.lr)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optim, T_max=self.max_iter)
        return [optim], [scheduler]
  • Next, we define the transforms, then download and load the CIFAR-10 dataset.
transform = transforms.Compose(
    [transforms.Resize(224),  # CIFAR-10 images are 32x32; the model expects 224x224
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
batch_size = 128
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=8)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=8)
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
  • Now we initialize the Model class. Here, we use a ViT variant (vit_tiny_patch16_224) that takes 224×224 images and uses 16×16 patches.
model = Model(model_name="vit_tiny_patch16_224", num_classes=len(classes), lr = 0.001, max_iter=10)
  • Let’s create a checkpoint callback to save the best checkpoint.
checkpoint_callback = ModelCheckpoint(
monitor='val_loss',
dirpath='./checkpoints',
filename='vit_tiny_patch16_224-cifar10-{epoch:02d}-{val_loss:.2f}-{val_acc:.2f}'
)
  • Almost done. Let’s create the trainer.
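from pytorch_lightning.callbacks import StochasticWeightAveraging # the stochastic_weight_avg Trainer flag was removed in newer Lightning releases
trainer = Trainer(
    deterministic=True,
    logger=False,
    callbacks=[checkpoint_callback, StochasticWeightAveraging(swa_lrs=1e-3)], # replaces stochastic_weight_avg=True; swa_lrs set to the model's lr
    accelerator="gpu", # change to "cpu" based on GPU availability
    devices=1, # replaces gpus=[0]
    max_epochs=10)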
  • Finally, let's train the model 😃
trainer.fit(model=model, train_dataloaders=trainloader, val_dataloaders=testloader)


A special thanks to Prakash Jay for guiding me through this project.
