Out-Of-Distribution Detection in Medical AI
By Karina Zadorozhny and Giovanni Cinà (Pacmed Labs)
This blogpost is for you if: you are interested in deploying medical AI in real-world contexts.
As Machine Learning is poised to revolutionize healthcare, we need to pause and think about all the ways in which our high-performing models may fail in practice. One key reason is that live data may differ from training data, that is, the data the model has learned from. And even if live data looks reasonably like training data now, there is no guarantee that it will stay that way in the future.
In healthcare settings, there are multiple reasons why patient data can change over time, for example:
- The patient population changes (due to changes in demographics, health policies, etc.).
- The protocol of care changes (e.g. a parameter is measured differently, or treatment behavior is modified).
- A previously undiscovered bug surfaces in the AI tool or in the healthcare infrastructure the tool connects to.
- The data is altered by cyberattacks or other intentional tampering.
It is therefore likely that a model in production will eventually receive data points that are different from training data.
The problem of OOD detection
Machine learning models assume that new samples are similar to data they have been trained on. More precisely, we assume that the data is independent and identically distributed. Samples that are similar to training data are considered to be in-distribution.
However, in practice, we have no guarantee that a model will only see data that is similar to the data it has been trained and tested on. Samples that are not well represented in the training dataset are considered to be Out-Of-Distribution (OOD). As described above, there are many clinically relevant scenarios that can lead to changes in data distribution.
For safe deployment of machine learning models in healthcare, it is crucial that we are able to flag ‘weird’ samples, since the model’s output is not reliable on those samples. Moreover, we want to flag such OOD samples in real time; otherwise, a large number of errors might pile up before we realize the data has changed.
To be able to tell how much we can trust a model, we would like to obtain a separate uncertainty score for each sample, rather than only a prediction. This uncertainty score can then be used to flag OOD patients for whom the prediction should not be trusted.
Approaches to OOD detection
Given the significance of the task, a variety of methods for uncertainty quantification have been proposed in recent years. Broadly speaking, we can either be uncertain about a model’s prediction or about a new sample itself. Using this intuition, we can divide OOD detectors into two groups:
- Discriminators express uncertainty in their own predictions.
- Density Estimators learn the density of training features and flag samples that are in the areas of low density.
Which technique is best at flagging OOD samples? In the rest of this section we list the methods we benchmarked on the OOD detection task, focusing on real-world medical data. You can find all the details in this paper, including detailed descriptions of all the models and metrics.
1. Discriminators
A discriminator is a model that outputs a prediction based on a sample’s features. Discriminators, such as standard feedforward neural networks or ensemble networks, can be tweaked to provide a score that indicates how certain they are in their predictions.
The discriminator models we studied include:
- Logistic regression
- Feed-forward neural networks
- Temperature-scaled neural networks
- Ensemble of neural networks
- Bayesian neural networks
- Monte Carlo Dropout
There are several simple metrics that can be added to quantify uncertainty of a network in a classification setting. We used:
- Maximum softmax probability
- Mutual Information
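The two metrics above can be sketched in a few lines of NumPy. This is a minimal illustration, not the code used in the benchmark: the function names and the toy probabilities are assumptions made for this example, and mutual information is computed in its common form for ensembles or Monte Carlo Dropout, as the entropy of the averaged prediction minus the average entropy of the individual predictions.

```python
import numpy as np

def max_softmax_uncertainty(probs):
    """probs: (n_samples, n_classes). Returns 1 - max probability,
    so that a higher value means a less confident prediction."""
    return 1.0 - probs.max(axis=1)

def entropy(probs, eps=1e-12):
    """Shannon entropy over the last axis (natural log)."""
    return -(probs * np.log(probs + eps)).sum(axis=-1)

def mutual_information(member_probs):
    """member_probs: (n_members, n_samples, n_classes), e.g. one slice per
    ensemble member or per dropout forward pass.
    MI = H[mean prediction] - mean of H[member prediction]."""
    mean_probs = member_probs.mean(axis=0)
    return entropy(mean_probs) - entropy(member_probs).mean(axis=0)

# Toy example (hypothetical numbers): two members, two samples, two classes.
# The members agree on sample 0 and disagree on sample 1.
member_probs = np.array([
    [[0.9, 0.1], [0.9, 0.1]],
    [[0.9, 0.1], [0.1, 0.9]],
])

mi = mutual_information(member_probs)
msp = max_softmax_uncertainty(member_probs.mean(axis=0))
print("mutual information:", mi)
print("1 - max softmax probability:", msp)
```

Mutual information is close to zero when all members agree, even if they agree on an uncertain prediction, and grows when members disagree. Maximum softmax probability needs only a single forward pass, which is why it also applies to a plain network or logistic regression.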
2. Density Estimators
As the name suggests, density estimators learn to estimate the density of the training data. To see what this means, suppose that there is some process that generates our training data, for example the process of collecting temperature measurements for a patient. Learning a density function of this process means learning how likely it is to measure a specific value.
By learning the density of all features of the training data, density estimators can tell whether a new sample comes from the same distribution. Some density models can, in principle, also output predictions about labels by learning a joint distribution of features and targets. However, density estimators are most commonly used separately from a main prediction model as they can flag OOD samples in a model-agnostic way.
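To make the temperature example concrete, here is a minimal sketch that fits a single Gaussian to simulated temperature measurements and flags values that fall in a low-density region. The data, the Gaussian assumption, and the 1st-percentile threshold are all illustrative choices for this example; they are not taken from the benchmark.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical training data: body temperatures in degrees Celsius.
train_temps = rng.normal(loc=37.0, scale=0.5, size=1000)

# "Learning the density" here means estimating the Gaussian parameters.
mu = train_temps.mean()
sigma = train_temps.std()

def log_density(x):
    """Log of the fitted Gaussian density at x."""
    return -0.5 * np.log(2 * np.pi * sigma**2) - (x - mu) ** 2 / (2 * sigma**2)

# Assumed rule: flag anything less likely than 99% of the training values.
threshold = np.percentile(log_density(train_temps), 1)

def is_ood(x):
    return log_density(x) < threshold

for temp in (37.2, 45.0):
    print(temp, "flagged as OOD:", bool(is_ood(temp)))
```

Note that the detector never looks at labels or at a prediction model, which is what makes this kind of check model-agnostic.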
Models we included in this category are*:
- Probabilistic PCA
- Variational Autoencoder
- Local Outlier Factor
- Neural Gaussian Process
*Note that explicit density estimators give us a way to express learned data distribution — for example, as parameters of a normal distribution. In this list we included models such as Local Outlier Factor that do not have this ability.
Experiments on Real-World Public Medical Data
As described above, there is a plethora of OOD detection techniques. But which of them work, and when?
To help answer this question in the medical field, we have created a set of benchmark experiments on publicly available medical tabular data (Electronic Health Records, EHR). We implemented the models described above and designed several clinically relevant experiments that could help us find the best OOD detector(s).
Specifically, we designed three experiments to mimic clinically relevant scenarios:
- Detecting corrupted features is relevant if a measurement device outputs a faulty value.
- Detecting clinically relevant OOD groups is relevant if a model receives a group of patients with a novel disease or with demographics that differ from the training data.
- Detecting the source of a dataset is relevant if patient data comes from a new hospital, which can have different procedures.
We used two publicly available datasets:
- MIMIC-III (Medical Information Mart for Intensive Care), which contains information about patients admitted to critical care units over a period of more than a decade. Features in this dataset include laboratory test results, diagnostic codes, and demographics.
- eICU (electronic Intensive Care Unit), a database that contains data from critical care units across the United States in 2014 and 2015. Similar to MIMIC, it contains laboratory measurements, lengths of stay, and patient demographics.
Mortality prediction task
All discriminator models were trained for a binary classification task of in-hospital mortality*. That is, discriminator models are not trained on OOD detection but on the actual task they are going to be used for.
Given the low rates of in-hospital fatalities, mortality prediction is a highly imbalanced task. Therefore, it makes more sense to measure the AUC-ROC score rather than accuracy. The performance of models was the following:
The density models were trained without labels, only on the features of the training data.
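To see why accuracy is misleading on an imbalanced task such as mortality prediction, consider a trivial model that always predicts survival. The sketch below uses a made-up 10% mortality rate and a small rank-based AUC-ROC helper written for this example (the probability that a random positive sample scores higher than a random negative one, with ties counted as half).

```python
import numpy as np

def auc_roc(labels, scores):
    """Rank-based AUC-ROC: P(score of a positive > score of a negative),
    counting ties as one half."""
    labels = np.asarray(labels)
    scores = np.asarray(scores)
    pos = scores[labels == 1]
    neg = scores[labels == 0]
    greater = (pos[:, None] > neg[None, :]).sum()
    ties = (pos[:, None] == neg[None, :]).sum()
    return (greater + 0.5 * ties) / (len(pos) * len(neg))

# Hypothetical cohort: 10 deaths (label 1) among 100 patients.
labels = np.array([1] * 10 + [0] * 90)

# A useless model that assigns everyone a mortality risk of zero.
always_survive = np.zeros(100)

accuracy = ((always_survive >= 0.5).astype(int) == labels).mean()
auc = auc_roc(labels, always_survive)

print("accuracy:", accuracy)  # 0.9, looks good but means nothing
print("AUC-ROC:", auc)        # 0.5, i.e. no better than chance
```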
OOD Detection Experiments
1. Detecting corrupted features
Consider the following scenario: a patient has undergone a laboratory test, but the measuring device outputs a flawed measurement with physiologically nonsensical values. Naturally, instead of getting a prediction for samples for which this malfunctioning device was used, we want these samples to be flagged as OOD.
To mimic this scenario, we selected a random feature and then scaled it by a factor (10, 100, 1000, or 10 000). Then, we compared how well each model can flag such scaled samples. The results were averaged over 100 experiments with different randomly selected features to prevent any feature bias.
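The protocol described above can be sketched as follows. This is an illustration on synthetic data, not the benchmark code: the feature scales are invented, and the novelty score is a simple sum of squared z-scores (a diagonal Gaussian density up to constants) standing in for the density models of the benchmark.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical tabular data: 5 features on different clinical-looking scales.
means = np.array([37.0, 80.0, 120.0, 7.4, 140.0])
stds = np.array([0.5, 10.0, 15.0, 0.05, 3.0])
X_train = rng.normal(means, stds, size=(500, 5))
X_test = rng.normal(means, stds, size=(200, 5))

def corrupt_feature(X, feature, factor):
    """Return a copy of X with one feature column scaled by `factor`."""
    X_corrupted = X.copy()
    X_corrupted[:, feature] *= factor
    return X_corrupted

def novelty_score(X, mean, std):
    """Sum of squared z-scores; higher means further from the training data."""
    return (((X - mean) / std) ** 2).sum(axis=1)

def ood_auc(scores_in, scores_ood):
    """AUC-ROC for separating OOD from in-distribution samples by score."""
    greater = (scores_ood[:, None] > scores_in[None, :]).mean()
    ties = (scores_ood[:, None] == scores_in[None, :]).mean()
    return greater + 0.5 * ties

train_mean, train_std = X_train.mean(axis=0), X_train.std(axis=0)
scores_in = novelty_score(X_test, train_mean, train_std)

results = {}
for factor in (10, 100, 1000, 10_000):
    aucs = []
    for _ in range(100):  # average over randomly selected features
        feature = rng.integers(X_test.shape[1])
        X_corrupted = corrupt_feature(X_test, feature, factor)
        scores_ood = novelty_score(X_corrupted, train_mean, train_std)
        aucs.append(ood_auc(scores_in, scores_ood))
    results[factor] = float(np.mean(aucs))

for factor, auc in results.items():
    print(f"scaling factor {factor}: mean OOD AUC-ROC = {auc:.3f}")
```

An AUC-ROC of 0.5 would mean the detector cannot tell corrupted samples from clean ones, while 1.0 means perfect separation.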
The expected trend would be to see the AUC-ROC score increase (darker square) as the scaling factor increases. For most of the discriminators, it is actually the opposite! Their ability to flag corrupted samples decreases as the flaw in the measurement grows. This means that these models assign a very high certainty score to corrupted samples. Density models perform much better in this test.
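A toy example gives some intuition for this behavior. The sketch below takes a logistic regression with fixed, made-up weights and scales one feature of a single sample. It is only an illustration of the general mechanism for a linear model, not a reproduction of the benchmark or of the theoretical analysis: once the scaled feature dominates the logit, the predicted probability saturates and the model looks more certain, not less.

```python
import numpy as np

def predict_proba(x, weights, bias):
    """Probability of the positive class under a logistic regression."""
    logit = np.clip(x @ weights + bias, -500, 500)  # clip to avoid overflow
    return 1.0 / (1.0 + np.exp(-logit))

def confidence(p):
    """Maximum class probability for a binary prediction."""
    return max(p, 1.0 - p)

# Hypothetical model parameters and input sample.
weights = np.array([0.8, -0.5])
bias = 0.1
x = np.array([1.2, 0.7])

confidences = []
for factor in (1, 10, 100, 1000, 10_000):
    x_corrupted = x.copy()
    x_corrupted[0] *= factor  # corrupt the first feature
    p = predict_proba(x_corrupted, weights, bias)
    confidences.append(float(confidence(p)))
    print(f"scaling factor {factor}: confidence = {confidences[-1]:.6f}")
```

A detector that flags samples with low confidence would therefore let the most heavily corrupted samples through.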
2. Detecting clinically relevant OOD groups
Another relevant scenario is observing a new group of patients that the model has not been trained on. This can be due to a new underlying disease or simply a demographic shift in the patient population.
Each clinically relevant group (e.g., patients with a specific disease such as renal failure, or with a specific demographic such as sex or ethnicity) was withheld during training of the models. We then introduced these patients at test-time and compared how well they are flagged as OOD.
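The withholding step can be sketched as a simple split on a boolean group mask. The column layout, the "newborn" rule (age below one year), and the 80/20 split are assumptions made for this example only.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical data with columns [age in years, heart rate]:
# 20 newborns followed by 180 adults.
ages = np.concatenate([np.zeros(20), rng.uniform(18, 90, size=180)])
heart_rates = rng.normal(80, 10, size=200)
X = np.column_stack([ages, heart_rates])

def split_by_group(X, group_mask):
    """Return (in-distribution samples, withheld OOD group)."""
    group_mask = np.asarray(group_mask, dtype=bool)
    return X[~group_mask], X[group_mask]

is_newborn = X[:, 0] < 1
X_in, X_ood = split_by_group(X, is_newborn)

# Models are trained on X_train only; at test time they score both
# X_test (in-distribution) and X_ood (the withheld group).
order = rng.permutation(len(X_in))
n_train = int(0.8 * len(X_in))
X_train, X_test = X_in[order[:n_train]], X_in[order[n_train:]]

print("train:", X_train.shape, "test:", X_test.shape, "OOD group:", X_ood.shape)
```

Comparing the uncertainty scores on X_test against those on X_ood, for example with AUC-ROC, then shows how well the withheld group is detected.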
With the exception of newborns, the most distinct group, the models detected the OOD groups very poorly. In this test, even the density models failed.
3. Detecting new datasets
As a third experiment, we used the data source as an OOD criterion. This is also a relevant problem: changing hospital protocols can result in imprecise predictions.
To test this, we can train the models on the MIMIC dataset and then test their uncertainty on the eICU dataset (or the other way around). This is different from a domain adaptation problem, as we are not looking at the predictive performance of the models but at their ability to distinguish samples coming from two different datasets.
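A minimal sketch of this setup, using two synthetic "hospitals" instead of MIMIC and eICU. The data, the size of the shift, and the detector are assumptions for this example: the novelty score is the mean distance to the k nearest training points, a simple distance-based stand-in rather than any of the benchmarked models.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic stand-ins for two data sources; values and shift are hypothetical.
X_a_train = rng.normal([37.0, 80.0], [0.5, 10.0], size=(400, 2))
X_a_test = rng.normal([37.0, 80.0], [0.5, 10.0], size=(200, 2))
X_b = rng.normal([38.0, 100.0], [0.5, 10.0], size=(200, 2))

# Standardize with statistics from the training source only.
mean, std = X_a_train.mean(axis=0), X_a_train.std(axis=0)

def standardize(X):
    return (X - mean) / std

def knn_novelty_score(X_ref, X_query, k=5):
    """Mean Euclidean distance from each query point to its k nearest
    reference points; larger means further from the training data."""
    dists = np.linalg.norm(X_query[:, None, :] - X_ref[None, :, :], axis=2)
    return np.sort(dists, axis=1)[:, :k].mean(axis=1)

ref = standardize(X_a_train)
scores_a = knn_novelty_score(ref, standardize(X_a_test))
scores_b = knn_novelty_score(ref, standardize(X_b))

# AUC-ROC: probability that a sample from source B gets a higher
# novelty score than a held-out sample from source A.
source_auc = float((scores_b[:, None] > scores_a[None, :]).mean())
print(f"source-detection AUC-ROC: {source_auc:.3f}")
```

Swapping the roles of the two sources gives the experiment in the other direction. The predictive labels are never used, which is what separates this setup from domain adaptation.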
Similar to the perturbation experiments, the density models performed better than the discriminators. It seems that they can flag samples coming from a new dataset more reliably than neural discriminators do.
In this blog post we introduced the problem of OOD detection and highlighted its importance for reliable deployment of Machine Learning models in healthcare. There are many different techniques for OOD detection and uncertainty estimation and they can be broadly classified into two groups — discriminators and density estimators.
As seen from the real-world experiments on medical tabular data, the methods are not perfect and fail to detect some types of OOD. This stresses the need for more robust novelty detection approaches.
Our hope is that by providing open-source code for the experiments and using publicly available datasets, our experimental framework can serve as an easy-to-use benchmark for new techniques.
For further reading, please see our original paper:
Ulmer, D., Meijerink, L., and Cinà, G. Trust Issues: Uncertainty Estimation does not Enable Reliable OOD Detection on Medical Tabular Data. In Proceedings of the Machine Learning for Health NeurIPS Workshop, volume 136, pp. 341–354, 2020.
And in this paper, we provide a theoretical explanation of why neural discriminators systematically fail at the perturbation task:
Ulmer, D. and Cinà, G. Know Your Limits: Uncertainty Estimation with ReLU Classifiers Fails at Reliable OOD Detection. arXiv:2012.05329, 2021.
All code used for the experiments and models described here can be found at: