Train a Neural Network to Detect Breast MRI Tumors with PyTorch

A practical tutorial for medical image analysis

Nick Konz
Towards Data Science


An example breast MRI scan from our dataset.

Most research in computer vision with deep learning is conducted on common natural image datasets such as MNIST, CIFAR-10, and ImageNet. However, an important application area of computer vision is medical image analysis, where deep learning has been used for tasks such as cancer detection, organ segmentation, data harmonization, and many others. At the same time, medical image datasets can often be more involved to “plug” into deep learning systems than natural image datasets. Here I provide a practical, step-by-step tutorial on how to use deep learning for a simple medical image analysis task, from data acquisition all the way to model testing.

In this post I will show how to train a neural network classifier to detect tumors in breast MRI images, using PyTorch. In my previous post (found here on my lab’s blog), I introduced my lab’s publicly available breast MRI dataset and showed how to interact with raw medical imaging data using Python, demonstrating how to extract and sort the image files into a format and labeling scheme that will be useful for training and testing our model with PyTorch. This dataset was originally presented in the paper:

Saha, A., Harowicz, M.R., Grimm, L.J., Kim, C.E., Ghate, S.V., Walsh, R. and Mazurowski, M.A., 2018. A machine learning approach to radiogenomics of breast cancer: a study of 922 subjects and 529 DCE-MRI features. British journal of cancer, 119(4), pp.508–516. (A free version of this paper is available here: PMC6134102)

With our data in place, we can proceed to building, training, and testing our deep classifier. Beyond the typical way of measuring classifier performance in computer vision (total prediction accuracy), I will also show how to further analyze the performance of our model with metrics common to medical image analysis.

All code for this tutorial can be found here; some experience with Python is needed, and PyTorch knowledge is helpful but not essential.

(1) Building a data loading pipeline

Deep learning is data-driven, so having a reliable framework for working with (image) data is essential. PyTorch (the torch and torchvision libraries in Python) is one of the most popular deep learning frameworks; among other things, it allows for the efficient manipulation and management of numerical matrices, which is exactly what neural networks need, since they operate and learn via many matrix multiplications and additions. Images can also be described with large matrices of numbers, where the dimensions of the matrix correspond to the size of the image and each element is a pixel intensity value. As such, it is extremely helpful to handle all image data loading/processing and neural network operations with PyTorch, which comes with a myriad of convenient modules and tools for these tasks and more.

The central objects that we will use are the Dataset and DataLoader of torch.utils.data. While Dataset allows for the easy storage and indexing of data samples and labels, DataLoader makes it possible to easily access these samples in a way that integrates very well with how we will train and test neural networks. Please see PyTorch's tutorial for more information. To start, let's import the needed libraries and objects:
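The exact imports depend on how you preprocess the images; a minimal set for the pipeline sketched in this post looks something like:

import os
import copy
import random

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm

import torch
from torch.utils.data import Dataset, DataLoader, random_split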

Next, let’s define some constants:
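The names and values below are just illustrative; in particular, data_dir should point to wherever you saved the extracted slice images from the previous post:

data_dir = 'dbc_mri_png'  # directory of extracted 2D slice images (illustrative path)
img_size = (224, 224)     # size to which each slice is resized
batch_size = 200          # training batch size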

Datasets

First, we must define our own Dataset for the DBC dataset, called DBCDataset. In summary, the important methods defined in DBCDataset are:

  1. the create_labels() method assigns an easily-accessible label to each image in the dataset,
  2. the normalize() method normalizes images to the pixel value range [0,255], as it's important to standardize data for deep learning,
  3. the __getitem__() method, which is required for any Dataset, describes how a data sample is retrieved from the dataset given an index (or indices).

Check out the code block below, where I’ve added comments to explain each step.
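The full class is in the linked code; the condensed sketch below assumes the slices were sorted into 'neg' and 'pos' subfolders of the data directory (as PNG files), and its details may differ slightly from the full version:

class DBCDataset(Dataset):
    def __init__(self, img_dir, img_size=(224, 224)):
        self.img_dir = img_dir
        self.img_size = img_size
        self.create_labels()

    def create_labels(self):
        # assign an easily-accessible (filepath, label) pair to every image:
        # 1 for slices containing a tumor ('pos'), 0 otherwise ('neg')
        print('building DBC dataset labels.')
        self.labels = []
        for label, cls in enumerate(['neg', 'pos']):
            cls_dir = os.path.join(self.img_dir, cls)
            for fname in sorted(os.listdir(cls_dir)):
                self.labels.append((os.path.join(cls_dir, fname), label))

    def normalize(self, img):
        # standardize pixel intensities to the range [0, 255]
        return (img - img.min()) / (img.max() - img.min()) * 255.

    def __getitem__(self, idx):
        # load one image and its label, returning a (1, H, W) float tensor
        fpath, label = self.labels[idx]
        img = np.array(Image.open(fpath).resize(self.img_size), dtype=np.float32)
        img = torch.from_numpy(self.normalize(img)).unsqueeze(0)  # add channel dim
        return img, label

    def __len__(self):
        return len(self.labels)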

From here we can simply create an instance of the dataset with:
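Using the illustrative constants from above:

dataset = DBCDataset(data_dir, img_size=img_size)
print(len(dataset))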

Output:

building DBC dataset labels. 
5200

Training, validation, and test sets: what are the differences?

To develop our classification model we will need to split our dataset into training, validation, and test sets. For each image in the dataset, there is an associated label that we want the classification model to predict. But what are the differences between these data subsets?

  1. The training set is used to provide the model with examples of how to make predictions; this is what the model “learns” from. The learning algorithm simply modifies the neural network’s parameters to minimize the average prediction error on the training set.
  2. The validation set is used to estimate how well the model performs at predicting labels for new data that it has not learned from (the ultimate goal of developing this model, also known as generalization). This validation prediction error is used to select at which point in training we’d like to save the model: we want to save the model when the validation error is lowest. You can also use the validation set to select hyperparameters or settings for the training algorithm that are not learned from the training set.
  3. The test set, like the validation set, is also used to estimate the generalization ability of the neural network on new data; however, it must be kept separate from the validation set because the validation set itself was used to select the final model, and we want an unbiased estimate of generalization ability.

With a dataset size of 2600 + 2600 = 5200, a typical set of percentages for splitting the dataset into training/validation/testing could be something like 80%/10%/10%, which results in a training set of size 4160, and validation and testing set sizes of 520, each. Practically, we can extract these subsets from our dataset with the useful function torch.utils.data.random_split(), which randomly splits up the full dataset into subsets:
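For an 80%/10%/10% split, something like:

train_count = int(0.8 * len(dataset))
val_count = int(0.1 * len(dataset))
test_count = len(dataset) - train_count - val_count

train_set, val_set, test_set = random_split(dataset,
                                            [train_count, val_count, test_count])

print(len(dataset))
print(len(train_set), len(val_set), len(test_set))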

Output:

5200 
4160 520 520

Dataloaders

We’ve created PyTorch Datasets for model training, validation, and testing, which is most of the work for our data-loading pipeline. To finish, we will need to create PyTorch DataLoaders to conveniently access batches of images from our datasets. But first, a quick note on batch sizes.

Batch sizes and computation devices

While we could train neural networks on one image at a time, this would be prohibitively slow, as training usually requires passing over hundreds or thousands of images many times. Instead, we can train on batches of multiple images at a time, with the batch size limited by the memory capacity of our computation device; for example, a GPU (graphics processing unit), which is specially designed for this kind of highly parallel image computation.

Below, we will create Dataloaders for each of our three data subsets. We will use a batch size of 200 for the training set, but this is very much dependent on the CPU or GPU hardware that you use for computations. For most realistic computer vision applications, a GPU is required, as CPUs are intractably slow; as such, we will be using an 8 GB NVIDIA GTX 1070. The device that we will load data on to can be specified with:
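For example:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('running on {}'.format(device))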

Output:

running on cuda

Now, let’s create our dataloaders:
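Something like the following; the validation and test batch size of 10 here is just an example choice (it only needs to fit in memory), consistent with the 52 validation batches per epoch in the training log further below:

train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_set, batch_size=10, shuffle=False)
test_loader = DataLoader(test_set, batch_size=10, shuffle=False)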

Next, to ensure that our results will be reproducible, we will fix all random seeds with:
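The seed value itself is arbitrary; any fixed number will do:

seed = 42
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)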

With this, we are ready to introduce and build our classification neural network!

(2) Loading a neural network

Neural networks are really just functions with many, many parameters (a.k.a. dials to tune). These parameters are learned from lots of data to tune the network to best approximate the function that we’re trying to emulate. For example, an image classification neural network, like the one we will work with, is trained to take images as input, and outputs the correct class identity of the image, e.g., whether a breast image is cancerous. The many successive computational layers of neural networks allow them to learn very complicated functions that would be practically impossible to hand-design.

Convolutional neural networks were particularly designed to learn to detect the spatial patterns found in images, so are especially well-suited for our task. Today, we will work with a very popular modern neural network architecture called a residual network, or ResNet for short. In fact, the original ResNet paper is one of the most cited papers of all time, with over 120,000 citations as of June 2022, according to Google Scholar. We will be using a particular version of the ResNet model known as ResNet-18, the details of which are beyond the scope of this tutorial. ResNet-18 and models like it can be easily loaded (untrained) with PyTorch’s torchvision.models library, as:
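That looks like:

import torch.nn as nn
from torchvision.models import resnet18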

Here we also imported PyTorch’s neural network library torch.nn. Next, we will instantiate a ResNet-18 to work with (resnet18 is a constructor that returns an untrained ResNet model):
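One simple way to do this, since we have two classes (tumor vs. no tumor), is to request a two-output network directly; asking for num_classes=2 here is my choice, and you could equally keep the default output layer and replace it afterwards:

net = resnet18(num_classes=2)  # untrained ResNet-18 with two output classes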

ResNets are designed to work with color, three-channel images. However, our MRI slices are one-channel, so we will need to modify our net to take one-channel inputs. This can be accomplished by modifying the input layer of net as:
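The first layer of a stock ResNet-18 is a three-input-channel Conv2d; we can swap it for an otherwise-identical layer that takes a single channel:

net.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)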

Finally, we need to load the net onto our computation device:
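That is simply:

net = net.to(device)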

With that, our network is ready to be trained to classify our images. Let’s set up a training pipeline!

(3) Setting up for training

Neural networks “learn” by minimizing the average error of their predictions on the entire training set. On each iteration, or epoch, over the training set, the parameters are adjusted so that the network performs better on the next iteration. The change for each parameter is determined by the backpropagation algorithm, which adjusts each parameter in the direction of steepest decrease in error at the given iteration, on average (this process is known as stochastic gradient descent, or “SGD”).

There are a couple of things that we will need to define in order to make this happen. First, we must define this prediction error, also known as the loss. For the task of classification, the loss that we need is nn.CrossEntropyLoss(), which grows as the network misclassifies more images in the training set. This is what we want to minimize in training; we can define it as:
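Here, criterion is just a conventional variable name:

criterion = nn.CrossEntropyLoss()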

Next, we want to define the error minimization algorithm that we will use; again, this is stochastic gradient descent, or SGD. There are others, of course, but SGD is the most basic and will serve fine. When creating an instance of SGD, we will need to tell it which parameters it will adjust (the parameters of net) and the learning rate. The learning rate (lr) is a fixed constant that roughly determines the size of the adjustments made to the parameters during learning. A good choice of learning rate can vary depending on the task, data, network, and other factors, but for now we will choose lr=0.001.
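In code:

learning_rate = 0.001
optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate)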

Finally, let’s set the number of training epochs (passes over the entire training set) to 100:
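That is:

epochs = 100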

(4) Train and validate your model!

We now have everything that we need to train our classification model. As mentioned earlier, on each training epoch we can evaluate the model on the validation dataset to estimate how well it will perform on unseen data. Then, we save our final trained model as the model found to have the best performance/classification accuracy on the validation set during training.

To do this in practice, we can create a copy of our training model net, and save it as a separate network net_final. Let's go ahead and initialize that:
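Using copy.deepcopy (imported earlier) keeps the copy’s architecture, parameters, and device identical to net:

net_final = copy.deepcopy(net)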

Finally, we can create and run our training loop, with everything that we discussed earlier, in the code below (with comments at each step). In practice, we will use classification accuracy as our measure of prediction performance, i.e., the percentage of images in the given dataset that are classified correctly by the network. One subtlety to deal with is that the network actually outputs a score for the input image belonging to each class, so the single predicted class is simply the class with the highest score.

We can also store our accuracy vs. epoch data for the training and validation sets, in order to observe how the model evolves through training.

Alright, let’s go ahead and train our model! This may take some time, depending on the strength of your computation device. I’ve also included some code to log each step of training.
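The loop below is a condensed sketch of this procedure (tqdm draws the progress bars you’ll see in the log); the bookkeeping in the full linked code may differ slightly, but the structure is the same:

best_val_acc = 0.0
train_accs, val_accs = [], []  # accuracy history for plotting later

for epoch in range(epochs):
    print('### Epoch {}:'.format(epoch))

    # --- training pass over the training set ---
    net.train()
    n_correct, n_total = 0, 0
    for imgs, labels in tqdm(train_loader):
        imgs, labels = imgs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = net(imgs)                # class scores for each image in the batch
        loss = criterion(outputs, labels)  # cross-entropy loss
        loss.backward()                    # backpropagation
        optimizer.step()                   # SGD parameter update

        preds = outputs.argmax(dim=1)      # predicted class = highest-scoring class
        n_correct += (preds == labels).sum().item()
        n_total += labels.size(0)
    train_acc = n_correct / n_total
    train_accs.append(train_acc)
    print('Training accuracy: {}'.format(train_acc))

    # --- evaluation on the validation set ---
    net.eval()
    n_correct, n_total = 0, 0
    with torch.no_grad():
        for imgs, labels in tqdm(val_loader):
            imgs, labels = imgs.to(device), labels.to(device)
            preds = net(imgs).argmax(dim=1)
            n_correct += (preds == labels).sum().item()
            n_total += labels.size(0)
    val_acc = n_correct / n_total
    val_accs.append(val_acc)
    print('Validation accuracy: {}'.format(val_acc))

    # keep a copy of the model with the best validation accuracy so far
    # (optionally, also torch.save its state_dict to disk)
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        net_final = copy.deepcopy(net)
        print('Validation accuracy improved; saving model.')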

Output:

### Epoch 0:
21it [00:16, 1.30it/s]
Training accuracy: 0.23557692307692307
100%|████████████████████████████████████████████████████████████████| 52/52 [00:01<00:00, 29.09it/s]
Validation accuracy: 0.5153846153846153
Validation accuracy improved; saving model.
### Epoch 1:
21it [00:15, 1.37it/s]
Training accuracy: 0.5454326923076923
100%|████████████████████████████████████████████████████████████████| 52/52 [00:01<00:00, 28.46it/s]
Validation accuracy: 0.551923076923077
Validation accuracy improved; saving model.

… (lines not shown) …

### Epoch 98:
21it [00:14, 1.40it/s]
Training accuracy: 0.9992788461538461
100%|████████████████████████████████████████████████████████████████| 52/52 [00:01<00:00, 29.27it/s]
Validation accuracy: 0.925
Validation accuracy improved; saving model.
### Epoch 99:
21it [00:15, 1.39it/s]
Training accuracy: 0.9992788461538461
100%|████████████████████████████████████████████████████████████████| 52/52 [00:01<00:00, 28.57it/s]
Validation accuracy: 0.9230769230769231

Let’s see how our model’s performance evolved over time, with a simple plot via matplotlib:
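For example, using the accuracy histories stored during training:

plt.plot(range(epochs), train_accs, label='training')
plt.plot(range(epochs), val_accs, label='validation')
plt.xlabel('epoch')
plt.ylabel('accuracy')
plt.legend()
plt.show()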

Output: a plot of training and validation accuracy vs. epoch.

The model obtained a validation accuracy of 92.5% once fully trained, which means that out of the validation set of 520 images, it correctly classified about 480 of them as either cancerous or non-cancerous.

You may also notice that the model overfit to the training set. This could have been mitigated with some sort of regularization, but that is beyond the scope of this introductory tutorial.

This performance is good, but we will only know the true ability of the model to classify new data by evaluating it on the test set, as follows.

(5) Testing your best model

Now that our model is trained, how well does it do on the test set? We can test this with the code below (shown after a brief note on performance metrics), where I also show a few classification examples; this is very similar to how we evaluated on the validation set.

A Word on Measuring a Classifier’s Performance in Medical Image Analysis

In medical image analysis, it is common to report further performance metrics than solely classification accuracy, to better analyze how the classifier is doing. A false positive (FP) is when a classifier misclassifies a negative (cancer-free) image as positive, and a true positive (TP) is when a positive image is correctly classified. Let’s estimate these as well in our code:
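A sketch of this test-set evaluation is below; as before, details may differ slightly from the full linked code, and here I simply display the first test batch as the example images:

net_final.eval()
n_correct, n_total = 0, 0
true_pos, false_pos = 0, 0
shown_examples = False

with torch.no_grad():
    for imgs, labels in test_loader:
        imgs, labels = imgs.to(device), labels.to(device)
        preds = net_final(imgs).argmax(dim=1)

        n_correct += (preds == labels).sum().item()
        n_total += labels.size(0)

        # positives are predictions of class 1 (tumor present)
        true_pos += ((preds == 1) & (labels == 1)).sum().item()
        false_pos += ((preds == 1) & (labels == 0)).sum().item()

        # show the first test batch as example images, with labels and predictions
        if not shown_examples:
            print('Example Images: ')
            fig, axs = plt.subplots(1, len(imgs), figsize=(2 * len(imgs), 2))
            for ax, img in zip(axs, imgs.cpu()):
                ax.imshow(img.squeeze(0), cmap='gray')
                ax.axis('off')
            plt.show()
            print('Target labels: {}'.format(labels.tolist()))
            print('Classifier predictions: {}'.format(preds.tolist()))
            shown_examples = True

print('Test set accuracy: {}'.format(n_correct / n_total))
print('{} true positive classifications, {} false positive classifications'.format(
    true_pos, false_pos))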

Output:

Example Images: 
Target labels: [0, 0, 0, 1, 1, 1, 0, 0, 1, 0]
Classifier predictions: [0, 0, 0, 1, 1, 1, 0, 0, 1, 0]
Test set accuracy: 0.9442307692307692
238 true positive classifications, 19 false positive classifications

On our test set of 520 unseen examples, we got a prediction accuracy of 94.4% for our cancer detection task, or only about 30 misclassifications! Out of 257 positive (cancer) detections, 238, or about 93%, were true positives (correct), while 19 (about 7%) were false positives (incorrect).

Discussion

When designing automated approaches for safety-critical applications such as medical imaging, it is essential to examine possible risks and limitations. An example of this is the possibility of false positives: if a cancer detection model like this one were used clinically, a positive detection would immediately warrant further study of the patient, so a false positive could mislead clinicians and trigger unnecessary follow-up. A related possibility is that of false negatives, where a tumor is missed entirely by the detection model. While our model had a high prediction accuracy overall, it was not perfect, and so should not be completely trusted for diagnostic decisions. This is why a common development goal for computer-assisted diagnosis (CAD) devices is to assist, not replace, the radiologist.

It is also important to remember that models trained with deep learning are completely data-driven: they learn exactly according to how the training set was labeled. For example, I chose to label the dataset in this tutorial by assigning each 2D slice of a 3D MRI scan to a class of either positive (containing a breast tumor annotation) or negative. However, even if some slices from a 3D scan are predicted to be negative, that does not mean that the entire scan/patient is negative for cancer: there could be other slices in the scan that are positive. This is an example of why the scope of any prediction made by a CAD system should be stated clearly and quantifiably.

There are many possible additions that could be made to improve this detection model. As we saw in the training evolution plot in section (4), it appears that our model overfit to the training set (indicated by the gap between training accuracy and validation accuracy). This could be mitigated or prevented with a number of techniques; see this tutorial for a good introduction to how to do this in PyTorch. Another improvement that we could make would be to train our model not to classify each 2D MRI slice as possessing a tumor somewhere, but to precisely locate any possible tumors within slices; a task known as object detection in computer vision. I did not explore that in this post because object detection is a considerably more nuanced problem than classification, but it would certainly be possible with our dataset because it contains tumor location labels/bounding boxes. A good starting point could be the PyTorch implementation of the Faster R-CNN object detection model.

Conclusion

In this tutorial, I showed how to train a breast MRI classification model via deep learning with PyTorch, on our DBC-MRI dataset.

In my previous post on my lab blog, I introduced the DICOM medical imaging datatype, showed how to obtain the data from the Cancer Imaging Archive, and showed how to extract images from the data in a format useful for PyTorch.

In this post, I showed how to load, train and test a classification neural network on real breast MRI data.

I wrote these blog posts to provide an introductory example of how to use neural networks for the realistic application of medical image analysis. This only scratches the surface of the wide array of medical image analysis (MIA) applications of deep learning. For those interested in learning more, check out:

  1. The most recent proceedings of the MICCAI conference.
  2. The journals Medical Image Analysis and IEEE Transactions on Medical Imaging.
  3. My lab’s website, and scholarly publications by my lab’s advisor and myself.
  4. My Twitter and my lab’s Twitter.

Thanks for reading!
