Classification of medical images: Should you build a model from scratch or use transfer learning with state-of-the-art pre-trained models?

Olga Mindlina · Published in Python’s Gurus · Jun 15, 2024 · 9 min read

The focus of this article is the methodology of training models on custom datasets. I’ve already discussed two solutions for the problem of classifying medical images:

1. Implementation and training of a CNN model from scratch — the link

2. Transfer learning using Vision Transformer (ViT) for classification — the link

Here, I return to the concept of transfer learning using state-of-the-art (SOTA) models pre-trained on large, well-known datasets. The ViT large model for classification is trained on the ImageNet-21k dataset, which contains approximately 14 million images. SOTA deep neural networks (DNNs) for classification are typically trained on the ImageNet dataset, a subset of ImageNet-21k, containing about 1.3 million images.

I use a Kaggle dataset of chest X-ray images, the same one I used to train my CNN model and for transfer learning from the pre-trained ViT (see the links above). Here I work with two pneumonia classes: “Pneumonia-bacteria” and “Pneumonia-virus”. The training set consists of 1000 “Pneumonia-bacteria” images and 1000 “Pneumonia-virus” images, 2000 images in total. The test set consists of 1777 “Pneumonia-bacteria” images and 555 “Pneumonia-virus” images, 2332 images in total.

In previous experiments with this dataset, I achieved an average accuracy of approximately 75.5% across the two classes using a custom CNN with three convolutional blocks. With transfer learning based on ViT, the average accuracy improved to around 78.5% (see the links above). Note that distinguishing “Pneumonia-bacteria” images from “Pneumonia-virus” images can be challenging in some cases. Examples of the cropped input images:

Now, I am attempting to apply the transfer learning methodology, which was previously used with ViT, to well-known SOTA DNNs: AlexNet, VGG16, ResNet18, ResNet50, ResNet101, EfficientNetB0, EfficientNetB1, EfficientNetB2, and EfficientNetB3.

SOTA DNNs for classification. Overview

The SOTA classification DNNs mentioned above differ in capacity and architecture, and some of them, such as AlexNet and VGG16, are relatively old. However, they all share the same high-level architecture:

Fig. 1

The concepts and implementations of the Feature Extraction and Classifier blocks are different.

· The feature extraction blocks of AlexNet and VGG16 are classic CNNs where convolutional blocks are called sequentially, with the output of the current block serving as an input for the next block. These networks are not very deep, as the number of convolutions is small compared to newer DNN architectures. The number of trainable parameters in the feature extraction blocks of AlexNet and VGG16 is much lower than that in their classifier blocks.

· The feature extraction blocks of ResNet<X> (where <X> can be 18, 50, or 101 in this article) are based on residual connections between convolutional blocks. The residual connections add the original input to the transformed output. The number of trainable parameters in the feature extraction blocks of ResNet<X> is much higher than that in their classifier blocks. These are deep CNNs.

· EfficientNet_B<Y> (where <Y> can be 0, 1, 2, or 3 in this article) is a convolutional neural network architecture with a scaling method that uniformly scales all dimensions of depth, width, and resolution using a compound coefficient. It employs skip connections between convolutional blocks that merge features from different stages. The number of trainable parameters in the feature extraction blocks of EfficientNet_B<Y> is higher than that in their classifier blocks. These are deep CNNs. (A short code check of this parameter split follows Table 1.)

The table below shows some characteristics of these DNNs:

Table 1
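As a quick check of the parameter split described in the list above, the trainable parameters of each block can be counted directly on the torchvision models. The sketch below is only an illustration: the attribute names features/classifier apply to AlexNet, VGG16 and EfficientNet, while ResNet keeps its classifier in the fc layer.

import torch.nn as nn
from torchvision.models import alexnet, efficientnet_b0, resnet18

def count_params(module: nn.Module) -> int:
    # total number of trainable parameters in a (sub)module
    return sum(p.numel() for p in module.parameters() if p.requires_grad)

m = alexnet(weights=None)
print('AlexNet   features:', count_params(m.features), ' classifier:', count_params(m.classifier))

m = efficientnet_b0(weights=None)
print('EffNet-B0 features:', count_params(m.features), ' classifier:', count_params(m.classifier))

m = resnet18(weights=None)
fc = count_params(m.fc)
print('ResNet18  features:', count_params(m) - fc, ' classifier (fc):', fc)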

Transfer learning methodology with the use of SOTA DNNs

For implementing the transfer learning I use the following idea: extract the full hidden state from the feature extraction block (Fig. 1) of the pre-trained model and use this hidden state as the input for training my own classifier block. Note that many transfer learning examples on the internet train only the last fully-connected layer of the classifier to convert the pre-trained embeddings to class probabilities. That approach works only for fine-tuning the model to a subset of ImageNet classes, i.e., for training on ImageNet-like data. I propose a transfer learning solution for a custom dataset that is very different from ImageNet. This is why I use the full hidden state and implement a classifier with a relatively large number of trainable parameters, as was done with ViT in this article.
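In PyTorch terms, the idea can be sketched as freezing the pre-trained feature-extraction block and training only a new classifier head on top of it. The snippet below is just an illustration of the concept; the placeholder head and the 512 * 7 * 7 input size are assumptions for VGG16 at 224×224 input, and the concrete classifier I actually train is shown later in the article.

import torch
import torch.nn as nn
from torchvision.models import vgg16, VGG16_Weights

backbone = vgg16(weights=VGG16_Weights.DEFAULT).features  # pre-trained feature extraction block
backbone.requires_grad_(False)  # freeze: the hidden state is used as-is
backbone.eval()

# placeholder head; my real classifier (with ~440,000 parameters) is defined later
head = nn.Sequential(nn.Flatten(), nn.Linear(512 * 7 * 7, 2))

# only the head parameters are passed to the optimizer
optimizer = torch.optim.Adam(head.parameters(), lr=0.001)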

Let’s start with the use of pre-trained models. I use PyTorch.

Installation:

!pip install torchvision
!pip install torchinfo

Common imports:

from PIL import Image
import torch
import torch.nn as nn
from torchinfo import summary

Cuda device setting:

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

Loading pre-trained models. The code below demonstrates the initialization for each model that I use for transfer learning:

# AlexNet:
from torchvision.models import alexnet, AlexNet_Weights
weights = AlexNet_Weights.DEFAULT
model = alexnet(weights=weights).to(device)
# don't forget to call 'eval' to use the model for inference
model.eval()

# VGG16:
from torchvision.models import vgg16, VGG16_Weights
weights = VGG16_Weights.DEFAULT
model = vgg16(weights=weights).to(device)
# don't forget to call 'eval' to use the model for inference
model.eval()

# ResNet<X>, <X> is one of the following: 18, 50, 101
from torchvision.models import resnet<X>, ResNet<X>_Weights
weights = ResNet<X>_Weights.DEFAULT
model = resnet<X>(weights=weights).to(device)
# don't forget to call 'eval' to use the model for inference
model.eval()

# EfficientNet_B<Y>, <Y> is one of the following: 0, 1, 2, 3
from torchvision.models import efficientnet_b<Y>, EfficientNet_B<Y>_Weights
weights = EfficientNet_B<Y>_Weights.DEFAULT
model = efficientnet_b<Y>(weights=weights).to(device)
# don't forget to call 'eval' to use the model for inference
model.eval()

The following lines of code demonstrate how to use the pre-trained model to obtain the final output:

img = <read PIL image>

preprocess = weights.transforms()
x = preprocess(img).unsqueeze(0).to(device)
with torch.no_grad():
    outputs = model(x)
# vector of class-probabilities:
prediction = outputs.squeeze(0).softmax(0)
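As a side note, if one wanted to sanity-check this plain ImageNet inference, the predicted class name can be read from the weights metadata; a short example, assuming the prediction tensor from the block above:

class_id = prediction.argmax().item()
score = prediction[class_id].item()
# weights.meta['categories'] holds the ImageNet class names in torchvision
print(f"{weights.meta['categories'][class_id]}: {score:.3f}")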

We don’t need the final output; we need the hidden state from the feature extraction block. For AlexNet, VGG16, and EfficientNet_B<any_model>, it is quite straightforward: calling model.features(x) returns the output of the feature extraction block. The code block below shows how to prepare an input image for my classifier when using AlexNet, VGG16, or EfficientNet_B<any_model> for transfer learning:

img = <read PIL image>

preprocess = weights.transforms()
x = preprocess(img).unsqueeze(0).to(device)
with torch.no_grad():
    x = model.features(x)
x = x.squeeze(0)
# x.shape: (number of features, picture height, picture width); 3 dimensions
x = torch.reshape(x, (x.shape[0], x.shape[1]*x.shape[2])) # 2 dimensions in the result

Note: I use x.squeeze(0) to remove the batch dimension for a single image, because the prepared features will be fed to a torch DataLoader, which adds the batch dimension back when building batches. Additionally, I reshape x to combine the two spatial dimensions into a single dimension.
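To make this concrete, here is a minimal sketch of such a dataset of pre-computed features; the feature and label lists are hypothetical placeholders collected with the extraction code above:

from torch.utils.data import Dataset, DataLoader

class FeatureDataset(Dataset):
    # stores pre-computed 2-D feature tensors (num_features, H*W) with their class labels
    def __init__(self, features, labels):
        self.features = features  # list of tensors produced by the code above
        self.labels = labels      # list of int labels (0 or 1)

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        return self.features[idx], self.labels[idx]

# hypothetical usage:
# train_loader = DataLoader(FeatureDataset(train_features, train_labels), batch_size=16, shuffle=True)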

For ResNet<any_model>, I need to call a sequence of the model’s functions to obtain the output of the feature extraction block. The code block below shows how to prepare an input image for my classifier when using ResNet<any_model> for transfer learning:

img = <read PIL image>
preprocess = weights.transforms()
x = preprocess(img).unsqueeze(0).to(device)

with torch.no_grad():
    x = model.conv1(x)
    x = model.bn1(x)
    x = model.relu(x)
    x = model.maxpool(x)

    x = model.layer1(x)
    x = model.layer2(x)
    x = model.layer3(x)
    x = model.layer4(x)

x = x.squeeze(0)
# x.shape: (number of features, picture height, picture width); 3 dimensions
x = torch.reshape(x, (x.shape[0], x.shape[1]*x.shape[2])) # 2 dimensions in the result

As in the previous code block, x.squeeze(0) removes the batch dimension and the reshape combines the two spatial dimensions into one.
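As an alternative to calling the layers one by one, the same feature-extraction sequence of a torchvision ResNet can be wrapped into a single module by dropping the final average-pooling and fc layers; a small sketch, assuming the same model, preprocess and img as above:

import torch.nn as nn

# all children up to and including layer4, i.e. without avgpool and fc
feature_extractor = nn.Sequential(*list(model.children())[:-2]).to(device)
feature_extractor.eval()

with torch.no_grad():
    x = feature_extractor(preprocess(img).unsqueeze(0).to(device))  # same tensor as after model.layer4(x)
x = x.squeeze(0)
x = torch.reshape(x, (x.shape[0], x.shape[1]*x.shape[2]))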

The table below shows how the resolution of the input image is transformed into the resolution of the features tensor and subsequently into the resolution of the input for my classifier, depending on the model:

Table 2
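The feature resolutions in Table 2 can be checked by pushing a dummy batch through the feature-extraction block and printing its shape; a small sketch, assuming a model with a features attribute and a 224×224 input (the actual input resolution for each set of weights comes from weights.transforms()):

import torch

dummy = torch.randn(1, 3, 224, 224).to(device)  # assumed resolution; see weights.transforms()
with torch.no_grad():
    feats = model.features(dummy)  # AlexNet / VGG16 / EfficientNet_B<Y>
print(feats.shape)  # (1, num_features, H, W), as in Table 2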

Let’s look at the classifier that processes the features obtained by the pre-trained model. For all models, I apply the same architecture, adapted to the input sizes and maintaining approximately the same number of trainable parameters (~440,000). The following code block demonstrates the implementation of the classifier:

class ChestClassifierTransfer(nn.Module):
    def __init__(self, num_classes, input_size_1, input_size_2, hidden_size):
        super(ChestClassifierTransfer, self).__init__()
        self.num_classes = num_classes

        self.ln1 = nn.Linear(input_size_2, hidden_size)
        self.relu = nn.ReLU(inplace=True)
        self.ln2 = nn.Linear(input_size_1*hidden_size, self.num_classes)
        self.dropout = nn.Dropout(p=0.2)

    def forward(self, x):
        # x: (batch, input_size_1, input_size_2) - features from the pre-trained model
        x = self.ln1(x)
        x = self.relu(x)
        x = self.dropout(x)

        x = nn.Flatten()(x)
        x = self.ln2(x)

        return x

net = ChestClassifierTransfer(2, <input_size_1>, <input_size_2>, <hidden_size>).to(device)

<input_size_1> and <input_size_2> are the input resolution values for my classifier. Refer to the right column of Table 2 for more details.

<hidden_size> is a parameter that maintains the number of trainable parameters in my classifier at approximately 440,000.

The table below shows the <hidden_size> values depending on the model:

Table 3

A few words about the classifier architecture: I found that a classifier with approximately 440,000 parameters showed the best performance on the test set; decreasing or increasing the number of parameters led to lower performance. Note that for transfer learning based on ViT, I used a classifier with a similar architecture and a comparable number of parameters. For more details, refer to the article.
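The ~440,000 figure is easy to verify once the classifier is instantiated, for example:

n_trainable = sum(p.numel() for p in net.parameters() if p.requires_grad)
print(f'trainable parameters: {n_trainable:,}')  # roughly 440,000 for the sizes in Tables 2 and 3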

An example summary of the classifier that uses EfficientNet_B1 features as input:

summary(model=net, input_size=(1, 1280, 64), col_names=['input_size', 'output_size', 'num_params', 'trainable'])

I use the Adam optimizer with a learning rate of 0.001. After training, for each model I select the checkpoint that shows the best performance on the test set.
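For completeness, a minimal sketch of such a training loop; the cross-entropy loss, the number of epochs and the train_loader over pre-computed features are my assumptions, not the exact training code from the original experiments:

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

for epoch in range(num_epochs):  # num_epochs is a hypothetical setting
    net.train()
    for features, labels in train_loader:
        features, labels = features.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(net(features), labels)
        loss.backward()
        optimizer.step()
    # after each epoch: evaluate on the test set and keep the best checkpoint, e.g.
    # torch.save(net.state_dict(), f'checkpoint_{epoch}.pt')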

The results of transfer learning: Comparison of different models

In this section, I compare the results of all classifiers: the previously trained custom model with three convolutional blocks (3CNN), the model based on transfer learning from ViT (discussed here), and all models based on transfer learning from the SOTA DNNs. The common criterion I use to compare the models’ prediction quality is the F-measure, the harmonic mean of precision and recall. It takes values between 0 and 1: the closer the value is to 1, the better the model’s performance.
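Per-class F-measure values can be computed, for example, with scikit-learn; a sketch, assuming lists of true and predicted class indices over the test set:

from sklearn.metrics import f1_score

# y_true, y_pred: hypothetical lists of 0/1 class indices over the test set
f1_per_class = f1_score(y_true, y_pred, average=None)  # one value per class
f1_macro = f1_score(y_true, y_pred, average='macro')   # unweighted mean over the classes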

The table below shows the prediction quality for each model. Models with performance lower than that of 3CNN are highlighted in gray, while yellow and green mark models with performance higher than 3CNN. Green highlights models with the best performance:

Table 4

Finally, let’s examine the execution time of each model in an environment with one GPU, where all inference is performed on the GPU. I measure the execution time of image preprocessing (assuming the image is already loaded as a PIL Image) followed by classification of the prepared image. The average time of “preprocessing + classification” for the 3CNN model in my environment is 0.0045 seconds. I take the execution time of the 3CNN model as the relative unit and measure the execution time of the other models against it: the 3CNN model’s time equals 1, and if the execution time of model “NN” equals 3, then model “NN” is 3 times slower than the 3CNN model.
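The timing itself can be reproduced with a simple loop; a sketch, assuming a classify(img) callable that runs “preprocessing + classification” for one PIL image and a list of test images (torch.cuda.synchronize() keeps the GPU timings honest):

import time

def average_time(classify, images, n_warmup=10):
    for img in images[:n_warmup]:   # warm-up runs are not timed
        classify(img)
    torch.cuda.synchronize()
    start = time.perf_counter()
    for img in images:
        classify(img)
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / len(images)

# relative units: divide by the measured time of the 3CNN baseline
# ratio = average_time(sota_classify, test_images) / average_time(cnn3_classify, test_images)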

The table below shows the execution time of each model in these relative units:

Table 5

Interestingly, the models with the best prediction quality have approximately the same execution time (highlighted).

Conclusion

Transfer learning from ViT or from the most effective SOTA DNNs is a way to achieve the best prediction quality for a classifier trained on a custom dataset such as medical images. The loss in model speed is the price paid for the higher quality.
