Published in Geek Culture

Out-Of-Distribution Detection in Medical AI

Why it is a problem and a benchmark to find a solution

Introduction

As Machine Learning is poised to revolutionize healthcare, we need to pause and think about all the possible ways in which our super-performing models may not work well in practice. One key reason is that real, live data may differ from the training data, that is, the data the model has learned from. And even if live data looks reasonably like training data now, there is no guarantee that it will stay that way. The data distribution can shift for several reasons:

  • The patient population changes (due to shifts in demographics, health policies, etc.).
  • The protocol of care changes (e.g. a parameter is measured differently, or treatment behavior is modified).
  • A previously undiscovered bug surfaces in the AI tool or in the healthcare infrastructure the tool connects to.
  • The data is intentionally tampered with, for example through a cyberattack.

The problem of OOD detection

Machine learning models assume that new samples are similar to the data they have been trained on. More precisely, we assume that the data is independent and identically distributed (i.i.d.). Samples that are similar to the training data are considered in-distribution; samples that deviate substantially from it are out-of-distribution (OOD), and a model's predictions on them cannot be trusted. The goal of OOD detection is to flag such samples before their predictions are acted upon.

Approaches to OOD detection

Given the significance of the task, a variety of methods for uncertainty quantification have been proposed in recent years. Broadly speaking, we can either be uncertain about a model’s prediction or about a new sample itself. Using this intuition, we can divide OOD detectors into two groups:

  1. Discriminators express uncertainty in their own predictions.
  2. Density Estimators learn the density of the training features and flag samples that fall in areas of low density.

Shades of orange and blue indicate confidence levels. For the discriminator, the regions of low certainty are defined by the model's decision boundaries. For density estimators, areas with low sample density end up with low confidence.

1. Discriminators

A discriminator is a model that outputs a prediction based on a sample's features. Discriminators, such as standard feed-forward neural networks or ensembles of networks, can be tweaked to provide a score that indicates how certain they are in their predictions. We compared the following discriminator models and the uncertainty metrics derived from their outputs (a minimal scoring sketch follows the lists):

Models:

  • Logistic regression
  • Feed-forward neural networks
  • Temperature-scaled neural networks
  • Ensemble of neural networks
  • Bayesian neural networks
  • Monte Carlo Dropout

Uncertainty metrics:

  • Maximum softmax probability
  • Entropy
  • Variance
  • Mutual Information
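
To make this concrete, here is a minimal sketch (my own illustration in Python, not code from the paper) of turning a discriminator's predicted probabilities into OOD scores with the maximum softmax probability and entropy metrics. The classifier `model` is assumed to be any scikit-learn style model exposing `predict_proba`:

```python
import numpy as np

def max_prob_score(probs):
    """OOD score from the maximum class probability:
    a lower maximum probability means a less confident prediction."""
    return 1.0 - probs.max(axis=1)

def entropy_score(probs, eps=1e-12):
    """Predictive entropy of the class distribution:
    higher entropy means a more uncertain prediction."""
    return -(probs * np.log(probs + eps)).sum(axis=1)

# Hypothetical usage with any classifier exposing predict_proba:
# probs = model.predict_proba(X_new)      # shape (n_samples, n_classes)
# ood_scores = entropy_score(probs)       # rank samples; highest = most suspect
```

For ensembles, Bayesian networks, and Monte Carlo Dropout, the same metrics are applied to the averaged probabilities, while the variance and mutual information across the stochastic forward passes provide additional uncertainty scores.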

2. Density Estimators

As the name suggests, density estimators learn to estimate the density of the training data. To see what this means, suppose there is some process that generates our training data, for example the process of collecting temperature measurements from patients. To learn the density function of this process means to learn how likely it is to measure any specific value. Samples to which the learned density assigns a very low likelihood can then be flagged as OOD. We compared the following density estimators (a minimal sketch follows the list):

  • Probabilistic PCA
  • Autoencoder
  • Variational Autoencoder
  • Local Outlier Factor
  • Neural Gaussian Process
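
As an illustration of the idea (a sketch of my own, using scikit-learn and a synthetic stand-in for the data rather than the paper's implementation), probabilistic PCA can be fitted with `sklearn.decomposition.PCA`, whose `score_samples` method returns the log-likelihood of each sample under the fitted model; the negative log-likelihood then serves as an OOD score:

```python
import numpy as np
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

# Hypothetical training features (n_samples, n_features); replace with real EHR data.
rng = np.random.default_rng(0)
X_train = rng.normal(size=(1000, 20))

scaler = StandardScaler().fit(X_train)
ppca = PCA(n_components=10).fit(scaler.transform(X_train))

def novelty_score(X):
    """Negative log-likelihood under the probabilistic PCA model:
    a higher score means lower density, i.e. more likely OOD."""
    return -ppca.score_samples(scaler.transform(X))

# Flag samples whose score exceeds, say, the 99th percentile of training scores.
threshold = np.percentile(novelty_score(X_train), 99)
```

Autoencoders and variational autoencoders follow the same recipe, with the reconstruction error or the approximate negative log-likelihood playing the role of the score.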

Experiments on Real-World Public Medical Data

As described above, there is a plethora of OOD detection techniques. But which of them work, and when? We evaluated the methods in three scenarios:

  1. Detecting corrupted features is relevant if a measurement device outputs a faulty value.
  2. Detecting clinically relevant OOD groups is relevant if a model receives a group of patients with a novel disease or with demographics that differ from the training data.
  3. Detecting the source of a dataset is relevant if patient data comes from a new hospital, which can have different procedures.

Datasets

We used two publicly available datasets:

  • MIMIC-III (Medical Information Mart for Intensive Care), which contains information about patients admitted to critical care units over more than a decade. Features in this dataset include laboratory test results, diagnostic codes, and demographics.
  • eICU (electronic Intensive Care Unit), a database containing data from critical care units across the United States in 2014 and 2015. Similar to MIMIC, it contains laboratory measurements, lengths of stay, and patient demographics.

Mortality prediction task

All discriminator models were trained on the binary classification task of predicting in-hospital mortality. That is, the discriminator models are not trained for OOD detection but for the actual clinical task they are going to be used for; their uncertainty scores are then evaluated for OOD detection.
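
A minimal sketch of this setup, with a scikit-learn MLP standing in for a feed-forward network and synthetic data in place of the real EHR features (both are assumptions for illustration only):

```python
import numpy as np
from sklearn.neural_network import MLPClassifier
from sklearn.model_selection import train_test_split

# Synthetic stand-ins for preprocessed features X and in-hospital mortality labels y.
rng = np.random.default_rng(0)
X = rng.normal(size=(5000, 30))
y = (rng.random(5000) < 0.1).astype(int)   # mortality is a rare, imbalanced outcome

X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, random_state=0)

# The discriminator is trained purely on the clinical prediction task...
clf = MLPClassifier(hidden_layer_sizes=(64, 64), max_iter=300, random_state=0)
clf.fit(X_train, y_train)

# ...and its predictive uncertainty (e.g. the entropy of predict_proba) is reused
# as an OOD score at test time, without ever seeing labelled OOD examples.
probs = clf.predict_proba(X_test)
```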

OOD Detection Experiments

1. Detecting corrupted features

Consider the following scenario: a patient has been subjected to a laboratory test, but the measuring device outputs a flawed measurement with physiologically nonsensical values. Naturally, instead of getting a prediction for samples from this malfunctioning device, we want these samples to be flagged as OOD.

The perturbation experiment on the eICU dataset. A random feature was scaled by a factor of 10, 100, 1,000, or 10,000, and the models were compared in their ability to flag the perturbed samples. The figure shows the final AUC-ROC score for OOD detection, averaged over 100 experiments with randomly selected features.
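
A simplified version of this experiment can be written as follows (a hypothetical helper of mine; it assumes an OOD scoring function such as the `entropy_score` or `novelty_score` sketched earlier and does not reproduce the paper's exact preprocessing):

```python
import numpy as np
from sklearn.metrics import roc_auc_score

def perturbation_auc(X_test, score_fn, scale=100, rng=None):
    """Scale one randomly chosen feature by `scale` and measure how well
    `score_fn` separates the perturbed copies from the clean samples."""
    rng = np.random.default_rng() if rng is None else rng
    X_perturbed = X_test.copy()
    feature = rng.integers(X_test.shape[1])
    X_perturbed[:, feature] = X_perturbed[:, feature] * scale

    scores = np.concatenate([score_fn(X_test), score_fn(X_perturbed)])
    labels = np.concatenate([np.zeros(len(X_test)), np.ones(len(X_perturbed))])
    return roc_auc_score(labels, scores)

# Average over 100 randomly selected features, as in the experiment above:
# aucs = [perturbation_auc(X_test, novelty_score, scale=1000) for _ in range(100)]
# print(np.mean(aucs))
```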

2. Detecting clinically relevant OOD groups

Another relevant scenario is observing a new group of patients that the model has not been trained on. This can be due to a new underlying disease or simply a demographic shift in the patient population.

Clinically relevant OOD groups for the MIMIC dataset. We first withhold a particular group during training and then compare the AUC-ROC scores of detecting that group as OOD at test time. Diff shows how different the group's features are from the features in the training data. (Average results over 5 experiments are shown.)
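
The evaluation behind these numbers can be sketched like this (a hypothetical helper; `is_group` is assumed to be a boolean mask marking the withheld patient group, and `fit_score_fn` fits any of the detectors above and returns its scoring function):

```python
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score

def group_detection_auc(X, is_group, fit_score_fn, random_state=0):
    """Withhold one patient group during training, then measure how well the
    detector's OOD score separates that group from held-out in-distribution data."""
    X_in, X_ood = X[~is_group], X[is_group]
    X_train, X_id_test = train_test_split(X_in, test_size=0.2, random_state=random_state)

    score_fn = fit_score_fn(X_train)   # e.g. fit probabilistic PCA, return novelty_score
    scores = np.concatenate([score_fn(X_id_test), score_fn(X_ood)])
    labels = np.concatenate([np.zeros(len(X_id_test)), np.ones(len(X_ood))])
    return roc_auc_score(labels, scores)
```

The cross-dataset experiment below follows the same recipe, with membership in the new dataset playing the role of the OOD label.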

3. Detecting new datasets

As a third experiment, we used the data source as an OOD criterion. This is also a relevant problem: changing hospital protocols can result in imprecise predictions.

Detecting samples coming from a different dataset. The models were trained on the MIMIC dataset and then presented with samples from eICU (top row), or vice versa (bottom row). The scores show the AUC-ROC of detecting the new dataset as OOD. Diff shows how different the new dataset's features are from the features in the training data. (Average results over 5 experiments are shown.)

Conclusions

In this blog post we introduced the problem of OOD detection and highlighted its importance for the reliable deployment of Machine Learning models in healthcare. There are many techniques for OOD detection and uncertainty estimation, and they can be broadly classified into two groups: discriminators and density estimators.

References

For further reading, please see our original paper and the accompanying code repository:

https://github.com/Pacmed/ehr_ood_detection
