Out-of-Distribution Detection in Deep Neural Networks

Neeraj Varshney · Published in Analytics Vidhya · Dec 25, 2020

Making AI systems Robust and Reliable

Deep neural networks are often trained under a closed-world assumption, i.e., the test data distribution is assumed to be similar to the training data distribution. However, when such models are deployed on real-world tasks, this assumption does not hold, leading to a significant drop in performance. While this performance drop is acceptable for tolerant applications like product recommendation, it is dangerous to deploy such systems in intolerant domains like medicine and home robotics, where errors can cause serious accidents. An ideal AI system should generalize to Out-of-Distribution (OOD) examples whenever possible and flag the ones that are beyond its capability so that a human can intervene. In this article, we'll dive deeper into the concept of OOD and explore various techniques for its detection. The article is self-contained but assumes familiarity with basic Machine Learning concepts.

Outline:

  1. A bit on OOD
    — Why is OOD Detection Important?
    — Why Are Models Brittle to OOD Examples?
    — Types of Generalizations
    — Plausible Reasons Why Pre-trained Models (like BERT) Are More Robust than Traditional Models
    — Other Related Problems
  2. Approaches to Detect OOD instances
    — Maximum Softmax Probability
    — Ensembling of Multiple Models
    — Temperature Scaling
    — Training a Binary Classification model as a Calibrator
    — Monte-Carlo Dropout

A bit on OOD

The term “distribution” has slightly different meanings for Language and Vision tasks. Consider a dog-breed image classification task: images of dogs are in-distribution, while images of bikes, balls, etc. are out-of-distribution.

For Language tasks, some associate distribution shift with a change in author, writing style, vocabulary, dataset, etc., while others associate it with a change in the required reasoning skill. For example, for a Question-Answering model trained on Maths questions, a question from History is OOD.

Why is OOD Detection important?

In real-world tasks, the data distribution usually drifts over time, and chasing an evolving data distribution by continually retraining is costly. Hence, OOD detection is important: it lets an AI system flag inputs it is unlikely to handle correctly instead of silently making prediction errors.

Why Are Models Brittle to OOD Examples?

  1. Neural network models can rely heavily on spurious cues and annotation artifacts present in the training data, while OOD examples are unlikely to contain the same spurious patterns as in-distribution examples.
  2. The training data cannot cover all facets of a distribution, which limits the model’s generalization ability.

Types of Generalizations:

  • In-Distribution Generalization — Generalization to examples that are novel but drawn from the same distribution as the training set.
  • Out-of-Distribution Generalization — Generalization to examples that are drawn from a different distribution than the training set.

Plausible Reasons Why Pre-trained Models (like BERT) Are More Robust than Traditional Models:

  • Pre-training with diverse data.
  • Pre-training with self-supervised learning objectives.
    Note: Interested readers can get a quick refresher on BERT here.

Other Related Problems:

  1. Success and Error prediction — Predicting whether the model will solve an example correctly or not.
  2. Selective Prediction — Making a prediction only for those examples where the model is sufficiently confident. This is especially useful in settings where errors are costly but abstention (not making a prediction) is acceptable.
  3. Domain Adaptation — Extrapolating from training data to test data from a different distribution. We try to generalize a model to a new distribution but assume some knowledge about the test distribution, such as unlabeled examples or a few labeled examples.

There are several other related tasks like outlier detection, anomaly detection, etc.

Approaches to Detect OOD instances:

One class of OOD detection techniques is based on thresholding the prediction confidence, i.e., compute a confidence for each prediction and label the example as OOD if the confidence falls below a chosen threshold. Another class of techniques jointly optimizes the main task (classification or regression) and OOD detection during training.

1. Maximum Softmax Probability (MaxProb):

For classification problems, a neural network outputs a vector of real-valued scores known as logits. The logits vector is passed through a softmax function to get class probabilities.

Softmax Function: softmax(z)ⱼ = exp(zⱼ) / Σₖ exp(zₖ), where z is the logits vector and j indexes the classes.

The maximum softmax probability (MaxProb), i.e., the maximum softmax value across all classes, is used as the prediction confidence. This is one of the simplest yet strongest OOD detection baselines. The intuition behind this approach is:

Correctly classified examples tend to have greater MaxProb than erroneously classified and out-of-distribution examples. In other words, more confident predictions indeed tend to be more accurate.
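
To make this concrete, here is a minimal NumPy sketch of MaxProb-based detection. The logits array and the threshold of 0.7 are illustrative assumptions; in practice, the threshold is tuned on held-out data.

import numpy as np

def softmax(logits):
    # Subtract the per-row max for numerical stability before exponentiating.
    z = logits - np.max(logits, axis=-1, keepdims=True)
    exp_z = np.exp(z)
    return exp_z / np.sum(exp_z, axis=-1, keepdims=True)

def maxprob_ood_flags(logits, threshold=0.7):
    # Confidence = maximum softmax probability across classes (MaxProb).
    probs = softmax(logits)
    confidence = probs.max(axis=-1)
    # Flag an example as OOD when its confidence falls below the threshold.
    return confidence < threshold, confidence

# Toy usage: two confident-looking examples and one uncertain one.
logits = np.array([[4.0, 0.5, 0.1], [3.2, 0.3, 0.2], [1.1, 1.0, 0.9]])
is_ood, conf = maxprob_ood_flags(logits, threshold=0.7)
print(conf, is_ood)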

2. Ensembling of Multiple Models:

In Ensemble Learning, multiple models (trained using the same learning algorithm, e.g., with different random initializations) are used to make predictions for each data point, and the decisions from these models are combined to improve the overall performance. There are various ways of combining decisions:

  • Max Voting — The prediction which we get from the majority of the models is used as the final prediction.
  • Averaging — Averaging the predictions of all models is straightforward for regression tasks but for classification tasks, we can average the softmax probabilities.
  • Weighted Averaging — In this technique, models are assigned different weights and a weighted average is taken to compute the final prediction.

In the context of OOD detection, combining the decisions means computing a single prediction confidence from the outputs of multiple models, e.g., the MaxProb of the averaged softmax probabilities, as sketched below.
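
A minimal sketch of computing an ensemble confidence by averaging softmax probabilities; the list of per-model logits and the use of SciPy's softmax are assumptions made for illustration.

import numpy as np
from scipy.special import softmax

def ensemble_confidence(list_of_logits):
    # list_of_logits: one (num_examples, num_classes) logits array per ensemble member.
    probs = np.stack([softmax(l, axis=-1) for l in list_of_logits])  # (num_models, num_examples, num_classes)
    mean_probs = probs.mean(axis=0)        # average the softmax probabilities across models
    predictions = mean_probs.argmax(axis=-1)
    confidences = mean_probs.max(axis=-1)  # MaxProb of the averaged distribution
    return predictions, confidences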

Let’s learn the concept of calibration before proceeding to the remaining approaches.

Calibration:

Calibration is the problem of producing probability estimates that are representative of the true likelihood of correctness. For this, a model should output its prediction along with a confidence measure. Let’s try to understand this mathematically.

Let ‘f’ be a neural network with f(x) = (y, p), where y is the prediction and p is the prediction confidence, i.e., the probability of correctness.

We would like p to be calibrated, i.e., p should represent the true probability that the prediction y is correct.

Perfect Calibration: P(prediction is correct | confidence p) = p, for all p ∈ [0, 1].

For instance, if a perfectly calibrated predictor makes 100 predictions each with a confidence of 0.8 then 80 predictions should be correct.

Reliability Diagram:

A reliability diagram is used to represent model calibration. It plots expected sample accuracy as a function of confidence.

Sample Reliability Diagrams. Source: Guo, Chuan, et al.

Examples are grouped into bins based on their prediction confidence, and the accuracy (i.e., the fraction of correctly answered examples) in each bin is computed. Let there be a total of M bins; Bₘ denotes the mᵗʰ bin, i.e., it contains the samples whose confidence lies in the range ((m−1)/M, m/M].

In case of perfect calibration, the diagram should plot the identity function. Any deviation from a perfect diagonal (up or down) represents miscalibration.

Note: Reliability diagrams do not display the proportion of samples in each bin (they only show the fraction of correctly answered samples per bin) and thus cannot be used to estimate how many samples are calibrated.

Measuring Calibration Error:

  • Expected Calibration Error (ECE)

    ECE = Σₘ (|Bₘ| / n) · |acc(Bₘ) − conf(Bₘ)|

    where n is the total number of samples, acc(Bₘ) is the accuracy of bin Bₘ, and conf(Bₘ) is the average confidence of the samples in Bₘ. ECE is a weighted average of the bins’ accuracy/confidence gaps (a computational sketch follows this list).

  • Maximum Calibration Error (MCE)

    MCE = maxₘ |acc(Bₘ) − conf(Bₘ)|

    MCE is the maximum of the bins’ accuracy/confidence gaps.

  • Negative Log-Likelihood (NLL)
    This is the same as the cross-entropy loss.

We discuss other relevant evaluation metrics in the supplementary section.
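
Here is a minimal sketch of computing ECE and MCE from per-example confidences and correctness indicators with equal-width bins; the number of bins (10) is an illustrative choice.

import numpy as np

def calibration_errors(confidences, correct, num_bins=10):
    # confidences: predicted confidence per example; correct: 1 if the prediction was right, else 0.
    confidences = np.asarray(confidences)
    correct = np.asarray(correct, dtype=float)
    n = len(confidences)
    ece, mce = 0.0, 0.0
    for m in range(1, num_bins + 1):
        lo, hi = (m - 1) / num_bins, m / num_bins
        in_bin = (confidences > lo) & (confidences <= hi)
        if in_bin.sum() == 0:
            continue
        acc = correct[in_bin].mean()          # accuracy of bin B_m
        conf = confidences[in_bin].mean()     # average confidence of bin B_m
        gap = abs(acc - conf)
        ece += (in_bin.sum() / n) * gap       # weighted by the bin's share of samples
        mce = max(mce, gap)
    return ece, mce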

3. Temperature Scaling

For MaxProb, the prediction confidence is computed using the softmax function directly. Temperature Scaling is an extension of Platt scaling that uses a single scalar parameter T > 0 to rescale the logits before the softmax. Here, the prediction confidence is computed using the function ‘q’ shown below.

Temperature Scaling: qᵢ = exp(zᵢ / T) / Σⱼ exp(zⱼ / T), and the prediction confidence is maxᵢ qᵢ.

With T > 1, it “softens” the softmax (i.e., raises the output entropy). This makes the network less overconfident, which can make the confidence scores better reflect true probabilities.

  • As T → ∞, the probability approaches 1/J (where J is the number of classes), which represents maximum uncertainty.
  • With T = 1, the original softmax probability is recovered.
  • Parameter T is learned with respect to Negative Log-Likelihood on the validation set.
  • Since dividing by T does not change the maximum of the softmax function (the logits of all classes are scaled by the same factor, so the class prediction is unchanged), temperature scaling does not affect the model’s accuracy. A sketch of fitting T follows below.
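
A minimal PyTorch sketch of learning T by minimizing NLL on validation-set logits; the optimizer choice (Adam over log T) and the tensor names val_logits/val_labels are assumptions for illustration, not the exact recipe from the original paper.

import torch

def fit_temperature(val_logits, val_labels, steps=200, lr=0.01):
    # val_logits: (num_examples, num_classes) raw logits; val_labels: (num_examples,) true class ids.
    log_T = torch.zeros(1, requires_grad=True)     # optimize log T so that T stays positive
    optimizer = torch.optim.Adam([log_T], lr=lr)
    nll = torch.nn.CrossEntropyLoss()              # NLL of the temperature-scaled softmax
    for _ in range(steps):
        optimizer.zero_grad()
        loss = nll(val_logits / log_T.exp(), val_labels)
        loss.backward()
        optimizer.step()
    return log_T.exp().item()

# At test time, the confidence is softmax(logits / T).max() for the learned T.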

4. Training a Binary Classification model as a Calibrator:

This approach requires evaluating the trained model on a held-out dataset and annotating correctly answered examples as positive and incorrectly answered examples as negative (note that this step is independent of the actual labels of the examples). A binary classification model can then be trained on this annotated dataset to predict whether a new example belongs to the positive or the negative class; its positive-class probability serves as the confidence. Though this approach is more naturally suited to the Success and Error Prediction task, it can easily be adapted for OOD detection by incorporating OOD examples while training the calibrator.
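
A minimal scikit-learn sketch of such a calibrator; the choice of features (here just MaxProb) and the toy held-out arrays are assumptions for illustration.

import numpy as np
from sklearn.linear_model import LogisticRegression

# Held-out data: features per example (e.g., MaxProb, input length, ...) and
# a label that is 1 if the trained model answered the example correctly, else 0.
held_out_features = np.array([[0.95], [0.42], [0.88], [0.30], [0.76], [0.55]])
answered_correctly = np.array([1, 0, 1, 0, 1, 0])

calibrator = LogisticRegression()
calibrator.fit(held_out_features, answered_correctly)

# For a new example, the calibrator's positive-class probability serves as the confidence.
new_features = np.array([[0.65]])
confidence = calibrator.predict_proba(new_features)[:, 1]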

5. Monte-Carlo Dropout or Test-Time Dropout

Dropout is a regularization technique for preventing overfitting. It is usually disabled at test time, but it has been observed that keeping it enabled yields good confidence estimates on OOD data. In this approach, the input is passed through the network with K different dropout masks, and two statistics of the resulting predictions are commonly used: the mean (as the final prediction) and the variance (as an uncertainty measure). This is essentially ensembling, but with K dropout masks instead of K separately trained models (see the sketch after the list of drawbacks below).

This approach has a few drawbacks:

  • It requires access to internal model representations as different dropout masks are used.
  • It requires K forward passes of the model, thus increasing runtime by K times.
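
A minimal PyTorch sketch of Monte-Carlo Dropout, assuming a model whose dropout layers are activated by putting it in train mode (note that this also puts BatchNorm layers in train mode; in practice one may enable only the dropout modules).

import torch

def mc_dropout_predict(model, x, k=20):
    # Keep dropout active at test time by putting the model in train mode,
    # but disable gradient tracking since we are only predicting.
    model.train()
    with torch.no_grad():
        probs = torch.stack([torch.softmax(model(x), dim=-1) for _ in range(k)])  # (k, batch, classes)
    mean_probs = probs.mean(dim=0)   # mean over the K dropout masks -> prediction / confidence
    variance = probs.var(dim=0)      # variance over the K masks -> uncertainty estimate
    return mean_probs, variance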

Other methods like SelectiveNet, ODIN, etc. will soon be added to this article.

Supplementary:

Evaluation Metrics

In this section, we’ll describe the popular metrics: AUROC, AUPR, FAR95, and the Risk-Coverage Curve.

AUROC (Area under Receiver Operating Characteristics):

  • The ROC curve is a graph showing the performance of a classification model at all classification thresholds.
  • It’s a plot of the False Positive Rate (x-axis) against the True Positive Rate (y-axis).
  • It measures how well the model is capable of distinguishing between classes.
  • It is typically used for binary classification problems but can be extended to multi-class settings as well.

Demonstrating ROC Curve. Source: Glass Box.

This curve (especially its name) seems daunting but let’s dive deeper and try to understand it.

  • A higher x-axis value indicates a larger proportion of False Positives relative to True Negatives.
  • A higher y-axis value indicates a larger proportion of True Positives relative to False Negatives.
  • Every classification threshold corresponds to a point on this curve (assume that the model labels all instances scoring above the threshold as positive (class 1) and the rest as negative (class 0)). Hence, lowering the classification threshold classifies more items as positive, increasing both False Positives and True Positives.
  • AUC provides an aggregate measure of performance across all possible classification thresholds.
  • AUC value ranges from 0 to 1. A model whose predictions are 100% wrong has an AUC of 0 while the one whose predictions are 100% correct has an AUC of 1.
  • The curves of different models can be compared directly, either overall or at specific thresholds. The higher the AUC, the better the model is at predicting 0s as 0s and 1s as 1s, i.e., negatives as negative and positives as positive.
  • We can also compare different points (i.e., different thresholds) for the same model, depending on how many false positives and false negatives we are willing to tolerate for our classifier.
  • AUC close to 1 → the model has a good measure of separability.
  • AUC of 0 → the model is reciprocating the result i.e predicting 0s as 1s and 1s as 0s.
  • AUC of 0.5 → the model has no discrimination capacity to distinguish between positive class and negative class.
  • All points above the diagonal line (the Random Classifier line) correspond to the situation where the True Positive Rate exceeds the False Positive Rate, i.e., the proportion of correctly classified points belonging to the Positive class is greater than the proportion of incorrectly classified points belonging to the Negative class.

AUROC is not ideal when the positive class and negative class have greatly differing base rates, and the AUPR adjusts for these different positive and negative base rates.
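
A minimal sketch of computing AUROC for an OOD detector with scikit-learn, assuming MaxProb-style scores where a higher score means "more in-distribution"; the toy arrays are illustrative.

import numpy as np
from sklearn.metrics import roc_auc_score

# 1 = in-distribution, 0 = out-of-distribution; scores = detector confidence (e.g., MaxProb).
labels = np.array([1, 1, 1, 0, 0, 0])
scores = np.array([0.92, 0.85, 0.78, 0.60, 0.55, 0.40])

auroc = roc_auc_score(labels, scores)  # 1.0 here, since the scores separate the two groups perfectly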

AUPR (Area under Precision-Recall Curve):

  • ROC curves are appropriate when the observations are balanced between each class, whereas precision-recall curves are appropriate for imbalanced datasets.
PR curve. Source: towardsdatascience.com
  • A no-skill classifier is one that cannot discriminate between the classes and would predict a random class or a constant class in all cases. The no-skill line depends on the distribution of positive and negative classes: it is a horizontal line at the ratio of positive cases in the dataset. For a balanced dataset, this is 0.5.
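
A minimal sketch of AUPR with scikit-learn for the same toy setup as the AUROC example, with in-distribution treated as the positive class; the no-skill baseline is simply the ratio of positives.

import numpy as np
from sklearn.metrics import average_precision_score

labels = np.array([1, 1, 1, 0, 0, 0])           # 1 = positive (here, in-distribution) class
scores = np.array([0.92, 0.85, 0.78, 0.60, 0.55, 0.40])

aupr = average_precision_score(labels, scores)  # summary of the precision-recall curve
no_skill = labels.mean()                        # baseline: ratio of positive cases in the data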

FAR95 (False Alarm Rate):

  • The FAR95 is the probability that an in-distribution example raises a false alarm, assuming that 95% of all out-of-distribution examples are detected.
  • Hence a lower FAR95 is better.
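
A minimal sketch of FAR95 under the convention above, assuming higher scores mean "more in-distribution"; the 95th percentile of the OOD scores plays the role of the detection threshold.

import numpy as np

def far95(in_dist_scores, ood_scores):
    # Scores where higher means "more in-distribution" (e.g., MaxProb).
    # Choose the threshold so that 95% of OOD examples fall below it (i.e., are detected).
    threshold = np.percentile(ood_scores, 95)
    # False alarm rate: fraction of in-distribution examples that also fall below the threshold.
    return np.mean(in_dist_scores < threshold)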

Risk-Coverage Curve:

This is a popular metric in the selective answering literature.

Sample Risk-Coverage Curve.

It’s a plot of Coverage (x-axis) against Risk (y-axis). Coverage is the fraction of examples on which the model makes a prediction, and Risk is the error rate on that fraction of examples. For any choice of confidence threshold, a model has an associated coverage and risk. The area under this curve (AUC) is used to compare the performance of different models; a lower area indicates better selective prediction.
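
A minimal sketch of tracing the risk-coverage curve by sweeping the confidence threshold and computing its area; the confidence and correctness arrays are assumed to come from a held-out evaluation set.

import numpy as np

def risk_coverage_curve(confidences, correct):
    # Sort examples from most to least confident.
    order = np.argsort(-np.asarray(confidences))
    correct = np.asarray(correct, dtype=float)[order]
    n = len(correct)
    coverages = np.arange(1, n + 1) / n                      # fraction of examples answered
    risks = 1.0 - np.cumsum(correct) / np.arange(1, n + 1)   # error rate on the answered fraction
    auc = np.trapz(risks, coverages)                         # area under the risk-coverage curve
    return coverages, risks, auc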


References:

  1. Hendrycks, Dan, and Kevin Gimpel. “A baseline for detecting misclassified and out-of-distribution examples in neural networks.”
  2. McCoy, R. Thomas, Junghyun Min, and Tal Linzen. “BERTs of a feather do not generalize together: Large variability in generalization across models with similar test set performance.”
  3. Hendrycks, Dan, et al. “Pretrained transformers improve out-of-distribution robustness.”
  4. Geifman, Yonatan, and Ran El-Yaniv. “SelectiveNet: A deep neural network with an integrated reject option.”
  5. Guo, Chuan, et al. “On calibration of modern neural networks.”
