Coding a CNN for Medical Imaging using TensorFlow 2

Esmitt Ramírez
MICCAI Educational Initiative
Jul 24, 2020


Chest X-ray images with color gradients highlighting possible lesions in the thoracic region
Taken from CheXNet: Radiologist-Level Pneumonia Detection on Chest X-Rays with Deep Learning (2017)

Great! I have already completed a Machine Learning course, and it is time to start coding “my own neural network” in Python by myself, using the most powerful tool for students… Google! 🤓 I realized there are plenty of tutorials on classifying images of dogs and cats, but not many on classifying pancreatic tissue in CT images or segmenting coronary arteries.

Among the types of machine learning (supervised, unsupervised, and reinforcement learning), this guide focuses on supervised learning, specifically Convolutional Neural Networks (CNNs) for classification. The goal is to provide a technical introduction to running a CNN on medical images, highlighting some key features to consider when working with them. Many valuable resources on the Internet already provide techniques and functions for classification, localization, detection, and segmentation using deep learning. Here, I present a subset of strategies for coding your own CNN with TensorFlow 2 (TF for short), from its design to its evaluation 🖖

I hope this guide will help other developers interested in machine learning to enter the fascinating world of medical imaging. This tutorial assumes that you possess basic knowledge of deep learning as well as Python 3.

An example scheme for a CNN showing different layers
Source: https://fr.mathworks.com/solutions/deep-learning/convolutional-neural-network.html

What do you need?

Perhaps you are struggling with a research paper related to your project that involves deep learning. My first recommendation is to use the official documentation of the technologies it relies on, followed by medical imaging conferences such as MICCAI, excellent sites such as Papers with Code, and the other resources available on the Internet.

For this article you need the following tools and technologies:

  • Python 3, installed on your machine (ideally in a virtual environment or a Docker container), plus a mechanism to install additional packages (e.g. conda, pip).
  • TensorFlow 2, running on CPU or GPU. For CPU, you may sometimes need to build TF from source so that it is compiled with the optimizations your processor supports (e.g. AVX instructions).
  • An editor of your choice, on your local machine or in the cloud. Computational notebooks (e.g. Jupyter Notebook/Lab, Google Colaboratory) are excellent choices too.
  • Moreover, you can use other notebook-based platforms, such as the one provided by Kaggle.

In any case, consider the data transfer limitations of your platform 🦾
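A quick way to check which TF version is installed and whether it can see a GPU, assuming TF 2.1 or later (in TF 2.0, list_physical_devices lives under tf.config.experimental):

```python
import tensorflow as tf


def describe_environment():
    """Return the TF version and the names of visible GPUs (empty list on CPU-only machines)."""
    gpus = tf.config.list_physical_devices("GPU")
    return tf.__version__, [gpu.name for gpu in gpus]


if __name__ == "__main__":
    version, gpus = describe_environment()
    print(f"TensorFlow {version}")
    print(f"GPUs: {gpus if gpus else 'none (running on CPU)'}")
```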

What will you find here?

  • Numerous links to resources you can consult later that might be useful along your path.
  • Steps to develop a CNN for binary classification employing medical images.
  • A subset of common metrics used in medical imaging.

This guide first defines the dataset to work with; next, it designs and compiles the CNN using TF. Following this, it runs the network training process with its hyper-parameters, and finally it evaluates the model and uses it for prediction 🤖

1. Make the medical data great again

Depending on the research field, you may need a massive amount of data so that your model can learn from it by identifying relations and recurring features of the objects of interest: features that enable classification, generation, localization, and other tasks on medical images. The book Visual Computing for Medicine [1] is an excellent starting point to discover more about medical volume data, such as imaging modalities, PACS/DICOM, Hounsfield unit ranges, and more.

Technically, if the data fit into memory, they can be stored in NumPy arrays. However, I recommend using Dataset objects or TFRecord files, both supported by the tf.data API, which makes it possible to build complex input pipelines for TF programs. On large datasets, it is necessary to employ data generators that load mini-batches to feed your deep learning model dynamically. Generators build a pipeline from storage to CPU or GPU memory, loading data only when it is required, and they allow applying pre-processing functions to properly prepare the data for the model 🧠
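A minimal sketch of such a generator-based pipeline, assuming TF 2.4 or later (for output_signature and tf.data.AUTOTUNE). The sample_generator function is a hypothetical stand-in for code that would read real images from disk; the map step shows where a pre-processing function plugs in.

```python
import numpy as np
import tensorflow as tf

IMG_SHAPE = (64, 64, 3)  # assumed image size, matching the example model in this guide


def sample_generator(num_samples=256, seed=0):
    """Hypothetical stand-in for a loader that reads one image/label pair at a time."""
    rng = np.random.default_rng(seed)
    for _ in range(num_samples):
        image = rng.integers(0, 256, size=IMG_SHAPE, dtype=np.uint8)
        label = np.float32(rng.integers(0, 2))  # 0 = normal, 1 = disease
        yield image, label


def preprocess(image, label):
    """Scale pixel intensities from [0, 255] to [0, 1]."""
    return tf.cast(image, tf.float32) / 255.0, label


def make_pipeline(batch_size=32):
    ds = tf.data.Dataset.from_generator(
        sample_generator,
        output_signature=(
            tf.TensorSpec(shape=IMG_SHAPE, dtype=tf.uint8),
            tf.TensorSpec(shape=(), dtype=tf.float32),
        ),
    )
    # Data is produced lazily: only the batches being consumed are held in memory.
    return (
        ds.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
        .batch(batch_size)
        .prefetch(tf.data.AUTOTUNE)
    )
```

For data already serialized as TFRecord files, tf.data.TFRecordDataset plays the role of from_generator, and the same map/batch/prefetch chain applies.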

Back to basics: remember that coding conventions matter in every programming environment. Much published research shares its findings (data, code, and experiments) for reproduction and reuse, so variable names such as final_test_v2 can be impractical for other researchers building on your work. I decided to name the data following the regex /(x|y)_(train|test|validation)/g, where x represents the input images or volumes and y the labels associated with x, while train, test, and validation denote the training, testing, and validation data, respectively.

Note that cross-validation is not covered here, but given the limited sample sizes common in medical imaging, such methods should be considered. For instance, scikit-learn offers efficient tools for predictive data analysis in classification, regression, clustering, and more. Its model_selection module implements functions to randomly split data into training and test sets, to perform cross-validation, and more, which could be useful in your research.
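A minimal sketch with scikit-learn's model_selection module on hypothetical toy arrays: a train/test split followed by k-fold cross-validation. Both use stratification (an addition not discussed above), which keeps the class ratio in every split; this matters for the imbalanced medical data discussed later.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold, train_test_split

# Hypothetical toy data: 100 "images" of 8x8 pixels with an imbalanced 80/20 label split.
rng = np.random.default_rng(42)
x = rng.random((100, 8, 8), dtype=np.float32)
y = np.array([0] * 80 + [1] * 20)

# Hold out 20% for testing; stratify keeps the 80/20 class ratio in both parts.
x_train, x_test, y_train, y_test = train_test_split(
    x, y, test_size=0.2, stratify=y, random_state=42
)

# 5-fold stratified cross-validation over the training portion.
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
folds = []
for fold, (train_idx, val_idx) in enumerate(skf.split(x_train, y_train)):
    # In a real experiment, build and fit a fresh model on each fold here.
    folds.append((train_idx, val_idx))
    print(f"fold {fold}: {len(train_idx)} train / {len(val_idx)} validation, "
          f"positives in validation = {y_train[val_idx].sum()}")
```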

If you are using NumPy arrays, you can convert them to Dataset objects as follows:
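A minimal sketch of that conversion, using small random placeholder arrays in place of real images (64 x 64 RGB, labels 0 = normal, 1 = disease). The split is sequential, so it assumes the arrays are already shuffled; the batch size of 64 is an assumption.

```python
import numpy as np
import tensorflow as tf

# Placeholder arrays standing in for the real data.
rng = np.random.default_rng(0)
x_train = rng.random((1000, 64, 64, 3), dtype=np.float32)
y_train = rng.integers(0, 2, size=1000).astype(np.float32)
x_test = rng.random((200, 64, 64, 3), dtype=np.float32)
y_test = rng.integers(0, 2, size=200).astype(np.float32)

# 80% of the training arrays for training, the remaining 20% for validation.
count_training = int(0.8 * len(x_train))

train_ds = tf.data.Dataset.from_tensor_slices(
    (x_train[:count_training], y_train[:count_training])
)
validation_ds = tf.data.Dataset.from_tensor_slices(
    (x_train[count_training:], y_train[count_training:])
)
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test))

BATCH_SIZE = 64  # assumed value
# Only the training data is reshuffled every epoch.
train_ds = train_ds.shuffle(count_training, seed=0).batch(BATCH_SIZE)
validation_ds = validation_ds.batch(BATCH_SIZE)
test_ds = test_ds.batch(BATCH_SIZE)
```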

Let’s detail the code: x_train and y_train are NumPy arrays that contain all the images/volumes and labels, respectively, used to train our network, while the data for the test stage is stored in x_test and y_test. The idea is to split the training arrays into training and validation data using a percentage of the whole dataset, and the integer count_training marks that split. For instance, with 100,000 images and an 80/20 distribution, count_training equals 80,000, leaving 80,000 images for training and 20,000 for validation. Finally, train_ds, validation_ds, and test_ds are the Dataset objects created from these NumPy arrays.

Always examine your data, in particular how it is distributed. Valuable tools for this are pandas, for high-level data analysis and manipulation using tables (DataFrame is your friend), and seaborn, a matplotlib-based library for data visualization. However, how you display your data depends entirely on your domain. Several libraries can help with this process, such as Mayavi, itkwidgets, plotly, and K3D, or you can use plain matplotlib.
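A minimal pandas sketch, assuming a hypothetical metadata table with one row per image, its label, and its split. value_counts reveals class imbalance at a glance, and crosstab checks that every split contains both classes.

```python
import numpy as np
import pandas as pd

# Hypothetical metadata: 100 images, 75% normal / 25% disease, 80 train / 20 test.
labels = np.tile([0, 0, 0, 1], 25)
splits = np.array(["train"] * 80 + ["test"] * 20)
df = pd.DataFrame({"label": labels, "split": splits})
df["label_name"] = df["label"].map({0: "normal", 1: "disease"})

# Class counts and proportions: a quick check for class imbalance.
counts = df["label_name"].value_counts()
proportions = df["label_name"].value_counts(normalize=True)
print(counts)
print(proportions)

# Labels per split, to verify that every split contains both classes.
per_split = pd.crosstab(df["split"], df["label_name"])
print(per_split)

# With seaborn installed, the same information as a bar chart:
# import seaborn as sns; sns.countplot(data=df, x="label_name", hue="split")
```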

More about data processing

Sometimes finding a dataset for a project is a hard task. There are open-access medical repositories (e.g. Aylward, UCL) for exploring public datasets, and sometimes a dataset must be compiled from various sites (e.g. the COVID-19 image data collection [2]). In any case, it is important to verify each site’s conditions of data usage regarding patient anonymity and confidentiality 🤫

Regardless of the amount of data, there is an inherent factor to consider: class imbalance. Many medical datasets are imbalanced [3], which calls for data augmentation techniques such as geometric transformations, color space augmentations, or adversarial training, among others [4][5]. Keep in mind that data augmentation is intended to reduce overfitting, that is, the memorization of the training data.

TF provides an image data generator that performs on-the-fly data augmentation, such as rotations, translations, zoom, shearing, and flipping, just before feeding the network. In general, medical images may lack a canonical orientation; however, in some cases you should carefully consider your working domain. For instance, applying a horizontal flip to a chest X-ray could produce an image resembling dextrocardia, a rare congenital condition in which the heart is located on the right side of the chest instead of the left.
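A minimal sketch of on-the-fly augmentation, assuming TF 2.6 or later. Instead of ImageDataGenerator, which recent TF releases no longer recommend for new code, it swaps in Keras preprocessing layers applied inside a tf.data pipeline; the purpose is the same. Flips are deliberately omitted because of the dextrocardia issue above, shearing is left out, and the rotation, translation, and zoom ranges are illustrative values.

```python
import numpy as np
import tensorflow as tf

# Random transformations, re-drawn for every batch.
augment = tf.keras.Sequential([
    tf.keras.layers.RandomRotation(0.03),           # up to +/- 0.03 * 360 degrees (about 11)
    tf.keras.layers.RandomTranslation(0.05, 0.05),  # shift up to 5% of height and width
    tf.keras.layers.RandomZoom(0.1),                # zoom in or out by up to 10%
])


def augment_batch(images, labels):
    # training=True makes the random layers actually transform the images.
    return augment(images, training=True), labels


if __name__ == "__main__":
    images = np.random.default_rng(0).random((8, 64, 64, 1), dtype=np.float32)
    labels = np.zeros(8, dtype=np.float32)
    ds = tf.data.Dataset.from_tensor_slices((images, labels)).batch(4)
    ds = ds.map(augment_batch, num_parallel_calls=tf.data.AUTOTUNE)
    for batch_images, _ in ds.take(1):
        print(batch_images.shape)  # (4, 64, 64, 1)
```

Applying augmentation only in the training pipeline (not to validation or test data) keeps the evaluation metrics comparable across runs.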

Posterior-Anterior X-ray with situs inversus with dextrocardia
Situs inversus with dextrocardia X-ray. Source: MedPix — https://9sh.re/X9VSCTaYp

Another aspect to consider is data normalization and standardization, and their impact [6][7]. Normalizing all inputs to a standard scale allows the network to learn the optimal parameters for each input node more quickly, and it also avoids numerical errors associated with floating-point precision. Moreover, bringing the data to a standard input helps control variations in the location of the problem (e.g. tumors appear in certain parts of the breast) and in the scale of the data (e.g. spacing and thickness values in volumes).
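A minimal NumPy sketch of the two operations. min_max_normalize rescales intensities to [0, 1], optionally clipping to a fixed window first; the CT window in the test values (center 40 HU, width 400 HU, i.e. [-160, 240]) is an illustrative soft-tissue setting, not a recommendation. standardize applies z-score standardization. Which one suits your data depends on the modality.

```python
import numpy as np


def min_max_normalize(volume, low=None, high=None):
    """Rescale intensities to [0, 1].

    If both low and high are given (e.g. a Hounsfield-unit window for CT),
    values are clipped to that range first; otherwise the data min/max are used.
    """
    volume = volume.astype(np.float32)
    if low is not None and high is not None:
        volume = np.clip(volume, low, high)
    else:
        low, high = float(volume.min()), float(volume.max())
    return (volume - low) / max(high - low, 1e-8)


def standardize(volume, eps=1e-8):
    """Z-score standardization: zero mean and unit standard deviation."""
    volume = volume.astype(np.float32)
    return (volume - volume.mean()) / (volume.std() + eps)
```

For example, the CT values [-1000, 0, 40, 400, 3000] HU with the window [-160, 240] map to [0.0, 0.4, 0.5, 1.0, 1.0].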

For this guide, we assume that the data fit into memory and are stored in NumPy arrays and Dataset objects (for demonstration purposes only).

2. Model design & Metrics

A conventional CNN is composed of convolutional layers (to extract features from the input image), pooling layers (to reduce the dimensionality of each feature map while retaining the most significant information), and flattening layers (to convert the feature maps into a one-dimensional array). These are followed by dense (fully connected) layers; for classification, the last dense layer uses the proper activation function (e.g. sigmoid for binary classification).

The easiest way to design a model in TF is with the Sequential model.

In this example design, the two convolutional layers have 128 and 64 filters, respectively. The design takes as input RGB images of 64 x 64 pixels, and its labels are 0/1, representing normal/disease, respectively. Calling the summary function prints the following:

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
conv2d (Conv2D)              (None, 128, 64, 64)       9728
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 64, 64, 64)        73792
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 64, 32, 32)        0
_________________________________________________________________
batch_normalization (BatchNo (None, 64, 32, 32)        128
_________________________________________________________________
flatten (Flatten)            (None, 65536)             0
_________________________________________________________________
dropout (Dropout)            (None, 65536)             0
_________________________________________________________________
dense (Dense)                (None, 1)                 65537
=================================================================
Total params: 149,185
Trainable params: 149,121
Non-trainable params: 64
_________________________________________________________________
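A minimal sketch of a Sequential model consistent with this summary. The kernel sizes (5 x 5 and 3 x 3) and "same" padding follow from the parameter counts, while the ReLU activations, the 0.5 dropout rate, and the sigmoid output are assumptions. The sketch uses TF's default channels_last layout with input shape (64, 64, 3), whereas output shapes such as (None, 128, 64, 64) above indicate channels_first. That difference also explains the BatchNormalization row: its default axis=-1 normalizes the last dimension, which in the summary above is the 32-pixel width (4 x 32 = 128 parameters), while here it is the 64 feature channels (4 x 64 = 256 parameters). The total therefore becomes 149,313, of which 149,185 are trainable.

```python
import tensorflow as tf

layers = tf.keras.layers


def build_sequential_model(input_shape=(64, 64, 3), dropout_rate=0.5):
    """Binary classifier: label 0 = normal, 1 = disease."""
    return tf.keras.Sequential([
        tf.keras.Input(shape=input_shape),
        layers.Conv2D(128, kernel_size=5, padding="same", activation="relu"),
        layers.Conv2D(64, kernel_size=3, padding="same", activation="relu"),
        layers.MaxPooling2D(pool_size=2),   # 64x64 -> 32x32
        layers.BatchNormalization(),        # normalizes the channel axis (last axis)
        layers.Flatten(),                   # 32 * 32 * 64 = 65,536 values
        layers.Dropout(dropout_rate),
        layers.Dense(1, activation="sigmoid"),
    ])


if __name__ == "__main__":
    build_sequential_model().summary()
```

Keeping channels_last also avoids a practical pitfall: TF's CPU kernels for convolution and pooling generally expect that layout.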

Regardless of the CNN design, I prefer writing the model with the Functional API, which produces the same output (with different layer names).

The Functional API provides more flexibility than the Sequential API: it makes it possible to create more complex networks and to get information about each layer. In addition, the model can be displayed with the plot_model function (remember to install pydot).
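A minimal sketch of the same architecture with the Functional API. Its assumptions are 5 x 5 and 3 x 3 kernels, ReLU activations, 0.5 dropout, and a channels_last input of shape (64, 64, 3); with these defaults it has 149,313 parameters. Writing it as a function with hyper-parameters as arguments makes it easy to try different configurations. save_diagram is a small hypothetical helper around tf.keras.utils.plot_model, which needs pydot and the Graphviz binaries installed.

```python
import tensorflow as tf

layers = tf.keras.layers


def build_model(input_shape=(64, 64, 3), filters=(128, 64), kernel_sizes=(5, 3),
                dropout_rate=0.5):
    """Functional-API CNN for binary classification (0 = normal, 1 = disease)."""
    inputs = tf.keras.Input(shape=input_shape, name="image")
    x = inputs
    for n_filters, k in zip(filters, kernel_sizes):
        x = layers.Conv2D(n_filters, k, padding="same", activation="relu")(x)
    x = layers.MaxPooling2D(pool_size=2)(x)
    x = layers.BatchNormalization()(x)
    x = layers.Flatten()(x)
    x = layers.Dropout(dropout_rate)(x)
    outputs = layers.Dense(1, activation="sigmoid", name="disease_probability")(x)
    return tf.keras.Model(inputs=inputs, outputs=outputs, name="cnn_classifier")


def save_diagram(model, path="model.png"):
    """Draw the layer graph to an image file (requires pydot and Graphviz)."""
    tf.keras.utils.plot_model(model, to_file=path, show_shapes=True)
```

For example, build_model(filters=(16, 8), dropout_rate=0.2) builds a much smaller variant (10,601 parameters) for quick experiments.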

An image that shows a graph representing the design of the network
Plot of the network using the function plot_model

Implementing the design as a function allows passing parameters to try different hyper-parameter configurations for the model. The next step is to compile the model, which defines the loss function, the optimizer, and the metrics.

Let’s talk about metrics

In medical imaging, the objective is predominantly related to detecting a disease in a patient. In deep learning, the first metric that comes up is usually accuracy, the proportion of all examples that are correctly classified. In medical imaging terms, accuracy can be interpreted as the probability that the model is correct and the patient has the disease plus the probability that the model is correct and the patient is healthy (normal) 💉

Most metrics are based on the following counts:

  • True positive (TP): number of positive instances correctly predicted as positive.
  • False negative (FN): number of positive instances incorrectly predicted as negative.
  • True negative (TN): number of negative instances correctly predicted as negative.
  • False positive (FP): number of negative instances incorrectly predicted as positive.

With these values, it is possible to compute two key metrics: sensitivity and specificity, also known as the true positive rate and the true negative rate. Sensitivity is the probability that the model classifies a patient as having the disease given that the disease is present. Specificity is the probability that the model classifies a patient as normal given that the patient is normal. Sensitivity and specificity usually trade off against each other: moving the decision threshold to raise one tends to lower the other.

Another metric is the positive predictive value (PPV), the probability that a patient actually has the disease given that the model predicts positive. Similarly, the probability that a patient is normal given that the model predicts negative is called the negative predictive value (NPV). PPV is also known as precision, and sensitivity is also known as recall. The following image summarizes these computations, where GT means ground truth and output refers to the prediction.

The formulas to compute the sensitivity, specificity, PPV and NPV
Basic metrics on deep learning (in the medical domain)
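For reference, the same quantities written out, with TP, TN, FP, and FN as defined above. The second part plugs in the test-set counts reported in the evaluation section (TP = 860, FP = 845, TN = 18181, FN = 5482); the sensitivity, PPV, and accuracy values match the recall, precision, and accuracy that TF prints there.

```latex
\begin{aligned}
\text{Sensitivity (TPR, recall)} &= \frac{TP}{TP + FN}, &
\text{Specificity (TNR)} &= \frac{TN}{TN + FP}, \\
\text{PPV (precision)} &= \frac{TP}{TP + FP}, &
\text{NPV} &= \frac{TN}{TN + FN}, \\
\text{Accuracy} &= \frac{TP + TN}{TP + TN + FP + FN} \\[6pt]
\text{Sensitivity} &= \frac{860}{860 + 5482} \approx 0.1356, &
\text{Specificity} &= \frac{18181}{18181 + 845} \approx 0.9556, \\
\text{PPV} &= \frac{860}{860 + 845} \approx 0.5044, &
\text{NPV} &= \frac{18181}{18181 + 5482} \approx 0.7683, \\
\text{Accuracy} &= \frac{860 + 18181}{25368} \approx 0.7506
\end{aligned}
```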

Among other metrics, the AUC (area under the ROC curve) equals the probability that a classifier ranks a randomly chosen positive sample higher than a randomly chosen negative sample. In TF, the AUC metric creates four local variables (true positives, true negatives, false positives, and false negatives) that are used to compute it.
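To make the ranking interpretation concrete, a small sketch compares every positive score with every negative score (ties count as one half) and checks the result against scikit-learn's roc_auc_score on hypothetical toy scores. TF's AUC metric instead approximates the area with a Riemann sum over a fixed set of thresholds (200 by default), so its value can differ slightly.

```python
import numpy as np
from sklearn.metrics import roc_auc_score


def auc_by_ranking(y_true, scores):
    """AUC as P(score of a random positive > score of a random negative); ties count 1/2."""
    y_true = np.asarray(y_true).astype(bool)
    scores = np.asarray(scores, dtype=float)
    pos, neg = scores[y_true], scores[~y_true]
    # Compare every positive with every negative via broadcasting.
    greater = (pos[:, None] > neg[None, :]).sum()
    ties = (pos[:, None] == neg[None, :]).sum()
    return (greater + 0.5 * ties) / (len(pos) * len(neg))


y_true = np.array([0, 0, 1, 1, 0, 1])
scores = np.array([0.1, 0.4, 0.35, 0.8, 0.2, 0.9])
print(auc_by_ranking(y_true, scores), roc_auc_score(y_true, scores))
```

Here 8 of the 9 positive/negative pairs are ranked correctly (only 0.35 < 0.4 is not), so both functions return 8/9, about 0.889.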

These metrics can be tracked during the training process of the model. Once the model design is complete, the next step is to compile it so it is ready for training. Compiling requires the loss function to apply (the cross-entropy loss between true and predicted labels), the optimizer (Adam, an optimization algorithm that iteratively updates the network weights), and the metrics mentioned before.
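A minimal sketch of the compile step with the metrics listed above. BinaryCrossentropy matches a single sigmoid output. The metric named sensitivity is an assumption: here it is tf.keras.metrics.SensitivityAtSpecificity(0.5), which reports the best sensitivity achievable while keeping specificity at or above 0.5. It therefore differs from recall, which is sensitivity at the fixed 0.5 threshold. A setting like this would explain why the logs below show different values for recall and sensitivity, but the original configuration is not known.

```python
import tensorflow as tf


def compile_model(model, learning_rate=1e-3):
    """Compile a binary classifier with confusion-matrix counts and summary metrics."""
    metrics = [
        tf.keras.metrics.TruePositives(name="tp"),
        tf.keras.metrics.FalsePositives(name="fp"),
        tf.keras.metrics.TrueNegatives(name="tn"),
        tf.keras.metrics.FalseNegatives(name="fn"),
        tf.keras.metrics.BinaryAccuracy(name="accuracy"),
        tf.keras.metrics.Precision(name="precision"),  # PPV at threshold 0.5
        tf.keras.metrics.Recall(name="recall"),        # sensitivity at threshold 0.5
        tf.keras.metrics.AUC(name="auc"),
        # Assumed setting: best sensitivity while specificity >= 0.5.
        tf.keras.metrics.SensitivityAtSpecificity(0.5, name="sensitivity"),
    ]
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        loss=tf.keras.losses.BinaryCrossentropy(),
        metrics=metrics,
    )
    return model
```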

TensorFlow offers different types of loss functions, and its extension library TensorFlow Addons implements additional ones, including well-known losses designed to address class imbalance. Addons also offers activation functions (e.g. Hardshrink, Sparsemax), layers (e.g. Maxout, Adaptive Max Pooling), metrics (e.g. F1-Score, Cohen’s Kappa), optimizers (e.g. Lazy Adam, Yogi), and other loss functions such as Triplet Hard, Pinball, etc.

3. Training

In TF, training, evaluation, and prediction work the same way for Sequential, Functional, and subclassed models. As mentioned before, memory limitations usually prevent feeding all the training data into the network in one pass. Instead, training runs for a number of epochs, where an epoch is one complete pass (forward and backward) of the entire dataset through the network, and the dataset is divided into batches. The batch size is the number of training samples in a single batch. Lastly, an iteration is one update of the model’s weights during training, so the number of iterations per epoch is the number of batches needed to complete one epoch 🤔
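The relation between these quantities is simple arithmetic. As an illustration, the TP + FP + TN + FN counts in the training log below sum to 64,739 samples over 1,012 iterations, and those in the test evaluation sum to 25,368 samples over 397 steps. Both are consistent with a batch size of 64, although that batch size is an inference, not a value stated in this guide.

```python
import math


def steps_per_epoch(num_samples, batch_size):
    """Number of weight updates (iterations) needed to see every sample once.

    The last batch may be smaller than batch_size, hence the ceiling.
    """
    return math.ceil(num_samples / batch_size)


print(steps_per_epoch(64_739, 64))  # 1012
print(steps_per_epoch(25_368, 64))  # 397
```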

Now, what are the right numbers of epochs, batch size, and iterations? There is no magic rule for choosing these values. For example, a small batch size makes each weight update noisier (higher variance), and a small sample may not represent the problem well (consider the imbalanced data in medical imaging, such as lesions vs. no-lesions). In contrast, a large batch size may not fit in memory and tends to generalize worse.

After defining the training parameters, it is possible to define some callbacks: functions that TF invokes at given points during training. An example is EarlyStopping, which stops training when a monitored metric has stopped improving.

Now, training is performed with the fit function by setting its parameters (training and validation data, number of epochs, callbacks, etc.).
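A minimal sketch of the training call, assuming the model was compiled with an AUC metric named auc and that train_ds and validation_ds are batched Dataset objects. epochs=60 matches the "Epoch 18/60" line in the log below; the monitored metric and patience are illustrative choices, not the values behind that log.

```python
import tensorflow as tf


def train(model, train_ds, validation_ds, epochs=60):
    callbacks = [
        # Stop when validation AUC has not improved for 10 epochs and keep the best weights.
        tf.keras.callbacks.EarlyStopping(
            monitor="val_auc", mode="max", patience=10, restore_best_weights=True
        ),
    ]
    return model.fit(
        train_ds,
        validation_data=validation_ds,
        epochs=epochs,
        callbacks=callbacks,
        verbose=2,  # one line per epoch; 0 = silent, 1 = progress bar
    )
```

The returned History object is what the plots below are built from.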

Output during training can be controlled with the verbose parameter (verbose=0 suppresses it). Part of the output should look like this:

...
Epoch 18/60
1012/1012 [==============================] - ETA: 0s - loss: 0.4979 - tp: 3800.0000 - fp: 3390.0000 - tn: 45164.0000 - fn: 12385.0000 - accuracy: 0.7563 - precision: 0.5285 - recall: 0.2348 - auc: 0.7480 - sensitivity: 0.5719
...

Notice that the metric values are shown during the process, and the History object returned by fit holds a record of the loss and metric values during training. It is possible to visualize them! 🧐

Plotting epochs vs. loss value of train and test
A plot of the loss function on a log scale to show a wide range of values

To examine our metrics, they can be plotted as well.

Six images representing six metrics: loss, precision, recall, AUC, TP, and sensitivity
Six metrics of our network: loss, precision, recall, AUC, TP and sensitivity
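A minimal sketch of such plots with matplotlib, assuming history is the History.history dictionary returned by fit and that the metric names match those used at compile time (tp, precision, recall, auc, sensitivity). Validation curves are drawn when the matching val_ key exists, and the loss panel uses a log scale.

```python
import matplotlib.pyplot as plt


def plot_history(history, metrics=("loss", "precision", "recall", "auc", "tp", "sensitivity")):
    """Plot training vs. validation curves for each metric in a History.history dict."""
    fig, axes = plt.subplots(2, 3, figsize=(12, 6))
    for ax, name in zip(axes.ravel(), metrics):
        epochs = range(1, len(history[name]) + 1)
        ax.plot(epochs, history[name], label="train")
        if f"val_{name}" in history:
            ax.plot(epochs, history[f"val_{name}"], linestyle="--", label="validation")
        if name == "loss":
            ax.set_yscale("log")  # a log scale shows a wide range of loss values
        ax.set_xlabel("epoch")
        ax.set_title(name)
        ax.legend()
    fig.tight_layout()
    return fig
```

Typical usage: fig = plot_history(history.history), then plt.show() or fig.savefig("metrics.png").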

Great! The model is trained, but can I get better results? I do not know 💀, but you can try different hyper-parameter configurations automatically. scikit-learn’s GridSearchCV generates candidates from a grid of parameter values and evaluates each of them to determine the optimal values for your model. Now, it’s time to evaluate the model.
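GridSearchCV expects a scikit-learn-compatible estimator, so a Keras model needs a wrapper (for example, from the third-party SciKeras package). A dependency-free alternative, swapped in here, is to iterate over scikit-learn's ParameterGrid yourself and score each candidate on a validation split. This sketch uses a single validation split rather than cross-validation, and build_small_model is a hypothetical tiny network used only to illustrate the loop.

```python
import tensorflow as tf
from sklearn.model_selection import ParameterGrid


def build_small_model(learning_rate, dropout_rate, input_dim):
    """Hypothetical tiny model used only to illustrate the search loop."""
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(input_dim,)),
        tf.keras.layers.Dense(16, activation="relu"),
        tf.keras.layers.Dropout(dropout_rate),
        tf.keras.layers.Dense(1, activation="sigmoid"),
    ])
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate),
                  loss="binary_crossentropy",
                  metrics=[tf.keras.metrics.AUC(name="auc")])
    return model


def grid_search(x_train, y_train, x_val, y_val, param_grid, epochs=5):
    """Train one model per parameter combination; return (val_auc, params), best first."""
    results = []
    for params in ParameterGrid(param_grid):
        model = build_small_model(input_dim=x_train.shape[1], **params)
        model.fit(x_train, y_train, epochs=epochs, verbose=0)
        val_auc = model.evaluate(x_val, y_val, verbose=0, return_dict=True)["auc"]
        results.append((val_auc, params))
    return sorted(results, key=lambda r: r[0], reverse=True)
```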

4. Evaluation & Prediction

There are research papers that can guide you on how to properly evaluate a CNN for medical imaging according to your domain (e.g. cardiac image segmentation, retinal vessel detection, etc.). For now, the first step is to evaluate the model on our test dataset.

Evaluate on test data
397/397 [==============================] - 11s 27ms/step - loss: 0.5186 - tp: 860.0000 - fp: 845.0000 - tn: 18181.0000 - fn: 5482.0000 - accuracy: 0.7506 - precision: 0.5044 - recall: 0.1356 - auc: 0.7195 - sensitivity: 0.5247

{'accuracy': 0.7505912780761719,
'auc': 0.7195032238960266,
'fn': 5482.0,
'fp': 845.0,
'loss': 0.5186316967010498,
'precision': 0.5043988227844238,
'recall': 0.1356039047241211,
'sensitivity': 0.5247030258178711,
'tn': 18181.0,
'tp': 860.0}

The predict function returns the model’s output for the given input. With this result, the evaluation compares the ground truth with the predicted values to determine how they differ. The variables y_train_pred and y_test_pred hold the predictions (i.e. inference) for the training and test datasets, respectively. Now, let’s plot some useful images 📈
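A minimal sketch of this step, assuming a compiled model, a batched test_ds, and the NumPy arrays x_train and x_test from the data section. return_dict=True (TF 2.2 or later) makes evaluate return a dictionary like the one printed above.

```python
import tensorflow as tf


def evaluate_and_predict(model, test_ds, x_train, x_test, batch_size=64):
    # Test-set metrics as a {name: value} dict.
    test_metrics = model.evaluate(test_ds, verbose=0, return_dict=True)
    # Predicted disease probabilities (sigmoid outputs), flattened to 1-D arrays.
    y_train_pred = model.predict(x_train, batch_size=batch_size, verbose=0).ravel()
    y_test_pred = model.predict(x_test, batch_size=batch_size, verbose=0).ravel()
    return test_metrics, y_train_pred, y_test_pred
```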

No-Lesions Correctly Identified (True Negatives): 8919
No-Lesions Incorrectly Detected as Lesions (False Positives): 3814
Lesions Missed (False Negatives): 2576
Lesions Detected (True Positives): 10059
Total Lesions: 12635
A square divided into 4 sections: TP, TN, FP, and FN
Confusion matrix of a model. For better visualization, this corresponds to an improved model.

From the confusion matrix for binary classification, it is possible to extract the TP, FP, FN, and TN values. It is also possible to compute these and other metrics with TensorFlow functions, for example, the precision of the training set.
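A minimal sketch on hypothetical toy predictions. scikit-learn's confusion_matrix gives the four counts (rows are ground truth, columns are predictions, so ravel() returns TN, FP, FN, TP, the same order as the printout above), and tf.keras.metrics.Precision reproduces the precision computed from them. With real data, y_true and y_prob would be y_train and y_train_pred.

```python
import numpy as np
import tensorflow as tf
from sklearn.metrics import confusion_matrix


def confusion_counts(y_true, y_prob, threshold=0.5):
    """Return TN, FP, FN, TP from predicted probabilities (label 1 = disease)."""
    y_pred = (np.asarray(y_prob) > threshold).astype(int)
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    tn, fp, fn, tp = cm.ravel()  # rows = ground truth, columns = prediction
    return tn, fp, fn, tp


y_true = np.array([0, 0, 0, 1, 1, 1, 1, 0])
y_prob = np.array([0.1, 0.7, 0.3, 0.9, 0.4, 0.8, 0.6, 0.2])
tn, fp, fn, tp = confusion_counts(y_true, y_prob)

# The same precision, computed by hand and with TF's metric object.
precision_manual = tp / (tp + fp)
precision_metric = tf.keras.metrics.Precision(thresholds=0.5)
precision_metric.update_state(y_true, y_prob)
print(tn, fp, fn, tp, precision_manual, float(precision_metric.result()))  # 3 1 1 3 0.75 0.75
```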

Another useful tool is the ROC (receiver operating characteristic) curve, a model-wide evaluation measure that plots sensitivity against 1 - specificity over all decision thresholds.

The ROC curve (false positive rate on the X-axis vs. true positive rate on the Y-axis)
ROC curve
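A minimal sketch of the ROC curve with scikit-learn, assuming y_true holds 0/1 labels and y_prob the sigmoid outputs (e.g. y_test and y_test_pred). The axes are rates (FPR = 1 - specificity, TPR = sensitivity), not raw counts.

```python
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score, roc_curve


def plot_roc(y_true, y_prob, label="model", ax=None):
    """Plot sensitivity (TPR) against 1 - specificity (FPR) over all thresholds."""
    fpr, tpr, thresholds = roc_curve(y_true, y_prob)
    auc = roc_auc_score(y_true, y_prob)
    ax = ax if ax is not None else plt.gca()
    ax.plot(fpr, tpr, label=f"{label} (AUC = {auc:.3f})")
    ax.plot([0, 1], [0, 1], linestyle=":", color="gray", label="chance")
    ax.set_xlabel("False positive rate (1 - specificity)")
    ax.set_ylabel("True positive rate (sensitivity)")
    ax.legend(loc="lower right")
    return fpr, tpr, thresholds
```

The returned thresholds let you pick an operating point, for example the threshold that reaches a required sensitivity.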

scikit-learn provides several metrics that are common in medical imaging. For binary classification, the classification_report function computes some of them (precision, recall, F1-score, and support for each class). Other evaluation metrics depend directly on the final goal of the network [8][9] (e.g. segmentation, registration, etc.).
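A short example on hypothetical toy labels. In the binary case, the recall of the positive class ("disease") is the sensitivity, and the recall of the negative class ("normal") is the specificity.

```python
import numpy as np
from sklearn.metrics import classification_report

y_true = np.array([0, 0, 0, 1, 1, 1, 1, 0])
y_prob = np.array([0.1, 0.7, 0.3, 0.9, 0.4, 0.8, 0.6, 0.6])
y_pred = (y_prob > 0.5).astype(int)

names = ["normal", "disease"]
# Per-class precision, recall, F1-score and support as a text table...
print(classification_report(y_true, y_pred, target_names=names))
# ...or as a dict for programmatic access.
report = classification_report(y_true, y_pred, target_names=names, output_dict=True)
```

Here the disease recall (sensitivity) is 3/4 = 0.75, while the normal recall (specificity) is only 2/4 = 0.5.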

Diverse methods can be applied to visualize a model’s results and its composition (e.g. layers and weights). A good starting point for visualizing layers is the work of Zeiler & Fergus [10], followed by techniques such as Saliency Maps, Score-CAM, Grad-CAM, Grad-CAM++, Activation Maximization, CNN Fixations, and more.
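The simplest of these techniques, a vanilla gradient saliency map, takes the gradient of the model’s output score with respect to the input pixels: pixels with large gradient magnitudes are those that most affect the prediction. A minimal sketch, assuming a Keras model with a single sigmoid output such as the one built in this guide; it does not implement Grad-CAM or the other methods listed.

```python
import numpy as np
import tensorflow as tf


def saliency_map(model, image):
    """Vanilla gradient saliency: |d score / d pixel|, maximum over channels.

    image: NumPy array of shape (H, W, C); returns an (H, W) map scaled to [0, 1].
    """
    x = tf.convert_to_tensor(image[np.newaxis], dtype=tf.float32)
    with tf.GradientTape() as tape:
        tape.watch(x)  # x is a plain tensor, so it must be watched explicitly
        score = model(x, training=False)[:, 0]
    grads = tape.gradient(score, x)  # shape (1, H, W, C)
    saliency = tf.reduce_max(tf.abs(grads), axis=-1)[0]
    # Rescale to [0, 1] for display as a heat map.
    saliency = saliency / (tf.reduce_max(saliency) + 1e-8)
    return saliency.numpy()
```

The result can be overlaid on the input, e.g. plt.imshow(image); plt.imshow(saliency, cmap="jet", alpha=0.4).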

Final Remarks

This guide gave an overview of developing a CNN to classify medical images, demonstrating different options available in TF. A variety of options were presented to support others in the deep learning race, and there are many more experiments we can perform to improve our network. As you may have noticed, learning means adjusting the parameters! 👽

In the healthcare domain, there is a broad variety of exciting and forward-looking applications based on AI and machine learning. You can review these papers [11][12][13] for further reading on the deep learning challenges in medical imaging 🏥

Some topics were not covered in this guide, such as transfer learning with pre-trained networks in TF (both feature extraction and fine-tuning). TF offers a collection of pre-trained networks in TensorFlow Hub.

For further information about training and evaluation using TF, you can check the official documentation.

From a geek to geeks

References

[1] Bernhard Preim and Charl Botha. Visual Computing for Medicine: Theory, Algorithms, and Applications, 2nd edition, ScienceDirect, 2014.
[2] Joseph Paul Cohen, Paul Morrison and Lan Dao. COVID-19 image data collection, arXiv 2003.11597, Github repository, 2020.
[3] Belarouci, Sara and Chikh, Mohammed. Medical imbalanced data classification, Advances in Science, Technology and Engineering Systems Journal, vol. 2, pp. 116–124, 2017.
[4] Shorten, C. and Khoshgoftaar, T.M. A survey on Image Data Augmentation for Deep Learning, J Big Data 6, 60, 2019.
[5] Hussain Z, Gimenez F, Yi D and Rubin D. Differential Data Augmentation Techniques for Medical Imaging Classification Tasks, AMIA — Annual Symposium proceedings. AMIA Symposium, pp. 979–984, 2017.
[6] F. Ciompi et al. The importance of stain normalization in colorectal tissue classification with convolutional networks, 2017 IEEE 14th International Symposium on Biomedical Imaging (ISBI 2017), Melbourne, VIC, 2017, pp. 160–163, 2017.
[7] Reinhold, J. C., Dewey, B. E., Carass, A. and Prince, J. L. Evaluating the Impact of Intensity Normalization on MR Image Synthesis. Proceedings of SPIE — the International Society for Optical Engineering, 10949, 109493H, 2019.
[8] Taha, A.A. and Hanbury, A. Metrics for evaluating 3D medical image segmentation: analysis, selection, and tool, BMC Med Imaging 15, 29, 2015.
[9] Sara Moccia, Elena De Momi, Sara El Hadji and Leonardo S. Mattos. Blood vessel segmentation algorithms — Review of methods, datasets and evaluation metrics, Computer Methods and Programs in Biomedicine, vol. 158, pp. 71–91, 2018.
[10] Zeiler M.D. and Fergus R. Visualizing and Understanding Convolutional Networks. In: Fleet D., Pajdla T., Schiele B., Tuytelaars T. (eds) Computer Vision — ECCV 2014. Lecture Notes in Computer Science, vol 8689. Springer, 2014.
[11] F. Altaf, S. M. S. Islam, N. Akhtar and N. K. Janjua. Going Deep in Medical Image Analysis: Concepts, Methods, Challenges, and Future Directions, in IEEE Access, vol. 7, pp. 99540–99572, 2019.
[12] Saraf V., Chavan P., Jadhav A. Deep Learning Challenges in Medical Imaging. In: Vasudevan H., Michalas A., Shekokar N., Narvekar M. (eds) Advanced Computing Technologies and Applications. Algorithms for Intelligent Systems. Springer, Singapore, 2020.
[13] A. S. Panayides et al. AI in Medical Imaging Informatics: Current Challenges and Future Directions, in IEEE Journal of Biomedical and Health Informatics, vol. 24, no. 7, pp. 1837–1857, July 2020.

Originally published at https://www.ecode.dev on July 24, 2020.


Software R&D engineer and computer scientist. I code for fun and food! One of my projects is ecode.dev, which spreads Computer Science knowledge among programmers.