The MedMNIST dataset, v1.

MultiResViT

Abhirath Anand
Published in cvrs.notepad
Aug 30, 2021


Introduction

In the recent past, self-attention, and vision transformer (ViT) based models in particular, have attracted interest in the computer vision community. There have been attempts to modify the originally proposed ViT architecture so that it learns visual representations more effectively and performs well even with minimal training data. Data-efficient image transformers (DeiT) tackle this problem through distillation from a CNN teacher, but this increases the complexity of the network. Other approaches, such as Early Convolutions help Transformers see better and Visformer, still underperform CNN-based models.

The model architecture of the original ViT model. Screenshot taken by author. The source is Figure 1 on the paper “An Image is Worth 16x16 Words.”

We experimented with vision transformers for medical image classification, particularly on datasets that are small in both size and resolution. Our proposed architecture, MultiResViT, comes very close to matching the state of the art on MedMNIST, and along the way we make several important observations about improving the performance of vision transformers trained on little data. We achieved these results despite being severely constrained in compute. Our code is public in our repository, and we hope it can serve as a useful starting point for others interested in medical imaging, and specifically in transformers for medical image classification.

Dataset

The MedMNIST dataset is a collection of 10 pre-processed open medical datasets. MedMNIST is standardised for classification tasks on lightweight 28 × 28 images and requires no background knowledge in medical imaging, which makes it extremely convenient for benchmarking deep learning models on a variety of medical image classification tasks. Due to limited compute, we focused on three of these datasets:

  1. PneumoniaMNIST, 5,856 chest X-rays for binary classification of a person as normal or infected;
  2. RetinaMNIST, 1,600 retina fundus images for ordinal regression, grading diabetic retinopathy severity on a 5-level scale;
  3. DermaMNIST, based on HAM10000, a large collection of multi-source dermatoscopic images of common pigmented skin lesions: 10,015 images labelled into 7 categories for multi-class classification.

Due to compute restrictions, we used only the train and test splits of the original datasets, holding out just 0.001 of the training set (0.1%) for validation. After normalisation with a mean and standard deviation of 0.5, the training data is further augmented with random horizontal flips and random rotations of up to 10 degrees.
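The preprocessing described above can be sketched in NumPy as follows. This is a minimal sketch, not the authors' pipeline: the function names are hypothetical, images are assumed to arrive as uint8 arrays, and random rotation is left out (in torchvision it would typically be `transforms.RandomRotation(10)`, which samples an angle from [-10, 10] degrees).

```python
import numpy as np


def normalise(images_uint8: np.ndarray) -> np.ndarray:
    """Scale uint8 pixels to [0, 1], then apply mean 0.5 / std 0.5, giving values in [-1, 1]."""
    x = images_uint8.astype(np.float32) / 255.0
    return (x - 0.5) / 0.5


def random_hflip(image: np.ndarray, rng: np.random.Generator, p: float = 0.5) -> np.ndarray:
    """Flip an (H, W) or (H, W, C) image left-right with probability p."""
    if rng.random() < p:
        return image[:, ::-1].copy()
    return image


def train_val_split(n_train: int, val_fraction: float = 0.001, seed: int = 0):
    """Shuffle training indices and hold out val_fraction of them (at least one) for validation."""
    rng = np.random.default_rng(seed)
    idx = rng.permutation(n_train)
    n_val = max(1, int(round(n_train * val_fraction)))
    return idx[n_val:], idx[:n_val]


# Usage on hypothetical data: 1,000 training images of size 28 x 28 x 3.
images = np.random.default_rng(1).integers(0, 256, size=(1000, 28, 28, 3), dtype=np.uint8)
train_idx, val_idx = train_val_split(len(images))
print(len(train_idx), len(val_idx))  # 999 1
x = normalise(images[train_idx[0]])
x = random_hflip(x, np.random.default_rng(2))
print(x.shape)  # (28, 28, 3)
```

With a validation fraction this small, the held-out set is tiny (a single image for 1,000 training images), so it mainly serves as a sanity check rather than a reliable estimate of generalisation.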

Evaluation

The original metrics on the MedMNIST dataset (v1). Screenshot by the author. The source is Table 2 from the paper “MedMNIST Classification Decathlon: A Lightweight AutoML Benchmark for Medical Image Analysis”.

Given that this is a medical imaging dataset, the classes are inherently highly imbalanced, so the metrics have to be interpreted very carefully. We chose to compare raw accuracy (ACC) and area under the ROC curve (AUC) on each dataset, since these are the metrics reported in the original paper. Because AUC is less sensitive to class imbalance than ACC, we considered it best to look at both metrics together when judging overall model performance.
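To illustrate why ACC alone can mislead on imbalanced data, here is a small pure-Python sketch on a hypothetical binary toy set. AUC is computed with the rank-based (Mann-Whitney) formulation: the probability that a randomly chosen positive is scored above a randomly chosen negative, with ties counting as half. For multi-class datasets AUC is usually averaged over one-vs-rest problems; this sketch only covers the binary case.

```python
def accuracy(y_true, y_pred):
    """Fraction of predictions that match the labels."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def binary_auc(y_true, scores):
    """ROC AUC as P(score of random positive > score of random negative), ties count 0.5."""
    pos = [s for t, s in zip(y_true, scores) if t == 1]
    neg = [s for t, s in zip(y_true, scores) if t == 0]
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


# Hypothetical imbalanced toy set: 90 negatives, 10 positives.
y_true = [0] * 90 + [1] * 10

# A degenerate model that gives every sample the same low score, i.e. always predicts class 0.
constant_scores = [0.1] * 100
constant_preds = [int(s >= 0.5) for s in constant_scores]
print(accuracy(y_true, constant_preds))     # 0.9
print(binary_auc(y_true, constant_scores))  # 0.5

# A model that ranks every positive above every negative, but still below the 0.5 threshold.
ranking_scores = [0.2] * 90 + [0.4] * 10
print(binary_auc(y_true, ranking_scores))   # 1.0
```

Both models reach 90% ACC at a 0.5 threshold, yet their AUCs are 0.5 and 1.0, which is why reading the two metrics together gives a fuller picture.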

Architecture

Motivation

NesT performs well on CIFAR-10, which has considerably fewer training images than other vision transformers usually need, and whose 32 × 32 images are comparable in size to the 28 × 28 images of MedMNIST. NesT is also less sensitive to data augmentation and converges faster than DeiT. These properties make it attractive for training on datasets that do not have a large number of images.

NesT base model architecture. Screenshot by the author. The source is Figure 1 from the paper “Aggregating Nested Transformers”.

While self-attention has proven remarkably general across the range of tasks it can be applied to, a large amount of data is needed to achieve results comparable to CNNs. Papers such as LeViT, LocalViT and TransMed argue for injecting convolutional elements into ViT-based networks to strengthen the network's inductive biases for image data. In particular, Early Convolutions help Transformers see better demonstrates the usefulness of early convolutions: the model becomes more robust to the choice of learning rate and weight decay, and also converges faster.

LeViT base model architecture. Screenshot by the author. The source is Figure 4 from the paper “LeViT: a Vision Transformer in ConvNet’s Clothing for Faster Inference”.
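The early-convolution idea can be illustrated with a small PyTorch module. This is a hypothetical sketch loosely modelled on LeViT's stem of four stride-2 3 × 3 convolutions (16× downsampling); the channel widths and the GELU activations are assumptions, not LeViT's exact configuration. It also shows why input resolution matters: a 112 × 112 image yields a 7 × 7 grid of 49 tokens, while a 28 × 28 image yields only 2 × 2 = 4.

```python
import torch
from torch import nn


class ConvStem(nn.Module):
    """Hypothetical LeViT-style convolutional stem that turns an image into tokens."""

    def __init__(self, in_channels: int = 3, widths=(32, 64, 128, 256)):
        super().__init__()
        layers, prev = [], in_channels
        for i, width in enumerate(widths):
            # Each 3x3, stride-2 convolution (padding 1) halves the spatial size, rounding up.
            layers.append(nn.Conv2d(prev, width, kernel_size=3, stride=2, padding=1))
            if i < len(widths) - 1:
                layers.append(nn.GELU())  # non-linearity between convolutions
            prev = width
        self.stem = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        feats = self.stem(x)                     # (B, C, ceil(H/16), ceil(W/16))
        return feats.flatten(2).transpose(1, 2)  # (B, num_tokens, C)


stem = ConvStem()
print(stem(torch.randn(2, 3, 112, 112)).shape)  # torch.Size([2, 49, 256])
print(stem(torch.randn(2, 3, 28, 28)).shape)    # torch.Size([2, 4, 256])
```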

Super-resolution (SR) is a method for enhancing the resolution of an imaging system. Advances in deep learning have made it possible to train super-resolution networks, such as the one in Enhanced Deep Residual Networks for Single Image Super-Resolution (EDSR). This paper proposes x2 and x4 models for increasing the resolution of a given image.

EDSR base model architecture. Screenshot by the author. The source is Figure 3 from the paper “Enhanced Deep Residual Networks for Single Image Super-Resolution”.
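The x4 upsampling step can be illustrated with an EDSR-style toy network in PyTorch. This is an untrained, much smaller stand-in, not the pre-trained EDSR model used in the post; the widths and block counts are arbitrary assumptions. It shows the main ingredients: residual blocks without batch normalisation, a global skip connection, and an upsampler built from two convolution + `nn.PixelShuffle(2)` stages, each of which rearranges a (4C, H, W) tensor into (C, 2H, 2W).

```python
import torch
from torch import nn


class ResBlock(nn.Module):
    """EDSR-style residual block: conv -> ReLU -> conv, no batch normalisation."""

    def __init__(self, n_feats: int, res_scale: float = 1.0):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(n_feats, n_feats, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(n_feats, n_feats, 3, padding=1),
        )
        self.res_scale = res_scale

    def forward(self, x):
        return x + self.res_scale * self.body(x)


def upsampler_x4(n_feats: int) -> nn.Sequential:
    """Two x2 stages: each conv expands channels 4x, then PixelShuffle(2) trades them for 2x resolution."""
    layers = []
    for _ in range(2):
        layers += [nn.Conv2d(n_feats, 4 * n_feats, 3, padding=1), nn.PixelShuffle(2)]
    return nn.Sequential(*layers)


class TinySRx4(nn.Module):
    """Hypothetical toy x4 super-resolution network (untrained)."""

    def __init__(self, n_feats: int = 16, n_resblocks: int = 2):
        super().__init__()
        self.head = nn.Conv2d(3, n_feats, 3, padding=1)
        self.body = nn.Sequential(*[ResBlock(n_feats) for _ in range(n_resblocks)])
        self.tail = nn.Sequential(upsampler_x4(n_feats), nn.Conv2d(n_feats, 3, 3, padding=1))

    def forward(self, x):
        h = self.head(x)
        h = h + self.body(h)  # global (long) skip connection
        return self.tail(h)


sr = TinySRx4().eval()
with torch.no_grad():
    upscaled = sr(torch.randn(1, 3, 28, 28))
print(upscaled.shape)  # torch.Size([1, 3, 112, 112])
```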

Base model

Our idea was to combine these two models, NesT and LeViT, to leverage both the spatial information captured by early convolutions and a model that was tried and tested on low-resolution data for extracting feature maps. In addition, we used a pre-trained super-resolution network (EDSR x4) to upsample the 28 × 28 images to 112 × 112.

The architecture fed the low-resolution images to the NesT model and the high-resolution images to the LeViT model, then concatenated their output feature maps and passed them through a fully connected layer followed by a softmax. The exact configuration of each model is available in our repository. Our LeViT and NesT code is based on the implementations in the repository by lucidrains. For most benchmarks, we used RetinaMNIST to gauge how well the model was doing and to check for potential issues, and then tested it on the other two datasets.

Code for the base MultiResViT model.
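Since the linked code is not reproduced here, the fusion step can be sketched in PyTorch as below. This is a minimal sketch under assumptions: `TinyEncoder` is a hypothetical stand-in for the NesT and LeViT backbones (a single convolution plus global pooling), and the feature widths are arbitrary. The point is the data flow: the 28 × 28 image goes through one branch, the 112 × 112 super-resolved image through the other, the two feature vectors are concatenated, and a fully connected layer produces class scores.

```python
import torch
from torch import nn


class TinyEncoder(nn.Module):
    """Hypothetical stand-in for a NesT/LeViT backbone: image -> feature vector of size `dim`."""

    def __init__(self, dim: int, in_channels: int = 3):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, dim, 3, stride=2, padding=1),
            nn.GELU(),
            nn.AdaptiveAvgPool2d(1),  # global average pooling -> (B, dim, 1, 1)
            nn.Flatten(),             # -> (B, dim)
        )

    def forward(self, x):
        return self.net(x)


class TwoBranchClassifier(nn.Module):
    """Low-res branch + high-res branch -> concatenated features -> fully connected layer."""

    def __init__(self, low_dim: int, high_dim: int, num_classes: int):
        super().__init__()
        self.low_branch = TinyEncoder(low_dim)    # fed the original 28 x 28 images
        self.high_branch = TinyEncoder(high_dim)  # fed the super-resolved 112 x 112 images
        self.fc = nn.Linear(low_dim + high_dim, num_classes)

    def forward(self, x_low, x_high):
        feats = torch.cat([self.low_branch(x_low), self.high_branch(x_high)], dim=1)
        return self.fc(feats)  # logits


model = TwoBranchClassifier(low_dim=64, high_dim=96, num_classes=5)
logits = model(torch.randn(4, 3, 28, 28), torch.randn(4, 3, 112, 112))
probs = logits.softmax(dim=1)
print(probs.shape)  # torch.Size([4, 5])
```

The module returns logits rather than softmax probabilities because `nn.CrossEntropyLoss` applies log-softmax internally; the softmax is applied only when probabilities are needed, for example when computing AUC.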

Enhancements

However, the base model did not give us the results we expected: it saturated at around 50% ACC and 70% AUC on RetinaMNIST, roughly 5% lower in ACC and 6% lower in AUC than we expected. Even using super-resolved images on both branches did not improve the results significantly. Further, an ensemble-like approach (training each branch on only a specific subset of classes) revealed that the NesT model was not learning well on the data.

So we dropped the NesT branch and used LeViT on both branches instead, which improved the results: ACC rose to 51% and AUC to 72%. Using super-resolved images on both branches improved this further, to 53.5% ACC and 73% AUC.

Still unsatisfied, we tried training each branch of the model on a limited subset of classes and then combining the branches with a fully connected layer. The rationale was to deal with the class imbalance and, hopefully, improve overall performance in the process. This did improve the overall scores: we reached 54% ACC and 77% AUC on RetinaMNIST, and the same model reached 73.37% ACC and 89.60% AUC on DermaMNIST. For both RetinaMNIST and DermaMNIST, one branch handled 3 classes and the other branch the remaining classes; we chose this split after scrutinising the class weights of each dataset.
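The post does not spell out how the class weights were computed or exactly how the split was chosen. As a hypothetical illustration, the sketch below computes common inverse-frequency ("balanced") class weights, w_c = N / (K · n_c) for N samples and K classes, and applies one possible grouping rule: the most frequent classes go to one branch and the rest to the other. The label counts are made up for the example.

```python
from collections import Counter


def inverse_frequency_weights(labels):
    """Balanced class weights: w_c = N / (K * n_c), so rarer classes get larger weights."""
    counts = Counter(labels)
    n, k = len(labels), len(counts)
    return {c: n / (k * counts[c]) for c in sorted(counts)}


def split_by_frequency(labels, n_first):
    """Hypothetical rule: the n_first most frequent classes go to branch A, the rest to branch B."""
    ranked = [c for c, _ in Counter(labels).most_common()]
    return sorted(ranked[:n_first]), sorted(ranked[n_first:])


# Made-up, imbalanced 5-class label list (not real RetinaMNIST counts).
labels = [0] * 50 + [1] * 20 + [2] * 15 + [3] * 10 + [4] * 5
print(inverse_frequency_weights(labels))  # {0: 0.4, 1: 1.0, 2: 1.3333333333333333, 3: 2.0, 4: 4.0}
print(split_by_frequency(labels, n_first=3))  # ([0, 1, 2], [3, 4])
```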

For PneumoniaMNIST, each branch had a single class output, and the model reached 87.82% ACC and 84.70% AUC. This is underwhelming compared with the expected results; we suspect it may be because PneumoniaMNIST consists of grayscale images, which carry less context.

We note that, because of limited compute, our model has just 2,227,496 parameters, yet it performs remarkably well on these datasets given how little data is available.

Code for the enhanced MultiResViT model as used for RetinaMNIST.
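A minimal PyTorch sketch of the class-split structure follows, assuming the RetinaMNIST configuration described above (3 classes on one branch, the remaining 2 on the other). Each branch ends in a linear head with one output per class in its subset; the per-branch scores are concatenated and a fully connected layer maps them to all 5 classes. `tiny_encoder` is a hypothetical stand-in for the LeViT backbones, and how each branch is supervised on its subset during training is not shown. `count_parameters` is the usual way to obtain a trainable-parameter count like the 2,227,496 mentioned above.

```python
import torch
from torch import nn


def tiny_encoder(dim: int) -> nn.Sequential:
    """Hypothetical stand-in backbone (the post uses LeViT on both branches)."""
    return nn.Sequential(
        nn.Conv2d(3, dim, 3, stride=2, padding=1), nn.GELU(),
        nn.AdaptiveAvgPool2d(1), nn.Flatten(),
    )


class ClassSplitClassifier(nn.Module):
    """Each branch scores only its own subset of classes; an FC layer fuses all scores."""

    def __init__(self, classes_a, classes_b, num_classes: int, dim: int = 64):
        super().__init__()
        self.branch_a = nn.Sequential(tiny_encoder(dim), nn.Linear(dim, len(classes_a)))
        self.branch_b = nn.Sequential(tiny_encoder(dim), nn.Linear(dim, len(classes_b)))
        self.fc = nn.Linear(len(classes_a) + len(classes_b), num_classes)

    def forward(self, x_a, x_b):
        scores = torch.cat([self.branch_a(x_a), self.branch_b(x_b)], dim=1)
        return self.fc(scores)  # logits over all classes


def count_parameters(model: nn.Module) -> int:
    """Number of trainable parameters."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


model = ClassSplitClassifier(classes_a=[0, 1, 2], classes_b=[3, 4], num_classes=5)
x = torch.randn(2, 3, 112, 112)  # super-resolved images on both branches
out = model(x, x)
print(out.shape)                # torch.Size([2, 5])
print(count_parameters(model))  # 3939
```

For PneumoniaMNIST, the analogous configuration would give each branch a single output (`classes_a=[0]`, `classes_b=[1]`).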

Training procedure

Because of limited compute, we first super-resolved the images and stored the dataset in HDF5 (.h5) files for easy read/write access. We then used custom dataset pipelines to load the data into PyTorch dataloaders and iterate through it during training.

Custom dataset pipeline written to load the original and super-resolved images.
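A minimal PyTorch `Dataset` in the spirit of that pipeline is sketched below, assuming each split stores the original images, the pre-computed super-resolved images and the labels as three aligned arrays. The class name, the HDF5 key names in `open_h5_split` and the file layout are hypothetical. Because h5py datasets support integer indexing just like NumPy arrays, the same class works for both; h5py is only imported when a file is actually opened.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset


class MultiResDataset(Dataset):
    """Returns (original image, super-resolved image, label) triples from aligned arrays."""

    def __init__(self, low, high, labels):
        assert len(low) == len(high) == len(labels), "arrays must be aligned"
        self.low, self.high, self.labels = low, high, labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, i):
        x_low = torch.as_tensor(np.asarray(self.low[i]), dtype=torch.float32)
        x_high = torch.as_tensor(np.asarray(self.high[i]), dtype=torch.float32)
        # Labels may be stored with shape (N, 1); take the scalar class index either way.
        y = int(np.asarray(self.labels[i]).reshape(-1)[0])
        return x_low, x_high, y


def open_h5_split(path, low_key="low", high_key="high", label_key="labels"):
    """Hypothetical HDF5 layout: one file per split with three datasets. Requires h5py."""
    import h5py  # imported lazily so the Dataset class itself does not depend on h5py

    f = h5py.File(path, "r")
    return MultiResDataset(f[low_key], f[high_key], f[label_key])


# Demo with random in-memory arrays standing in for the HDF5 datasets.
rng = np.random.default_rng(0)
low = rng.random((10, 3, 28, 28), dtype=np.float32)
high = rng.random((10, 3, 112, 112), dtype=np.float32)
labels = (np.arange(10) % 5).reshape(10, 1)
loader = DataLoader(MultiResDataset(low, high, labels), batch_size=4, shuffle=True)
x_low, x_high, y = next(iter(loader))
print(x_low.shape, x_high.shape, y.shape)
# torch.Size([4, 3, 28, 28]) torch.Size([4, 3, 112, 112]) torch.Size([4])
```

When using `num_workers > 0` in the `DataLoader`, it is generally safer to open the HDF5 file inside each worker (for example lazily on first access) than to share one handle created in the main process.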

The optimiser used was Adam with a learning rate of 0.00075, and the models converged in 100 epochs. The learning rate was reduced twice, once after 50 epochs and then again after 75 epochs.
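In PyTorch, this schedule maps naturally onto `torch.optim.Adam` plus `torch.optim.lr_scheduler.MultiStepLR` with milestones at epochs 50 and 75. The decay factor is not stated in the post, so `gamma=0.1` (the scheduler's default) is an assumption, as are the placeholder model and the single random batch standing in for MultiResViT and the real data.

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Linear(16, 5)  # placeholder for the MultiResViT model
optimizer = torch.optim.Adam(model.parameters(), lr=0.00075)
# gamma (the decay factor) is an assumption; the post only says the LR was reduced twice.
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[50, 75], gamma=0.1)

x, y = torch.randn(8, 16), torch.randint(0, 5, (8,))  # one random batch standing in for an epoch
lr_per_epoch = []
for epoch in range(100):
    lr_per_epoch.append(optimizer.param_groups[0]["lr"])
    loss = nn.functional.cross_entropy(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()  # called once per epoch, after the optimiser step

print(lr_per_epoch[0], lr_per_epoch[50], lr_per_epoch[75])
```

With these settings the learning rate is 0.00075 for epochs 0 to 49, about 0.000075 for epochs 50 to 74, and about 0.0000075 for the final 25 epochs.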

Inferences

The experiments show that the lack of data and the low resolution are significant hurdles for training transformer models, since transformers lack the inductive biases that traditional convolutional models introduce. This is evident from the better performance of LeViT, which uses early convolutional layers to extract embeddings, compared with a model like NesT, which uses convolutional projections only in later steps. This is consistent with the findings of FAIR's Early Convolutions help Transformers see better.

The other inference was that it is effective to train branches on a limited number of classes and then combine them with a fully connected layer. This not only improved overall performance but also gave us better accuracy reports, showing that ensemble-like approaches can, to some extent, help with class-imbalanced datasets. We also found that 28 × 28 images are not very conducive to learning good representations with vision transformers, especially given how much harder medical image classification is than a simpler dataset like MNIST.

From these experiments, we conclude that there is still work to be done to refine vision transformers for low-data, low-resolution datasets, especially in medical imaging. However, we came to understand the nature of vision transformers much better, including how to combine vision transformer models and use ensemble-like techniques to improve performance on class-imbalanced datasets.
