Exploring the Benefits of Transfer Learning While Building Brain Tumor Classifiers

A comparative analysis of DL techniques ⚖️

Vinaya Sharma
Towards Dev

--

AI is huge: there are countless models, parameters, and even sub-parameters to choose from. For image classification, the two most popular approaches today are CNNs built from scratch and CNNs built with transfer learning.

Transfer learning is a popular ML technique that applies knowledge learnt on one dataset to a new one. Both from-scratch models and transfer learning models can be used for image classification, and today we apply both to brain tumor classification. When comparing the two, transfer learning typically returns higher accuracy because it starts from better initial weights, meaning less training is needed. Transfer learning models also converge faster, since they already recognize many features, and they generalize well thanks to the vast amount of data they were previously trained on.

This article will guide you through coding brain tumor classifiers, but if you want to learn more about the theory first, check out this one! These will be 2 exciting builds in which you get to experience the benefits of transfer learning for yourself, but first, let's start with some background!

Brain tumor diagnosis

A major healthcare concern to date is diagnosing patients accurately and efficiently from scans and images. Currently, highly skilled doctors are required to diagnose and classify medical images, and it is a largely manual process 🏭.

According to Kang Zhang, a professor at MUST’s Faculty of Medicine, “It will take a senior radiologist at least 20 minutes to look at a computed tomography (CT) scan which comprises anywhere from 200 to 400 images.” An AI-based CT scan reading, by contrast, takes only about 20 seconds, all without compromising accuracy and, in many cases, actually improving it.

How does this work 👷‍♀️

AI is being integrated into medical imaging services, particularly with radiology scans such as X-rays, CT scans, and MRIs. Deep neural networks (which are just mathematical functions) take the images as input and output a diagnosis. 📸 + 🤖 → 📝

Behind the scenes, the goddess of computer vision, the Convolutional Neural Network (CNN), does most of the heavy lifting by identifying key features in the image; linear layers then connect the activated filters to a classification.

Version 1 — Building a CNN from scratch 👩‍💻

  1. First things first, let’s import libraries.
import numpy as np
import torch
import matplotlib.pyplot as plt
%matplotlib inline
import torchvision
from torchvision import datasets
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import os
import pathlib
%config Completer.use_jedi = False

The backbones of our code:

  • numpy: Python library for manipulating and computing mathematical functions on matrices
  • torch: the core PyTorch library. We will be using PyTorch, a popular deep-learning framework, for this build. It includes common DL mathematical functions.
  • matplotlib: Python visualizations 📊
  • torchvision: PyTorch helper library for computer vision problems. Includes datasets, models, transformations, etc.
  • os / pathlib: manipulate files and paths on our operating system ⚙️

2. Now let’s embrace our inner detectives — time to analyze and organize the data 🕵️‍♀️

We will be using the Kaggle Brain Tumor Classification (MRI) dataset. This dataset contains 3,264 files, is separated into train and test folders, and is further divided into 4 classes: 3 types of brain tumors (glioma, meningioma, and pituitary) plus a no-tumor class.

All JPEGs need to be converted into tensors so that we can do some exciting math to extract a classification. We do this with transformations! All data needs to be the same dimensions, so I resize every image to 128x128 to save computational time. I also convert the images to grayscale, as grayscale data produced better accuracy during model tests. 🌚

I have set up 2 transformations, with the second adding random autocontrast and color jitter. This data augmentation technique effectively doubles our data and adds variance, allowing our model to generalize better.

# grab data
test_path = "../input/brain-tumor-classification-mri/Testing"
train_path = "../input/brain-tumor-classification-mri/Training"

# transformations
transforms1 = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
])
transforms2 = transforms.Compose([
    transforms.RandomAutocontrast(),  # randomly maximize image contrast
    transforms.ColorJitter(),         # random brightness/contrast/saturation jitter
    transforms.Resize((128, 128)),
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
])

3. Now we can download the transformed data 💪

I concatenate the 2 sets of training data that received different transformations, and then split all of the training data into a training and a validation set. The validation set will be used to select the model that generalizes best, while the testing set will grade our model’s accuracy on never-before-seen data 👀.

# download data 
train_data1 = datasets.ImageFolder(train_path, transform=transforms1)
train_data2 = datasets.ImageFolder(train_path, transform=transforms2)
test_data = datasets.ImageFolder(test_path, transform=transforms1)

image_datasets = torch.utils.data.ConcatDataset([train_data1, train_data2])
# two augmented copies of the training set: 5,740 images total, split 740 validation / 5,000 training
valid_data, train_data = torch.utils.data.random_split(image_datasets, [740, 5000])
# create data loaders
batch_size = 32
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=True)
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=batch_size, shuffle=True)
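
Since we imported matplotlib earlier, a quick optional sanity check (my own snippet, not part of the original pipeline) is to pull one batch from the loader and display a transformed image:

# sanity check: visualize one transformed training image
images, labels = next(iter(train_loader))
plt.imshow(images[0].squeeze(), cmap="gray")
plt.title(f"class index: {labels[0].item()}")
plt.show()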

Check it out, we are done with data pre-processing 🤯

Now we can begin building a model that not only detects brain tumors but accurately classifies them into 1 of the 3 most common varieties, or recognizes that no tumor is present.

Let’s get our hands dirty with some architectural engineering! 👩‍🎨

Finding the right model is tricky. It can get very overwhelming when you think of all of the layer options (convolutions, pooling, activations, dropout, etc.), not to mention possible hyperparameter tunings (stride, kernel size, input and output channels).

While building out this project I read a lot of articles, watched a ton of videos, and tested out a bunch of techniques. I’ve come up with a 5-step process that helps define an architecture.

Behind the scenes: the art of choosing your CNN architecture 🖌

  1. Decipher your data and problem statement
  • There is no one size fits all solution for ML models. Each problem is different, all data is different, and all biases are different.
  • Is your problem working with text, tables, or images? What format do you want your inputs and outputs in? Are you predicting, classifying, or segmenting?

2. Gain experience and understand conventions

  • Once you understand your problem at hand, you must understand the basics.
  • For example, today we are building a CNN classification model (as identified in step 1). Before we begin we must break down why CNNs are used, and their benefits. Once you understand the basic principles you will be able to imagine a vague mental model of the architectural flow.
  • For a CNN, the basic principle is to keep feature maps spatially wide and shallow (few channels) in the initial stages of the network, and make them narrower and deeper (more channels) toward the end.

3. Inspo

  • Once you have a general opinion, it is a good idea to follow the flow outlined by conventional models.
  • In the case of a CNN, checking out popular models such as ResNet, AlexNet, and GoogLeNet is a good idea. Most CNN models follow the basic structure of convolution → activation → normalization → alternating pooling layers → fully connected layers.
  • Read research papers tackling your problem, explore the best and worst models currently in the industry, and learn from both the good and bad.

4. Choose depth (until overfitting)

  • Deeper neural nets typically perform better. Add more layers while evaluating when the model begins to overfit. Graphically visualizing the loss on the training and validation sets gives a good signal of when to stop.

5. Search (automated search method to test multiple models)

  • Finally, what you’ve been dreading to hear: search, test, try, and experiment. Continual small tweaks to your model can add up to major improvements. My first model started off at 55% accuracy, and I was able to bring this up to more than 75%.
  • Hack: run models in a loop and keep track of their performances within your code (see the sketch after this list). Leave this running on your GPU while you continue with other tasks. It saves time and automates the search process.
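
Here is a minimal sketch of that search hack (my own illustration, reusing the network class, loss, and data loaders built elsewhere in this article; the candidate learning rates are made up for the example):

# automated search sketch: train each candidate briefly, record its
# validation accuracy, and keep the best configuration
import torch
import torch.nn as nn
import torch.optim as optim

def validation_accuracy(model, loader, device):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            correct += (model(images).argmax(dim=1) == labels).sum().item()
            total += labels.size(0)
    return correct / total

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
results = {}
for lr in [1e-2, 1e-3, 1e-4]:  # hypothetical candidates to compare
    model = network().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    for epoch in range(3):  # short trial runs, not full training
        model.train()
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = criterion(model(images), labels)
            loss.backward()
            optimizer.step()
    results[lr] = validation_accuracy(model, valid_loader, device)
    print(f"lr={lr}: validation accuracy {results[lr]:.3f}")
print("best candidate:", max(results, key=results.get))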

This is what I have found works best for me. If you have more convenient or valuable methods, definitely let me know! And there we have it: once we decide on an architecture, we can get back to the code. 🎉

4. Coding the model 🤩

The model architecture I have selected consists of 7 blocks, each containing a convolution, a ReLU activation, 2D batch normalization, and a max pool, followed by 4 fully connected layers, the first 3 of which are each followed by a 1D batch normalization.

  • Convolutional layer — nn.Conv2d(…): this layer takes input and output channel counts signifying the incoming and outgoing image depth. Each channel of depth represents the possible presence of a certain feature at any given pixel, and for each channel the model decides whether that feature is present. The kernel size is the dimensions of the feature being checked, the stride is the number of pixels shifted each time, and the padding is the layers of zeros added to the sides.
  • Rectified Linear Unit — nn.ReLU(): ReLU is the activation function I’ve used and is defined by y = max(0, x). ReLU is commonly used due to its ability to add non-linearity and its cheap compute. With this non-linearity, we are able to build arbitrarily shaped curves on the feature plane.
graph of ReLU function
  • Batch Normalization — nn.BatchNorm2d(…): normalization rescales values to a small, consistent range (roughly zero mean and unit variance). A large range or variance in values makes it difficult for the model to converge; scaling them down lets the model train faster. One can normalize the data during the transformation process, but after the first layer the distribution can drift again. That’s where batch normalizations come to the rescue! 🦸
  • Max Pool — nn.MaxPool2d(…): while the convolutional layers increase the data’s depth, the max pooling layer downsamples the image by reducing its width and height. Similar to convolutional layers, pooling layers have strides and kernels which determine the output feature map size. The kernel slides over the data and takes the maximum value within its view, effectively extracting the most important features, reducing compute, and helping prevent overfitting.
# convolution network
class conv_layer(nn.Module):
    def __init__(self, inp, out):
        super(conv_layer, self).__init__()
        self.conv1 = nn.Conv2d(inp, out, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()
        self.norm = nn.BatchNorm2d(out)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        x = self.pool(self.norm(self.relu(self.conv1(x))))
        return x

Once the data passes through the convolutional layers, it is put through fully connected layers before an output is predicted.

The FC layers are a little simpler. Each neuron consists of weights and biases which are learnt during the training process to classify the features extracted in the first part of the network. These layers just need the input and output sizes as arguments.

Calculating image dimensions after convolutions

In order to calculate the feature dimensions after the 7 layers of convolutions and pooling, you can use this formula:

W_out = (W − F + 2P) / S + 1

W — width, F — filter size, P — padding, S — stride
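
As a quick worked check (my own helper, not from the article), we can trace the width through the 7 conv + pool blocks and see why the first fully connected layer defined below takes 512 input features:

# trace the spatial width through the 7 conv + pool blocks
def out_size(w, f, p, s):
    return (w - f + 2 * p) // s + 1

w = 128
for _ in range(7):
    w = out_size(w, f=3, p=1, s=1)  # 3x3 conv, stride 1, padding 1: width unchanged
    w = out_size(w, f=2, p=0, s=2)  # 2x2 max pool, stride 2: width halved
print(w, 512 * w * w)  # 1 512 -> a 512 x 1 x 1 feature map feeds fc1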

Be sure that the number of logits (the raw final prediction values) for our problem is 4, as we are classifying each scan into 1 of 4 categories. 🗂

# model network
class network(nn.Module):
    def __init__(self):
        super(network, self).__init__()
        self.input_norm = torch.nn.BatchNorm2d(1, affine=False)
        self.layer1 = conv_layer(inp=1, out=8)
        self.layer2 = conv_layer(inp=8, out=16)
        self.layer3 = conv_layer(inp=16, out=32)
        self.layer4 = conv_layer(inp=32, out=64)
        self.layer5 = conv_layer(inp=64, out=128)
        self.layer6 = conv_layer(inp=128, out=256)
        self.layer7 = conv_layer(inp=256, out=512)

        self.net = nn.Sequential(self.layer1, self.layer2, self.layer3,
                                 self.layer4, self.layer5, self.layer6,
                                 self.layer7)

        self.fc1 = torch.nn.Linear(in_features=512, out_features=128)
        self.bn1 = torch.nn.BatchNorm1d(128)
        self.fc2 = torch.nn.Linear(in_features=128, out_features=32)
        self.bn2 = torch.nn.BatchNorm1d(32)
        self.fc3 = torch.nn.Linear(in_features=32, out_features=8)
        self.bn3 = torch.nn.BatchNorm1d(8)
        self.fc4 = torch.nn.Linear(in_features=8, out_features=4)

        self.lin = torch.nn.Sequential(self.fc1, self.bn1, self.fc2, self.bn2,
                                       self.fc3, self.bn3, self.fc4)

    def forward(self, x):
        x = self.net(x)
        x = x.view(x.size(0), -1)  # flatten the 512 x 1 x 1 feature maps
        x = self.lin(x)
        return x
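
A quick way to verify the architecture (my own snippet): push a dummy batch through and confirm the output is one row of 4 logits per image.

# a dummy batch of two 1-channel 128x128 grayscale images
# should come out as two rows of 4 logits
print(network()(torch.randn(2, 1, 128, 128)).shape)  # torch.Size([2, 4])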

5. And our CNN model architecture is complete.🎊 Time to grade it 👩‍🏫

One of our final steps is defining our loss and optimizer functions. We will be using cross-entropy loss and the Adam optimizer from the torch library. Adam has become the go-to optimizer as it is straightforward, fast, and requires less memory and less tuning than other optimization algorithms. Adam computes the gradients of the cost and updates weights with per-parameter adaptive learning rates. 🧮

If you want to learn more about the Adam optimizer check out this article, and to learn more about the functions discussed above, check out my previous article that goes in-depth on the theory of CNN architecture. 🤿

model = network()  # instantiate the architecture defined above

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1.0E-3)

Training and testing 🏋️‍♂️
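
The training loop itself isn't shown above, so here is a hedged sketch of one (my own reconstruction, not the author's exact code). It records the per-epoch training and validation losses so they can be plotted to spot overfitting, as suggested in step 4 earlier:

# training loop sketch: track losses per epoch for later plotting
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
train_losses, valid_losses = [], []

for epoch in range(20):
    model.train()
    running = 0.0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()
        running += loss.item() * images.size(0)
    train_losses.append(running / len(train_loader.dataset))

    model.eval()
    running = 0.0
    with torch.no_grad():
        for images, labels in valid_loader:
            images, labels = images.to(device), labels.to(device)
            running += criterion(model(images), labels).item() * images.size(0)
    valid_losses.append(running / len(valid_loader.dataset))
    print(f"epoch {epoch + 1}: train loss {train_losses[-1]:.3f}, valid loss {valid_losses[-1]:.3f}")

# plot both curves to see where the model starts to overfit
plt.plot(train_losses, label="train")
plt.plot(valid_losses, label="validation")
plt.legend()
plt.show()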

After training on a GPU for 20 epochs, the model achieved more than 75% accuracy on the testing dataset. The fact that this simple model identified MRI scans without tumors with 100% accuracy provides tremendous hope for the medical imaging community. Imagine if doctors were able to utilize AI to help advise and efficiently classify scans in order to help patients in need. Not only would time be saved, but processes could be made cheaper and more accessible as well. 🌎 = 🙂

With PyTorch, building CNN models has become pretty simple; a lot of the mathematical work has been abstracted away for us. However, over the years new improvements have come to the ML community, and we have a couple more options to test out to improve our accuracy.

Version 2 — Introducing transfer learning 👊

Transfer learning has taken over the world of deep learning by questioning one fundamental flaw of the current system: why is everyone reinventing the wheel? In the case of CNNs, most image classifiers will share similar early weights and parameters; after all, most computer vision tasks must detect edges in the early layers and simple shapes in the middle ones. 🔵 🔶 ⎺ 🟪 By using state-of-the-art models trained on large datasets, we can utilize their basic knowledge and fine-tune the model for our specific task.

  1. I similarly loaded all of the data in, then downloaded the EfficientNet model and froze its parameters. 🥅

from torchvision import models

model = models.efficientnet_b0(pretrained=True)

for param in model.parameters():
    param.requires_grad = False

2. We will keep everything the same except for the last part of the model: the classifier. We will simply swap it out for fully connected layers that produce 4 outputs.

from collections import OrderedDict

classifier = nn.Sequential(OrderedDict([
    ('fc1', nn.Linear(in_features=1280, out_features=540)),
    ('fc2', nn.Linear(in_features=540, out_features=4)),
]))
model.classifier = classifier
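
As a quick hedged follow-up (my own snippet), it is worth confirming that only the new classifier is trainable, and handing the optimizer just those parameters:

# only the swapped-in classifier should require gradients
print([name for name, p in model.named_parameters() if p.requires_grad])
# expected: only 'classifier.fc1.*' and 'classifier.fc2.*' entries
optimizer = optim.Adam(model.classifier.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()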

3. This time after training for 15 epochs we get close to a 90% accuracy. 👏

With just a couple of lines of code, we are on par with professionally trained radiologists. Transfer learning proves to deliver great accuracy when less data is available and when models already exist for tasks similar to the one at hand. 🤯

Feel free to check out all of the code in my GitHub repo.

Is AI the cure for our broken medical system?

Over 2/3 of people on Earth do not have access to radiologists. Some countries in Africa have no radiologists at all (14 African countries, to be exact), countries like India have roughly 1 radiologist per 100,000 people, and even developed countries are being hit hard by this shortage. In Canada and the UK, for example, it can take 30 days to get your medical images read.

Estimated number of radiologists per million inhabitants per country

But there is hope. AI is showing promising results for detection and classification in medical imaging. 🤖 + 👩‍⚕️ = 🏆

  • Niramai is a low-cost portable device capable of detecting breast cancer without the need for wifi or electricity. With the help of AI, women in developing countries can be diagnosed in the privacy of their homes.
  • Synapsica’s AI is helping radiologists automate tasks. Their software is analyzing MRIs in less than a minute.
  • Scientists at Université Lorraine and Tecnológico de Monterrey have demonstrated that AI can help identify hard-to-detect endoscopic kidney stones.

Closing thoughts 💭

Comparing the 75% accuracy of our from-scratch model with the 90% accuracy of transfer learning, there is no doubt that we have a winner. With more tweaks to the architectures, both models can be improved, but when working with less data at a beginner level, it is usually best to go with transfer learning techniques, as they have repeatedly produced better results. 🚀

My name is Vinaya, a high school student passionate about using technology to create a better future. If you have any suggestions, questions, or just want to talk, you can message me on LinkedIn or Twitter. Feel free to subscribe to my monthly newsletter to stay updated on advancements in technology, opportunities and AI breakthroughs. Thank you for reading and I hope you learnt something new!

