Transfer learning on retinal fundus images made easy with EasyTorch: a library on top of PyTorch.

Sraashis
7 min read · Mar 12, 2022


Deep learning models pick up the salient features, and the underlying biases, that make up the essence of a dataset. Given enough data they perform very well, but there are many scenarios where collecting enough training data is simply not feasible. Today, we will discuss one such case: retinal vessel segmentation in color fundus images. We will see why we can't have enough training data, and why transfer learning is the practical way forward.

A sample fundus image and its vessel segmentation from the DRIVE[1] dataset.

Retinal fundus images are 2D photographs of the back of the eye. Acquiring them is a non-invasive procedure, and the images expose the arteries and veins, retina, optic nerve head, optic disc, cup, and other structures. An ophthalmologist manually inspects such images to make a diagnosis. This inspection sometimes includes laborious measurements like the artery-to-vein ratio (ratio of vessel diameters), the cup-to-disc ratio, and the identification of abnormal deposits and patches like hemorrhages, exudates, and more.

As you can see, creating a manual segmentation involves pixel-by-pixel labeling by an ophthalmologist. So for new, real-world fundus images generated in a hospital, we can use transfer learning to carry over vessel segmentation knowledge from existing datasets that already have manual segmentations. Transfer learning is an art, where deep learning is the artist and data augmentation is the color palette. Let's call the data with ground truth the training data, and the data we want to transfer-learn to the target data.

It is a challenging task because there is a wide range of fundus cameras: some produce clear HD pictures, others low-quality ones. The data acquisition environment also plays a vital role in terms of brightness and exposure. Thus, crafting a well-suited augmentation scheme is of utmost importance (a sketch of one follows below). This article explains how to use the publicly available datasets listed in the GitHub repo and transfer to another dataset called DDR[7]. The training datasets together contain around 200–300 images with manual vessel segmentations, whereas DDR consists of ~13k images. Each image differs in size, lighting, contrast, exposure, and field of view, making this a perfect transfer learning task. So let's make a plan.
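
To make the augmentation idea concrete, here is a minimal sketch of the kind of geometric and photometric jitter that helps bridge camera differences. It uses plain Pillow and is only an illustration of the concept, not the augmentation code shipped with the repository:

import random

from PIL import Image, ImageEnhance

def augment(image, label):
    # Geometric jitter: flips must be applied to image and label together.
    if random.random() < 0.5:
        image = image.transpose(Image.FLIP_LEFT_RIGHT)
        label = label.transpose(Image.FLIP_LEFT_RIGHT)
    if random.random() < 0.5:
        image = image.transpose(Image.FLIP_TOP_BOTTOM)
        label = label.transpose(Image.FLIP_TOP_BOTTOM)
    # Photometric jitter on the image only: mimic different cameras,
    # brightness levels, and exposures.
    image = ImageEnhance.Brightness(image).enhance(random.uniform(0.7, 1.3))
    image = ImageEnhance.Contrast(image).enhance(random.uniform(0.7, 1.3))
    return image, label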

1. Installation

pip install easytorch

git clone https://github.com/sraashis/retinal-fundus-transfer

2. CNN Model

We will use the popular U-Net[8], which takes a fixed-size patch of an image (roughly 400 x 400; 388 x 388 in the specifications below, and configurable) and slides over the image to cover it entirely.
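
To see what that sliding means with the numbers used in the dataset specifications below, here is a hypothetical helper (not part of EasyTorch) that computes the top-left corners of the overlapping patches. Note that expand_by = (184, 184) in those specifications pads each 388 x 388 output patch up to a 572 x 572 input, which is exactly the input size of the original U-Net with valid convolutions:

def patch_corners(h, w, patch=(388, 388), offset=(300, 300)):
    # Top-left corners of overlapping patches tiling an h x w image.
    # With a 300-pixel stride and 388-pixel patches, neighbors overlap,
    # and the last row/column is clamped to the image boundary.
    rows = list(range(0, h - patch[0], offset[0])) + [h - patch[0]]
    cols = list(range(0, w - patch[1], offset[1])) + [w - patch[1]]
    return [(r, c) for r in rows for c in cols]

print(len(patch_corners(896, 896)))  # 9 patches (3 rows x 3 columns)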

3. Data preparation

We will be using six datasets that are publicly available (or can be acquired from the authors for research purposes): DRIVE[1], STARE[2], CHASE-DB[3], HRF[4], IOSTAR[5], and AV-WIDE[6].

One of the best things about easytorch is that it lets you create independent dataset specifications and automatically merges them internally into a single dataset. So one can create as many specifications as there are datasets. The full list of data specifications is in this file. Here is an example specification for the DRIVE dataset:

from os import sep


def get_label_drive(file_name):
    return file_name.split('_')[0] + '_manual1.gif'


def get_mask_drive(file_name):
    return file_name.split('_')[0] + '_mask.gif'


DRIVE = {
    'name': 'DRIVE',
    'patch_shape': (388, 388),  # For U-Net
    'patch_offset': (300, 300),  # For U-Net
    'expand_by': (184, 184),  # For U-Net
    'data_dir': 'DRIVE' + sep + 'images',
    'label_dir': 'DRIVE' + sep + 'manual',
    'mask_dir': 'DRIVE' + sep + 'mask',
    'label_getter': get_label_drive,
    'mask_getter': get_mask_drive,
    'resize': (896, 896),
    'thr_manual': 50
}

This specification resizes the image, mask, and ground truth to 896 x 896. Resizing a binary image can yield non-binary values, so thr_manual specifies a threshold for re-binarizing it. Basic augmentation features like resizing, flipping, and adding noise are integrated into the library itself, but it is super easy to add new ones.
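
To see why thr_manual is needed, here is a small self-contained sketch (the file path is hypothetical) showing how a resized binary label picks up intermediate gray values, and how thresholding restores a clean binary map:

import numpy as np
from PIL import Image

# Hypothetical path; any binary vessel label will do.
label = Image.open('DRIVE/manual/21_manual1.gif').convert('L')

# Interpolation during resizing produces gray values between 0 and 255.
resized = np.array(label.resize((896, 896), Image.BILINEAR))

# thr_manual = 50 re-binarizes the resized label.
binary = (resized > 50).astype(np.uint8) * 255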

Similarly, we can create specifications for our target dataset as:

from os import sep

DDR_TRAIN = {
    'name': 'DDR_train',
    'patch_shape': (388, 388),
    'patch_offset': (300, 300),
    'expand_by': (184, 184),
    'data_dir': 'DDR' + sep + 'DR_grading' + sep + 'train',
    'extension': 'jpg',
    'bbox_crop': True,
    'resize': (896, 896),
    'thr_manual': 50
}

The argument 'bbox_crop': True crops the DDR images so that the field of view is nicely centered. It does this by finding a bounding box around the field of view with OpenCV and keeping only that region. It is applied out of the box to all associated images, including ground truths and masks (a simplified sketch follows the figure below):

Result after applying bbox_crop in one of the DDR[7] images.
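
For intuition, here is a simplified stand-in for such a crop, not EasyTorch's exact implementation: threshold the image to find the bright circular field of view, take its bounding box, and keep only that region.

import cv2

def bbox_crop(image, threshold=10):
    # Pixels brighter than `threshold` count as inside the field of view.
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    _, fov = cv2.threshold(gray, threshold, 255, cv2.THRESH_BINARY)
    # Bounding box around all non-zero (field-of-view) pixels.
    x, y, w, h = cv2.boundingRect(cv2.findNonZero(fov))
    return image[y:y + h, x:x + w]

# cropped = bbox_crop(cv2.imread('path/to/ddr_image.jpg'))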

Note: For the best results, we need to make sure the training data and our target data match in resolution. So it is better to train with multiple resize options like (950, 950), (896, 896), and (800, 800).

Please visit EasyTorch to see how we maximize GPU utilization with disk caching and multi-worker pre-processing (different from PyTorch's num_workers in the DataLoader).

The dataset class has the following signature; a concrete implementation can be found here:

from easytorch import ETDataset


class MyDataset(ETDataset):
    def load_index(self, dataset_name, file):
        # This method is parallelized by default.
        dataspec = self.dataspecs[dataset_name]
        self.indices.append([dataset_name, file])

    def __getitem__(self, index):
        dataset_name, file = self.indices[index]
        dataspec = self.dataspecs[dataset_name]

        image = ...  # TODO: load the image file.
        label = ...  # TODO: load the corresponding label.

        # Extra preprocessing, if needed.
        # Apply transforms, if needed.

        return image, label
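
As a hint of what the TODOs above might look like, here is a hypothetical loader built only from the dataspec fields defined earlier; the repository's concrete implementation may differ:

from os import sep

import numpy as np
from PIL import Image

def load_pair(dataspec, file):
    # Load the raw fundus image.
    image = np.array(Image.open(dataspec['data_dir'] + sep + file))
    # Map the image file name to its label file name via the getter.
    label_file = dataspec['label_getter'](file)
    label = Image.open(dataspec['label_dir'] + sep + label_file).convert('L')
    # Re-binarize the label with thr_manual (see above).
    label = (np.array(label) > dataspec['thr_manual']).astype(np.uint8)
    return image, label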

Now, once we run with the correct parameters and configuration, it trains on the training data, automatically picks the best model (based on F1 score), and generates vessel segmentations for the target dataset.
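
For reference, the F1 score used for model selection is the harmonic mean of precision and recall over vessel pixels. Here is a minimal NumPy sketch, not the library's internal metric code:

import numpy as np

def f1_score(pred, target, eps=1e-8):
    # pred and target are binary arrays: 1 = vessel, 0 = background.
    tp = np.sum((pred == 1) & (target == 1))
    fp = np.sum((pred == 1) & (target == 0))
    fn = np.sum((pred == 0) & (target == 1))
    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    return 2 * precision * recall / (precision + recall + eps)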

4. Finally, to run

Case 1: Run a working example on the DDR dataset using two datasets (DRIVE and STARE) for transfer learning.

python main.py -ph train -data datasets --training-datasets DRIVE STARE --target-datasets DDR_train -spl 0.75 0.25 0 -b 8 -nw 6 -lr 0.001 -e 501 -pat 101 -rcw True

Case 2: Use more datasets as below.

python main.py -ph train -data <path to your dataset> --training-datasets DRIVE CHASEDB HRF IOSTAR STARE --target-datasets DDR_train -spl 0.75 0.25 0 -b 8 -nw 6 -lr 0.001 -e 501 -pat 101 -rcw True

This code uses the easytorch framework and inherits some default args; consult the easytorch repo for details. But worry not, I will explain each of these:

  • -ph train: specifies the phase: train, or test (for inference).
  • -data datasets: Path to your datasets, so that you can run this code wherever your data is. Just point it to your datasets folder (check the datasets folder in the repo for an example).
  • --training-datasets …: Which datasets to use for transfer learning, from the specifications in the dataspecs directory. A single model will be trained on all of them together.
  • --target-datasets …: After picking the best model, which dataset to run it on to generate vessel segmentation results.
  • -spl 0.75 0.25 0: Split ratio for the training dataset, in the order train, validation, test. We don't need test data for this transfer learning, but we do need a validation set to pick the best model.
  • -b 8: Batch size.
  • -nw 6: Number of workers.
  • -lr 0.001: Learning rate.
  • -e 501: Epochs.
  • -pat 101: Patience to stop training: if the model does not improve within the previous 101 epochs, stop.
  • -rcw True: Stochastic weighting scheme to improve predictions on fainter vessels, as detailed in the paper below[9].

Segmentation results on the DDR dataset after transfer learning.

Thanks! Please cite [9] if you use this library in your research, and the respective datasets if you use them.

References:

  1. DRIVE Dataset: J. Staal, M. Abramoff, M. Niemeijer, M. Viergever, and B. van Ginneken, "Ridge-based vessel segmentation in color images of the retina," IEEE Transactions on Medical Imaging, 23, 501–509 (2004).
  2. STARE Dataset: A. D. Hoover, V. Kouznetsova, and M. Goldbaum, "Locating blood vessels in retinal images by piecewise threshold probing of a matched filter response," IEEE Transactions on Medical Imaging, 19, 203–210 (2000).
  3. CHASE-DB: Fraz, M. M., Remagnino, P., Hoppe, A., Uyyanonvara, B., Rudnicka, A. R., Owen, C. G., & Barman, S. A. (2012). An ensemble classification-based approach applied to retinal blood vessel segmentation. IEEE Transactions on Biomedical Engineering, 59(9), 2538–2548. https://doi.org/10.1109/TBME.2012.2205687
  4. HRF Dataset: Budai, A., Bock, R., Maier, A., Hornegger, J., & Michelson, G. (2013). Robust vessel segmentation in fundus images. International Journal of Biomedical Imaging, 2013, 154860. https://doi.org/10.1155/2013/154860
  5. IOSTAR Dataset: J. Zhang, B. Dashtbozorg, E. Bekkers, J. P. W. Pluim, R. Duits, and B. M. ter Haar Romeny, "Robust Retinal Vessel Segmentation via Locally Adaptive Derivative Frames in Orientation Scores," IEEE Transactions on Medical Imaging, 35(12), 2631–2644, Dec. 2016. https://doi.org/10.1109/TMI.2016.2587062
  6. AV-WIDE Dataset: Estrada, R., Tomasi, C., Schmidler, S. C., & Farsiu, S. (2015). Tree Topology Estimation. IEEE Transactions on Pattern Analysis and Machine Intelligence, 37(8), 1688–1701. https://doi.org/10.1109/TPAMI.2014.2382116
  7. DDR Dataset: Li, T., Gao, Y., Wang, K., Guo, S., Liu, H., & Kang, H. (2019). Diagnostic Assessment of Deep Learning Algorithms for Diabetic Retinopathy Screening. Information Sciences, 501, 511–522.
  8. Architecture used: O. Ronneberger, P. Fischer, and T. Brox, "U-Net: Convolutional networks for biomedical image segmentation," in MICCAI (2015).
  9. Khanal, A., & Estrada, R. (2020). Dynamic Deep Networks for Retinal Vessel Segmentation. Frontiers in Computer Science, 2. https://doi.org/10.3389/fcomp.2020.00035
