Diabetic Retinopathy Detection with ResNet50

Shivang Ganjoo
Apr 27, 2018

In January, I started working on my first computer vision project, using the Keras library with TensorFlow as its backend.

The project was about detecting diabetic retinopathy, a disease in which high blood sugar levels damage the blood vessels in the retina. These blood vessels can swell and leak, or they can close and stop blood from passing through. Sometimes abnormal new blood vessels grow on the retina. All of these changes can eventually lead to blindness.

Clinicians can take close to two days to diagnose this disease, and many doctors cannot determine its stage accurately.

In this post, I’ll show a step-by-step approach to automating the diagnosis of this disease.

Pre-Processing

We first need to load all the images, either with Keras’s ImageDataGenerator or with the cv2 (OpenCV) library.

Next, resize all the images and augment them with random rotations, shears and zooms. Augmentation lets the model predict the correct stage of the disease even when a test image looks slightly different from the training images, and it also helps reduce overfitting.
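The TensorFlow documentation marks ImageDataGenerator as not recommended for new code. One alternative is a stack of Keras preprocessing layers, sketched below. The rotation and zoom factors and the 224-pixel size are assumptions. Shearing is left out because not every Keras version ships a shear layer. `build_augmenter` is a hypothetical helper name.

```python
import tensorflow as tf

IMG_SIZE = 224  # hypothetical model input size


def build_augmenter(img_size=IMG_SIZE):
    """Resize, then randomly rotate and zoom a batch of images."""
    return tf.keras.Sequential([
        tf.keras.layers.Resizing(img_size, img_size),
        # factor 0.1 = up to +/- 10% of a full turn (about 36 degrees)
        tf.keras.layers.RandomRotation(0.1),
        # zoom in or out by up to 20%
        tf.keras.layers.RandomZoom(0.2),
    ])
```

The random layers only change images when called with `training=True`. You would typically map the augmenter over the training split only, for example `train_ds.map(lambda x, y: (augmenter(x, training=True), y))`.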

Lastly, split the dataset into training and held-out data in a 9:1 ratio.
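Here is a minimal NumPy sketch of a shuffled 9:1 split, assuming the images and labels are already in two aligned arrays. The function name and the fixed seed are hypothetical.

```python
import numpy as np


def split_9_to_1(x, y, seed=42):
    """Shuffle, then split arrays into 90% training and 10% held-out data."""
    if len(x) != len(y):
        raise ValueError("x and y must have the same number of samples")
    rng = np.random.default_rng(seed)
    indices = rng.permutation(len(x))  # shuffle so file order doesn't bias the split
    n_held_out = len(x) // 10          # one part in ten is held out
    held_idx, train_idx = indices[:n_held_out], indices[n_held_out:]
    return x[train_idx], y[train_idx], x[held_idx], y[held_idx]
```

The disease stages are usually imbalanced. A stratified split keeps each stage's proportion the same in both parts, and scikit-learn's `train_test_split(x, y, test_size=0.1, stratify=y)` does this for you.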

The source images should be of very high resolution (approximately 4000 × 3000 pixels).

Training

There are two options for training: transfer learning or training from scratch. I found that training from scratch, with two fully connected layers and dropout added on top, gave better results. I used categorical_crossentropy as the loss function and Adam as the optimizer.
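The sketch below builds a from-scratch model of this kind, using ResNet50 from `tf.keras.applications` with `weights=None` so that every weight starts random. Several details are assumptions: five output classes (stages 0 to 4), the layer sizes (512 and 256), the dropout rate of 0.5, and the 224 × 224 input. `build_scratch_model` is a hypothetical name.

```python
import tensorflow as tf

NUM_CLASSES = 5  # assumption: five DR stages, 0 (no DR) to 4 (proliferative)


def build_scratch_model(input_shape=(224, 224, 3), num_classes=NUM_CLASSES,
                        dense_units=(512, 256), dropout=0.5):
    # weights=None: ResNet50 starts from random weights, i.e. trained from scratch.
    # pooling="avg" turns the final feature maps into one 2048-d vector per image.
    base = tf.keras.applications.ResNet50(
        include_top=False, weights=None, input_shape=input_shape, pooling="avg")
    x = base.output
    # Two fully connected layers, each followed by dropout (sizes are hypothetical).
    for units in dense_units:
        x = tf.keras.layers.Dense(units, activation="relu")(x)
        x = tf.keras.layers.Dropout(dropout)(x)
    outputs = tf.keras.layers.Dense(num_classes, activation="softmax")(x)
    model = tf.keras.Model(base.input, outputs)
    model.compile(optimizer="adam", loss="categorical_crossentropy",
                  metrics=["accuracy"])
    return model
```

categorical_crossentropy expects one-hot labels, so integer stage labels need converting first, for example with `tf.keras.utils.to_categorical(y, NUM_CLASSES)`.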

Tune the hyperparameters with an approach such as grid search, Bayesian optimization or a bandit-based method, then fit the model. Further techniques, such as test-time augmentation (TTA) and differential learning, can improve accuracy further.
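Grid search is the simplest of these approaches to write by hand. The sketch below tries every combination in a grid and keeps the best one. `train_and_evaluate` is a hypothetical callable you would supply. It should build and fit a model with the given settings and return a validation score, where higher is better. The parameter names in the example grid are also only illustrations.

```python
import itertools


def grid_search(train_and_evaluate, param_grid):
    """Try every combination in param_grid; return (best_params, best_score)."""
    names = list(param_grid)
    best_params, best_score = None, float("-inf")
    # itertools.product yields every combination of the candidate values.
    for values in itertools.product(*(param_grid[n] for n in names)):
        params = dict(zip(names, values))
        score = train_and_evaluate(**params)
        if score > best_score:
            best_params, best_score = params, score
    return best_params, best_score


# Example grid (hypothetical names and values):
# grid_search(train_and_evaluate, {"learning_rate": [1e-2, 1e-3, 1e-4],
#                                  "dropout": [0.3, 0.5]})
```

Each combination means a full training run, so the cost grows quickly with the size of the grid. That cost is what motivates Bayesian optimization and Hyperband, and the KerasTuner library provides both as `keras_tuner.BayesianOptimization` and `keras_tuner.Hyperband`.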

For transfer learning, first train only the newly added layers while the pre-trained layers stay frozen. Then unfreeze all the layers and retrain the whole network.
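The two-phase procedure could be sketched in Keras as shown below. The head design, the learning rates (1e-3, then 1e-5) and the epoch counts are assumptions, and the function names are hypothetical. With `weights="imagenet"`, the images should first be passed through `tf.keras.applications.resnet50.preprocess_input`.

```python
import tensorflow as tf


def build_transfer_model(input_shape=(224, 224, 3), num_classes=5, weights="imagenet"):
    """ResNet50 base plus a new classification head; the base starts frozen."""
    base = tf.keras.applications.ResNet50(
        include_top=False, weights=weights, input_shape=input_shape, pooling="avg")
    base.trainable = False  # phase 1: only the new layers are trained
    inputs = tf.keras.Input(shape=input_shape)
    # training=False keeps BatchNormalization in inference mode, even after unfreezing.
    x = base(inputs, training=False)
    x = tf.keras.layers.Dropout(0.5)(x)
    outputs = tf.keras.layers.Dense(num_classes, activation="softmax")(x)
    return tf.keras.Model(inputs, outputs), base


def fit_two_phases(model, base, x, y, head_epochs=5, full_epochs=10, batch_size=32):
    # Phase 1: train the newly added head with the pre-trained base frozen.
    model.compile(optimizer=tf.keras.optimizers.Adam(1e-3),
                  loss="categorical_crossentropy", metrics=["accuracy"])
    model.fit(x, y, epochs=head_epochs, batch_size=batch_size, verbose=0)
    # Phase 2: unfreeze everything and retrain with a much smaller learning rate,
    # so the pre-trained weights are adjusted rather than overwritten.
    base.trainable = True
    # Recompiling is required for the change in trainable status to take effect.
    model.compile(optimizer=tf.keras.optimizers.Adam(1e-5),
                  loss="categorical_crossentropy", metrics=["accuracy"])
    model.fit(x, y, epochs=full_epochs, batch_size=batch_size, verbose=0)
    return model
```

Training the head first means the randomly initialized layers do not send large, noisy gradients back into the pre-trained features when the whole network is unfrozen.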

Testing

To test on new images, load and pre-process them the same way as the training images; they must have the same size as the training images. For a single image, use NumPy’s expand_dims function to add a leading batch dimension, because at prediction time Keras expects a batch of shape (batch_size, height, width, channels).
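Here is a small sketch of that step. It assumes the model was trained on 224 × 224 RGB images that were scaled to [0, 1] by dividing by 255. `prepare_for_prediction` is a hypothetical name.

```python
import numpy as np

TRAIN_SHAPE = (224, 224, 3)  # hypothetical: must match the training image size


def prepare_for_prediction(image):
    """Turn one uint8 (H, W, 3) image into a batch of one: shape (1, H, W, 3)."""
    if image.shape != TRAIN_SHAPE:
        raise ValueError(f"expected {TRAIN_SHAPE}, got {image.shape}; resize it first")
    image = image.astype("float32") / 255.0  # assumed: same scaling as in training
    return np.expand_dims(image, axis=0)     # add the leading batch dimension
```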

Finally, apply argmax to the predicted class probabilities to get the predicted class, which is the disease stage.
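The sketch below shows the argmax step, plus one simple form of the test-time augmentation (TTA) mentioned in the training section: averaging predictions over the original and horizontally flipped images. The stage names assume the common 0 to 4 grading scale, and the horizontal flip as the TTA transform is also an assumption. `predict_stage` is a hypothetical helper. It takes any `predict_fn` that maps a batch to class probabilities, for example `lambda b: model.predict(b, verbose=0)`.

```python
import numpy as np

# Assumed label order: the common 0-4 diabetic retinopathy grading scale.
STAGES = ["No DR", "Mild", "Moderate", "Severe", "Proliferative DR"]


def predict_stage(predict_fn, batch, tta=False):
    """Return the predicted stage name for each image in batch (N, H, W, C)."""
    probs = predict_fn(batch)  # shape (N, num_classes)
    if tta:
        # Average with predictions on horizontally flipped copies (axis 2 is width).
        probs = (probs + predict_fn(batch[:, :, ::-1, :])) / 2.0
    class_ids = np.argmax(probs, axis=1)  # index of the most probable class per image
    return [STAGES[i] for i in class_ids]
```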

Resources

Bayesian Optimization -> https://www.youtube.com/watch?v=cWQDeB9WqvU

Bandit Approach -> http://fastml.com/tuning-hyperparams-fast-with-hyperband/

For the starter source code, visit my GitHub account -> https://github.com/lightsalsa251

Disclaimer

This is a very dangerous disease, so I strongly recommend that everyone see a doctor: so far, no model has proven reliable enough. A lack of data has also held back the automation of this disease’s diagnosis.
