How deep learning can help doctors prevent blindness in diabetes
One day, diagnosing serious diseases may be as easy as taking a temperature or checking blood pressure. But in the near term, millions of diabetics could keep their vision thanks to an AI algorithm helping doctors quickly spot diabetic retinopathy.
Table of contents
- Introduction
- Data
- Evaluation Metric
- EDA and Image Processing
- TensorFlow Input Pipeline
- Model
- Error Analysis
- References
- Future Work
1. Introduction
The COVID-19 pandemic is stretching hospital resources to the breaking point in many countries. It’s no surprise that many people hope AI could speed up patient screening and ease the strain on clinical staff. This case study, however, is not about COVID-19 but about diabetes, which is a growing concern in its own right.
With 70 million people living with diabetes, India faces a growing problem with diabetic retinopathy (DR). The disease damages the retina, the light-sensitive tissue at the back of the eye, and can lead to total blindness; 18 percent of diabetic Indians already have it. With 415 million diabetics at risk of blindness worldwide, the disease is a global concern.
But the good news is that permanent vision loss is not inevitable. If caught early, the disease can be treated; if not, it can lead to total blindness.
One common way to detect diabetic retinopathy is to have a specialist examine photographs of the back of the eye and determine whether the disease is present and how severe it is. Severity is determined by the type of damage present. Interpreting these photographs requires specialized training.
Recent advances in machine learning and computer vision can improve the DR screening process. Deep learning algorithms can detect signs of DR in retinal photographs, helping doctors screen more patients.
2. Data
The data is obtained from the Kaggle competition APTOS 2019 Blindness Detection. The dataset contains a large set of retina images taken using fundus photography under a variety of lighting conditions. There are a total of 3662 retina images in the dataset. A clinician has rated each image on a scale of 0 to 4.
0 — No DR, 1 — Mild, 2 — Moderate, 3 — Severe, 4 — Proliferative DR
3. Evaluation Metric
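Common evaluation metrics for multi-class classification include accuracy and the F-score. The Kaggle competition, however, defined its own evaluation metric: quadratic weighted kappa (QWK).
Quadratic weighted kappa measures the agreement between two sets of ratings, here the predicted grades and the clinician's grades. It typically ranges from 0 (agreement no better than chance) to 1 (perfect agreement), and it can fall below 0 when agreement is worse than chance. Because the penalty grows with the square of the distance between grades, predicting grade 0 for a grade-4 image costs far more than predicting grade 3.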
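To make the metric concrete, here is a minimal pure-Python sketch, assuming integer grades 0 to 4. It follows the standard definition κ = 1 − Σ w_ij O_ij / Σ w_ij E_ij with w_ij = (i − j)² / (N − 1)², where O is the observed matrix of (true, predicted) counts and E is the chance-level matrix built from the two grade histograms. The function name is our own; in practice, scikit-learn's `cohen_kappa_score(y_true, y_pred, weights="quadratic")` computes the same quantity.

```python
from collections import Counter


def quadratic_weighted_kappa(y_true, y_pred, num_classes=5):
    """Cohen's kappa with quadratic weights for integer grades 0..num_classes-1."""
    n = len(y_true)
    # observed[i][j] counts images with true grade i and predicted grade j.
    observed = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(y_true, y_pred):
        observed[t][p] += 1
    # Histograms of true and predicted grades give the chance-level (expected) matrix.
    true_hist = Counter(y_true)
    pred_hist = Counter(y_pred)
    weighted_observed = 0.0
    weighted_expected = 0.0
    for i in range(num_classes):
        for j in range(num_classes):
            # Quadratic penalty: being k grades off costs k^2 (normalized to [0, 1]).
            w = (i - j) ** 2 / (num_classes - 1) ** 2
            expected = true_hist[i] * pred_hist[j] / n
            weighted_observed += w * observed[i][j]
            weighted_expected += w * expected
    return 1.0 - weighted_observed / weighted_expected


truth = [0, 1, 2, 3, 4]
print(round(quadratic_weighted_kappa(truth, [0, 1, 2, 3, 3]), 2))  # 0.94: grade 4 predicted as 3
print(round(quadratic_weighted_kappa(truth, [0, 1, 2, 3, 0]), 2))  # 0.2: grade 4 predicted as 0
```

Both predictions get four of five images right, yet the one that is off by four grades scores far lower, which is exactly the behavior an ordinal severity scale calls for.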
4. EDA and Image Processing
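The dataset is imbalanced: there are far more images of healthy retinas (class 0) than of any other class, and only about 5% of the images belong to class 3 (severe DR).
To compensate for this imbalance, we use class weighting, which scales each image's contribution to the loss so that rare classes count for more: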
Weight for class 0: 1.01
Weight for class 1: 4.95
Weight for class 2: 1.83
Weight for class 3: 9.49
Weight for class 4: 6.21
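The post does not show how these weights were computed. One recipe that reproduces all five values is weight = total / (2 × count), the formula used in TensorFlow's imbalanced-data tutorial, assuming the class counts commonly reported for the APTOS 2019 training set (1805, 370, 999, 193 and 295 images, which sum to 3662). The sketch below is hedged accordingly: both the formula and the counts are assumptions, and `class_weights` is a hypothetical helper.

```python
def class_weights(counts):
    """Weight each class by total / (2 * count), so rarer classes get larger weights."""
    total = sum(counts.values())
    return {label: total / (2.0 * n) for label, n in counts.items()}


# Assumed per-class image counts for the APTOS 2019 training set (grades 0-4).
aptos_counts = {0: 1805, 1: 370, 2: 999, 3: 193, 4: 295}
weights = class_weights(aptos_counts)
for label, w in weights.items():
    print(f"Weight for class {label}: {w:.2f}")
```

A dict of this shape can be passed as the `class_weight` argument of Keras's `Model.fit`. scikit-learn's "balanced" heuristic divides by the number of classes instead of 2, which multiplies every weight by the same constant; that changes the overall loss scale but not the relative emphasis between classes.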
A t-SNE visualization with a perplexity of 40 shows that class 0 is separable, but the remaining classes are not separable from one another.
5. TensorFlow Input Pipeline
1. Define the key configuration parameters
2. Load the data
The tf.data API lets you build complex input pipelines from simple, reusable pieces. We construct the dataset with tf.data.Dataset.from_tensor_slices() and then transform it into new datasets by chaining methods.
Training images count: 2929
Validating images count: 733
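These counts correspond to an 80/20 split in which the validation size is rounded up: ⌈0.2 × 3662⌉ = 733 and 3662 − 733 = 2929. That is also how scikit-learn's `train_test_split(..., test_size=0.2)` sizes its test set, although the post does not say which tool was used. The sketch below, built around a hypothetical `train_val_split` helper, reproduces the arithmetic with a plain seeded shuffle:

```python
import math
import random


def train_val_split(items, val_fraction=0.2, seed=42):
    """Shuffle items and hold out ceil(val_fraction * n) of them for validation."""
    items = list(items)
    random.Random(seed).shuffle(items)  # seeded so the split is reproducible
    n_val = math.ceil(val_fraction * len(items))
    return items[n_val:], items[:n_val]


train_ids, val_ids = train_val_split(range(3662))
print("Training images count:", len(train_ids))  # 2929
print("Validating images count:", len(val_ids))  # 733
```

With classes this imbalanced, a stratified split (for example, `train_test_split`'s `stratify=` argument) is usually preferable, because it keeps the class proportions the same in both subsets.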
Since there are 5 labels, we convert each label into a one-hot tensor; for example, 2 becomes [0, 0, 1, 0, 0]. We also need to map each filename to its label.
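Inside a TensorFlow pipeline, `tf.one_hot(label, depth=5)` performs the encoding and returns float32 values by default, which is why the label below prints as `[1. 0. 0. 0. 0.]`. The plain-Python sketch here only shows what the transformation produces. It assumes the Kaggle layout of a `train_images/` folder containing one `<id_code>.png` file per row of `train.csv`; `one_hot` and `make_examples` are hypothetical helpers, and the ids are made up.

```python
NUM_CLASSES = 5


def one_hot(label, num_classes=NUM_CLASSES):
    """Encode an integer grade as a one-hot list, e.g. 2 -> [0, 0, 1, 0, 0]."""
    vec = [0] * num_classes
    vec[label] = 1
    return vec


def make_examples(id_codes, diagnoses, image_dir="train_images"):
    """Pair each image path with its one-hot label (assumed <image_dir>/<id_code>.png layout)."""
    return [(f"{image_dir}/{code}.png", one_hot(grade))
            for code, grade in zip(id_codes, diagnoses)]


print(one_hot(2))  # [0, 0, 1, 0, 0]
print(make_examples(["img_a", "img_b"], [0, 4]))
```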
Let’s check the shapes of an image and its label.
Image shape: (320, 320, 3)
Label: [1. 0. 0. 0. 0.]
We use buffered prefetching so that data is read from disk while the model is still training on the previous batch, which keeps I/O from blocking training. For data augmentation, we use the tf.image API.
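In tf.data, prefetching is a single call, `dataset.prefetch(tf.data.AUTOTUNE)` (`tf.data.experimental.AUTOTUNE` in older releases), which lets the runtime pick the buffer size. To show what buffered prefetching does underneath, here is a standard-library sketch, not TensorFlow's implementation: a background thread reads ahead into a bounded queue while the consumer works on the current item. `prefetch` and `slow_reader` are hypothetical names.

```python
import queue
import threading
import time


def prefetch(iterable, buffer_size=2):
    """Yield items from `iterable` while a background thread keeps a bounded buffer full."""
    buf = queue.Queue(maxsize=buffer_size)

    def producer():
        try:
            for item in iterable:
                buf.put(("item", item))  # blocks while the buffer is full
            buf.put(("end", None))
        except Exception as exc:  # forward loading errors to the consumer
            buf.put(("error", exc))

    threading.Thread(target=producer, daemon=True).start()
    while True:
        kind, payload = buf.get()
        if kind == "end":
            return
        if kind == "error":
            raise payload
        yield payload


def slow_reader(n):
    """Stand-in for reading and decoding images from disk."""
    for i in range(n):
        time.sleep(0.01)  # simulated I/O latency
        yield i


print(list(prefetch(slow_reader(5))))  # [0, 1, 2, 3, 4]
```

The bounded queue is what makes the prefetching "buffered": the reader can run at most `buffer_size` items ahead of the consumer, so memory use stays bounded even when reading is faster than training.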
Visualize the dataset after image augmentation.
6. Model
1. Define Callbacks
The checkpoint callback saves the best weights of the model, so that we can reuse the model later without retraining it. The early stopping callback stops training when the model starts overfitting or stops improving. The reduce-LR-on-plateau callback lowers the learning rate when the monitored metric stops improving.
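In Keras, these callbacks are `tf.keras.callbacks.ModelCheckpoint` (with `save_best_only=True`), `EarlyStopping` and `ReduceLROnPlateau`, each configured with a monitored metric and a patience. The sketch below is a simplified, framework-free replay of how the three interact on a sequence of validation losses. It omits Keras details such as `cooldown` and `min_lr`, and the function name and default values are our own assumptions.

```python
def simulate_callbacks(val_losses, lr=1e-3, lr_patience=2, factor=0.5,
                       stop_patience=4, min_delta=0.0):
    """Replay validation losses through simplified checkpoint / reduce-LR / early-stop logic.

    Returns (best_epoch, stopped_epoch, lr_history).
    """
    best = float("inf")
    best_epoch = None
    wait_lr = wait_stop = 0
    lr_history = []
    for epoch, loss in enumerate(val_losses):
        lr_history.append(lr)  # learning rate used during this epoch
        if loss < best - min_delta:
            best, best_epoch = loss, epoch  # checkpoint: this is where weights would be saved
            wait_lr = wait_stop = 0
        else:
            wait_lr += 1
            wait_stop += 1
            if wait_lr >= lr_patience:
                lr *= factor  # reduce LR on plateau
                wait_lr = 0
            if wait_stop >= stop_patience:
                return best_epoch, epoch, lr_history  # early stopping
    return best_epoch, len(val_losses) - 1, lr_history


best_epoch, stopped_at, lrs = simulate_callbacks([1.0, 0.8, 0.7, 0.72, 0.71, 0.73, 0.74, 0.6])
print(best_epoch, stopped_at)  # 2 6
print(lrs)
```

Training stops at epoch 6, so the improvement to 0.6 at epoch 7 is never seen, and the weights from epoch 2 are the ones kept. This is why the early-stopping patience is usually set larger than the LR-reduction patience: the reduced learning rate then gets a few epochs to take effect before training is cut off.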
2. Transfer learning with pre-trained weights
We are initializing the model with pre-trained ImageNet weights.
We use accuracy, the fraction of correct predictions, as the metric. Since the labels are one-hot encoded over 5 classes, we use categorical cross-entropy as the loss function. We also pass the class weights discussed earlier.
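For a one-hot label y and predicted probabilities p, categorical cross-entropy is −Σ_c y_c log p_c, and Keras's `class_weight` argument multiplies each example's loss by the weight of its true class. The pure-Python sketch below computes that per-example quantity using the weights listed earlier; the function name and the clipping constant are our own assumptions. It shows why the same confident mistake costs much more on a rare class such as class 3 than on class 2.

```python
import math


def weighted_categorical_crossentropy(y_true, y_prob, class_weights, eps=1e-7):
    """Categorical cross-entropy for one example, scaled by the weight of its true class."""
    true_class = y_true.index(1)
    # Clip probabilities away from 0 and 1 so log() stays finite.
    ce = -sum(t * math.log(min(max(p, eps), 1.0 - eps))
              for t, p in zip(y_true, y_prob))
    return class_weights[true_class] * ce


weights = {0: 1.01, 1: 4.95, 2: 1.83, 3: 9.49, 4: 6.21}
probs = [0.1, 0.1, 0.6, 0.1, 0.1]  # model puts 60% on class 2
print(weighted_categorical_crossentropy([0, 0, 1, 0, 0], probs, weights))  # true class 2
print(weighted_categorical_crossentropy([0, 0, 0, 1, 0], probs, weights))  # true class 3
```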
Let’s plot the model accuracy and loss for the training and validation sets. The model reaches an accuracy of 83%. Accuracy on the validation data is lower than on the training data, which indicates overfitting.
The confusion matrix shows that images from classes 1, 3, and 4 are being misclassified as class 2. The model may not be detecting the spots and hemorrhages present in classes 3 and 4 (the severe cases of DR).
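A confusion matrix counts, for every actual class (row), how often each class was predicted (column), so the pattern described above shows up as large off-diagonal counts in column 2. Below is a minimal sketch with made-up labels, not the model's actual predictions; scikit-learn's `sklearn.metrics.confusion_matrix` follows the same rows-are-actual convention.

```python
def confusion_matrix(y_true, y_pred, num_classes=5):
    """Rows are actual classes, columns are predicted classes."""
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for actual, predicted in zip(y_true, y_pred):
        matrix[actual][predicted] += 1
    return matrix


# Toy labels (made up) in which classes 1, 3 and 4 are confused with class 2.
y_true = [0, 0, 1, 2, 2, 3, 4, 4]
y_pred = [0, 0, 2, 2, 2, 2, 2, 4]
for row in confusion_matrix(y_true, y_pred):
    print(row)
```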
3. High Resolution Network
HRNet was originally developed for human pose estimation, but it can also be used for image classification, object detection, and other vision tasks. The official code released by the researchers is written in PyTorch, so we rewrote it in TensorFlow. HRNet maintains high-resolution representations throughout the network by connecting high-to-low resolution convolution streams in parallel, and it produces strong high-resolution representations by repeatedly fusing information across these parallel streams. The research paper is “Deep High-Resolution Representation Learning for Human Pose Estimation” (Sun et al., CVPR 2019).
For image classification, we replace the network's head with a softmax classification layer. You can find the code in my GitHub repository. The results of this model were not encouraging: it reached an accuracy of only 68%. We experimented with different loss functions and optimizers but could not improve its performance.
7. Error Analysis
The classification report shows that recall for classes 3 and 4 is very low, which means the model misses many of the severe and proliferative cases, exactly where a missed diagnosis is most costly. We need to improve the model's recall for every class.
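Per-class recall, as shown in scikit-learn's `classification_report`, is the fraction of images of a class that the model labels correctly: recall_c = TP_c / (TP_c + FN_c). The following dependency-free sketch (the helper name and the toy labels are made up) shows how a model that pushes severe cases into class 2 ends up with low recall for classes 3 and 4:

```python
def per_class_recall(y_true, y_pred, num_classes=5):
    """Recall per class: correct predictions of c / actual images of c (None if c is absent)."""
    recalls = []
    for c in range(num_classes):
        actual = sum(1 for t in y_true if t == c)
        correct = sum(1 for t, p in zip(y_true, y_pred) if t == c and p == c)
        recalls.append(correct / actual if actual else None)
    return recalls


# Toy labels (made up): most grade-3 and grade-4 images are predicted as grade 2.
y_true = [0, 0, 2, 2, 3, 3, 4, 4]
y_pred = [0, 0, 2, 2, 2, 3, 2, 2]
print(per_class_recall(y_true, y_pred))  # [1.0, None, 1.0, 0.5, 0.0]
```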
Let’s visualize the image with actual and predicted labels.
We can see the image and the model prediction with the probability score for each class.
Actual Label - Proliferative DR
Predicted Label - Moderate
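The per-class probability scores come from the model's softmax output, and the predicted label is the class with the highest probability. Below is a small sketch of that last step. It assumes raw scores (logits) as input and uses the label names from the grading scale in the Data section; the logits are made up, and `softmax` and `predict_label` are hypothetical helpers. If the model's final layer already applies softmax, only the argmax step is needed.

```python
import math

LABELS = ["No DR", "Mild", "Moderate", "Severe", "Proliferative DR"]


def softmax(logits):
    """Numerically stable softmax: subtracting the max logit avoids overflow in exp()."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def predict_label(logits):
    """Return the most probable label name and the full probability vector."""
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    return LABELS[best], probs


label, probs = predict_label([0.1, 0.2, 2.0, 0.3, 1.5])  # made-up logits
print(label)  # Moderate
print([round(p, 3) for p in probs])
```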
The inference time is 0.45 seconds per prediction in the Jupyter notebook, which corresponds to about 2.2 predictions per second. The Streamlit web app is slower on a local machine, taking 10 to 20 seconds per prediction.
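Latency and throughput are reciprocals: at 0.45 seconds per prediction, a model handles 1 / 0.45 ≈ 2.2 predictions per second. Here is a minimal timing sketch using `time.perf_counter`. Warm-up calls are excluded from the timing because the first prediction often includes one-off setup costs. `measure_latency` and `dummy_predict` are hypothetical stand-ins for the real predict function.

```python
import time


def measure_latency(predict_fn, inputs, warmup=1):
    """Return (average seconds per prediction, predictions per second)."""
    for x in inputs[:warmup]:
        predict_fn(x)  # warm-up calls are not timed
    start = time.perf_counter()
    for x in inputs:
        predict_fn(x)
    elapsed = time.perf_counter() - start
    latency = elapsed / len(inputs)
    return latency, 1.0 / latency


def dummy_predict(x):
    time.sleep(0.02)  # stands in for a model forward pass
    return x


latency, rate = measure_latency(dummy_predict, list(range(5)))
print(f"{latency:.3f} s per prediction, {rate:.1f} predictions per second")
```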
References
Future Work
- Use data from other sources such as EyePACS or Messidor, which could further improve our accuracy.
- Set up a continuous integration system for our codebase that checks the functionality of the code and evaluates any model before it is deployed.
- Package the prediction system as a REST API and deploy it as a Docker container image running as a serverless function on AWS Lambda.
You can connect with me on LinkedIn. You can view the code in this GitHub repository.