Segmenting Brain Tumours from MRI Scans — 93% Accuracy!
Our vision is a crucial part of our lives. It enables us to do everything from coding and playing basketball to reading this article. And we humans always thought we were the only ones who could drive cars, interpret medical images like x-rays and MRIs, and tell different vehicles and animals apart.
BUT, we are no longer the only ones that can see, ever since computer vision came into the picture.
What is Computer Vision?
Computer vision is a subsection of Artificial Intelligence (AI) that gives computers the ability to derive meaningful information from digital images, videos, and other visual inputs. It can use this insight to take action or make recommendations.
Usually, in computer vision, a computer follows one of these tasks:
- Object Classification: What object category is in this photograph?
- Object Classification + Localization (Recognition): What objects are in this photograph and where are they?
- Object Detection: Where are the objects in the photograph?
- Semantic Segmentation: For each pixel in the image, which class label does it belong to? (i.e. does this pixel correspond to a cat? → answering this for every pixel accumulates into an image segmented into labelled classes; see the tiny sketch below)
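To make that last task more concrete, here is a tiny sketch (with made-up values) of what a semantic segmentation output looks like: one class label per pixel.

import numpy as np

# a toy 4x4 "segmented image": 0 = background, 1 = cat
segmented = np.array([[0, 0, 0, 0],
                      [0, 1, 1, 0],
                      [0, 1, 1, 0],
                      [0, 0, 0, 0]])
print(segmented)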
When I learned about semantic segmentation (also referred to as segmentation), I was really eager to do a segmentation project and I was thinking: what images should the model segment??? 🤔 A few days later, when doing my biology homework, which was about different types of tumours, I came across facts about brain tumours. I researched more into brain tumours and it turns out that, today, an estimated 700,000 people in the United States are living with a primary brain tumour, and approximately 88,970 more will be diagnosed in 2022.
With more research, I found out that a complication with brain tumours is that detecting and segmenting brain tumours in Magnetic Resonance Images (MRIs) are crucial but time-consuming tasks performed by medical experts. Segmentation in particular is challenging and error-prone due to the irregular form and confusing boundaries of a brain tumour. On top of that, MRI systems are not flawless and can sometimes miss important details in a scan. As a result, if neurosurgeons don't know exactly where the brain tumour is and fail to remove the entire tumour (to the extent that doing so doesn't damage vital brain tissue), it can be detrimental to the patient.
Some of the words in the above paragraph, specifically "detecting" and "segmenting", seem like something computer vision could solve, don't you think??? That's when I had an "aha" moment💡about using computer vision to detect different types of brain tumours and to segment brain tumours from MRI scans.
Note: despite research into computer-assisted scanning, medical experts still read scans manually because computer vision technology is still emerging and, before it can be deployed in a hospital or clinic setting, it must go through a clinical trial process, which can be lengthy.
Brain Tumour Classification Project
I started off with a project that can determine if there is or isn’t a brain tumour in the MRI scan. If there is a brain tumour, the model is also further able to identify what type of brain tumour the MRI scan is showing — a glioma, meningioma, or pituitary tumour.
The Github repository for this project includes a Jupyter notebook and an example image of the visualization (the output results displayed visually) from my classification model: https://github.com/shizacharania/Brain-Tumour-Classification
However, in this article, I will be discussing my Brain MRI Segmentation Project. After about a month of working on this project, my model can segment brain tumours from brain MRI scans with 93% accuracy.
In my Github repository, I have 3 different Jupyter notebooks: one with dice loss, one with cross-entropy loss, and one that has my code with extra comments about the errors I hurdled over. I also included an example image of the visualization (the output results displayed visually) from my code.
The code I will be explaining is in small chunks. However, if you’d like to see what my whole project looks like and use/refer to it, check out my Github repository for this project:
About the Dataset
Before building out the code, a crucial component I had to consider was the availability of a brain tumour segmentation dataset. Fortunately, I found a fairly reliable one on Kaggle that is open-source, meaning that I have permission from the dataset’s owner to use it:
This open-source dataset has a total of 7858 .tif files (.tif is a type of image format, like .jpg or .png). Most of these images are MRI scans of lower-grade gliomas, which is a certain type of brain tumour. However, there are also .tif files that either contain no brain scan at all or contain a scan without a tumour present. Along with the images are the masks, which are the ground-truth segmentations of the tumours in these scans. Out of the 7858 .tif files, there are 3929 images and 3929 masks. Also, note that the masks are what the AI model trains against to increase its accuracy (just like one would train to shoot free throws to get better at them).
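Once the dataset is downloaded (Step 1 below), these counts can be sanity-checked with a few lines of glob. Here is a sketch that assumes the kaggle_3m folder layout used later in this article:

import glob

# count all the .tif files, then split them into masks and images
all_tifs = glob.glob('/content/lgg-mri-segmentation/kaggle_3m/*/*.tif')
masks = [f for f in all_tifs if '_mask' in f]
print(len(all_tifs), len(masks), len(all_tifs) - len(masks))  # expecting 7858, 3929, 3929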
Now that we have covered more details about the dataset I used, let’s dive deeper into the code 🚀
My code is split up into 12 segments:
1. Importing the Dataset: connecting the Kaggle API + downloading and unzipping the data
2. Modifying and Visualizing the Dataset's Composition: importing + getting the images and masks' file paths + visualizing the file paths in a data frame + visual comparison between the number of tumours and non-tumours
3. Preprocessing the Dataset: importing + normalization + applying normalization and other transformations
4. Splitting Up the Data: using the 60–20–20 rule to split up the data into training, validation, and testing datasets
5. Data Augmentation: performing a 90 degree clockwise rotation, horizontal flip, or vertical flip on 1/3 of the images in the training dataset + adding these images to the current training dataset
6. Making Data Loaders: zipping the data into a list + declaring the 3 different data loaders
7. Visualizing the Images and Masks from the Different Data Loaders: iterating through the data loaders + reading and plotting the images and masks + arranging a sample of the images into a clear format
8. Creating the Model: about the UNET architecture + using classes to build different blocks of the UNET model from scratch
9. Picking a Loss and Optimizer: the reasoning behind choosing the Cross-Entropy Loss, the Dice Loss, and the Adam optimizer + making the Dice Loss function from scratch
10. Training and Validation: putting the training and validation images through the model, deriving the loss, performing backpropagation + saving the model (for validation only)
11. Testing the Model: calculating the accuracy and testing loss from the testing data
12. Visualizing the Results: using the testing dataset's output and visualizing the predictions it made + adding a threshold to those predicted masks
Step 1: Importing the Dataset
To use the dataset from Kaggle, we need to use the Kaggle API and download the dataset.
Connecting the Kaggle API
First, we install the Kaggle package using !pip install. 🗂
!pip install -q kaggle
We need to go to Kaggle and download a .json file that contains your generated Kaggle API token (the file is usually called kaggle.json). From there, we can upload this file from our local system. Since I used Google Colab for my code, I simply imported files from google.colab and then uploaded my .json file.
from google.colab import files
files.upload()
Next, we create a Kaggle directory on Google Colab using !mkdir, put the kaggle.json file in that directory using !cp, and set the .json file's permissions using !chmod. The reason for the ! is that it tells the notebook to run the line as a shell command.
!mkdir ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json
By this stage, the Kaggle API should have worked 🔗, meaning that Kaggle is linked to the Colab notebook. To ensure this is true, we can list some of the available datasets on Kaggle using this line:
!kaggle datasets list
Downloading and unzipping the data
If the code above runs and you get a short list of datasets from Kaggle, you can move on to the next part, which is downloading the dataset. You do this by first typing ! and then copying the API command from the Kaggle dataset page, which in this case is kaggle datasets download -d mateuszbuda/lgg-mri-segmentation. Together, this is what the line of code for downloading the dataset should look like:
!kaggle datasets download -d mateuszbuda/lgg-mri-segmentation
Currently, the dataset is in a .zip file containing the images and masks. Thus, in order to directly access them, we need to unzip the archive using !unzip followed by the name of the .zip file.
!unzip lgg-mri-segmentation.zip
Step 2: Modifying and Visualizing the Dataset’s Composition
A useful step at this point is to see what the dataset consists of (i.e. the number of images with and without tumours).
Importing
First of all, we need to import some libraries that will be useful to us in this step and later ones as well. We need to import:
- matplotlib.pyplot as plt, since we need to graph results and images in numerous scenarios
- numpy as np, due to its valuable operations
- glob, for getting the path of all the images in the dataset
- cv2, for getting specific data about the image, such as the diagnosis (if it has a tumour or not), and for reading the image
- pandas as pd, for viewing all the file paths in a data frame
import matplotlib.pyplot as plt
import numpy as np
import glob
import cv2
import pandas as pd
To help comprehend the process of getting all of the file paths of the images and masks, here are two examples of file paths (illustrative, following the dataset's folder layout):
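/content/lgg-mri-segmentation/kaggle_3m/TCGA_CS_4941_19960909/TCGA_CS_4941_19960909_1.tif
/content/lgg-mri-segmentation/kaggle_3m/TCGA_CS_4941_19960909/TCGA_CS_4941_19960909_1_mask.tif
Notice that the only difference is the "_mask" part before the extension; we will rely on this shortly.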
Getting the images and masks’ file paths
To get all the file paths, we need to have a root_path that all the mask and image file paths contain.
root_path = '/content/lgg-mri-segmentation/kaggle_3m/'
Let's start off with the mask file paths, where we can just use glob.glob().
We need to put root_path + "*/*_mask*" in the brackets because we are looking for file paths that start with the root_path. The first * matches the patient folder that leads up to the file name, the part before _mask matches the file name prefix, and the final * matches the ".tif" ending of the file path. Whatever file paths follow this format are returned to potential_mask_files.
potential_mask_files = glob.glob(root_path + "*/*_mask*")
One problem with the dataset is that many files don't have a tumour, and since the model is trying to segment tumours, we need to remove some of the masks' file paths (so potential_mask_files is not the final list of file paths). Note that doing this is a small part of preprocessing the dataset as well.
Here, we can use a for loop iterating through every file path in potential_mask_files. In the loop, we read the image using cv2.imread(), which returns a NumPy array with the value of each pixel. Then, we take the max value of all the pixels using np.max(), which will either be 0 (if there is no tumour) or 255 (if there is a tumour). If it's > 0, then we add the path to mask_files, which contains all the "final" file paths for the masks.
Note that whenever we want to add a single item to a list, we can use Python's .append() method.
We can then use this information along with add_count (a counter for the number of masks we iterate over) to add a third of the masks without a tumour to mask_files as well.
mask_files = []
add_count = 0

for mask in potential_mask_files:
    if np.max(cv2.imread(mask)) > 0:
        mask_files.append(mask)
    elif np.max(cv2.imread(mask)) == 0 and add_count % 3 == 0:
        mask_files.append(mask)
    add_count += 1
Now that we have our mask files, we can also get our image files. A shortcut to do this is iterating through all the masks in mask_files and replacing "_mask" with "" (basically removing "_mask"). This works because, as shown before, the difference between the file paths of an image and a mask is that the mask has the additional "_mask". Therefore, by removing that part and adding the result to image_files, we can get all the "final" file paths for the images.
image_files = []
for mask in mask_files:
rmask = mask.replace("_mask", "")
image_files.append(rmask)
Visualizing the file paths in a data frame 👓
After getting the file paths of the images and masks, visualizing them becomes fairly simple: we use np.max(cv2.imread()) to append "1" to the tumour_count list if there is a tumour or append "0" if there isn't one.
We can also create a data frame with the image paths, mask paths, and diagnoses using pd.DataFrame().
tumour_count = []

def diagnosis(mask_path):
    if np.max(cv2.imread(mask_path)) > 0:
        tumour_count.append("1")
        return 1
    else:
        tumour_count.append("0")
        return 0

files_df = pd.DataFrame({"image_path": image_files,
                         "mask_path": mask_files,
                         "diagnosis": [diagnosis(x) for x in mask_files]})
print(files_df)
Visual comparison between the number of tumours and non-tumours
To visualize the number of tumours and non-tumours in the modified dataset of images and masks, we first need to count the number of 1s and 0s in tumour_count (we declared this variable above) and assign the amounts to n_tumours and n_nontumours, respectively. We can input this information into plt.bar() to create a bar graph comparing these categories.
n_tumours = tumour_count.count("1")
n_nontumours = tumour_count.count("0")

plt.bar(["Tumours - " + str(n_tumours), "Non-Tumours - " + str(n_nontumours)],
        [n_tumours, n_nontumours], color=["green", "red"])
Step 3: Preprocessing the dataset
Although the dataset downloaded properly, the data still might not be clean. Usually, custom datasets have irregularities, such as different image sizes, many of the images being fully black ⬛️, different ranges for the pixel values, etc. In order to tackle these issues with our dataset, we need to preprocess it.
Importing
Similar to Step 2, we need to import some libraries that will be useful to us in this step and later ones as well. We need to import:
- torch, which is used to develop and train neural networks
- torchvision, because of the various functions and classes it incorporates, such as transformations for the data (we'll be getting into this soon) and making the AI model from scratch
- As mentioned in the point above, we need different classes of transformations for the images, so we can type from torchvision.transforms import transforms
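Written out, the imports for this step look like this (a small sketch covering the three points above):

import torch
import torchvision
from torchvision.transforms import transforms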
Normalization
One very important step within preprocessing is normalization, which, in simple words, is a process that makes something more normal or regular. Applied to neural networks, normalizing the dataset involves rescaling the data to make it easier and faster to train. We can do this by using the normalization formula, which involves subtracting the mean of the data and then dividing by the standard deviation: normalized = (pixel - mean) / std.
The mean of the data in this context is the sum of all the pixel values of every image divided by the total number of pixels in all the images.
num_pixels = len(image_files)*256*256
total_sum = 0

for data in image_files:
    image = cv2.imread(data)
    image = torch.from_numpy(image)
    total_sum += image[0].sum()  # note: image[0] is only the first row of pixels (see the note below)

mean = total_sum/num_pixels
The standard deviation in this context essentially just follows the formula: we derive the difference between each element and the mean, square those differences using **2, and then .sum() up all the squared differences; this is called the sum_squared_error. Next, we divide the sum_squared_error by the total number of pixels and take the square root of the quotient using torch.sqrt().
sum_squared_error = 0

for data in image_files:
    image = cv2.imread(data)
    image = torch.from_numpy(image)
    sum_squared_error += ((image[0] - mean)**2).sum()

std = torch.sqrt(sum_squared_error/num_pixels)
After calculating the mean and standard deviation and printing them out, my mean and standard deviation were usually around 0.0215 and 1.1606 ("usually" because every time I calculated those two values, I got similar — not exactly the same — answers).
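As flagged in the code comments above, image[0] indexes only the first row of each image rather than the whole tensor, which is why these values look unusual for pixel data. For reference, a conventional computation over every pixel and channel would look something like this sketch (it reuses cv2, torch, image_files, and num_pixels from above and will produce different values from the ones I reported):

total_sum = 0
sum_squared_error = 0

# pass 1: mean over every pixel value (all rows, columns, and colour channels)
for data in image_files:
    image = torch.from_numpy(cv2.imread(data)).float()
    total_sum += image.sum()
mean = total_sum / (num_pixels * 3)  # 3 colour channels per pixel

# pass 2: standard deviation using that mean
for data in image_files:
    image = torch.from_numpy(cv2.imread(data)).float()
    sum_squared_error += ((image - mean)**2).sum()
std = torch.sqrt(sum_squared_error / (num_pixels * 3))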
Next, we need to actually apply the normalization to the dataset along with other transformations and actions as well. For this part, I created a function called loading_data, which takes as inputs:
- files — the file paths we want to add to the dataset
- dataset — an empty list defined above the function that we will append the preprocessed images to
- actions — set to True if we are inputting images of brain MRI scans (I'll refer to these as "images") and False if we are inputting masks; the reason for this is that the transformations vary depending on whether we are preprocessing an image or a mask
Applying normalization and other transformations
For every file path that is inputted into the function, we read it using cv2.imread(), returning a NumPy array (we've done this in a previous step too). An important line to remember is resizing the image to 96 x 96 using cv2.resize(); the reason for its importance is that, after days of trying to figure out why my model took 11 hours to train 11 epochs (an epoch is one cycle of training the model with all the training data), decreasing the image size made a HUGE difference, bringing it down to about 7 minutes for every epoch 😅.
After reading and resizing, we can apply the transformations with torchvision.transforms.Compose(), which contains a list of all the transforms we want to apply to the images/masks.
For images (when actions is set to True), we:
- Transform the NumPy array of every image that we read into a Torch tensor — transforms.ToTensor()
- Cast the uint8 values of the pixels to float32 — transforms.ConvertImageDtype()
- Normalize the images with the mean and standard deviation values we calculated earlier — transforms.Normalize()
- Then, we need to input the image into the transform variable that holds all the transformations — image = transform(image) — and store it into the original image variable (to replace it)
The transformations applied to the masks (when actions is set to False) are largely the same, except we don't normalize them. Instead, we make them grayscale using transforms.Grayscale() ← this changes the number of colour channels of the mask from 3 (RGB) to 1. Similar to the images, we use image = transform(image).
Within the for loop going through all the images/masks, we append each transformed image to the dataset inputted into the loading_data function (either image_data or mask_data).
image_data = []
mask_data = []

def loading_data(files, dataset, actions):
    for data in files:
        image = cv2.imread(data)
        image = cv2.resize(image, (96, 96))
        if actions == True:
            transform = torchvision.transforms.Compose([
                transforms.ToTensor(),
                transforms.ConvertImageDtype(torch.float),
                transforms.Normalize((mean, mean, mean),
                                     (std, std, std))])
            image = transform(image)
        elif actions == False:
            transform = torchvision.transforms.Compose([
                transforms.ToTensor(),
                transforms.ConvertImageDtype(torch.float),
                transforms.Grayscale()])
            image = transform(image)
        dataset.append(image)
Once we define the loading_data function and all the actions we need to take within it, we can call the function for all the image and mask files.
loading_data(image_files, image_data, True)
loading_data(mask_files, mask_data, False)
Step 4: Splitting Up the Data
When working with the model, the data should be split up into:
- The training dataset, which is the data used to train the model (as the name suggests 😂) + it teaches the model how to perform the desired task
- The validation set is used for determining the parameters of the model that lead to a lower loss/higher accuracy and saving them + it aims to avoid over-fitting (when the model fits too well against the training data, so it can’t perform accurately with other data)
- The testing set is used for evaluating the final performance of the model on an unseen dataset
Before splitting the data, let's check the length of the dataset using the len() function in Python — this can help us later determine if the splitting was performed correctly.
print(len(image_data), len(mask_data))
This prints out 2219 and 2219 since the length of the dataset is the same for the image and mask data.
To split the dataset, we should follow the 60–20–20 rule, which means that the training data should have 60% of all the images and masks in the dataset, and the validation and testing data should have 20% each. We can do this by deriving the length each dataset should be and then using list slicing to get the split data.
len_trdataset = int(np.floor(len(image_data) * 0.60))
len_vtdataset = (len(image_data)-len_trdataset)//2

trimage_data = image_data[:len_trdataset]
trmask_data = mask_data[:len_trdataset]

vimage_data = image_data[len_trdataset:(len_trdataset+len_vtdataset)]
vmask_data = mask_data[len_trdataset:(len_trdataset+len_vtdataset)]

testimage_data = image_data[(len_trdataset+len_vtdataset):]
testmask_data = mask_data[(len_trdataset+len_vtdataset):]
Now that we've split up the data, it's helpful to verify that the number of images/masks in each split matches the intended proportions.
print(str(len(trimage_data)), "----", str(len(vimage_data)), "----", str(len(testimage_data)))
This prints out: 1331 ---- 444 ---- 444, meaning that the split was successful!
Step 5: Data Augmentation
Data augmentation involves increasing the amount of data 📈 by adding slightly modified copies of already existing data. This step is optional, but it can be very beneficial to get high accuracy.
For this step, I have defined a function called data_augmentation. In the function, we go through every image in the dataset inputted into the function. For every case, we use the torchvision.transforms we imported earlier. Then, we input that image into the transform variable we defined for that particular image's transformation (e.g. transform0). Finally, we .append() the transformed image to the current dataset.
There are 3 types of transformed images that can be added to the dataset, depending on where the original image is located in the dataset (the variable count is used for this, making it easier to augment a certain amount of modified images):
- If count is divisible by 9, then we rotate the image 90 degrees clockwise using transforms.RandomRotation(degrees=[90,90]) 🔁
- If count divided by 9 leaves a remainder of 1, then we flip the image vertically using transforms.RandomVerticalFlip() ➡️ ⬅️
- If count divided by 9 leaves a remainder of 2, then we flip the image horizontally using transforms.RandomHorizontalFlip() ⬇️ ⬆️
Or, if count doesn't satisfy any of these conditions (which accounts for 2/3 of the dataset), nothing is done to the current image.
In the end, we add 1 to the value of count, which helps keep track of the position of the image in the dataset and fulfils its role 😇.
def data_augmentation(dataset):
    count = 0
    for image in dataset:
        # p=1 makes the flips deterministic, so calling this function first on the
        # images and then on the masks applies the exact same transformations to both;
        # also, appended images are revisited by the loop, which is why the dataset
        # grows to roughly 1.5x its original size before the loop finishes
        if count % 9 == 0:
            transform0 = transforms.RandomRotation(degrees=[90,90])
            image = transform0(image)
            dataset.append(image)
        if count % 9 == 1:
            transform1 = transforms.RandomVerticalFlip(p=1)
            image = transform1(image)
            dataset.append(image)
        elif count % 9 == 2:
            transform2 = transforms.RandomHorizontalFlip(p=1)
            image = transform2(image)
            dataset.append(image)
        count += 1
Now, we need to call this function with the image and mask data belonging to the training dataset. This is because data augmentation is only applied to the training dataset, not to the validation or testing dataset.
data_augmentation(trimage_data)
data_augmentation(trmask_data)
We can check whether the data augmentation worked by seeing if the number of images and masks increased after this step. In other words, we can print out the len() of trimage_data and trmask_data before and after data augmentation using:
print(len(trimage_data), len(trmask_data))
In my case, comparing the number of training images and masks from before and after, it increased from 1331 to 1997.
Step 6: Making Data Loaders
When we want to retrieve or iterate over the data, it becomes a lot easier to store it in one variable. Thankfully, PyTorch's DataLoader class wraps an iterable around the dataset to enable easy access to batches; the batch size is the number of samples that are passed to the network at once.
Usually, the total number of images in a dataset is divisible by the batch size, but in our code, that's not the case. Our last batch has a smaller number of images passed through, but this doesn't cause a huge problem (see the quick check below).
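As a quick check with the numbers from this article (1997 training pairs and a batch size of 64; a small sketch):

import math

print(1997 // 64)            # 31 full batches of 64
print(1997 % 64)             # plus one final, smaller batch of 13 pairs
print(math.ceil(1997 / 64))  # 32 batches per epoch in total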
Zipping the data into a list
Before even making a data loader, we need to zip the images and masks into a list for each dataset (training, validation, and testing). This is fairly simple when using zip() and list(), which are built-in Python functions.
training_data = list(zip(trimage_data, trmask_data))
validation_data = list(zip(vimage_data, vmask_data))
testing_data = list(zip(testimage_data, testmask_data))
To visualize that segment of code: if you input two lists into list(zip()), it returns a list containing tuples of the same-index elements in both lists, as the small example below shows.
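Here is a minimal sketch of that behaviour (with made-up values):

images = ["img1", "img2", "img3"]
masks = ["mask1", "mask2", "mask3"]

print(list(zip(images, masks)))
# [('img1', 'mask1'), ('img2', 'mask2'), ('img3', 'mask3')]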
Declaring the trainloader, validationloader, and testingloader
Now we can make a data loader. Note that the DataLoader class itself lives in torch.utils.data (already available through torch), though my notebook also imports utils from torchvision at this step.
from torchvision import utils
Next, we can declare the variables for the trainloader, validationloader, and testingloader. In torch.utils.data.DataLoader(), we input the parameters:
- training_data, validation_data, or testing_data (the lists we defined above), depending on the corresponding data loader
- batch_size, which is set to 64 since it's a typical size for batches
- shuffle=True, meaning that we want to shuffle the dataset every time we iterate over the data
trainloader = torch.utils.data.DataLoader(training_data, batch_size=64, shuffle=True)
validationloader = torch.utils.data.DataLoader(validation_data, batch_size=64, shuffle=True)
testingloader = torch.utils.data.DataLoader(testing_data, batch_size=64, shuffle=True)
Step 7: Visualizing the Images and Masks from the Different Data Loaders
We have already declared a train, validation, and testing loader, but let's see if they work; we can do this by iterating through the loaders and reading + plotting the images.
Iterating through the data loaders
By making an iterator over a particular data loader with iter() and then calling next(), we can get one batch of images and masks from that data loader.
trimages, trmasks = next(iter(trainloader))
Reading and plotting the images and masks
We need to change up the dimensions of the image before using plt.imshow(), which requires image data with any of these supported shapes:
- (M, N): an image with scalar data — the first two dimensions (M, N) define the rows and columns of the image
- (M, N, 3): an image with RGB values (0–1 float or 0–255 int)
- (M, N, 4): an image with RGBA values (0–1 float or 0–255 int), where A represents the transparency
Thus, by permuting the images and masks, we change the dimension order from [64, 3, 96, 96] to [64, 96, 96, 3] for images. Since masks are grayscale, their dimension order is changed from [64, 1, 96, 96] to [64, 96, 96, 1]. Images now have a supported shape, but masks still don't because of the "1" in their shape (inspected with .shape). So, we can .squeeze() out this "1" and make the dimensions for the masks [64, 96, 96].
Note: the "64" in the images and masks' shape dimensions represents the batch size
trimages = trimages.permute((0,2,3,1))
trmasks = trmasks.permute((0,2,3,1))
trmasks = trmasks.squeeze()

print("TRAIN:")
print(trimages.shape)
print(trmasks.shape)
When we show an image or a mask, we can just retrieve an element from each batch. After every plt.imshow(), when we call plt.show(), it shows what the image/mask looks like on a 96 x 96 pixel plot.
print("Images:")
plt.imshow(trimages[3])
plt.show()
print("Masks:")
plt.imshow(trmasks[3])
plt.show()
Arranging a sample of the images and masks into a clear format
After experimenting with other features of plt, we can display a sample of these images and masks in a grid (this part is completely optional, but it makes the data more organized and visually clear).
def show_aug(loader, nrows=4, ncols=10):
images, masks = next(iter(loader))
images = images.permute((0,2,3,1))
masks = masks.permute((0,2,3,1))
masks = masks.squeeze()
plt.figure(figsize=(20, 20))
for i in range(len(images[:10])):
plt.subplot(nrows, ncols, i+1)
plt.imshow(images[i])
plt.axis('off')
plt.subplots_adjust(wspace=0, hspace=0)
plt.show()
plt.figure(figsize=(20, 20))
for j in range(len(masks[:10])):
plt.subplot(nrows, ncols, j+1)
plt.imshow(masks[j])
plt.axis('off')
plt.subplots_adjust(wspace=0, hspace=0)
plt.show()
show_aug(trainloader)
Step 8: Creating the Model
This part of the code, in my opinion, is the most crucial to this project. The model for segmentation uses an architecture called U-Net. To put it simply, the model consists of an encoder path (the left-hand side of the U) and a decoder path (the right-hand side). You could import a ready-made model from Github in about 3 lines, but I decided to make the model FROM SCRATCH with about 75 lines 😎.
About the UNET architecture
I will get a bit more technical in this part about different features of the model, so I would highly recommend reading up beforehand on how convolutional and pooling layers work.
Back to the U-Net model, the encoder path ↘ is generally used to gain context about the image. It involves:
- Two 3x3 convolutional layers where stride and padding are both set to one. These layers increase the number of channels and identify patterns in the images — nn.Conv2d()
- After each convolutional layer, we use batch normalization (think normalization, but for batches) → makes the model faster and more stable — nn.BatchNorm2d()
- Then, we use the ReLU activation function, which helps the model comprehend intricate patterns — nn.ReLU()
- After the activation function, we use a pooling layer called the max pool layer; this layer downsamples the image size and decreases the number of computations necessary, saving a lot of time — nn.MaxPool2d()
In the encoder path, the block of the convolutional layer + batch normalization + ReLU activation function + max pool layer is repeated a few times.
Then, we get into the decoder path, which in many ways, is opposite to the encoder path — it is used to enable precise localization.
- First, in terms of the pooling layer, this time we are using something quite different from max-pooling: the torch.nn.functional.interpolate() function to upsample the image size, i.e. increase the height and width
- Another difference is that we use skip connections to concatenate information from past layers in the encoder path with the current layers in the decoder path. The reason for this is to assemble a more precise output — torch.cat()
- Next, similar to the encoder path, there are still 2 convolutional layers that have a 3x3 kernel and the same stride and padding (1), but this time they decrease the number of channels
In the decoder path ↗, this block of the upsampling pooling layer + concatenation + convolutional layer + batch normalization + ReLU activation function is repeated a few times.
Going through the encoder and decoder paths, we should end up with an image of the same height and width as the input, but the image starts off with 3 channels and ends up with 1 (since the output mask is grayscale). The sketch below shows how the decoder's upsample-and-concatenate step affects tensor shapes.
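Here is a minimal, standalone sketch of that upsample-and-concatenate step, with illustrative shapes that match the comments in the model code further below:

import torch

x = torch.rand(1, 256, 12, 12)     # e.g. the bottleneck output
skip = torch.rand(1, 256, 24, 24)  # the matching encoder feature map

x = torch.nn.functional.interpolate(x, size=(24, 24))  # upsample height and width
out = torch.cat([x, skip], 1)                          # concatenate along the channel dimension
print(out.shape)  # torch.Size([1, 512, 24, 24])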
Using classes to build different blocks of the UNET model
There are many ways to create the UNET model, but after experimenting with these different ways, I found it efficient to first create classes for the different types of blocks in the model:
- A class for one convolutional layer (ConvBlock)
- A class for the stack of blocks in the encoder path (StackEncoder)
- A class for the stack of blocks in the decoder path (StackDecoder)
Each class contains an __init__ method, where we define the variables we'll need to perform operations on the input. Then, we need a forward method, where there is a progression of actions performed on the input (known as "x" here).
from torch import nn  # needed for nn.Module and the layers below

class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ConvBlock, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=(3,3), stride=1, padding=1)
        self.batchnorm = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.conv(x)
        x = self.batchnorm(x)
        x = self.relu(x)
        return x
class StackEncoder(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(StackEncoder, self).__init__()
        self.max_pool = nn.MaxPool2d(kernel_size=(2,2), stride=2)
        self.block = nn.Sequential(
            ConvBlock(in_channels, out_channels),
            ConvBlock(out_channels, out_channels))

    def forward(self, x):
        block_out = self.block(x)
        pool_out = self.max_pool(block_out)
        return block_out, pool_out
class StackDecoder(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(StackDecoder, self).__init__()
        self.block = nn.Sequential(
            ConvBlock(in_channels+in_channels, out_channels),
            ConvBlock(out_channels, out_channels))

    def forward(self, x, concat_tensor):
        batch, channels, height, width = concat_tensor.shape
        x = torch.nn.functional.interpolate(x, size=(height, width))
        x = torch.cat([x, concat_tensor], 1)
        blockout = self.block(x)
        return blockout
Then, we have to incorporate all of the above classes into one class for the actual U-Net model. The class gets the input shape of the image and splits the parameters up into batch size, channel, height, and width. Then, it goes down the encoder path and to the bottleneck. You can consider the bottleneck as part of the encoder path, but it’s essentially one convolutional layer + batch normalization + ReLU activation function. This leads to the decoder path.
I included comments on how the shape of the image changes over time to help understand the transformations. The images in the dataset should have a shape that goes from 96 x 96 x 3 to 96 x 96 x 1.
class UNET(nn.Module):
def __init__(self, input_shape):
super(UNET, self).__init__()
self.batch, self.channel, self.height, self.width = input_shape # 96 x 96 x 3
self.down1 = StackEncoder(self.channel, 64) # 48 x 48 x 64
self.down2 = StackEncoder(64, 128) # 24 x 24 x 128
self.down3 = StackEncoder(128, 256) # 12 x 12 x 256
self.bottleneck = ConvBlock(256, 256) # 12 x 12 x 256
self.up3 = StackDecoder(256, 128) # 24 x 24 x 128
self.up2 = StackDecoder(128, 64) # 48 x 48 x 64
self.up1 = StackDecoder(64, 1) # 96 x 96 x 1
def forward(self, x):
down1, out = self.down1(x)
down2, out = self.down2(out)
down3, out = self.down3(out)
bottleneck = self.bottleneck(out)
up3 = self.up3(x=bottleneck, concat_tensor=down3)
up2 = self.up2(x=up3, concat_tensor=down2)
        up1 = self.up1(x=up2, concat_tensor=down1)
        return up1
To check if the model's layers work, we can first create a tensor with the shape 1 x 3 x 96 x 96 using torch.rand(). Then, we need to create a model variable for the tensor to go through. If we compare the input and output shapes, we can view the differences to understand whether the layers worked. We can also print out the model, but it is quite long, so I won't display it in this article.
inp = torch.rand(1,3,96,96)
print(inp.shape)

model = UNET(inp.shape)
print(model)

out = model(inp)
print(out.shape)
Step 9: Picking a Loss and Optimizer
When training, validating, and testing the model, we need to have a loss function to calculate how off the predicted mask is from the actual mask. We also need an optimizer for the training that will help change up parameters in the model (such as weights and biases).
Reasoning for choosing Cross-Entropy and Dice Loss
For image segmentation, according to a research paper 📑 I read named “A survey of loss functions for semantic segmentation”, some of the most common loss functions for semantic segmentation are cross-entropy loss and dice loss.
The cross-entropy loss calculates the difference between two probability distributions; for a single pixel with true label y and predicted probability p, the binary form is CE = -(y*log(p) + (1-y)*log(1-p)).
On the other hand, the dice coefficient finds the area of overlap (intersection ⿻) between the predicted mask and the actual mask, multiplied by 2; it then divides that by the total of the actual mask's area, the predicted mask's area, and a smoothing term: Dice = (2*|A ∩ B| + smooth) / (|A| + |B| + smooth), where A and B are the predicted and actual masks. We use smooth to avoid division by zero when both the predicted and actual masks are completely black (all values 0).
As you may have noticed from the word "coefficient", the dice coefficient is not the same as the dice loss. In situations where a particular metric, like the dice coefficient, is used to judge model performance, loss functions derived from these metrics usually take the form 1 - f(x), where f(x) is the metric in question. Applying the formula to this situation, the dice loss is 1 minus the dice coefficient.
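To make that concrete, here is a tiny worked example of the dice loss on two 4-pixel binary masks (made-up values, with a smooth term of 1 as in the class below):

import torch

pred = torch.tensor([1., 1., 0., 0.])    # flattened predicted mask
target = torch.tensor([1., 0., 0., 0.])  # flattened actual mask

intersection = (pred * target).sum()                           # 1
dice = (2*intersection + 1) / (pred.sum() + target.sum() + 1)  # (2+1)/(2+1+1) = 0.75
print(1 - dice)  # dice loss = 0.25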
In my Github repository, there are a few files, including 2 different Jupyter notebooks that each use a different loss function (one uses the cross-entropy loss and one uses the dice loss). The cross-entropy loss was very easy to implement because the loss function is already available in the torch.nn module, and we can just store it in the criterion variable.
criterion = torch.nn.CrossEntropyLoss()
Making the Dice Loss function from scratch
For the dice loss, I made a class that calculates it. Firstly, the class flattens the input images and the target images (the masks in this case) into 1D vectors using .view(-1). Then, we find the intersection by multiplying the input and target element-wise and taking the .sum() over all the pixels ← this happens for all the pairs of input and target images. Next, we input the intersection into the rest of the formula for dice_loss.
class DiceLoss(nn.Module):
def __init__(self):
super(DiceLoss, self).__init__()
def forward(self, inputs, targets):
inputs = inputs.view(-1)
targets = targets.view(-1)
intersection = (inputs * targets).sum()
dice_loss = 1 - (2*intersection + 1) / (torch.sum(inputs) + torch.sum(targets) + 1)
return dice_loss
We can call this DiceLoss class and store it in the variable criterion.
criterion = DiceLoss()
The Adam optimizer
For the optimizer, I used the Adam optimizer, since it's widely used as a strong default for CNNs. Before we do this, it's easier to import optim from the torch package, since we can easily pick an optimizer from there.
from torch import optim
With optim.Adam() as our optimizer, we need to input the model's parameters, since that's what we're optimizing, and we have to set a learning rate.
Think about the learning rate as if you're trying to learn biology terms for a unit test; if you cram all the terms really fast, you will probably forget a lot of them during the test. On the other hand, learning them slow and steady 🐢 will be a lot more beneficial → this is why procrastination is a bad habit. In this case, we set the learning rate (lr) to 0.0001, which is pretty low and worked well during training, validation, and testing (as you'll see).
optimizer = optim.Adam(model.parameters(), lr=0.0001)
Step 10: Training and Validation
During training and validation (and testing), we define a loss.
To easily conceptualize “loss”, think of training/validating the model as trying to shoot a 3-pointer for the first time (you obviously won’t be like Steph Curry and get it in every time 😉). Also, to measure how far off you were from getting it in, you use the shortest distance between the net and the ball → the loss. You’d probably airball the first time and realize that the loss was large. Then, you change your technique, strive for a lower loss, and continue this cycle again and again until you get fairly decent accuracy.
To store the losses, we will have two lists — one for training (trlosses) and one for validation (vlosses).
trlosses = []
vlosses = []
I mentioned an epoch before, but to recap, an epoch is one cycle of forward passing and performing backpropagation on the model with all the data (we'll look into these terms soon). I set the number of epochs to 11, since several resources I came across suggested that as a reasonable number of epochs.
epochs = 11
We use a for loop to go through every epoch, in which training and validation are performed. Something that helped keep track of the time elapsed ⌚️ for each epoch is the datetime library, which we can import. In this library, datetime.datetime.now() gets the time at that exact moment; consequently, we can call it at the beginning and end of each epoch and subtract the two to derive the elapsed time.
Also, before starting the training and validation, we can declare the variables for storing the loss in each epoch.
import datetime

for epoch in range(epochs):
    startepoch = datetime.datetime.now()

    training_loss = 0
    validation_loss = 0

    # training
    # validation
    # some more necessary actions to perform

    endepoch = datetime.datetime.now()
    print("Epoch time:", str(endepoch-startepoch), "\n")
To train, we need to first tell our model that we are training it (because we can't hide secrets from it — just kidding 😝) using model.train(). Then, we have a for loop going through all the batches of images and masks in the trainloader and storing them into variables.
But, to get to the output images, there are a few important actions to take:
- Squeeze the masks: by getting rid of the "1" in [64, 1, 96, 96], which represents the number of channels, there are fewer dimensions to work with. Note that since we reduced the dimensions of the actual mask images, we need to reduce them for the output images (predicted mask images) as well.
- Use optimizer.zero_grad() to reset the gradients of the model to 0; the gradients represent the change to make in all of the model parameters, computed for the Adam optimizer. Clearing the gradients from each prior iteration is needed so the results don't "collide" 🚗 with each other.
- Then, we can put the images into the model() to get the output images ← we call this the "forward pass".
Now that the output images and masks have the same dimensions [batch_size, height, width], we can use the criterion to compare them and derive the loss (whether it is the cross-entropy loss or the dice loss).
The next part of this step is to call the backward() function on the loss, which runs through the model backwards and computes the gradients (backpropagation).
Furthermore, we use optimizer.step() to update the weights using those gradients → this is the gradient descent step.
We also need to add loss.item(), which is the total loss of the entire mini-batch divided by the batch size, to the training_loss variable defined outside the training for loop.
model.train()
for images, masks in trainloader:
    # the loop variables already hold one batch, so there's no need to call
    # next(iter(trainloader)) again here (that would draw a second, random batch)
    masks = masks.squeeze()
    optimizer.zero_grad()
    output = model(images)
    output = output.squeeze(1)
    loss = criterion(output, masks)
    loss.backward()
    optimizer.step()
    training_loss += loss.item()
The validation is quite similar, since we're iterating through the validation loader, doing the forward pass, squeezing the masks and outputs, finding the loss, and then adding loss.item() to the validation_loss. One major difference is that we tell the model we are at the validation step using model.eval() instead of model.train(), which was for training.
model.eval()
for images, masks in validationloader:
    # as above, the loop variables already hold one batch
    # (wrapping this loop in torch.no_grad() would also save memory)
    masks = masks.squeeze()
    out = model(images)
    out = out.squeeze(1)
    loss = criterion(out, masks)
    validation_loss += loss.item()
After the training and validation loops, there are some more necessary actions to perform:
- We need to find the mean loss over all the batches in the train loader and validation loader and then append it to the training losses and validation losses lists, respectively
- We need to print the epoch we are on, along with the training loss and the validation loss for that epoch
- If the current validation loss is the minimum of vlosses, we print out a statement that says the overall validation loss has decreased and then save the model's parameters to a checkpoint called brainmrisegmentation.pth. A checkpoint helps with early stopping, a technique that aims to avoid over-fitting (when the model fits too well against the training data, so it can't perform accurately on other, unseen data)
mean_tloss = training_loss/(len(trainloader))
mean_vloss = validation_loss/len(validationloader)
trlosses.append(mean_tloss)
vlosses.append(mean_vloss)
print("Epoch: {} ...".format(epoch+1), "Training Loss: {:.4f} ...".format(trlosses[-1]), "Validation Loss: {:.4f} ...".format(vlosses[-1]))
if vlosses[-1] <= min(vlosses):
print("Validation Loss has decreased - saving")
torch.save(model.state_dict(), "brainmrisegmentation.pth")
After 11 epochs (each taking about 6.5–7 minutes), the output showed the training loss and validation loss decreasing, and the parameters of the model being saved in several of the epochs.
We can visualize this easily by using plt.plot() with the parameters being the list of losses and its label for the legend (which doesn't have a frame around it — frameon=False).
plt.plot(trlosses, label="Training loss")
plt.plot(vlosses, label="Validation loss")
plt.legend(frameon=False)
Step 11: Testing
When testing the model, we need to test the "saved" model, which is stored in brainmrisegmentation.pth. We can do this by loading the saved parameters using torch.load(); we also need to update the current model with these saved parameters, which can be done with load_state_dict().
loaded_model = torch.load("brainmrisegmentation.pth")
model.load_state_dict(loaded_model)
When running the above code, it should print out “<All keys matched successfully>”.
Now, to test the saved model, we need to declare some variables:
- epochs — set to 1 epoch
- testing_loss — stores the testing loss after the epoch
- correct_pixels — the number of pixels in all the images whose predicted class (tumour or not) matches the mask
- total_pixels — the total number of pixels in all the images
epochs = 1
testing_loss = 0
correct_pixels = 0
total_pixels = 0
The first part of the epoch (up to the line where we add the loss to testing_loss) follows the same pattern as training and validation. After running the code up to that line, the actual mask and predicted mask already look quite similar, but the prediction contains some faint pixels. These pixels aren't very useful to us, so we can remove them by thresholding at 0.5: pixels with values above 0.5 become 1.0 and the rest become 0.0 (since pixel values of a tensor image range from 0.0 to 1.0) — this result is stored in preds.
Now, we need to calculate the number of correct and total pixels in all the images, which will help us find the accuracy:
- The correct pixels are calculated by checking how many prediction and mask pixels match in the image — the number of matches is computed with .sum() and added to the correct_pixels variable
- The total pixels are returned using .numel() and added to the total_pixels variable
for images, masks in testingloader:
    # as in training and validation, the loop variables already hold one batch
    masks = masks.squeeze()
    out = model(images)
    out = out.squeeze(1)
    loss = criterion(out, masks)
    testing_loss += loss.item()

    preds = (out > 0.5).float()
    correct_pixels += (preds == masks).sum()
    total_pixels += torch.numel(preds)
Then, after the epoch, we need to calculate the accuracy — it is the number of correct pixels divided by the total number of pixels in the testing images. We can then print this out.
accuracy = correct_pixels / total_pixels

print("Got {}/".format(correct_pixels) + "{}".format(total_pixels) + " correct with an accuracy of {:.2f}%".format(accuracy*100))
I got two different accuracies, one for each loss function:
- Cross-entropy loss: “Got 19281846/20643840 correct with an accuracy of 93.40%”
- Dice loss: “Got 18088871/20643840 correct with an accuracy of 87.62%”
Note: to get an even higher accuracy than 93%, I would have to further tune the parameters (weights, biases, dropout value, learning rate, etc.). However, the intention of this project is to guide and aid medical experts, not replace them, so an error rate of about 7% is not that significant.
Step 12: Visualization
Although the accuracy is very high for this project, it’s very beneficial to visualize the results to ensure that the actual and predicted masks are very similar (in accordance with the high accuracy).
One important point is that we don't want to visualize every single image and mask in the testing dataset (we'll use this dataset since it demonstrates the high accuracy from the previous step). So, after getting a batch of visual_images and visual_masks from the testing loader, we can slice each of them to keep the first ten images using [:10].
visual_images, visual_masks = next(iter(testingloader))
visual_images = visual_images[:10]
visual_masks = visual_masks[:10]
For every image and mask, we need to:
- Change the dimensions using .unsqueeze() and .permute(), to restore the batch-first, channels-first order the model expects before inputting the image into it
- Put the image through the model and convert the output (predicted mask) from a Torch tensor to a NumPy array using .detach().numpy()
- Make a copy of the predicted mask using .copy(). By making a copy, we can form a prediction with a threshold. The intention behind the threshold is to ensure high contrast between the foreground and background of the image. In other words, remember how there were some very faint pixels in the predicted mask? Essentially, we're removing those pixels with the threshold.
The threshold is applied with [np.nonzero(pred_t < 0.9)] on the image array, which sets every pixel with a value less than 0.9 to 0. On the contrary, with [np.nonzero(pred_t >= 0.9)], every pixel with a value of 0.9 or greater is set to 1.
for i in range(len(visual_images)):
    image = visual_images[i]
    mask = visual_masks[i]

    image = image.unsqueeze(1)
    image = image.permute(1,0,2,3)

    output = model(image)
    output = output.detach().numpy()
    output = output.squeeze()

    pred_t = output.copy()
    pred_t[np.nonzero(pred_t < 0.9)] = 0.
    pred_t[np.nonzero(pred_t >= 0.9)] = 1.
    pred_t = pred_t.astype("uint8")
Now that we have a list of 10 images, we can plot the image, actual mask, predicted mask, and predicted mask with a threshold. Since printing out all 10 would be very long, here's how we would plot one image using mostly plt, .imshow(), and .set_title().
fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(10, 10))

image = image.squeeze()
image = image.permute(1,2,0)

mask = mask.permute(1,2,0)
mask = mask.squeeze()

ax[0, 0].imshow(image)
ax[0, 0].set_title("image")

ax[0, 1].imshow(mask)
ax[0, 1].set_title("mask")

ax[1, 0].imshow(output)
ax[1, 0].set_title("prediction")

ax[1, 1].imshow(pred_t)
ax[1, 1].set_title("prediction with threshold")

plt.show()
That's the end of my code! Oh, wait…one more thing that I forgot to mention: for me, this project didn't incorporate 12 steps. IT INCORPORATED 200+ STEPS 🤯. My article doesn't account for the fact that I ran into errors every day. Some of these errors took minutes to fix, and some even took hours. This is why this project took me almost 2 months, more than 100 hours, and a LOT of resilience.
Anyways, it was definitely worth it and I learned so many skills in the computer vision field that will be beneficial to me in future projects!
Sources
- https://www.youtube.com/watch?v=57N1g8k2Hwc&t=336s
- https://www.youtube.com/watch?v=lu7TCu7HeYc&t=657s
- https://www.youtube.com/watch?v=lu7TCu7HeYc&t=682s
- https://arxiv.org/pdf/1505.04597.pdf
- https://arxiv.org/pdf/2105.07576.pdf
- https://arxiv.org/pdf/2006.14822.pdf
- https://towardsdatascience.com/understanding-semantic-segmentation-with-unet-6be4f42d4b47
- Stack Overflow 💻
- All my friends that supported me with this project and helped motivate me to keep going :D
If you’ve read this far, I want to give you a huge THANK YOU because I know this article may have been a bit long. If you enjoyed this article, feel free to give it a clap and read more of my intriguing articles here:
Also, feel free to reach out to me if you have any questions regarding this project or if you would like to set up a meeting to talk!
Hi, I’m Shiza, a 15 y/o computer vision enthusiast and an innovator at The Knowledge Society (TKS).
My Portfolio: https://tks.life/profile/shiza.charania#portfolio
My LinkedIn: https://www.linkedin.com/in/shiza-charania/
My Twitter: https://twitter.com/ShizaCharania
My Youtube Channel: https://www.youtube.com/channel/UC-yVSWP_BOh7UxbnyWH1vvg
Subscribe to my monthly newsletters: https://landing.mailerlite.com/webforms/landing/c5n5v7