The Power of Checkpointing in Medical Imaging

Upasana Bharadwaj
MICCAI Educational Initiative
Aug 24, 2019

Developing state-of-the-art deep learning models in a data scarce world

Have you ever tried to train a spine segmentation model with only a few dozen labeled examples? What about early detection of pancreatic cancer from a hundred CT scans? The medical imaging community is plagued by limited access to high-quality labeled data, yet every little bit counts when it comes to clinical outcomes. This post discusses one specific deep learning technique that ameliorates the effects of data scarcity: transfer learning with opportunistic initialization and checkpointing.

Why are medical images so hard to find?

The entire pipeline, from image acquisition to expert analysis by a radiologist, is prohibitively expensive. Combined with privacy and sensitivity concerns around medical records, this means only a subset of the data makes its way into machine learning. To top it all off, even medically prevalent conditions occur too rarely to supply the number of examples needed to train robust models.

Case in point: spine curvature estimation

State-of-the-art models in medical imaging, such as the lung cancer model that achieved radiologist-level performance, are typically composed of two stages: landmark/region-of-interest detection followed by a classifier (or a regressor if the prediction is real-valued and cannot be discretized).

Spine curvature estimation from radiographs is an ideal case study because: (1) it is a time-consuming task for doctors; (2) it is a well-researched topic, with its own MICCAI 2019 Challenge; (3) any promising solution will likely follow the same two-stage approach described above.

Landmark Detection

The goal of landmark detection is to identify regions within the image that correspond to clinically relevant areas, such as the vertebrae in the spine. It is one of the most important components of an end-to-end system, since landmarks not only improve the accuracy of second-stage models but also provide a basis for interpreting their predictions in a clinical context.

Accurate detection models require highly granular labels, such as a series of landmark coordinates as illustrated below.

Training Example from SpineWeb Dataset 16

The figure shows a typical radiograph used for spine curvature estimation. The training data used in the MICCAI Challenge consists of ~600 anterior-posterior (AP) x-ray images, each annotated with landmarks (in red) and Cobb (curvature) angles.
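The link between landmarks and Cobb angles can be sketched in a few lines. This is a simplified illustration, not the challenge's official metric or the author's method: it assumes four corner landmarks per vertebra in a hypothetical top-left, top-right, bottom-left, bottom-right order, measures the tilt of every endplate, and approximates the main Cobb angle as the largest difference between any two endplate tilts. A clinical measurement additionally identifies the end vertebrae of each curve and may report several angles.

```python
import math
from typing import List, Tuple

Point = Tuple[float, float]  # (x, y) in image pixel coordinates


def endplate_tilt(left: Point, right: Point) -> float:
    """Angle in degrees of the line from the left corner to the right corner."""
    dx = right[0] - left[0]
    dy = right[1] - left[1]
    return math.degrees(math.atan2(dy, dx))


def main_cobb_angle(landmarks: List[Point]) -> float:
    """Approximate the main Cobb angle from per-vertebra corner landmarks.

    Assumption (hypothetical layout): landmarks come in groups of four per
    vertebra, ordered top-left, top-right, bottom-left, bottom-right.
    """
    if len(landmarks) % 4 != 0:
        raise ValueError("expected four landmarks per vertebra")
    tilts = []
    for i in range(0, len(landmarks), 4):
        tl, tr, bl, br = landmarks[i:i + 4]
        tilts.append(endplate_tilt(tl, tr))  # superior (upper) endplate
        tilts.append(endplate_tilt(bl, br))  # inferior (lower) endplate
    # max - min equals the largest pairwise difference between endplate tilts,
    # i.e. the angle between the two most oppositely tilted endplates.
    return max(tilts) - min(tilts)
```

For example, three toy vertebrae whose endplates tilt by +45°, 0° and -45° give a main Cobb angle of about 90°.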

Landmark/object detection is a complex task; see the RetinaNet paper for an in-depth characterization of what makes the problem so challenging. The bottom line is that cutting-edge models with millions of parameters typically require thousands of labeled examples, if not more.

Techniques such as the focal loss help with class imbalance, and label-agnostic approaches that rely on data augmentation of unlabeled images (see the UDA paper) can further improve overall accuracy.
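To make the focal loss concrete: the RetinaNet paper defines it as FL(p_t) = -α_t (1 - p_t)^γ log(p_t), where p_t is the probability the model assigns to the true class, and reports γ = 2 and α = 0.25 as working well. A minimal single-prediction sketch of the binary form (framework implementations vectorize this over all anchors):

```python
import math


def focal_loss(p: float, y: int, alpha: float = 0.25, gamma: float = 2.0,
               eps: float = 1e-7) -> float:
    """Binary focal loss for a single prediction.

    p is the predicted probability of the positive class and y is the 0/1 label.
    With gamma = 0 and alpha = 0.5 this reduces to half the usual cross-entropy.
    """
    p = min(max(p, eps), 1.0 - eps)       # clamp so log() never sees 0
    p_t = p if y == 1 else 1.0 - p        # probability assigned to the true class
    alpha_t = alpha if y == 1 else 1.0 - alpha
    # (1 - p_t) ** gamma shrinks the loss of well-classified (easy) examples.
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)
```

An easy negative (y = 0, p = 0.01) contributes roughly 7.5e-7, versus about 0.01 under plain cross-entropy, while a hard positive (y = 1, p = 0.1) still contributes about 0.47. This is how the loss keeps the many easy background locations from swamping the few informative ones.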

But why not also rely on simple principles from transfer learning? The goal is to use every single x-ray that is available…

State-of-the-art detectors such as RetinaNet rely on an underlying convolutional neural network (CNN) to extract relevant features. Although the appropriate size and depth of the CNN largely depend on the task’s complexity, it is widely accepted that deeper and larger models tend to yield more accurate results, provided there is enough data to train them. Moreover, successive convolutional layers learn progressively more abstract concepts, from simple geometric features such as edges to semantic structures.

So the recipe for transferring knowledge in deep learning is rather simple: (1) train a convolutional model on any data that “resembles” our problem space; (2) initialize a new model by copying (a subset of) the weights from the pre-trained model; (3) fine-tune the new model on the desired dataset and its corresponding labels. The technique is so powerful that the natural language processing (NLP) community has released pre-trained checkpoints for models such as BERT, which can then be fine-tuned for a wide variety of tasks ranging from question answering to sentiment analysis.
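Step (2) of the recipe, copying a subset of weights, can be sketched without any framework. This is an illustrative sketch, not a specific library's API: a checkpoint is modeled as a plain dict from parameter name to (shape, values), an optional prefix map translates names between the two models, and only parameters whose translated name and shape both match are copied. The prefixes used in practice (for example a hypothetical `features.` to `backbone.` mapping) depend entirely on how the two models are defined.

```python
from typing import Dict, List, Optional, Tuple

# Stand-in for a framework checkpoint: parameter name -> (shape, flat list of values).
Param = Tuple[Tuple[int, ...], List[float]]
StateDict = Dict[str, Param]


def transfer_weights(
    pretrained: StateDict,
    target: StateDict,
    rename: Optional[Dict[str, str]] = None,
) -> Tuple[StateDict, List[str]]:
    """Initialize `target` from `pretrained` wherever names and shapes line up."""
    rename = rename or {}
    new_state = dict(target)  # shallow copy; the caller's target is left untouched
    copied: List[str] = []
    for name, (shape, values) in pretrained.items():
        # Map pre-trained names into the new model's namespace via the first matching prefix.
        target_name = name
        for old_prefix, new_prefix in rename.items():
            if name.startswith(old_prefix):
                target_name = new_prefix + name[len(old_prefix):]
                break
        # Copy only when the new model has a parameter of the same name and shape;
        # everything else (e.g. a task-specific head) keeps its fresh initialization.
        if target_name in target and target[target_name][0] == shape:
            new_state[target_name] = (shape, list(values))
            copied.append(target_name)
    return new_state, copied
```

The shape check matters in practice: PyTorch's `Module.load_state_dict(..., strict=False)` tolerates missing or unexpected keys but still raises on shape mismatches, so checkpoints are usually filtered like this before loading.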

Stanford spearheaded the “checkpointing revolution” in medical imaging by publicly releasing large datasets such as MURA and CheXpert. The latter has over 200,000 chest x-rays, most of which capture the spine. To exploit a significantly larger dataset such as CheXpert:

  1. Train a CNN feature extractor on CheXpert using a state-of-the-art model architecture such as DenseNet.
  2. Initialize a RetinaNet detection model with the weights from the pre-trained DenseNet model.
  3. Fine-tune the RetinaNet detection model with ~600 labeled examples from SpineWeb (AP x-rays with vertebrae landmarks).
  4. Incorporate orthogonal techniques (e.g., UDA) to further improve data efficiency.
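As one example of step 4, the core of UDA is a consistency loss on unlabeled images: the model's prediction on a clean image serves as a fixed target for its prediction on an augmented copy, and the KL divergence between the two is added to the supervised loss. The sketch assumes a classification-style output; adapting it to landmark detection is not shown here, and the temperature and threshold defaults are illustrative, not values from the source. Sharpening the target and masking unconfident examples are optional refinements described in the UDA paper.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float], temperature: float = 1.0) -> List[float]:
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(p: Sequence[float], q: Sequence[float], eps: float = 1e-12) -> float:
    """KL(p || q) for two discrete distributions."""
    return sum(pi * math.log((pi + eps) / (qi + eps)) for pi, qi in zip(p, q) if pi > 0)


def consistency_loss(
    clean_logits: List[List[float]],
    augmented_logits: List[List[float]],
    sharpen_temperature: float = 0.4,
    confidence_threshold: float = 0.8,
) -> float:
    """UDA-style consistency loss, averaged over unlabeled examples that pass the mask.

    In a real framework no gradient would flow through the clean-image target.
    """
    total, kept = 0.0, 0
    for clean, aug in zip(clean_logits, augmented_logits):
        if max(softmax(clean)) < confidence_threshold:
            continue  # confidence-based masking: skip unsure targets
        target = softmax(clean, temperature=sharpen_temperature)  # sharpened target
        total += kl_divergence(target, softmax(aug))
        kept += 1
    return total / kept if kept else 0.0
```

During training this term would be weighted and added to the supervised loss computed on the ~600 labeled examples.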

Curvature Estimation

Rinse and repeat is the motto of deep learning: reusing the exact same CNN for regression can be extremely powerful. Spine curvature is quantified by Cobb angles, which are involved and time-consuming to measure by hand. Training a new model from scratch would require significantly more labeled examples, and each new label can be expensive to acquire.

Intuition suggests that the landmark CNN learnt features that are highly relevant for Cobb angle estimation as well, so in the second stage:

  1. Initialize a regression model with weights from the DenseNet/CNN feature extractor used in landmark detection.
  2. Fine-tune with ~600 labeled examples of Cobb angles.
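The cheapest variant of these two steps keeps the transferred feature extractor frozen and fits only a new regression head on its output (a "linear probe"); full fine-tuning, as described above, would also update the backbone, usually with a small learning rate. A minimal sketch of the head-fitting part, using full-batch gradient descent on mean squared error. The feature vectors stand in for hypothetical outputs of the landmark CNN, and the targets in the test are toy numbers, not real Cobb angles.

```python
from typing import List, Sequence, Tuple


def predict(w: Sequence[float], b: float, x: Sequence[float]) -> float:
    """Linear regression head: w . x + b."""
    return sum(wi * xi for wi, xi in zip(w, x)) + b


def fit_regression_head(
    features: List[List[float]],
    targets: List[float],
    lr: float = 0.1,
    epochs: int = 1000,
) -> Tuple[List[float], float]:
    """Fit weights w and bias b minimizing mean squared error by gradient descent."""
    n, d = len(features), len(features[0])
    w, b = [0.0] * d, 0.0
    for _ in range(epochs):
        grad_w, grad_b = [0.0] * d, 0.0
        for x, y in zip(features, targets):
            err = predict(w, b, x) - y  # residual for this example
            for j in range(d):
                grad_w[j] += 2.0 * err * x[j] / n
            grad_b += 2.0 * err / n
        # Gradient step on the head only; the feature extractor stays frozen.
        w = [wi - lr * gi for wi, gi in zip(w, grad_w)]
        b -= lr * grad_b
    return w, b
```

If the frozen features already encode curvature well, even this head alone can give a useful baseline before committing to full fine-tuning.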

What Really Matters?

  1. Data is more powerful than a specific model architecture; mold your model to the data, not the other way around.
  2. Create an ecosystem of pre-trained models (like the NLP community) so that many tasks are a small fine-tuning step from an existing checkpoint.
  3. The recipe described here is general and extends to other clinically relevant models.

Upasana Bharadwaj

Physician exploring the role technology can play in shaping the future of global healthcare. Deeply interested in machine learning for medical imaging.