How fast is fast.ai? Building a world-class ophthalmologist in an evening
Background
I’m working through the fast.ai deep learning course this summer, and the instructor Jeremy often brags (in a good way, like a proud dad) about how students in the course can quickly build near-world-class models with the fast.ai library. I wanted to put this claim to the test, so I applied approaches from the first lecture alone to the dataset from a highly cited Cell paper from last year.
For a biologist, a Cell paper can be the highlight of a career. Cell publishes the best work from across all fields of biology, and the Kermany et al. paper from last year represented a breakthrough in applying deep learning to ocular pathology. Briefly, they used an imaging technology called retinal optical coherence tomography (OCT) to image healthy and diseased eyes (3 possible conditions) and employed a panel of expert ophthalmologists to classify them. They had multiple experts classify a subset of the images so they could also measure how accurate the humans were.
Interestingly, they took a similar approach to what we’ve been working on — transfer learning with a convolutional neural network. Let’s see how our simple approach compares.
%reload_ext autoreload
%autoreload 2
%matplotlib inline
from fastai.vision import *
Load the data
Luckily, the data are already organized and released on the Kaggle platform. Kaggle has a nice, pip-installable API for accessing its datasets. You need an API token, and all the details are explained well in the docs.
pip install kaggle
kaggle datasets download -d paultimothymooney/kermany2018
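The download arrives as a zip archive, and the later code expects a `path` that contains `train` and `test` folders. Here is a minimal sketch of that unpacking step, not code from the original notebook. The archive name `kermany2018.zip`, the `data` destination, and both helper names are assumptions for illustration, and the folder layout inside the archive may differ from what the helper searches for.

```python
import zipfile
from pathlib import Path


def extract_dataset(archive, dest):
    """Unzip `archive` into `dest` (created if missing) and return `dest`."""
    archive, dest = Path(archive), Path(dest)
    dest.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(archive) as zf:
        zf.extractall(dest)
    return dest


def find_split_root(root):
    """Return the first folder at or below `root` that holds both
    a train/ and a test/ subfolder, or None if there is no such folder."""
    root = Path(root)
    candidates = [root] + sorted(p for p in root.rglob("*") if p.is_dir())
    for folder in candidates:
        if (folder / "train").is_dir() and (folder / "test").is_dir():
            return folder
    return None


# Hypothetical usage (file and folder names are assumptions):
# path = find_split_root(extract_dataset("kermany2018.zip", "data"))
```

Searching for the folder instead of hard-coding it avoids guessing how deeply the archive nests its contents.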
Pre-process the data
Here, I used the API provided by fast.ai to organize the data into a training and validation set, apply the standard transforms, and normalize the data using statistics from ImageNet, the dataset our pre-trained model was trained on.
src = (ImageList.from_folder(path)
       .split_by_folder(train='train', valid='test')
       .label_from_folder())
tfms = get_transforms()
data = (src.transform(tfms, size=128)
        .databunch().normalize(imagenet_stats))
And we can take a look at the images and labels to make sure all looks good.
data.show_batch(rows=3, figsize=(12,9))
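To make the normalization step above concrete, here is a small pure-Python sketch of the per-channel arithmetic: subtract the channel mean, then divide by the channel standard deviation. The constants are the commonly used ImageNet RGB statistics that fastai exposes as `imagenet_stats`. The function names are my own, and the real library applies this to whole batches of tensors, not to single pixels.

```python
# Commonly used ImageNet per-channel statistics (RGB, pixel values scaled to 0..1)
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Shift and scale one RGB pixel so each channel is roughly zero-mean, unit-variance."""
    return tuple((v - m) / s for v, m, s in zip(rgb, mean, std))


def denormalize_pixel(rgb, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Invert normalize_pixel, e.g. before displaying an image."""
    return tuple(v * s + m for v, m, s in zip(rgb, mean, std))


# A mid-gray pixel ends up close to zero in every channel
gray = normalize_pixel((0.5, 0.5, 0.5))
```

The point of reusing the ImageNet statistics is that the pre-trained weights expect inputs on the same scale they saw during pre-training.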
Train the model
We’re trying to classify the eyes into one of 4 categories.
- Healthy
- Choroidal Neovascularization (blood vessel formation in the eye, related to macular degeneration)
- Diabetic macular edema (fluid in the retina)
- Drusen (fat deposits in the retina)
We’re using the ResNet-34 convolutional neural network architecture, pre-trained on the ImageNet dataset.
learn = cnn_learner(data, models.resnet34, metrics=[accuracy, error_rate])
learn.fit_one_cycle(4)
learn.save('stage-1')
97% before fine-tuning? That’s not a bad start. How does our confusion matrix look? The model’s most common confusion is between Drusen and CNV, which is the same confusion that was most common among the human experts, an encouraging sign.
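The code behind the confusion matrix is not shown in this post. In fastai v1 it typically comes from `ClassificationInterpretation.from_learner(learn)` followed by `plot_confusion_matrix()` or `most_confused()`. As a library-free sketch of what that summary computes, the snippet below counts (actual, predicted) pairs and ranks the off-diagonal ones. The labels are toy values made up for illustration, not results from this model, and the helper names are my own.

```python
from collections import Counter


def confusion_counts(actual, predicted):
    """Count how often each (actual, predicted) label pair occurs."""
    return Counter(zip(actual, predicted))


def most_confused(actual, predicted, min_count=1):
    """Return misclassified (actual, predicted, count) triples, most frequent first."""
    counts = confusion_counts(actual, predicted)
    errors = [(a, p, n) for (a, p), n in counts.items() if a != p and n >= min_count]
    return sorted(errors, key=lambda triple: triple[2], reverse=True)


# Toy labels for illustration only
actual = ["CNV", "CNV", "DRUSEN", "DRUSEN", "DRUSEN", "DME", "NORMAL", "NORMAL"]
predicted = ["CNV", "CNV", "CNV", "CNV", "DRUSEN", "DME", "NORMAL", "DRUSEN"]

worst = most_confused(actual, predicted)
```

With these toy labels, the top entry is Drusen images predicted as CNV, which mirrors the kind of confusion described above.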
Stepping it up — resnet50
The final strategy presented in Lecture 1 was increasing the size of the network up to a 50-layer CNN and increasing the image size. I’ll try that here and let it go for 4 epochs. Keep in mind that in the Cell paper, they trained for 100 epochs to a 96.6% accuracy.
First, we use the learning rate finder to find the learning rate at which the loss is decreasing fastest, then use that rate to train the last few layers of the network (4 epochs). Then we unfreeze the model, train for 4 more epochs to fine-tune the whole network, and see how we’re doing.
data = (src.transform(tfms, size=256)
.databunch(bs=32).normalize(imagenet_stats))
learn = cnn_learner(data, models.resnet50, metrics=[accuracy, error_rate])
learn.lr_find()
learn.recorder.plot()
lr = 0.001
learn.fit_one_cycle(4, slice(lr))
learn.save('rn50-stage-1')
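In the code above, the learning rate is read off the finder plot by eye. The idea of "the rate where the loss is decreasing fastest" can be sketched as picking the steepest downward slope of loss against log learning rate. This is a simplified illustration with made-up numbers, not the library's implementation. A real finder records many more points and smooths the losses first, and the function name is hypothetical.

```python
import math


def steepest_lr(lrs, losses):
    """Return the learning rate at the start of the interval where the loss
    falls fastest per unit of log10(lr), or None if the loss never falls."""
    best_lr, best_slope = None, 0.0
    for i in range(1, len(lrs)):
        d_loss = losses[i] - losses[i - 1]
        d_log_lr = math.log10(lrs[i]) - math.log10(lrs[i - 1])
        slope = d_loss / d_log_lr
        if slope < best_slope:  # more negative means a faster drop
            best_slope, best_lr = slope, lrs[i - 1]
    return best_lr


# Toy sweep for illustration: the loss drops, bottoms out, then blows up
toy_lrs = [1e-5, 1e-4, 1e-3, 1e-2, 1e-1]
toy_losses = [1.40, 1.35, 0.90, 0.70, 2.50]

suggested = steepest_lr(toy_lrs, toy_losses)
```

Picking the steepest point, and not the minimum of the curve, is the usual advice because the minimum tends to sit right before the loss diverges.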
And lastly, unfreeze the model and tune the entire set of parameters.
learn.unfreeze()
learn.fit_one_cycle(4, slice(5e-6, lr/5))
learn.save('rn50-stage-2')
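Passing `slice(5e-6, lr/5)` gives the earliest layers the smallest learning rate and the final layers the largest. As I understand fastai v1, the values in between are spaced geometrically across the model's layer groups. The sketch below reproduces that spacing in plain Python. The helper is a stand-in I wrote for illustration, and the choice of three layer groups is an assumption about how the ResNet learner is split.

```python
def even_mults(start, stop, n):
    """Return n values from start to stop, each a constant multiple of the previous one."""
    if n == 1:
        return [stop]
    step = (stop / start) ** (1 / (n - 1))
    return [start * step ** i for i in range(n)]


lr = 0.001
# Assumed: three layer groups (early body, late body, head)
group_lrs = even_mults(5e-6, lr / 5, 3)

for name, value in zip(["early body", "late body", "head"], group_lrs):
    print(f"{name}: {value:.2e}")
```

The early layers hold general features from pre-training that need little adjustment, so they get the gentlest updates, while the later layers get larger ones.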
Conclusion
We can get ~99% accuracy using the same dataset as Kermany et al. with the fast.ai workflow, even for a seemingly complex problem like retinal pathology. The model in the paper is ~96% accurate and the best human expert was in the 99% range, so we’re definitely in good company with this result.
My take-aways are twofold:
1. The fast.ai library works exceptionally well out of the box
2. Deep learning is an incredibly fast-moving field
This exercise is not meant to take anything away from the original authors. The task of assembling and annotating the dataset is an impressive feat in itself, and their model is incredibly good at triaging cases with a low false-negative rate, which is a key metric for clinical use.
What this exercise does highlight is one of the barriers to entry for someone coming into deep learning from another scientific field. Things are moving so quickly that it’s intimidating to catch up on the basics, knowing that within 18 months a Cell paper can be surpassed by a Python library. So for me, it has been really encouraging to take the library and do something useful, while still catching up on the basics as I go.
Moving forward, it has been shown that medical imaging models are highly sensitive to the instrumentation and hospital where the images were generated (ref). I’m working with a friend at Baylor Medical Center to gather an additional set of OCT images as a test set to compare the two models with data from an unrelated institution. The results of that experiment will be a good test for both models and I’ll be sure to provide an update.
Adapted from my website