Can Interpretable Deep Learning aid the diagnosis of Retinopathy of Prematurity?
The U.S. National Eye Institute reports that each year about 14,000 to 16,000 premature infants born in the U.S. are diagnosed with Retinopathy of Prematurity (ROP), and that 10% of them require medical treatment [more info]. ROP is an eye disease that causes abnormal growth of the blood vessels in the retina. When not promptly treated, ROP progresses and may cause total detachment of the retina, and thus blindness.
An indirect ophthalmoscope is used to visually inspect the retina and check for ROP. Special cameras can take high-resolution pictures of the retina. These pictures can be analyzed by multiple pathologists and by automated methods (particularly machine learning techniques). Moreover, they can be used to track the evolution of the disease over time. ROP diagnosis consists of identifying the affected zones of the retina, staging the disease on a scale from 1 to 5, and identifying symptoms of pre-plus or plus disease. Doctors determine the presence of pre-plus or plus disease on the basis of coexisting clinical factors, such as increased venous dilation and arterial tortuosity. The distinction between the two is very subtle, and it is often a source of strong disagreement among experts.
Earlier work in the literature relied on complex methods to track the retinal vessels and extract handcrafted features, which were then fed to machine learning algorithms to support clinicians in their decisions. Recently, Convolutional Neural Networks (CNNs) outperformed 6 of 8 ophthalmology experts in the diagnosis of plus disease [Brown et al., 2018]. But is the CNN looking at the same clinical factors that doctors check? In our work, Improved interpretability for computer-aided severity assessment of Retinopathy of Prematurity [presented at SPIE Medical Imaging 2019: Computer-Aided Diagnosis], we address this question via post-hoc model interpretation. Specifically, we generate explanations of the CNN’s classification of plus and pre-plus disease without needing to retrain the network.
In this post, we give a brief overview of our approach to answering the physicians’ question:
Is the CNN paying attention to the handcrafted features frequently used by machine learning approaches?
As a first step, we collaborated with physicians to identify six handcrafted features representing characteristics of the vessels that can be extracted from the vessel segmentation images of the retina:
- the curvature mean and median of the vessel segments
- the point diameter mean
- the segment diameter mean
- the mean and median of the cumulative tortuosity index
More details about the retinal measures are in our recent paper “Improved interpretability for computer-aided severity assessment of retinopathy of prematurity” and in the work by Esra Ataer-Cansizoglu et al., 2015. These retinal measures are computed on the automatic vessel segmentation produced by a U-Net, and were previously used as handcrafted features in traditional machine learning approaches to detect plus disease in ROP images.
The activations of a CNN layer contain information about a retinal measure if we can find a Regression Concept Vector (RCV) for that measure. Given a set of inputs and the corresponding retinal measure annotations, we solve a least squares linear regression in the activation space. The resulting RCV represents the direction of greatest increase of the retinal measure in the space of the layer’s activations. The coefficient of determination (R²) of the regression shows that some measures are much easier to regress in ROP cases with plus disease.
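The regression step described above can be sketched in a few lines of NumPy. This is a minimal illustration and not the code used in the paper: the function name fit_rcv, the flattened 16-dimensional activations, and the synthetic "tortuosity" values are all assumptions made for the example. It fits an ordinary least squares model from activations to a measure, normalizes the weight vector to obtain a direction, and reports R².

```python
import numpy as np

def fit_rcv(activations, measure):
    """Least squares fit of measure ~ activations @ w + b (hypothetical helper).

    Returns the unit-norm weight vector (the concept direction) and the
    coefficient of determination R^2 computed on the same data.
    """
    X = np.asarray(activations, dtype=float)   # shape: (n_images, n_features)
    y = np.asarray(measure, dtype=float)       # shape: (n_images,)
    X1 = np.hstack([X, np.ones((X.shape[0], 1))])  # append a bias column
    coef, *_ = np.linalg.lstsq(X1, y, rcond=None)
    w = coef[:-1]                               # drop the bias term
    pred = X1 @ coef
    ss_res = np.sum((y - pred) ** 2)
    ss_tot = np.sum((y - y.mean()) ** 2)
    r2 = 1.0 - ss_res / ss_tot
    # For a linear model the gradient is w, so w / ||w|| is the direction
    # along which the predicted measure increases fastest.
    return w / np.linalg.norm(w), r2

# Synthetic stand-in data (assumption): 200 images, 16 flattened activations,
# and a measure that is linearly encoded in the activations plus noise.
rng = np.random.default_rng(0)
acts = rng.normal(size=(200, 16))
true_dir = rng.normal(size=16)
tortuosity = acts @ true_dir + 0.1 * rng.normal(size=200)

rcv, r2 = fit_rcv(acts, tortuosity)
print(f"R^2 = {r2:.3f}")
```

In practice the activations would come from a chosen layer of the trained CNN, typically flattened or spatially pooled first, and the measure values from the vessel segmentation. A high R² suggests the layer linearly encodes the measure, while a low R² suggests it does not.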
To sum up: how can Regression Concept Vectors aid the diagnosis of ROP?
RCVs give a practical framework to verify whether the CNN’s features can be linked to handcrafted features, or even just to clinical annotations. For instance, given a set of images (30 or more) and annotations of clinical parameters such as vessel length, dilation or curvature, RCVs show whether the same parameters are meaningful to the network.
This can help doctors decide whether the CNN’s output should be trusted, for example, or help detect faulty behaviors in the network.
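With so few annotated images, one way to check whether a parameter is meaningful to the network is to score the regression on images that were not used to fit it. The sketch below is illustrative only: the helper heldout_r2, the 30/10 split, the 8-dimensional activations, and both synthetic concepts are assumptions for the example, not part of the published method.

```python
import numpy as np

def heldout_r2(acts, values, n_train):
    """Fit least squares on the first n_train rows, score R^2 on the rest."""
    X = np.asarray(acts, dtype=float)
    y = np.asarray(values, dtype=float)
    X1 = np.hstack([X, np.ones((X.shape[0], 1))])  # append a bias column
    coef, *_ = np.linalg.lstsq(X1[:n_train], y[:n_train], rcond=None)
    y_test = y[n_train:]
    pred = X1[n_train:] @ coef
    ss_res = np.sum((y_test - pred) ** 2)
    ss_tot = np.sum((y_test - y_test.mean()) ** 2)
    return 1.0 - ss_res / ss_tot

# Synthetic stand-in data (assumption): 40 annotated images, 8 activations.
rng = np.random.default_rng(1)
acts = rng.normal(size=(40, 8))
concepts = {
    # linearly encoded in the activations, plus a little noise
    "curvature": acts @ rng.normal(size=8) + 0.1 * rng.normal(size=40),
    # unrelated to the activations
    "unrelated": rng.normal(size=40),
}

scores = {name: heldout_r2(acts, vals, n_train=30)
          for name, vals in concepts.items()}
for name, score in sorted(scores.items(), key=lambda kv: -kv[1]):
    print(f"{name}: held-out R^2 = {score:.3f}")
```

A parameter whose held-out R² stays high is a candidate for being encoded by the network. A score near or below zero means the regression does no better than predicting the mean value.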