How fast is fast.ai? Building a world-class ophthalmologist in an evening

Alex Federation · Published in The Startup · May 28, 2019

Background

I’m working through the fast.ai deep learning course this summer, and the instructor Jeremy Howard often brags (in a good way, like a proud dad) about how quickly students in the course can build near-world-class models with the fast.ai library. I wanted to put this idea to the test, so I applied approaches from the first lecture only to a highly cited Cell paper from last year.

For a biologist, a Cell paper can be the highlight of a career. Cell publishes the best work from across all fields of biology, and the Kermany et al. paper from last year represented a breakthrough in applying deep learning to ocular pathology. Briefly, the authors used an imaging technology called retinal optical coherence tomography (OCT) to image healthy eyes and eyes with one of three disease conditions, and employed a panel of expert ophthalmologists to classify the scans. They had multiple experts classify a subset of the images, which also gave them a measure of how accurate the humans were.

[Figure: a colorized OCT scan]

Interestingly, they took a similar approach to what we’ve been working on: transfer learning with a convolutional neural network. Let’s see how our simple approach compares.

# notebook setup and fastai imports
%reload_ext autoreload
%autoreload 2
%matplotlib inline
from fastai.vision import *

Load the data

Luckily, the data are already organized and released on the Kaggle platform, which has a nice, pip-installable API for accessing its datasets. You need an API token, and all the details are explained well in the docs.

pip install kaggle
kaggle datasets download -d paultimothymooney/kermany2018
unzip -q kermany2018.zip -d kermany2018   # extract the download to a local folder

Pre-process the data

Here, I used the API provided by fast.ai to organize the data into a training and validation set, apply the standard transforms, and normalize the images using ImageNet statistics, since ImageNet is the dataset our pre-trained model was trained on.

path = Path('kermany2018/OCT2017')  # adjust to wherever the Kaggle download was unzipped

# build an ImageList, split by folder, and label from the folder names
src = (ImageList.from_folder(path)
       .split_by_folder(train='train', valid='test')
       .label_from_folder())

tfms = get_transforms()
data = (src.transform(tfms, size=128)
        .databunch()
        .normalize(imagenet_stats))

And we can take a look at the images and labels to make sure all looks good.

data.show_batch(rows=3, figsize=(12,9))
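
It’s also worth confirming the labels and split sizes; data.classes and the dataset lengths are standard DataBunch attributes in fastai v1:

print(data.classes)  # the four folder-derived labels, e.g. ['CNV', 'DME', 'DRUSEN', 'NORMAL']
print(len(data.train_ds), len(data.valid_ds))  # training and validation set sizes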

Train the model

We’re trying to classify the eyes into one of 4 categories.

  1. Healthy
  2. Choroidal neovascularization (CNV; blood vessel formation in the eye, related to macular degeneration)
  3. Diabetic macular edema (DME; fluid in the retina)
  4. Drusen (fat deposits in the retina)

We’re using the ResNet-34 convolutional neural network architecture, pre-trained on the ImageNet dataset.

learn = cnn_learner(data, models.resnet34, metrics=[accuracy, error_rate])
learn.fit_one_cycle(4)   # 4 epochs with ResNet-34
learn.save('stage-1')

97% accuracy before fine-tuning? That’s not a bad start. How does our confusion matrix look? The biggest difficulty the model has is distinguishing drusen from CNV, which was also the most common confusion among the human experts, an encouraging sign.
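
The confusion matrix itself isn’t shown in code above; here is a minimal sketch using fastai v1’s ClassificationInterpretation, assuming the learner from the previous step:

interp = ClassificationInterpretation.from_learner(learn)  # build an interpretation object from the trained learner
interp.plot_confusion_matrix(figsize=(6,6))                # confusion matrix on the validation set
interp.most_confused(min_val=2)                            # most-confused (actual, predicted, count) pairs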

Stepping it up: ResNet-50

The final strategy presented in Lecture 1 was scaling the network up to a 50-layer CNN and increasing the image size. I’ll try that here and let it go for 4 epochs. Keep in mind that in the Cell paper, the authors trained for 100 epochs to reach 96.6% accuracy.

First, we use the learning rate finder to locate the learning rate at which the loss decreases fastest, then use that to train the last few layers of the network (4 epochs). Then we’ll unfreeze the model, fine-tune the whole thing (4 more epochs), and see how we’re doing.

# rebuild the data at higher resolution, with a smaller batch size to fit in memory
data = (src.transform(tfms, size=256)
        .databunch(bs=32)
        .normalize(imagenet_stats))

learn = cnn_learner(data, models.resnet50, metrics=[accuracy, error_rate])

# find the learning rate where the loss drops fastest
learn.lr_find()
learn.recorder.plot()

lr = 0.001
learn.fit_one_cycle(4, slice(lr))
learn.save('rn50-stage-1')

And lastly, unfreeze the model and fine-tune the entire set of parameters, giving the earlier layers a lower learning rate than the later ones.

learn.unfreeze()
# discriminative learning rates: smaller for early layers, larger for later ones
learn.fit_one_cycle(4, slice(5e-6, lr/5))
learn.save('rn50-stage-2')
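
To read off the final accuracy and re-check the confusion matrix after fine-tuning, the same fastai v1 calls apply; this is just a sketch, since the post’s headline number comes from the training output:

learn.validate()   # returns [val_loss, accuracy, error_rate] on the validation set

interp = ClassificationInterpretation.from_learner(learn)
interp.plot_confusion_matrix(figsize=(6,6))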

Conclusion

We can get ~99% accuracy using the same dataset as Kermany et al. with the fast.ai workflow, even for a seemingly complex problem like retinal pathology. The model in the paper is ~96% accurate and the best human expert was in the 99% range, so we’re definitely in good company with this result.

My take-aways are twofold:

  1. The fast.ai library works exceptionally well out of the box.
  2. Deep learning is an incredibly fast-moving field.

This exercise is not meant to take anything away from the original authors. The task of assembling and annotating the dataset is an impressive feat in itself, and their model is incredibly good at triaging cases with a low false-negative rate, which is a key metric for clinical use.

What this exercise does highlight is one of the barriers to entry for someone coming into deep learning from another scientific field. Things are moving so quickly that it’s intimidating to catch up on the basics, knowing that within 18 months a Cell paper can be surpassed by a Python library. So for me, it has been really encouraging to take the library, do something useful, and catch up on the basics as I go.

Moving forward: it has been shown that medical imaging models are highly sensitive to the instrumentation and the hospital where the images were generated (ref). I’m working with a friend at Baylor Medical Center to gather an additional set of OCT images as a test set, so the two models can be compared on data from an unrelated institution. That experiment will be a good test for both models, and I’ll be sure to provide an update.
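
As a rough sketch of how that comparison could run in fastai v1 (the external_oct_images folder is a hypothetical placeholder, not real data):

learn.export()  # writes export.pkl next to the data for inference

# attach the external images as a test set (folder name is hypothetical)
test_images = ImageList.from_folder('external_oct_images')
inf_learn = load_learner(path, test=test_images)
preds, _ = inf_learn.get_preds(ds_type=DatasetType.Test)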

Adapted from my website
