Detect Eye Diseases With Pytorch

Giuseppe Minardi
Published in The Startup · Jul 1, 2020

Deep learning is part of a broader family of machine learning methods based on artificial neural networks (ANNs). Deep learning is ubiquitous today: it is used in many applications, from image classification to speech recognition. In this blog post I'm going to show you how to build a simple neural network that detects different eye diseases from retinal optical coherence tomography (OCT) images using PyTorch.

This is the fifth assignment for the course Zero to GANs on freeCodeCamp.com.

The dataset

OCT is an imaging technique used to capture high-resolution cross sections of the retinas of living patients. Approximately 30 million OCT scans are performed each year, and the analysis and interpretation of these images takes up a significant amount of time.

The dataset is taken from Kaggle. It is organized into 3 folders (train, test, val), each containing one subfolder per image category: choroidal neovascularization (CNV), diabetic macular edema (DME), multiple drusen present in early AMD (DRUSEN), and normal retina with preserved foveal contour and absence of any retinal fluid/edema (NORMAL).

Load and preprocess the images

First we load all the libraries and define the helper functions that move our data and our model to the GPU.
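A minimal sketch of such helpers, assuming a CUDA GPU is used when one is present and the CPU otherwise. The names `get_default_device`, `to_device`, and `DeviceDataLoader` are hypothetical, not taken from the original notebook.

```python
import torch
from torch.utils.data import DataLoader


def get_default_device():
    """Pick the GPU if one is available, otherwise fall back to the CPU."""
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def to_device(data, device):
    """Move a tensor or model, or a (nested) list/tuple of tensors, to `device`."""
    if isinstance(data, (list, tuple)):
        return [to_device(x, device) for x in data]
    # Works for tensors and for nn.Module, since both accept .to(device, non_blocking=...)
    return data.to(device, non_blocking=True)


class DeviceDataLoader:
    """Wrap a DataLoader so that every batch it yields is already on `device`."""

    def __init__(self, dl: DataLoader, device):
        self.dl = dl
        self.device = device

    def __iter__(self):
        for batch in self.dl:
            yield to_device(batch, self.device)

    def __len__(self):
        return len(self.dl)
```

With these, `to_device(model, device)` moves the model, and wrapping a loader as `DeviceDataLoader(train_dl, device)` yields batches that already live on the GPU, so the training loop never has to move tensors by hand.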

Then we go through all the images in the train folder to compute two vectors containing the mean and the standard deviation of each channel of the training images. We'll use those statistics to normalize the images.
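One way to compute these statistics, assuming a loader that yields `(images, labels)` batches of un-normalized tensors (for example an `ImageFolder` with only `ToTensor()` as its transform). It accumulates sums over every pixel instead of averaging per-image statistics, so it returns the exact dataset-wide mean and standard deviation. The function name `channel_stats` is hypothetical; if the raw images have different sizes, a `batch_size=1` loader avoids collation errors.

```python
import torch


def channel_stats(loader):
    """Per-channel mean and std over all pixels of all images yielded by `loader`."""
    n_pixels = 0
    channel_sum = None
    channel_sq_sum = None
    for images, _ in loader:
        # images: (batch, channels, height, width) -> (channels, batch*height*width)
        c = images.shape[1]
        flat = images.transpose(0, 1).reshape(c, -1)
        if channel_sum is None:
            # Accumulate in float64 to limit rounding error over millions of pixels
            channel_sum = torch.zeros(c, dtype=torch.float64)
            channel_sq_sum = torch.zeros(c, dtype=torch.float64)
        channel_sum += flat.sum(dim=1, dtype=torch.float64)
        channel_sq_sum += (flat.double() ** 2).sum(dim=1)
        n_pixels += flat.shape[1]
    mean = channel_sum / n_pixels
    # Var[x] = E[x^2] - E[x]^2; clamp guards against tiny negative values from rounding
    std = (channel_sq_sum / n_pixels - mean ** 2).clamp(min=0).sqrt()
    return mean.float(), std.float()
```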

Now we load the data using PyTorch. Each image is center-cropped to 490x490 pixels (so that all images have the same size), converted to a tensor, and then normalized.

Visualize the data

Now that we have loaded and pre-processed the data we can do some data exploration.

On the left we can see the raw image; on the right, the image after preprocessing. Normalization shifts every channel to zero mean and scales it to unit standard deviation. This helps the network learn faster, since the gradients act more uniformly across channels, and it helps bring out meaningful features in the image.
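The arithmetic behind this is just `(x - mean) / std` per channel. A small sketch of it and of its inverse, which is what you need to turn a normalized tensor back into a viewable image; the helper names are hypothetical.

```python
import torch


def _per_channel(values, like):
    # Shape (C, 1, 1) so it broadcasts over (C, H, W) and (B, C, H, W) tensors
    return torch.as_tensor(values, dtype=like.dtype, device=like.device).view(-1, 1, 1)


def normalize(images, mean, std):
    """Same arithmetic as transforms.Normalize: (x - mean) / std per channel."""
    return (images - _per_channel(mean, images)) / _per_channel(std, images)


def denormalize(images, mean, std):
    """Inverse of normalize: x * std + mean, e.g. to display a preprocessed image."""
    return images * _per_channel(std, images) + _per_channel(mean, images)
```

To display a single (C, H, W) tensor with matplotlib, something like `plt.imshow(denormalize(x, mean, std).permute(1, 2, 0).clamp(0, 1))` moves the channels last and keeps values in the displayable range.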

We plot the distribution of the labels in the train, validation, and test sets to check for label imbalance. The labels are evenly distributed in the validation and test sets, while the training set is imbalanced. There are different ways to deal with imbalanced datasets; here we tried subsampling.
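Assuming each split is an `ImageFolder`, its `.targets` (integer labels) and `.classes` (folder names) are all that is needed to compute the distribution shown in the bar plots. `label_distribution` is a hypothetical helper.

```python
from collections import Counter


def label_distribution(targets, classes):
    """Fraction of samples per class name, given integer labels such as ImageFolder.targets."""
    counts = Counter(targets)
    total = len(targets)
    return {name: counts.get(i, 0) / total for i, name in enumerate(classes)}
```

Calling it as `label_distribution(train_ds.targets, train_ds.classes)` for each split and passing the result to `plt.bar` reproduces this kind of plot; a balanced split gives 0.25 for each of the four classes.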

Load the data

Here we move the data to the GPU using a custom loader. The batch size lets us feed the data to the model in smaller batches, since the GPU memory isn't large enough to hold thousands of images at once. The training set is not loaded as a whole: we used a custom sampler to subsample the data (so that each label is equally represented) and to reduce the training data to 4,000 images in total, which cuts the computation time.
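The original custom sampler isn't shown. A minimal sketch of one way to get the same effect draws an equal number of indices per class (1,000 per class for 4 classes and 4,000 images) and hands them to PyTorch's `SubsetRandomSampler`. The helper names and the batch size of 32 are assumptions.

```python
import random
from collections import defaultdict

from torch.utils.data import DataLoader, SubsetRandomSampler


def balanced_indices(targets, per_class, seed=0):
    """Pick up to `per_class` dataset indices for each label, without replacement."""
    by_class = defaultdict(list)
    for idx, label in enumerate(targets):
        by_class[label].append(idx)
    rng = random.Random(seed)
    chosen = []
    for label in sorted(by_class):
        idxs = by_class[label]
        # A class with fewer than `per_class` images contributes all of them
        chosen.extend(rng.sample(idxs, min(per_class, len(idxs))))
    return chosen


def balanced_loader(dataset, targets, total=4000, batch_size=32, seed=0):
    """DataLoader over a class-balanced subset of `dataset` with `total` samples."""
    n_classes = len(set(targets))
    indices = balanced_indices(targets, total // n_classes, seed)
    sampler = SubsetRandomSampler(indices)  # reshuffles the chosen indices every epoch
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler)
```

With an `ImageFolder`, `balanced_loader(train_ds, train_ds.targets)` gives the subsampled training loader. This fixes one subset for the whole run; rebuilding the loader with a new `seed` each epoch would expose the model to more of the training images.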

The model

Finally, we create our model. Given our relatively small sample size (and to speed up training), we used transfer learning with a pretrained ResNet: we removed its final fully connected layer and added two linear layers with a ReLU activation in between. Since this is a classification task, our loss function is the cross-entropy loss.

Training the model

A pretrained model is useful because its layers are already trained to extract features (such as particular shapes or lines) from images, so the weights and biases of the convolutional part of the network shouldn't change much during training. The final fully connected layers we added, on the other hand, are initialized with random weights. One way to deal with this is to freeze the whole pretrained part of the network, train only the final fully connected layers, and then unfreeze the whole network and train it with a low learning rate. We adopted another technique: we used a different learning rate for each part of the network, training the pretrained layers with lower learning rates than the new classifier. This way we train the classifier and fine-tune the pretrained network without having to train it twice.

As an example we trained this network for 15 epochs using the Adam optimizer, a learning rate of 0.004, and a learning rate scheduler.

Scheduling the learning rate is extremely useful: high learning rates help the network learn faster but risk overshooting the minimum of the loss function, while low learning rates may make training too slow.

We used a scheduler that lowers the learning rate once the loss stops decreasing. We plot the loss, the learning rate, and several classification metrics for each epoch:
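The post doesn't name the scheduler. PyTorch's `torch.optim.lr_scheduler.ReduceLROnPlateau`, which lowers the learning rate when a monitored quantity stops improving, matches the description, so this sketch uses it on the validation loss. `fit` and `evaluate` are hypothetical helpers, and `factor=0.1`, `patience=2` are assumed values.

```python
import torch
from torch import nn


def evaluate(model, loader, loss_fn):
    """Mean loss and accuracy over a loader, without tracking gradients."""
    model.eval()
    total_loss, correct, count = 0.0, 0, 0
    with torch.no_grad():
        for xb, yb in loader:
            logits = model(xb)
            total_loss += loss_fn(logits, yb).item() * len(yb)
            correct += (logits.argmax(dim=1) == yb).sum().item()
            count += len(yb)
    return total_loss / count, correct / count


def fit(model, train_loader, val_loader, optimizer, epochs=15):
    """Train with cross-entropy; cut the learning rate when validation loss plateaus."""
    loss_fn = nn.CrossEntropyLoss()
    # Multiply every group's lr by 0.1 after `patience` epochs without improvement
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.1, patience=2
    )
    history = []
    for epoch in range(epochs):
        model.train()
        for xb, yb in train_loader:
            loss = loss_fn(model(xb), yb)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        val_loss, val_acc = evaluate(model, val_loader, loss_fn)
        scheduler.step(val_loss)  # the scheduler watches the validation loss
        history.append({
            "epoch": epoch,
            "val_loss": val_loss,
            "val_acc": val_acc,
            "lr": optimizer.param_groups[-1]["lr"],  # lr of the last parameter group
        })
    return history
```

Recording the learning rate in `history` is what makes it possible to plot it next to the loss and metrics. A run matching the settings above would be `fit(model, train_dl, val_dl, torch.optim.Adam(model.parameters(), lr=0.004), epochs=15)`.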

All the metrics reach 0.9 within just a few epochs, and at the end of training the validation accuracy is 93.75%. Not bad for a simple network!

Test the model

Finally, we evaluate the model on the test set. First let's look at the confusion matrix:

It's almost perfect! The model correctly classifies almost all the images, although it misclassifies 11% of the DRUSEN images as CNV. Now let's look at the metrics on the test set:
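A confusion matrix simply counts (true label, predicted label) pairs. A minimal sketch in plain PyTorch, assuming 1-D integer tensors of predictions and targets; row-normalizing the counts gives per-class fractions such as "11% of DRUSEN images predicted as CNV". The helper names are hypothetical.

```python
import torch


def confusion_matrix(preds, targets, num_classes):
    """cm[i, j] = number of samples whose true label is i and predicted label is j."""
    # Encode each (true, predicted) pair as one integer, then count occurrences
    idx = targets * num_classes + preds
    counts = torch.bincount(idx, minlength=num_classes ** 2)
    return counts.reshape(num_classes, num_classes)


def row_normalize(cm):
    """Fraction of each true class assigned to each predicted class."""
    cm = cm.float()
    return cm / cm.sum(dim=1, keepdim=True).clamp(min=1)
```

The predictions come from `logits.argmax(dim=1)` collected over the whole test loader.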

On the test set, all the metrics reach 0.97, even higher than on the validation set!
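The post doesn't list exactly which metrics are plotted. Assuming accuracy plus macro-averaged precision, recall, and F1, all of them can be read off a confusion matrix; `classification_metrics` is a hypothetical helper.

```python
import torch


def classification_metrics(cm):
    """Accuracy and macro-averaged precision/recall/F1 from a confusion matrix.

    cm[i, j] = number of samples with true class i predicted as class j.
    """
    cm = cm.double()
    tp = cm.diag()
    precision = tp / cm.sum(dim=0).clamp(min=1)  # column sums = how often each class was predicted
    recall = tp / cm.sum(dim=1).clamp(min=1)     # row sums = how many samples each class has
    f1 = 2 * precision * recall / (precision + recall).clamp(min=1e-12)
    return {
        "accuracy": (tp.sum() / cm.sum()).item(),
        "precision": precision.mean().item(),
        "recall": recall.mean().item(),
        "f1": f1.mean().item(),
        "per_class_f1": f1.tolist(),
    }
```

Macro averaging weights every class equally, which is a reasonable choice here because the test set is balanced across the four labels.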

What to do now?

This network performs well, but it can still be improved:

  • More epochs: we trained for only 15 epochs due to time constraints; training for longer may further improve the accuracy.
  • Early stopping: with more epochs we risk overfitting the training data and increasing the generalization error. Early stopping halts training when the validation loss starts to rise, a sign that the model is overfitting the training set.
  • Tune hyperparameters: the learning rate, weight decay, etc. can all be tuned to increase accuracy.
  • Use data augmentation: adding random rotations, crops, etc. can help the model generalize better.
