Vision Transformer for classification on medical images. Practical uses and experiments.

Olga Mindlina
12 min read · Jan 7, 2024


The focus of this article is the Vision Transformer (ViT) and its practical application to real-life problems. I return to the task of medical image classification, which I’ve already solved using a CNN and presented here. Now I present a solution based on ViT.

The Transformer architecture has become the de-facto standard for natural language processing tasks. What is Vision Transformer (ViT)? The ViT architecture represents an image as a set of patches. Image patches are non-overlapping image blocks with a size of 16x16 pixels. For example, an image with a resolution of 224x224 contains (224 / 16) * (224 / 16) = 14 * 14 = 196 patches. Image patches are treated the same way as tokens (words) in an NLP application. ViT flattens each patch into a vector of 16x16x3 = 768 pixel values and applies a linear projection to it, so the model operates on patch-embedding vectors of length 768. The picture below shows the full schema of ViT (the picture is from the article):

The main parts of the transformer are the following: preparation of patch + position embeddings, the Encoder, and Pooling (the MLP Head, i.e. a Multi-Layer Perceptron head).

1. Patch + position embeddings are formed from the input image pixels as a matrix with the size 196 x 768 (a vector of 768 values for each patch position, 196 patches for an image size of 224 x 224). A randomly initialized, learnable vector of 768 values (the class token) is prepended at the zero position, so the patch + position embeddings form a matrix with the size 197 x 768.

2. The Encoder contains a sequence of layers, each combining a Multi-Head Attention block and a Multi-Layer Perceptron (MLP) block with normalization layers. The Transformer Encoder is the main part of ViT: it learns the similarity between patches according to their class affiliation. It contains a sequence of linear, normalization and activation layers. The embedding matrix with the size 197 x 768 is transformed to express the interaction between patches and their class values. The zero-position row of this matrix is the class token (a vector of 768 values), which is used as the input of the following Pooling block.

3. Finally, the Pooling block transforms the class token (a vector of 768 values) into the output vector with embeddings for the classes of interest. Linear and activation layers are also used in this block.
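The first step of the list above can be sketched in a few lines of PyTorch. This is only an illustration of how the 197 x 768 matrix is assembled, not the Hugging Face implementation itself: the variable names are my own, and the zero-initialized parameters stand in for weights that would normally be learned.

```python
import torch
from torch import nn

image_size, patch_size, channels, hidden_size = 224, 16, 3, 768
num_patches = (image_size // patch_size) ** 2  # 14 * 14 = 196

# A convolution with kernel = stride = patch size applies one shared linear
# projection to every flattened 16x16x3 patch.
patch_projection = nn.Conv2d(channels, hidden_size,
                             kernel_size=patch_size, stride=patch_size)
class_token = nn.Parameter(torch.zeros(1, 1, hidden_size))
position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, hidden_size))

x = torch.randn(1, channels, image_size, image_size)   # fake input image
patches = patch_projection(x)                          # (1, 768, 14, 14)
patches = patches.flatten(2).transpose(1, 2)           # (1, 196, 768)

# Prepend the class token at position 0 and add the position embeddings.
tokens = torch.cat([class_token.expand(x.shape[0], -1, -1), patches], dim=1)
tokens = tokens + position_embeddings

print(tokens.shape)  # torch.Size([1, 197, 768])
```

The Encoder then works on this matrix without changing its shape, and row 0 is what the Pooling block reads at the end.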

ViT from Hugging Face: understanding the implementation in practice

Let’s look at the base ViT model from Hugging Face using the following code blocks:

Installation:

!pip install torchvision
!pip install torchinfo
!pip install -q git+https://github.com/huggingface/transformers.git

Imports:

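from PIL import Image
from torchinfo import summary
import torch
from torch import nn  # used by the classifier models below
from torchvision import transforms  # used for resizing in Model3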

Google drive mounting (for google colab):

from google.colab import drive
drive.mount('/content/gdrive')

Cuda device setting:

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In the code below I look into ViT base model:

from transformers import ViTConfig, ViTModel
configuration = ViTConfig()
print(configuration)

The default base model configuration is the following:

ViTConfig {
"attention_probs_dropout_prob": 0.0,
"encoder_stride": 16,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.0,
"hidden_size": 768,
"image_size": 224,
"initializer_range": 0.02,
"intermediate_size": 3072,
"layer_norm_eps": 1e-12,
"model_type": "vit",
"num_attention_heads": 12,
"num_channels": 3,
"num_hidden_layers": 12,
"patch_size": 16,
"qkv_bias": true,
"transformers_version": "4.37.0.dev0"
}

By changing the fields of the configuration, we can create a custom ViT model. Let’s try the default ViT base model:

model = ViTModel(configuration).to(device)
model.eval()

In the output we see all ViT base model layers:

Fig. 1

Model summary:

summary(model=model, input_size=(1, 3, 224, 224), col_names=['input_size', 'output_size', 'num_params', 'trainable'])

ViT base model has a large number of parameters — more than 86 million.
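To see where the 86 million parameters come from, the count can be derived by hand from the configuration fields. The function below is my own back-of-the-envelope sketch, assuming the standard ViT layer layout (patch projection, class token, position embeddings, 12 encoder layers, final LayerNorm and a pooler); it is not a call into the library.

```python
def vit_param_count(hidden_size=768, num_hidden_layers=12, intermediate_size=3072,
                    image_size=224, patch_size=16, num_channels=3):
    """Hand-derived parameter count for a ViT encoder with a pooler."""
    num_patches = (image_size // patch_size) ** 2
    patch_projection = num_channels * patch_size * patch_size * hidden_size + hidden_size
    class_token = hidden_size
    position_embeddings = (num_patches + 1) * hidden_size
    # query, key, value and output projections, each with weights and bias
    attention = 4 * (hidden_size * hidden_size + hidden_size)
    # two linear layers: hidden -> intermediate -> hidden
    mlp = (hidden_size * intermediate_size + intermediate_size
           + intermediate_size * hidden_size + hidden_size)
    layer_norms = 2 * 2 * hidden_size  # two LayerNorms per layer, weight + bias each
    encoder = num_hidden_layers * (attention + mlp + layer_norms)
    final_layer_norm = 2 * hidden_size
    pooler = hidden_size * hidden_size + hidden_size
    return (patch_projection + class_token + position_embeddings
            + encoder + final_layer_norm + pooler)


print(vit_param_count())  # 86389248
```

Almost all of the parameters sit in the 12 encoder layers (about 7.1 million each), so changing num_hidden_layers or hidden_size in the configuration has the largest effect on model size.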

Let’s look at the structure of the model output. I send a randomly generated fake image to the model:

x = torch.randn((3, 224, 224))
x = torch.unsqueeze(x, 0)

y = model(x.to(device))

print(y.pooler_output.shape)
print(y.last_hidden_state.shape)

And we can see in the output:

torch.Size([1, 768])
torch.Size([1, 197, 768])

The final output of the ViT base model contains two parts:

1. last_hidden_state with the shape (batch_size, 197, 768), which is the output of the sequence model.embeddings + model.encoder + model.layernorm (see Fig. 1), just before the model.pooler part;

2. pooler_output with the shape (batch_size, 768), which is the output of model.pooler. The input of the model.pooler block is the zero-position row of the normalized last_hidden_state matrix obtained in the previous step.

The picture below illustrates the equivalence of calling the blocks step by step (as described above) and getting the model output by calling the whole model once:

Fig. 2

If we run the left code and the right code in Fig. 2 with the same input tensor x, we see the same output tensors on printing.
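A comparison of this kind could be written as follows. This is my own sketch of the idea, not a copy of the code in the figure; it relies on the submodule names embeddings, encoder, layernorm and pooler mentioned above, and it runs on the CPU with a randomly initialized model.

```python
import torch
from transformers import ViTConfig, ViTModel

model = ViTModel(ViTConfig())  # randomly initialized base model
model.eval()

x = torch.randn(1, 3, 224, 224)  # fake input image

with torch.no_grad():
    # One call to the whole model
    y = model(x)

    # The same computation, block by block
    embeddings = model.embeddings(x)         # (1, 197, 768)
    encoded = model.encoder(embeddings)[0]   # (1, 197, 768)
    hidden = model.layernorm(encoded)        # corresponds to last_hidden_state
    pooled = model.pooler(hidden)            # corresponds to pooler_output

print(torch.allclose(hidden, y.last_hidden_state, atol=1e-6))
print(torch.allclose(pooled, y.pooler_output, atol=1e-6))
```

Because the model is in eval mode and the default dropout probabilities are 0.0, both paths are deterministic and should agree.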

It is important to understand the ViT blocks and the structure of their outputs in order to develop transfer-learning solutions based on ViT. The model.pooler block is replaced by a custom block, and this block is trained using the output of the preceding ViT blocks as its input.

Two pre-trained ViT models for image classification are available in Hugging Face:

1. Pre-trained on ImageNet-21k (a collection of 14 million images and 21k classes);

2. Fine-tuned on ImageNet (also referred to as ILSVRC 2012, a collection of 1.3 million images and 1,000 classes).

The classifier fine-tuned on ImageNet for 1000 classes (ViTForImageClassification) contains a model.classifier block instead of the model.pooler block. It is just the following Linear layer:

(classifier): Linear(in_features=768, out_features=1000, bias=True)

The input of this layer is a zero-position row of the normalized last_hidden_state matrix.
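In isolation, this head can be illustrated as below. The random tensor is a stand-in for the normalized last_hidden_state of a batch of 4 images, and the Linear layer is freshly initialized, so the predicted class indices are meaningless; only the shapes matter here.

```python
import torch
from torch import nn

# Stand-in for the normalized last_hidden_state of a batch of 4 images
last_hidden_state = torch.randn(4, 197, 768)

classifier = nn.Linear(in_features=768, out_features=1000, bias=True)

class_token = last_hidden_state[:, 0, :]  # zero-position row, shape (4, 768)
logits = classifier(class_token)          # shape (4, 1000)
predicted = logits.argmax(dim=-1)         # one class index per image

print(logits.shape, predicted.shape)
```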

ViT vs CNN

1. A CNN model extracts local features from a picture and considers the whole set of features to classify the input image. It is trained to calculate an image class label based on all features. ViT regards a picture as a set of patches and takes the patch positions into account. It is trained to calculate the similarity between patch embeddings and to decide on the class label for the “similar” patches, i.e. the ViT architecture contains the concept of segmentation.

2. A ViT model has a large number of parameters (86 million in the summary above) and requires a large dataset for good performance. A CNN model can be adapted to datasets of different sizes and may need a relatively small number of parameters to perform well.

ViT doesn’t perform well on a small custom dataset if it is trained from scratch. For small custom datasets, the practical use case is transfer learning: using the inference of ViT models pre-trained on large datasets.

ViT for X-Ray chest images classification. Practical experiments

In this section I return to the task that I solved with a CNN and described here. I use the same dataset of chest X-ray images. This dataset contains three classes of images: “Normal (no pneumonia)”, “Pneumonia-bacteria” and “Pneumonia-virus”.

I use unified cropped images with chest areas. The examples of the cropped images (“Normal (no pneumonia)”, “Pneumonia-bacteria”, “Pneumonia-virus” — from the left to the right):

The dataset is split into a training set and a test set. The training set contains 3000 images: 1000 “Normal (no pneumonia)”, 1000 “Pneumonia-bacteria” and 1000 “Pneumonia-virus”, selected at random from their respective groups. The rest of the images compose the test set, which thus contains 2908 images: 576 “Normal (no pneumonia)”, 1777 “Pneumonia-bacteria” and 555 “Pneumonia-virus”.
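A split of this kind can be sketched as follows. The file names and the fixed seed are hypothetical. The class totals (1576, 2777 and 1555) are simply the sums of the training and test counts given above.

```python
import random


def split_per_class(files_by_class, train_per_class=1000, seed=0):
    """Randomly pick train_per_class files per class; the rest go to the test set."""
    rng = random.Random(seed)
    train, test = {}, {}
    for label, files in files_by_class.items():
        shuffled = list(files)
        rng.shuffle(shuffled)
        train[label] = shuffled[:train_per_class]
        test[label] = shuffled[train_per_class:]
    return train, test


# Hypothetical file names, one list per class
files_by_class = {
    "normal": [f"normal_{i}.png" for i in range(1576)],
    "bacteria": [f"bacteria_{i}.png" for i in range(2777)],
    "virus": [f"virus_{i}.png" for i in range(1555)],
}

train, test = split_per_class(files_by_class)
print({label: len(files) for label, files in test.items()})
# {'normal': 576, 'bacteria': 1777, 'virus': 555}
```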

CNN vs ViT for 2-class classifier “Normal (no pneumonia)” / “Pneumonia (bacteria or virus)”

I’m solving the following task using the X-Ray data: create a system that determines whether an input chest X-ray image belongs to the class “Normal (no pneumonia)” or to the class “Pneumonia (bacteria or virus)”, i.e. a 2-class classifier, using ViT. I’ve already implemented a solution with a CNN containing 3 convolution blocks; the model is described here. This model shows the best results among the CNN models for this dataset. The summary of the 3-convolution model is the following:

The model contains 348,050 parameters, far fewer than the ViT model. Note that for the CNN model I use images with a resolution of 256x256.

Here I try the ViT model, pre-trained on the ImageNet-21k dataset and fine-tune it for X-Ray images.

Model1. “Small” linear classifier after processing of an input image by ViT.

First, I try the simplest solution: one Linear layer that takes as its input the zero-position row of the last_hidden_state matrix, a vector of 768 values. The same kind of head is used in the image classifier for the 1000 ImageNet classes.

Load the pre-trained ViT model + image processor:

from transformers import ViTConfig, ViTModel
from transformers import AutoImageProcessor

image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
model = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k").to(device)

The code below shows the pre-processing of one image: a PIL Image is converted into the class-token vector of 768 values by the pre-trained ViT model. This vector is the input of my Linear classifier:

# image_path is a placeholder for the path to one X-ray image file
img = Image.open(image_path).convert("RGB")

inputs = image_processor(img, return_tensors="pt").to(device)

with torch.no_grad():
    outputs = model(**inputs)

img = outputs.last_hidden_state  # shape (1, 197, 768)
img = img[:, 0, :]               # class-token row, shape (1, 768)

Note: I assume that the two-dimensional img is served through a torch DataLoader, which adds the batch dimension when it forms batches.
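The effect of the DataLoader on the shape can be checked with stand-in data. The sketch below assumes that the class-token vectors have already been computed and cached as tensors of shape (1, 768); the random values, the number of samples and the batch size are placeholders.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in for 10 cached class-token vectors, each of shape (1, 768)
features = torch.randn(10, 1, 768)
labels = torch.randint(0, 2, (10,))

loader = DataLoader(TensorDataset(features, labels), batch_size=4, shuffle=True)

batch_features, batch_labels = next(iter(loader))
print(batch_features.shape)  # torch.Size([4, 1, 768])
print(batch_labels.shape)    # torch.Size([4])
```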

A batch of processed input images with the shape (batch_size, 1, 768) is sent to the following model:

class ChestClassifier(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.num_classes = num_classes
        self.ln = nn.Linear(768, self.num_classes)

    def forward(self, x):
        x = nn.Flatten()(x)  # (batch_size, 1, 768) -> (batch_size, 768)
        x = self.ln(x)
        return x

model1 = ChestClassifier(2).to(device)

The summary of this model:

summary(model=model1, input_size=(1, 1, 768), col_names=['input_size', 'output_size', 'num_params', 'trainable'])

“Small” classifier model contains only 1,538 parameters.
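Training such a head on cached ViT features could look like the sketch below. The Adam optimizer and the learning rate of 0.001 match the settings used in this article; everything else is an assumption made for illustration: the cross-entropy loss, the batch size, the number of epochs, the random stand-in data, and the use of WeightedRandomSampler as one possible way to get approximately balanced batches.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

torch.manual_seed(0)

# Stand-in data: 300 cached feature vectors with a 1:2 class imbalance
features = torch.randn(300, 1, 768)
labels = torch.cat([torch.zeros(100), torch.ones(200)]).long()

# Draw each sample with a probability inversely proportional to its class size,
# so that batches contain roughly 50% of each class.
class_counts = torch.bincount(labels).float()
sample_weights = (1.0 / class_counts)[labels]
sampler = WeightedRandomSampler(sample_weights, num_samples=len(labels),
                                replacement=True)
loader = DataLoader(TensorDataset(features, labels), batch_size=32, sampler=sampler)

# Same layout as the "small" classifier: flatten + one Linear layer
head = nn.Sequential(nn.Flatten(), nn.Linear(768, 2))
optimizer = torch.optim.Adam(head.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()  # assumption: the loss is not stated in the text

for epoch in range(2):  # assumption: a real run would use more epochs
    head.train()
    for batch_features, batch_labels in loader:
        optimizer.zero_grad()
        loss = criterion(head(batch_features), batch_labels)
        loss.backward()
        optimizer.step()

print(sum(p.numel() for p in head.parameters()))  # 1538
```

Because the ViT features are computed once under torch.no_grad(), only the 1,538 parameters of the head receive gradients, which makes each epoch very cheap.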

I use the Adam optimizer with a learning rate of 0.001. 3000 images are used for training and 2908 images for testing. I make the training batches balanced (~50% of the images from each class). The picture below compares the results for the CNN architecture (obtained and presented here) and for model1 above. In the results below “Class 0” means “Normal (no pneumonia)” and “Class 1” means “Pneumonia (bacteria or virus)”. For both models I’ve chosen the best checkpoints:

The results of ViT fine-tuning with the “small” linear classifier are definitely worse than the results of the CNN architecture. I see the following reasons for this: medical images are very different from the ImageNet data on which the ViT model was trained, and the number of trainable parameters of my “small” linear classifier is not enough to make the transfer learning results better than the CNN results.

How can the model be improved? First, nothing prevents me from using the whole pre-trained patch-positional state, i.e. the whole last_hidden_state from the ViT output, for fine-tuning the classifier. Second, I can try a more complicated classifier model with a larger number of trainable parameters.

Model2. “Large” linear classifier after processing of an input image by ViT.

In comparison with model1, I change the pre-processing of the input PIL Image to get the whole transposed last_hidden_state matrix from the pre-trained ViT model. This matrix forms the input of my classifier model:

# image_path is a placeholder for the path to one X-ray image file
img = Image.open(image_path).convert("RGB")

inputs = image_processor(img, return_tensors="pt").to(device)

with torch.no_grad():
    outputs = model(**inputs)

img = outputs.last_hidden_state.permute(0, 2, 1)  # shape (1, 768, 197)
img = img.squeeze()                               # shape (768, 197)

Note: I use img.squeeze() to remove the batch dimension of a single image because I assume that it is served through a torch DataLoader, which adds the batch dimension when it forms batches.

A batch of processed input images with the shape (batch_size, 768, 197) is sent to the following model:

class ChestClassifierL(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.num_classes = num_classes
        self.ln1 = nn.Linear(197, 256)
        self.relu = nn.ReLU(inplace=True)
        self.ln2 = nn.Linear(768*256, self.num_classes)

    def forward(self, x):
        x = self.ln1(x)      # (batch_size, 768, 197) -> (batch_size, 768, 256)
        x = self.relu(x)
        x = nn.Flatten()(x)  # (batch_size, 768*256)
        x = self.ln2(x)
        return x

model2 = ChestClassifierL(2).to(device)

The summary of this model:

summary(model=model2, input_size=(1, 768, 197), col_names=['input_size', 'output_size', 'num_params', 'trainable'])

“Large” classifier model contains 443,906 parameters.
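Both parameter counts can be verified with simple arithmetic: a Linear layer has in_features * out_features weights plus out_features biases. The helper name below is my own.

```python
def linear_params(in_features, out_features):
    """Weights plus biases of one fully connected layer."""
    return in_features * out_features + out_features


small = linear_params(768, 2)                                  # model1
large = linear_params(197, 256) + linear_params(768 * 256, 2)  # model2

print(small, large)  # 1538 443906
```

Most of the parameters of the “large” classifier (393,218 of them) belong to the last layer, which sees the flattened 768 x 256 matrix.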

I use the Adam optimizer with a learning rate of 0.001. The picture below compares the results for the CNN architecture (obtained and presented here) and for model2 above. In the results below “Class 0” means “Normal (no pneumonia)” and “Class 1” means “Pneumonia (bacteria or virus)”. For both models I’ve chosen the best checkpoints:

ViT fine-tuning with the “large” classifier shows better performance than the CNN! The reason is not only the larger number of trainable parameters but also the use of the whole patch-positional information. The segmentation concept is important for medical images because they may contain abnormal areas that are specific to a particular problem.

Below I demonstrate the same positive trend for ViT on another classifier, one that distinguishes between kinds of pneumonia: “Pneumonia-bacteria” and “Pneumonia-virus”. I have 1000 “Pneumonia-bacteria” images plus 1000 “Pneumonia-virus” images in the training set and use 1777 “Pneumonia-bacteria” images plus 555 “Pneumonia-virus” images for the test. So the training set contains 2000 images and the test set contains 2332 images. I compare the same CNN architecture with 3 convolution blocks and the same combination of ViT plus model2 as for the “Normal (no pneumonia)” / “Pneumonia (bacteria or virus)” classifier above. In the results below “Class 0” means “Pneumonia bacteria” and “Class 1” means “Pneumonia virus”. For both models I’ve chosen the best checkpoints:

I’ve already shown in my previous post that it is difficult to distinguish between different kinds of pneumonia with good quality. In any case, the results above demonstrate better performance for the ViT fine-tuning solution in comparison with the CNN.

Model3. Fine-tuning ViT for custom input resolution.

In all the examples above I compare the results of a CNN model trained on input images with a resolution of 256x256 with ViT fine-tuning results, where the pre-trained ViT model requires input images with a resolution of 224x224. In this article I’ve found a solution for transfer learning at a higher resolution: the output of the pre-trained model is resized according to the number of embedding positions at the higher resolution, and then it is sent to the model that is fine-tuned for the new resolution. A 224x224 image has 196 patches, and the ViT last_hidden_state has the shape 197x768. A 256x256 image has 256 patches, so the ViT last_hidden_state should have the shape 257x768. So, to fine-tune ViT for the input resolution 256x256, I need to resize the last_hidden_state matrix to 257x768 and continue training with this matrix.
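The number of rows for any square input resolution follows from the same arithmetic: the number of 16x16 patches plus one class token. A small helper (the name is my own) makes this explicit:

```python
def num_tokens(image_size, patch_size=16):
    """Rows of last_hidden_state for a square image: patches + 1 class token."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be a multiple of patch_size")
    return (image_size // patch_size) ** 2 + 1


for size in (224, 256, 384):
    print(size, num_tokens(size))
# 224 197
# 256 257
# 384 577
```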

Let’s try it in practice. The pre-processing of an input PIL Image is the following:

# image_path is a placeholder for the path to one X-ray image file
img = Image.open(image_path).convert("RGB")

inputs = image_processor(img, return_tensors="pt").to(device)

with torch.no_grad():
    outputs = model(**inputs)

img = outputs.last_hidden_state.permute(0, 2, 1)  # shape (1, 768, 197)
# new patch-position embeddings resolution (needs: from torchvision import transforms)
img = transforms.Resize((768, 257))(img)          # shape (1, 768, 257)
img = img.squeeze()                               # shape (768, 257)

Note: I use img.squeeze() to remove the batch dimension of a single image because I assume that it is served through a torch DataLoader, which adds the batch dimension when it forms batches.
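The resizing step can also be tried without a real image. The sketch below uses a random stand-in for last_hidden_state and torch.nn.functional.interpolate with bilinear interpolation, which is the default interpolation mode of torchvision’s Resize. It is meant as an illustration of the shapes involved, not as an exact reproduction of the pipeline above.

```python
import torch
import torch.nn.functional as F

hidden = torch.randn(1, 197, 768)  # stand-in for outputs.last_hidden_state
img = hidden.permute(0, 2, 1)      # (1, 768, 197)

# interpolate expects (batch, channels, height, width), so add a channel axis
resized = F.interpolate(img.unsqueeze(1), size=(768, 257),
                        mode="bilinear", align_corners=False)
resized = resized.squeeze()        # (768, 257)

print(resized.shape)  # torch.Size([768, 257])
```

Only the position axis is stretched from 197 to 257 values; the embedding axis keeps its 768 values.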

A batch of processed input images with the shape (batch_size, 768, 257) is sent to the following model:

class ChestClassifierL256(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.num_classes = num_classes
        self.ln1 = nn.Linear(257, 256)
        self.relu = nn.ReLU(inplace=True)
        self.ln2 = nn.Linear(768*256, self.num_classes)

    def forward(self, x):
        x = self.ln1(x)      # (batch_size, 768, 257) -> (batch_size, 768, 256)
        x = self.relu(x)
        x = nn.Flatten()(x)  # (batch_size, 768*256)
        x = self.ln2(x)
        return x

model3 = ChestClassifierL256(2).to(device)

The summary of this model:

summary(model=model3, input_size=(1, 768, 257), col_names=['input_size', 'output_size', 'num_params', 'trainable'])

I’ve tried model3 for the 2-class classifier “Normal (no pneumonia)” / “Pneumonia (bacteria or virus)”. The picture below compares the results of model2 for the input resolution 224x224 and of model3 for the input resolution 256x256. In the results “Class 0” means “Normal (no pneumonia)” and “Class 1” means “Pneumonia (bacteria or virus)”. For both models I’ve chosen the best checkpoints:

The effect of changing the resolution is more noticeable at resolutions that differ strongly from 224x224.

Conclusion

A proper combination of ViT inference and a fine-tuned classifier model might improve classification performance even on a very specific dataset such as medical images.
