How deep learning can help doctors prevent blindness in diabetes

Aniket Mishrikotkar
7 min read · Dec 18, 2020

One day, diagnosing serious diseases may be as easy as taking a temperature or checking blood pressure. But in the near term, millions of diabetics could keep their vision thanks to an AI algorithm helping doctors quickly spot diabetic retinopathy.

Retina image

Table of contents

  1. Introduction
  2. Data
  3. Evaluation Metric
  4. EDA and Image Processing
  5. TensorFlow Input Pipeline
  6. Model
  7. Error Analysis
  8. Future Work

1. Introduction

The COVID-19 pandemic is stretching hospital resources to the breaking point in many countries around the world. It’s no surprise that many people hope AI can speed up patient screening and ease the strain on clinical staff. This case study, however, is not about COVID-19 but about diabetes, which is a growing concern in its own right.

With 70 million people living with diabetes, India faces a growing burden of diabetic retinopathy. The disease damages or abnormally changes the tissue at the back of the retina, which can lead to total blindness, and 18 percent of diabetic Indians already have the ailment. With 415 million diabetics at risk of blindness worldwide, the disease is a global concern.

But the good news is that permanent vision loss is not inevitable. If caught early, the disease can be treated; if not, it can lead to total blindness.

Examples of retinal fundus photographs taken to screen for DR. The image on the left (A) is of a healthy retina, whereas the image on the right (B) shows a retina with referable diabetic retinopathy due to a number of hemorrhages (red spots).

One of the common ways to detect diabetic retinopathy is to have a specialist examine photographs of the back of the eye and determine the disease’s presence and its severity. Severity is determined by the type of damage present.

Specialized training is required to interpret these photographs.
Recent advances in Machine Learning and Computer Vision can improve the DR screening process. Deep Learning algorithms can interpret signs of DR in the retinal photographs, helping doctors screen more patients.

2. Data

The data is obtained from the Kaggle competition APTOS 2019 Blindness Detection. The dataset contains a large set of retina images taken using fundus photography under a variety of lighting conditions. There are a total of 3662 retina images in the dataset. A clinician has rated each image on a scale of 0 to 4.

0 — No DR, 1 — Mild, 2 — Moderate, 3 — Severe, 4 — Proliferative DR

A few images from the dataset

3. Evaluation Metric

The evaluation metric for a multi-class classification problem could be classification accuracy or an F-score, but the Kaggle competition defined its own evaluation metric: quadratic weighted kappa.

Quadratic weighted kappa measures agreement between two raters and ranges from 0 (agreement no better than chance) to 1 (perfect agreement). There is a better explanation available here.

Cohen’s kappa
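As a quick illustration, scikit-learn can compute this metric directly via cohen_kappa_score; the grades below are made-up values for demonstration.

```python
from sklearn.metrics import cohen_kappa_score

# Toy example: grades from a human rater vs. a model on six images (0-4 scale).
human_grades = [0, 2, 2, 4, 1, 3]
model_grades = [0, 2, 1, 4, 1, 2]

# weights="quadratic" penalizes disagreements by the squared distance between
# grades, so confusing grade 0 with 4 costs far more than confusing 0 with 1.
print(cohen_kappa_score(human_grades, model_grades, weights="quadratic"))
```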

4. EDA and Image Processing

The dataset is imbalanced: there are far more images of healthy retinas than of any diseased class. Only 5% of the total images belong to class 3 (severe DR).

Class distribution

To correct for data imbalance, we will use class weighting.

Class weighting
Weight for class 0: 1.01
Weight for class 1: 4.95
Weight for class 2: 1.83
Weight for class 3: 9.49
Weight for class 4: 6.21
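A sketch of how these weights can be computed. The formula weight_c = (1 / n_c) * (total / 2) is an assumption borrowed from the TensorFlow imbalanced-data guide, but it reproduces the printed values from the class counts in train.csv.

```python
import pandas as pd

# Assumed input: the competition's train.csv with a `diagnosis` column (0-4).
counts = pd.read_csv("train.csv")["diagnosis"].value_counts().sort_index()
total = counts.sum()

# Rarer classes get proportionally larger weights; scaling by total / 2
# keeps the overall loss magnitude roughly unchanged.
class_weight = {c: (1.0 / n) * (total / 2.0) for c, n in counts.items()}
for c, w in class_weight.items():
    print(f"Weight for class {c}: {w:.2f}")
```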

Let’s use a t-SNE visualization with a perplexity of 40. Class 0 is separable, but the remaining classes are not.
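A minimal sketch of such a visualization. Random placeholder data stands in for the real inputs here; in the post, the feature matrix would come from the flattened, downsized retina images.

```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

# Placeholder features and labels; replace with the real image features.
rng = np.random.default_rng(42)
features = rng.normal(size=(500, 256))
y = rng.integers(0, 5, size=500)

# Project to 2-D with the same perplexity used in the post.
embedding = TSNE(n_components=2, perplexity=40, random_state=42).fit_transform(features)
plt.scatter(embedding[:, 0], embedding[:, 1], c=y, cmap="tab10", s=5)
plt.colorbar(label="DR grade")
plt.show()
```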

5. TensorFlow Input Pipeline

1. Define the key configuration parameters.

Configurations
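A minimal sketch of what these parameters might look like. The image size matches the (320, 320, 3) shape shown later; the batch size is an assumption, as the post does not state it.

```python
import tensorflow as tf

IMG_SIZE = 320     # matches the (320, 320, 3) image shape shown later
BATCH_SIZE = 32    # assumption; the post does not state the batch size
NUM_CLASSES = 5    # DR grades 0-4
AUTOTUNE = tf.data.experimental.AUTOTUNE
```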

2. Load the data

The tf.data API enables you to build complex input pipelines from simple, reusable pieces. To construct the dataset, we use tf.data.Dataset.from_tensor_slices(), then transform it into a new dataset by chaining method calls.

Load the data using tf.data API
Training images count:  2929
Validating images count: 733
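A sketch of this loading step, assuming the competition’s file layout (a train.csv with id_code and diagnosis columns, and a train_images/ directory of PNGs). An 80/20 split reproduces the counts above.

```python
import pandas as pd
import tensorflow as tf

# Shuffle once so the split is random but reproducible.
df = pd.read_csv("train.csv").sample(frac=1.0, random_state=42)
paths = ("train_images/" + df["id_code"] + ".png").values
labels = df["diagnosis"].values

n_train = int(0.8 * len(df))  # 80/20 split -> 2929 / 733 images
train_ds = tf.data.Dataset.from_tensor_slices((paths[:n_train], labels[:n_train]))
val_ds = tf.data.Dataset.from_tensor_slices((paths[n_train:], labels[n_train:]))

print("Training images count: ", train_ds.cardinality().numpy())
print("Validating images count:", val_ds.cardinality().numpy())
```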

As we have 5 classes, we will convert each label into a one-hot tensor; for example, 2 becomes [0, 0, 1, 0, 0]. We also have to map each filename to its label. We can do both with the following methods.

Create (image, label) pairs
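A sketch of these mapping functions, reusing the configuration above. Decoding as PNG and resizing to IMG_SIZE are assumptions consistent with the (320, 320, 3) shape shown below.

```python
def decode_image(path):
    img = tf.io.read_file(path)
    img = tf.image.decode_png(img, channels=3)
    img = tf.image.convert_image_dtype(img, tf.float32)   # scale to [0, 1]
    return tf.image.resize(img, [IMG_SIZE, IMG_SIZE])

def process_path(path, label):
    # Map a filename to an (image, one-hot label) pair.
    return decode_image(path), tf.one_hot(label, NUM_CLASSES)

train_ds = train_ds.map(process_path, num_parallel_calls=AUTOTUNE)
val_ds = val_ds.map(process_path, num_parallel_calls=AUTOTUNE)
```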

Let’s visualize the shape of the image and label.

Image shape: (320, 320, 3)
Label: [1. 0. 0. 0. 0.]

Let’s use buffered prefetching so we can read data from disk without I/O becoming a bottleneck. We use the tf.image API for data augmentation.

Data augmentation
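A sketch of such a pipeline. The specific transformations (flips and brightness jitter) are assumptions, since the post does not list which tf.image ops it used.

```python
def augment(image, label):
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_flip_up_down(image)
    image = tf.image.random_brightness(image, max_delta=0.1)
    image = tf.clip_by_value(image, 0.0, 1.0)  # keep pixels in [0, 1]
    return image, label

train_ds = (train_ds
            .shuffle(1000)
            .map(augment, num_parallel_calls=AUTOTUNE)
            .batch(BATCH_SIZE)
            .prefetch(AUTOTUNE))   # overlap preprocessing with training
val_ds = val_ds.batch(BATCH_SIZE).prefetch(AUTOTUNE)
```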

Visualize the dataset after image augmentation.

Visualize augmented images

6. Model

1. Define Callbacks

The checkpoint callback saves the best weights of the model, so that the next time we want to use the model, we do not have to retrain it. The early stopping callback stops training if the model starts overfitting or stagnates. The reduce-LR-on-plateau callback lowers the learning rate when the monitored metric stops improving.

Callbacks
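A sketch of the three callbacks described above; the monitored metric, patience values, and checkpoint filename are assumptions.

```python
callbacks = [
    # Save the best weights seen so far, judged by validation loss.
    tf.keras.callbacks.ModelCheckpoint("best_weights.h5", monitor="val_loss",
                                       save_best_only=True, save_weights_only=True),
    # Stop training once validation loss stops improving.
    tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=5,
                                     restore_best_weights=True),
    # Lower the learning rate when validation loss plateaus.
    tf.keras.callbacks.ReduceLROnPlateau(monitor="val_loss", factor=0.2, patience=2),
]
```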

2. Transfer learning with pre-trained weights

We are initializing the model with pre-trained ImageNet weights.

For our use case, we use accuracy as the metric, which tells us the fraction of correct predictions. Since there are 5 classes, we use categorical crossentropy as the loss function. We also pass the class weights discussed earlier.
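A sketch of the model definition and training call, reusing the class weights and callbacks from earlier. The post does not name the backbone, so ResNet50 stands in here as an assumption.

```python
# ResNet50 is an assumption; the post only says "pre-trained ImageNet weights".
base = tf.keras.applications.ResNet50(include_top=False, weights="imagenet",
                                      input_shape=(IMG_SIZE, IMG_SIZE, 3))

model = tf.keras.Sequential([
    base,
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dropout(0.3),   # assumed regularization
    tf.keras.layers.Dense(NUM_CLASSES, activation="softmax"),
])

model.compile(optimizer="adam",
              loss="categorical_crossentropy",
              metrics=["accuracy"])

history = model.fit(train_ds,
                    validation_data=val_ds,
                    epochs=30,                  # assumed epoch budget
                    class_weight=class_weight,  # from the EDA section
                    callbacks=callbacks)
```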

Let’s plot the model accuracy and loss for the training and validation sets. The model reaches an accuracy of 83%, but validation accuracy is lower than training accuracy, which indicates overfitting.

Loss and Accuracy

The confusion matrix indicates that classes 1, 3, and 4 are being misclassified as class 2. Perhaps the model has not been able to detect the spots/hemorrhages present in classes 3 and 4 (the severe cases of DR).

Confusion matrix for validation set

3. High Resolution Network

HRNet was recently developed for human pose estimation but can also be used for image classification, object detection, and other tasks. The researchers provide the code here; the official code is written in PyTorch, so we had to rewrite it in TensorFlow. HRNet maintains high-resolution representations through the whole process by connecting high-to-low-resolution convolutions in parallel, and it produces strong high-resolution representations by repeatedly conducting fusions across the parallel convolutions. The research paper is linked here.

High Resolution Network Architecture

For image classification, we need to replace the head with a softmax layer. You can find the code in my GitHub repository. The results of this model were not encouraging: we got an accuracy of 68%. We experimented with different loss functions and optimizers but were not able to improve the performance.
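A generic sketch of that head replacement, not the official HRNet code. Here hrnet_backbone is a hypothetical stand-in for the TensorFlow rewrite of the backbone, assumed to return a final feature map.

```python
# `hrnet_backbone` is hypothetical: a callable returning an (H, W, C) feature map.
inputs = tf.keras.Input(shape=(IMG_SIZE, IMG_SIZE, 3))
features = hrnet_backbone(inputs)
x = tf.keras.layers.GlobalAveragePooling2D()(features)
outputs = tf.keras.layers.Dense(NUM_CLASSES, activation="softmax")(x)
hrnet_classifier = tf.keras.Model(inputs, outputs)
```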

7. Error Analysis

Let’s look at the classification report. In our case, the recall for classes 3 and 4 is very low, which means we are misidentifying exactly the classes where the cost of an error is highest. We need to improve the model and the recall score for each class.

Classification report
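A sketch of producing such a report with scikit-learn, reusing the model and (unshuffled) validation dataset from earlier.

```python
import numpy as np
from sklearn.metrics import classification_report

# Ground truth: undo the one-hot encoding batch by batch.
y_true = np.concatenate([np.argmax(y, axis=1) for _, y in val_ds])
y_pred = np.argmax(model.predict(val_ds), axis=1)

names = ["No DR", "Mild", "Moderate", "Severe", "Proliferative DR"]
print(classification_report(y_true, y_pred, target_names=names, digits=3))
```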

Let’s visualize the image with actual and predicted labels.

Actual and predicted labels with score in percent

We can see the image and the model prediction with the probability score for each class.

Visualize image and its predicted label with score
Actual Label - Proliferative DR
Predicted Label - Moderate
Image and prediction score

In a Jupyter notebook, inference takes about 0.45 seconds per image, or roughly two predictions per second. The Streamlit web app takes more time on a local machine: 10–20 seconds per prediction.

Streamlit web app to upload an image and get the prediction

8. Future Work

  • Use data from other sources, such as EyePACS or Messidor, which could further improve our accuracy.
  • Set up a continuous integration system for our codebase, which will check the functionality of the code and evaluate any model about to be deployed.
  • Package the prediction system as a REST API and deploy it as a Docker container running as a serverless function on AWS Lambda.

You can connect with me on LinkedIn. You can view the code in this GitHub repository.
