Using CNNs to Diagnose Diabetic Retinopathy

Dev Patel
Published in The Startup
10 min read · Dec 28, 2020

Computer vision has transformed the way we look at the world’s most challenging problems (see what I did there, look).

Although new algorithms and model architectures keep changing the playing field, we still have a long way to go in terms of the accessibility, efficiency, and accuracy of these models before we put all of our doctors on the GPU.

For now, though, let’s look at what can be done with a stable internet connection, a computer, and PyTorch. The goal: a CNN that can diagnose diabetic retinopathy.

I won’t focus much on the underlying biology of DR. Instead, I’ll talk about the benefits of ML in diagnosing diseases and the opportunities these models present:

  1. Millions of data points and accurate predictions: The number of scans one doctor can analyze is tiny compared to what a computer can process, which can rapidly improve the speed and accuracy of diagnoses. Experience yields accuracy.
  2. Identifying trends in data over time: AI and ML algorithms can identify subtle trends in a dataset in relation to a much larger set of features and variables, including time. Being able to detect disease progression early, and at the right time, can save millions of lives.
  3. Quality access for those without doctors: Although it may sound far-fetched, doctors and even specialists may become far less central to diagnosis in the next few decades due to advances in computer vision and AI. For countries and communities where healthcare is already sparse, AI offers the opportunity to provide real-time care while removing barriers such as the years of training and experience a specialist needs and the distance to the nearest one.

With that being said, disease prediction with computer vision is a difficult task that comes with several issues of its own: collecting good data, tuning the hyperparameters well, and making sure patients aren’t misinformed about a disease they may not have.

Despite all of this, what you need to know is that AI is powerful, and the fact that a 15-year-old can write a program that performs similarly to, if not better than, a trained doctor should blow your mind.

But what is Diabetic Retinopathy? A brief overview:

According to Fighting Blindness Canada, diabetic retinopathy (DR) is the most common form of vision loss associated with diabetes. Affecting approximately 500,000 Canadians, it is the leading cause of blindness among working-age adults.

Most of these cases arise when people with diabetes are unable to control their blood-sugar levels, which over time damages the blood vessels in the eye. When these vessels can no longer supply vital nutrients to the eye tissue, the body produces new, weaker vessels that can burst and damage the eye itself.

Credit: Magrabi Hospitals. Comparison of the anatomy of a healthy eye and a diabetic eye. Notice the spots, aneurysms, and congestion of smaller blood vessels throughout the eye.

From here, nerve fibres can swell, microaneurysms can protrude from the walls of the smaller vessels, and the damage can lead to macular ischemia. This is a very common disease and treatments are available, but the biggest problem is diagnosis: catching DR early, during its moderate stage.

Diagnostic technologies such as fluorescein angiography, widefield imaging, and optical coherence tomography (OCT) are available, but they do not provide a comprehensive view of the entire retina, making it harder for doctors to give an accurate diagnosis, let alone an early one.

According to the Byers Eye Institute at Stanford, 30% of DR patients are unaware that they have diabetic retinopathy until it reaches the severe stage.

Credit: Modern Optometry. Widefield image showing severe non-proliferative diabetic retinopathy in a patient’s right eye (A) and proliferative diabetic retinopathy in a patient’s left eye (B).

CNNs have great potential here and could significantly improve early detection of diabetic retinopathy in seemingly healthy patients. In this first part of a two-part series, I’ll go through the technical details of the code, the data, and the approach, while giving you the basic knowledge to design your own custom CNN applications.

I want to emphasize that the results of this model turned out to be very poor; part 2 addresses those issues and shows you how to improve the model’s accuracy from 23% to 98.95%.

Still, this article lays out a concrete method for developing CNNs that can be applied to any dataset, since the issue was not the model but the data itself.

If you don’t have the necessary prerequisites (Python and PyTorch, the basics of CNNs, and some familiarity with the common libraries), I would recommend reading these 2 articles and the second part of this series first, so that you get the most value out of this one. Even so, you should be able to follow along.

The Dataset: Kaggle API and Preprocessing Dataset

To start, the dataset I used for the model contains 3,662 images, each assigned one of 5 labels: No_DR: 0, Mild: 1, Moderate: 2, Severe: 3, Proliferate_DR: 4.

Next, we need to download the dataset using the Kaggle API. You can use this article to point you in the right direction, or figure it out yourself as a challenge :) .

In my script, the Kaggle download command is left as a commented-out line; uncomment it and run it once to download the dataset.

Next, let’s look at the dataset in the project directory:

Notice the train.csv file: if you open it, you’ll see that it maps each image name to its integer label. Open /gaussian_filtered_images/gaussian_filtered_images/ and you’ll find 5 folders, titled Mild, Moderate, No_DR, Proliferate_DR, and Severe, each containing the images for that class.

Because train.csv already gives us the integer label for every image name, we don’t need the class folders. Instead, we combine the contents of all five folders into the root folder.

To do this, go into each folder, copy the images, and paste them into /gaussian_filtered_images/gaussian_filtered_images/.

Then, delete the labelled folders. Make sure that the gaussian_filtered_images directory now contains 3662 images.
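If you’d rather not copy and paste by hand, the same flattening can be scripted. The sketch below uses only Python’s standard library; the ROOT path and the helper name flatten_class_folders are assumptions (adjust ROOT to wherever you extracted the archive), and it assumes image file names are unique across the five folders, which should hold since train.csv keys each image by name.

```python
import shutil
from pathlib import Path

# Assumption: adjust ROOT to wherever the Kaggle archive was extracted.
ROOT = Path("gaussian_filtered_images/gaussian_filtered_images")
CLASS_FOLDERS = ["Mild", "Moderate", "No_DR", "Proliferate_DR", "Severe"]


def flatten_class_folders(root: Path, folders=CLASS_FOLDERS) -> int:
    """Move every image out of the per-class folders into root, then delete the empty folders.

    Returns the number of files now sitting directly inside root.
    """
    for name in folders:
        folder = root / name
        if not folder.is_dir():
            continue  # already flattened, or the folder is missing
        for image_path in folder.iterdir():
            if image_path.is_file():
                # Assumes names are unique, so nothing in root gets overwritten.
                shutil.move(str(image_path), str(root / image_path.name))
        folder.rmdir()  # only succeeds once the folder is empty
    return sum(1 for p in root.iterdir() if p.is_file())
```

Calling flatten_class_folders(ROOT) once should report 3662 images for this dataset; running it a second time is harmless, because class folders that no longer exist are skipped.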

Now, I used this script to load in the data and preview some of the images with the corresponding output:

In the script, trainer_names_csv is a DataFrame (loaded from train.csv) that can be indexed using iloc.
The image name in each row should correspond to an image file name in the directory.
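The original script isn’t reproduced here, but a minimal sketch of the same idea looks like this. It assumes, hypothetically, that the first column of train.csv holds the image name, the second holds the integer label, and that image files are stored as <name>.png; the function names load_labels and describe_row are mine, not the author’s.

```python
import pandas as pd

# Integer label -> class name, as listed in the dataset description above.
LABEL_NAMES = {0: "No_DR", 1: "Mild", 2: "Moderate", 3: "Severe", 4: "Proliferate_DR"}


def load_labels(csv_path):
    """Read train.csv into a DataFrame with one row per image (name, integer label)."""
    return pd.read_csv(csv_path)


def describe_row(df, i, extension=".png"):
    """Return (image file name, integer label, class name) for row i via positional iloc indexing.

    Assumption: column 0 holds the image name and column 1 the integer label.
    """
    name = df.iloc[i, 0]
    label = int(df.iloc[i, 1])
    return f"{name}{extension}", label, LABEL_NAMES[label]
```

For example, describe_row(trainer_names_csv, 0) would return the file name, integer label, and class name of the first row, which you can then check against the files in the directory.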

Here are some examples of images with their given labels for reference:

These examples give a sense of the visual differences between the severity levels of diabetic retinopathy.


Next, we need to preprocess the data so that the CNN can use it. There are many ways to do this, but for a custom dataset, writing a dataset class is especially helpful.

Essentially, we have created a reusable class that can format any dataset laid out the same way (a CSV of labels plus a folder of images). You may be wondering why I didn’t combine the image and the label into one object: I need to access each part separately anyway, and combining them would only make the data harder to index.

I chose to use a class because it lets us create copies of one dataset with different transformations (for example, one that is normalized and one that is not).
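Here is a minimal sketch of such a dataset class, assuming the same train.csv layout as above (column 0 = image name, column 1 = integer label) and .png files. The class name RetinaDataset is hypothetical. It subclasses torch.utils.data.Dataset and accepts an optional transform, which is what makes the normalized and unnormalized copies possible.

```python
from pathlib import Path

from PIL import Image
from torch.utils.data import Dataset


class RetinaDataset(Dataset):
    """Map-style dataset pairing each image in image_dir with its label from a DataFrame.

    Hypothetical sketch: assumes column 0 = image name, column 1 = integer label,
    and files stored as <name><extension>.
    """

    def __init__(self, labels_df, image_dir, transform=None, extension=".png"):
        self.labels_df = labels_df.reset_index(drop=True)
        self.image_dir = Path(image_dir)
        self.transform = transform  # None for previews, a transform pipeline for training
        self.extension = extension

    def __len__(self):
        return len(self.labels_df)

    def __getitem__(self, idx):
        name = self.labels_df.iloc[idx, 0]
        label = int(self.labels_df.iloc[idx, 1])
        image = Image.open(self.image_dir / f"{name}{self.extension}").convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        # Image and label are returned separately so each can be indexed on its own.
        return image, label
```

With torchvision, the two copies could then be created as image_dataset = RetinaDataset(df, image_dir) and transformed_dataset = RetinaDataset(df, image_dir, transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])); the exact transform list here is an assumption.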

The Loader: Initializing Hyperparameters for the Model

Now we can get the data ready to load into the model.

The following code is used to preview the images while loading in a transformed version of the dataset.

import matplotlib.pyplot as plt

for i in range(len(image_dataset)):  # index through all the samples in the dataset (3662)
    sample = image_dataset[i]  # (image, label) tuple for sample i
    ax = plt.subplot(1, 4, i + 1)  # 1 row, 4 images per row
    plt.tight_layout()
    ax.set_title('Sample #{}'.format(i))
    plt.imshow(sample[0])  # works because image_dataset holds untransformed images
    print(sample)
    ax.axis('off')

    if i == 3:  # stop after 4 images (i = 0 to 3), but you can change this
        plt.show()
        break

Here is the output of the preview below:

As you can see, each image has a size of 224 x 224, and the corresponding disease label is printed alongside it (2, 4, 1, 0).

You may be wondering what transforms.Normalize() does. Simply put, it shifts and rescales the data using the parameters mean and std:

image = (image - mean) / std, where mean is the first parameter, [0.485, 0.456, 0.406], and std is the second, [0.229, 0.224, 0.225], with one value per RGB channel. These are the per-channel means and standard deviations of the ImageNet dataset. Normalizing this way centers each channel around 0 with a spread of roughly 1, which helps the model train more stably.

Strictly speaking, it is transforms.ToTensor() that rescales the RGB values from [0, 255] to [0, 1]. Normalize then applies the formula above to each channel, so the final values fall roughly between -2.1 and 2.6, centered near 0.
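To make the two steps concrete, the sketch below re-implements them in NumPy: to_tensor_like mimics ToTensor (rescale to [0, 1] and move channels first) and normalize applies the formula above per channel. These helper names are hypothetical; in practice you would simply use the torchvision transforms.

```python
import numpy as np

# ImageNet per-channel statistics passed to transforms.Normalize.
MEAN = np.array([0.485, 0.456, 0.406])
STD = np.array([0.229, 0.224, 0.225])


def to_tensor_like(image_uint8):
    """Mimic transforms.ToTensor: HWC uint8 in [0, 255] -> CHW float in [0, 1]."""
    return image_uint8.astype(np.float32).transpose(2, 0, 1) / 255.0


def normalize(chw):
    """Mimic transforms.Normalize: (image - mean) / std, applied channel by channel."""
    return (chw - MEAN[:, None, None]) / STD[:, None, None]


black = np.zeros((224, 224, 3), dtype=np.uint8)
white = np.full((224, 224, 3), 255, dtype=np.uint8)
print(normalize(to_tensor_like(black))[:, 0, 0])  # approximately -2.118, -2.036, -1.804 (R, G, B)
print(normalize(to_tensor_like(white))[:, 0, 0])  # approximately 2.249, 2.429, 2.640 (R, G, B)
```

The black and white pixels mark the extremes, which is where the rough range of -2.1 to 2.6 comes from.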

Normalization has several benefits for training. For now, though, notice that if we replaced image_dataset with transformed_dataset in the preview code, the images would not display properly, because each sample is now a normalized tensor rather than an ordinary image.

You can see this in the output of the line below in loading_data.py, where the values of the 224 x 224 matrix are small floating-point numbers rather than integers between 0 and 255:

print(image) 
Note that (a) your values will differ from mine because the dataset is shuffled when train_loader is initialized, and (b) the batch contains 10 labels because our batch_size is 10, so the 224 x 224 matrix shown here comes from just one of the 10 images in the batch.
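The shapes involved can be checked without the real data. This sketch builds a hypothetical stand-in dataset of 30 random 3 x 224 x 224 tensors with labels 0 to 4 and wraps it in a DataLoader with batch_size=10 and shuffle=True, mirroring the settings described above.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in for transformed_dataset: random "images" with labels 0-4.
fake_images = torch.randn(30, 3, 224, 224)
fake_labels = torch.randint(0, 5, (30,))
fake_dataset = TensorDataset(fake_images, fake_labels)

# shuffle=True reorders the samples every epoch, so printed values change between runs.
train_loader = DataLoader(fake_dataset, batch_size=10, shuffle=True)

images, labels = next(iter(train_loader))
print(images.shape)        # torch.Size([10, 3, 224, 224])
print(labels.shape)        # torch.Size([10]): one label per image in the batch
print(images[0, 0].shape)  # torch.Size([224, 224]): one channel of one image
```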

The Model: Starting the Training Process

After this, constructing the model architecture is fairly simple:
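The author’s exact architecture isn’t reproduced here, but the numbers given in the error-handling note further down (a third layer nn.Conv2d(8, 16, 5) and a flattened size of 16 * 24 * 24 = 9216) pin down most of it. The sketch below is one architecture consistent with them, assuming 2 x 2 max pooling after each conv layer; conv1’s 6 output channels, the 120-unit hidden layer, the dropout rate of 0.5, and the name DRNet are my assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DRNet(nn.Module):
    """Hypothetical 3-conv CNN for 224 x 224 RGB inputs and 5 DR classes."""

    def __init__(self, num_classes=5):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)   # 224 -> 220, pooled to 110 (6 channels assumed)
        self.conv2 = nn.Conv2d(6, 8, 5)   # 110 -> 106, pooled to 53
        self.conv3 = nn.Conv2d(8, 16, 5)  # 53 -> 49, pooled to 24
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 120)   # 16 * 24 * 24 = 9216; 120 units assumed
        self.fc2 = nn.Linear(120, num_classes)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        # print(x.shape)  # uncomment to check the size after changing conv layers
        x = x.view(-1, 16 * 24 * 24)
        x = self.dropout(F.relu(self.fc1(x)))
        return self.fc2(x)  # raw logits; nn.CrossEntropyLoss applies log-softmax itself
```

A quick sanity check is to pass a random batch through the model: an input of shape (2, 3, 224, 224) should come out as (2, 5), one score per class.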

From here, we need to create a training routine while also setting up the optimizer and loss function.

The training routine is written as the following:

This should start printing the loss, and once all the epochs are done, training is complete.
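A minimal training routine along these lines might look like the sketch below. nn.CrossEntropyLoss is the usual loss for single-label classification over 5 classes; the choice of SGD with momentum, the learning rate, and the function name train are assumptions, since the original routine isn’t shown.

```python
import torch
import torch.nn as nn


def train(model, train_loader, epochs=2, lr=0.001, device="cpu"):
    """Minimal training routine: CrossEntropyLoss + SGD with momentum (assumed choices).

    Prints the average loss per epoch and returns the list of averages.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    model.to(device)
    history = []
    for epoch in range(epochs):
        model.train()  # enables dropout during training
        running_loss = 0.0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()                    # clear gradients from the previous step
            loss = criterion(model(images), labels)  # logits vs. integer labels 0-4
            loss.backward()                          # backpropagate
            optimizer.step()                         # update the weights
            running_loss += loss.item()
        avg = running_loss / len(train_loader)
        history.append(avg)
        print(f"epoch {epoch + 1}: loss {avg:.4f}")
    return history
```

Passing your model and train_loader to train(...) prints one average loss per epoch, which makes it easy to see whether the loss is actually going down.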

NOTE FOR CNN ARCHITECTURE AND HANDLING ERRORS

Note that the input size of 9216 in the first fully connected layer depends on the previous layers. If you change any of their parameters, uncomment the print(x.shape) line and run the training routine; it will throw an error. For example, I changed the 16-channel output of the third conv layer to a 20-channel output:

self.conv3 = nn.Conv2d(8, 20, 5)

It threw the following error (after uncommenting the print(x.shape) line).

Based on the printed shape of the conv layer’s output, update the x.view call to match. In our case, that means changing 16 * 24 * 24 to 20 * 24 * 24.

After that, you should get an error that says the shapes of the 2 matrices cannot be multiplied.

Based on the second shape (10 x 11520: 10 images per batch, 11520 features each), replace the 9216 in the first fully connected layer with 11520 so that the matrix multiplication inside nn.Linear lines up. Once you do so, comment out the print(x.shape) line again and continue with training. The changed code will be:
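Where do 9216 and 11520 come from? With no padding and stride 1, each 5 x 5 convolution shrinks the height and width by 4, and each 2 x 2 max pool (assumed here, as it reproduces the 24 x 24 size above) halves them, rounding down. The small pure-Python sketch below follows a 224 x 224 input through three such blocks; the helper names are hypothetical.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output height/width of nn.Conv2d: floor((size + 2*padding - kernel) / stride) + 1."""
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size, kernel=2, stride=2):
    """Output height/width of nn.MaxPool2d(kernel, stride) with no padding."""
    return (size - kernel) // stride + 1


def flattened_features(input_size=224, out_channels=16, num_blocks=3, kernel=5):
    """Number of features fed to the first linear layer after num_blocks conv + pool blocks."""
    size = input_size
    for _ in range(num_blocks):
        size = pool_out(conv_out(size, kernel))  # 224 -> 110 -> 53 -> 24
    return out_channels * size * size


print(flattened_features(out_channels=16))  # 9216  = 16 * 24 * 24
print(flattened_features(out_channels=20))  # 11520 = 20 * 24 * 24
```

Only the last conv layer’s channel count changes the result here, which is why changing conv3 from 16 to 20 output channels is enough to move the flattened size from 9216 to 11520.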

However, something seems off here. The loss doesn’t decrease the way it normally should, and the accuracy isn’t great either.

Based on this, there are a few options to fix it: you can change the model architecture, adjust the learning rate, double-check that the model is running correctly, change the dropout rate, or check whether the model is stuck in a local minimum.

I tried all of this and none of it worked.

It seemed as if I didn’t have enough data to train the model effectively, or the data didn’t vary enough for the model to learn to tell the classes apart.

In the end, the final model reached an accuracy of only 23% on the validation set.
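For context, guessing uniformly at random among 5 classes gives an expected accuracy of 20%, so 23% is barely better than chance. Validation accuracy of this kind is typically computed as in the sketch below: the fraction of samples whose highest-scoring class matches the label. The function name evaluate_accuracy and the validation loader passed to it are assumptions, since the original evaluation code isn’t shown.

```python
import torch


@torch.no_grad()  # no gradients needed for evaluation
def evaluate_accuracy(model, loader, device="cpu"):
    """Fraction of samples whose highest-scoring class (argmax of the logits) matches the label."""
    model.eval()  # disables dropout
    correct = total = 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        predictions = model(images).argmax(dim=1)
        correct += (predictions == labels).sum().item()
        total += labels.size(0)
    return correct / total
```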

At this point, I was stuck and unable to produce an accurate model. Not all hope was lost, however. There is a technique that, although used in almost every CNN project, is rarely the focus of discussions about building accurate and efficient computer vision models. That saviour, which I’ll discuss in part 2, is transfer learning. Stay tuned ( ͡• ͜ʖ ͡•)

Thanks for taking the time to read this and I hope you got something out of it. If you want to get more technical or simply reach out to me, you can find me on LinkedIn, Email, or GitHub. You can also subscribe to my newsletter here.
