Medical Imaging Analysis using PyTorch
How to perform spinal cord gray matter segmentation using MedicalTorch, a medical imaging framework for PyTorch.
I truly believe that artificial intelligence (AI) will shape our future and will have a tremendous impact on industries such as health and agriculture. One of the things that I aim to achieve with dair.ai is to discuss interesting open-source AI technologies that help to address important problems such as medical diagnosis and personalized learning. One of the tools that has caught my attention this week is MedicalTorch (developed by Christian S. Perone), which is an open-source medical imaging analysis tool built on top of PyTorch. It contains a set of loaders, pre-processors and utility functions to efficiently and easily analyze medical images such as those acquired from magnetic resonance imaging (MRI) scans.
In this post, I will summarize some of the functionalities offered by the medicaltorch library and how it can be used to conduct medical imaging analysis. Specifically, this will be a tutorial on how to perform spinal cord gray matter segmentation using a technique based on convolutional neural networks (CNNs). The figure below shows a snapshot of the type of data we will be exploring in this post and what tasks we will perform. The top row of the figure shows the original MRI images and the bottom row shows the crop of the spinal cord (marked by the green rectangle).
Once we have re-sampled, cropped, and pre-processed the MRI data with the built-in functions available through medicaltorch, we can train a segmentation model using CNNs. The figures below show an axial slice paired with its segmentation. In simple terms, our models need to be trained to get good at predicting those segmented parts as shown on the right.
Prior to MedicalTorch, I had never actually used a medical imaging framework, so this will be a walk-through based on what I have learned so far. Note that this tutorial is based on Google Colab notebooks (I have shared the link to the tutorial at the end of this post). You can also run the code snippets in a local notebook and ignore the Google Colab parts.
Getting Started
This tutorial assumes a basic knowledge of Python and PyTorch. In order to get started, you need to install MedicalTorch as follows:
If you are working on Google Colab, you need to install the following libraries (refer to this post if you are not familiar with Colab):
Installing MedicalTorch also installs the required version of PyTorch for you. That’s nice! Now let’s get started!
This tutorial was adapted from the example posted by the author on the project website, which you can find here.
Let’s import a few libraries:
Data and Libraries
It’s not mentioned in the documentation, but you will need to install a few extra libraries like tensorboardX if you want to display and analyze your results using the TensorBoard tool. For the purpose of this tutorial, we exclude this part. Also, you will need to manually download an MRI dataset known as the GM SC Challenge. It contains MRI scans of the spinal cords of healthy subjects.
Data Exploration
Before we build and train our model, let’s explore the MRI dataset first. Note that I have stored the data on my personal Google Drive, so you will need to point the paths to the appropriate folders on your own drive. Let’s look at one sample (an axial slice) from the dataset. The mt_datasets.SegmentationPair2D class can be used to read the data and convert it into a format that we can better explore in our environment:
The above code plots the following image:
PyTorch DataLoaders
Let us try to use the transformation functions offered in MedicalTorch. PyTorch offers a native “transforms” module that helps us to stack up and apply many transformations to our data. In the code below, we first re-sample the dataset so that all samples are of the same size and then apply a crop filter, followed by a type transformation (to tensor format).
As you can see, we have loaded the data and printed out the size of a mini-batch. The mini-batch contains 4 images, each 200 × 200 px. You can also visualize the data by batches (refer to the notebook below for source code).
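To make the idea of stacking transformations more concrete, here is a minimal NumPy sketch of a crop step followed by a type conversion, applied to four slices that are then stacked into a mini-batch. This is not the MedicalTorch or torchvision API: `center_crop_2d`, `compose`, and the randomly generated slices are hypothetical stand-ins, and the re-sampling step is left out.

```python
import numpy as np


def center_crop_2d(image, size):
    """Crop a 2D array to `size` (height, width) around its center."""
    height, width = image.shape
    crop_h, crop_w = size
    if crop_h > height or crop_w > width:
        raise ValueError("crop size must not exceed the image size")
    top = (height - crop_h) // 2
    left = (width - crop_w) // 2
    return image[top:top + crop_h, left:left + crop_w]


def compose(*transforms):
    """Chain transforms so that they run in the order given."""
    def apply(sample):
        for transform in transforms:
            sample = transform(sample)
        return sample
    return apply


# Hypothetical stand-ins for four axial slices of different sizes.
rng = np.random.default_rng(0)
sizes = [(320, 320), (400, 380), (256, 300), (512, 512)]
slices = [rng.random(size) for size in sizes]

pipeline = compose(
    lambda img: center_crop_2d(img, (200, 200)),  # crop filter
    lambda img: img.astype(np.float32),           # type conversion
)

# Every slice now has the same shape, so they can be stacked into a batch.
batch = np.stack([pipeline(s) for s in slices])
print(batch.shape)  # (4, 200, 200)
```

The crop is what makes batching possible: samples can only be stacked once they all share the same height and width.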
Constructing The Segmentation Model
Now that we have explored the images, we want to build the gray matter segmentation model using the MRI spinal cord images. Let’s define a helper function that helps to decide the final predictions of the model.
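A helper of this kind typically turns the per-pixel probabilities produced by the network into a binary mask. The sketch below shows one way this could look in NumPy; the function name and the default threshold of 0.5 are assumptions made for illustration, not necessarily the values used in the notebook.

```python
import numpy as np


def threshold_predictions(predictions, threshold=0.5):
    """Return a binary mask: 1 where prediction >= threshold, else 0.

    The default threshold is an illustrative assumption.
    """
    predictions = np.asarray(predictions)
    return (predictions >= threshold).astype(np.uint8)


# Toy 2 x 2 "probability map" standing in for a model output.
probabilities = np.array([[0.10, 0.80],
                          [0.50, 0.49]])
mask = threshold_predictions(probabilities)
print(mask)
# [[0 1]
#  [1 0]]
```

A stricter threshold keeps only the pixels the model is most confident about, which trades recall for precision.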
And here are all the transformations that we apply to both the training and validation datasets:
We load the dataset into our environment and split it into training and validation portions. Again, note that you need to change the directory to point to your own data source.
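One common way to make such a split is by subject rather than by slice, so that slices from the same person never end up in both portions. Here is a minimal sketch of that idea; the subject IDs, the 80/20 ratio, and the `split_subjects` helper are hypothetical, and the split used in the notebook may differ.

```python
import random


def split_subjects(subject_ids, val_fraction=0.2, seed=0):
    """Shuffle subject IDs and split them into (train, validation) lists."""
    ids = list(subject_ids)
    random.Random(seed).shuffle(ids)  # seeded for a reproducible split
    n_val = max(1, int(round(len(ids) * val_fraction)))
    return sorted(ids[n_val:]), sorted(ids[:n_val])


# Hypothetical subject IDs 1..10.
train_ids, val_ids = split_subjects(range(1, 11))
print(len(train_ids), len(val_ids))  # 8 2
```

Splitting at the subject level matters because neighboring slices from one scan look very similar; mixing them across portions would make the validation score look better than it really is.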
Here are the final data loaders that we will use during training:
Models and Parameters
Below we declare our model and hyperparameters. Note that we run this model on a GPU. The model used below is U-Net, a convolutional architecture proposed by Ronneberger et al. (2015). It combines a contracting path that captures context with an expanding path that recovers spatial detail, and it links the two with skip connections to produce a per-pixel segmentation. See a figure of the U-Net architecture below.
We build helper functions, such as accuracy, to produce the desired performance metrics for the model:
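As a rough illustration of what such a metric computes, the sketch below measures pixel accuracy, that is, the percentage of pixels whose predicted label matches the ground truth. The `pixel_accuracy` function is a hypothetical NumPy stand-in, not the helper from the notebook or from the medicaltorch library.

```python
import numpy as np


def pixel_accuracy(prediction, target):
    """Percentage of pixels whose binary label matches the ground truth."""
    prediction = np.asarray(prediction).astype(bool)
    target = np.asarray(target).astype(bool)
    if prediction.shape != target.shape:
        raise ValueError("prediction and target must have the same shape")
    return 100.0 * np.mean(prediction == target)


# Toy 2 x 2 masks: three of the four pixels agree.
prediction = np.array([[1, 0],
                       [0, 0]])
target = np.array([[1, 1],
                   [0, 0]])
print(pixel_accuracy(prediction, target))  # 75.0
```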
Training
Now we finally train the model for spinal cord gray matter segmentation. We train for only 10 epochs and report the training and validation accuracy below.
A lot happened in the code above! I would recommend that you spend some time digesting it and understanding every bit by going through the complete notebook provided at the end of this post. The final results produced by the U-Net model are as follows:
Train loss: -0.9267, Training Accuracy: 97.0087
Val Loss: -0.9262, Validation Accuracy: 99.5775
The validation accuracy is above 99% after 10 epochs! Not bad for what is considered a very complex and important task.
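Two things about these numbers are worth a closer look. The negative loss values are what one would expect if the loss is defined as the negative of the Dice coefficient, so that values closer to -1 mean better overlap; the post does not state this explicitly, so treat it as an assumption. Also, gray matter covers only a small part of each slice, so pixel accuracy can be high even for a poor segmentation. The NumPy sketch below uses a hypothetical 200 × 200 slice and hypothetical helper functions to show both effects.

```python
import numpy as np


def dice_score(prediction, target, eps=1e-7):
    """Dice coefficient between two binary masks (1.0 is perfect overlap)."""
    prediction = np.asarray(prediction, dtype=np.float64).ravel()
    target = np.asarray(target, dtype=np.float64).ravel()
    intersection = np.sum(prediction * target)
    return (2.0 * intersection + eps) / (prediction.sum() + target.sum() + eps)


def negative_dice_loss(prediction, target):
    """Assumed loss form: minimizing it maximizes the Dice overlap."""
    return -dice_score(prediction, target)


# Hypothetical slice in which only a 10 x 10 patch is gray matter.
target = np.zeros((200, 200))
target[95:105, 95:105] = 1

# A useless model that predicts "no gray matter" everywhere.
empty_prediction = np.zeros_like(target)

accuracy = 100.0 * np.mean(empty_prediction == target)
print(f"accuracy of empty prediction: {accuracy:.2f}")
print(f"loss of empty prediction: {negative_dice_loss(empty_prediction, target):.4f}")
print(f"loss of perfect prediction: {negative_dice_loss(target, target):.4f}")
```

In this toy case the empty prediction scores 99.75% accuracy while its Dice overlap is essentially zero, which is why it helps to read the accuracy alongside an overlap-based metric.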
Final Words
In summary, you learned how to process MRI image scans using a neat and powerful tool known as medicaltorch. In addition, you learned how to pre-process, prepare and load the data using MedicalTorch’s and PyTorch’s built-in data loader functions. Finally, you trained a model based on convolutional neural networks to conduct spinal cord gray matter segmentation. Feel free to explore more of the utility functions provided in the medicaltorch API and explore different types of datasets.
Contributions
I also submitted a pull request where I tried to improve the documentation a bit and suggested minor changes to MedicalTorch’s project page. I hope it helps jump-start others into the cool and fun world of MRI data analysis. If you have any trouble running the code above or spot any bugs, please leave a comment below.
Things to Try
You can try the following tasks:
- Suggest improvements to the attached notebook, which we can then propose as a PR to the official documentation website of MedicalTorch
- Apply the models to the testing dataset (I only did validation)
- Perform segmentation using the other models offered by the medicaltorch library
- Apply the models to a different type of dataset and contribute to the tutorial section of the medicaltorch documentation
References
- Spinal cord gray matter segmentation using deep dilated convolutions
- U-Net: Convolutional Networks for Biomedical Image Segmentation
Final Notebooks
Please find the Google Colab tutorial here.