Fast and Scalable Estimation of Uncertainty using Bayesian Deep Learning
Recently, I attended an hour-long presentation by Dr. Emtiyaz Khan, the team lead of the Approximate Bayesian Inference (ABI) Team at the RIKEN Center for Advanced Intelligence Project (Tokyo). His talk revolved around learning variance via natural gradients. He systematically explained why computing uncertainty is challenging, and how his team took inspiration from the Adam optimizer in their latest publication at ICML’18 (which beats the state of the art). Overall, he cogently walked us through variational inference, Bayesian models, natural gradients and fast Gaussian approximations for deep learning models.
“My main goal is to understand the principles of learning from data and use them to develop algorithms that can learn like living beings.” - Dr. Khan
“Say you have to detect animals like cats and dogs. It is very common because there are a lot of images available. Assume that you are asked to detect a cow from rural India: far fewer images are available. Hence, the confidence in predicting cats and dogs is higher than that for cows.” - Dr. Khan
Uncertainty helps in estimating the confidence in the predictions of a deep learning system. To explain variance and why it’s important to know what we don’t know, he gave us a simple example of Japan earthquake data. The green squares in Fig.2(a) represent data points. After a brief introduction to the problem statement, he asked the attendees whether the blue dashed line or the continuous red line would be a better fit. While a majority voted for the red line, his response took us by surprise: both lines were correct. In such a scenario, how does one decide the confidence of the prediction? The observation was that the variance (green region) increases with the magnitude. It implies that earthquakes with lower magnitudes can be predicted with higher confidence, and that uncertainty rises with magnitude.
To visualize uncertainty, he quoted an example of image segmentation. Fig.1 shows the ground truth, prediction and uncertainty maps for several cityscapes. The uncertainty map plots the variance in the segmentation. A cursory glance at the results shows that the lighter coloured regions in the uncertainty maps correspond to areas that have not been adequately segmented in the predicted maps. We can obtain more robust predictions if uncertainty estimation sits on top of a deep learning system. Returning to the current example, imagine the risks involved if a driverless car ventures onto footpaths (areas of low confidence) because the prediction maps are not segmented properly!
- Computing the uncertainty helps if the data is unreliable, scarce or missing.
- It is difficult to compute the variance or confidence for massive data and large models.
- Using the concepts of Bayesian statistics, optimization and information geometry, they have worked on a fast computation of uncertainty.
Why is it Difficult to Compute Uncertainty?
If you’re a deep learning enthusiast, you’re no stranger to the naive approach in eqn.1. θ is generated from a prior distribution p(θ) (a Gaussian prior in this case). D refers to the data set and θ represents the parameters, while x_i and y_i are the input and output terms respectively. The function f is a neural network.
Fig. 2(b) shows the variance of the draws from a known distribution. We randomly draw θ, compute f_θ(x_i) and the corresponding ‘p’, and rank the data accordingly. The grey lines show that we can have multiple best fits in the adjacent green region. If there are numerous such neural networks, we average over them. How narrow is the spread? The narrowness of the spread is the uncertainty (green region).
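The spread of such draws can be sketched in a few lines of NumPy. This is a hedged toy illustration rather than the talk’s actual setup: `f` is a made-up one-hidden-unit network, and the prior is a standard normal.

```python
import numpy as np

rng = np.random.default_rng(0)

def f(theta, x):
    """A tiny toy 'network': f_theta(x) = theta[1] * tanh(theta[0] * x)."""
    return theta[1] * np.tanh(theta[0] * x)

x = np.linspace(-2, 2, 50)

# Draw many parameter vectors from the Gaussian prior p(theta) = N(0, I)
# and look at the spread of the resulting predictions.
draws = [f(rng.standard_normal(2), x) for _ in range(1000)]
preds = np.stack(draws)

mean_pred = preds.mean(axis=0)    # average over the sampled networks
uncertainty = preds.std(axis=0)   # narrowness of the spread = uncertainty
```

The per-input standard deviation plays the role of the green region: where it is wide, many different networks fit the prior equally well and the prediction is uncertain.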
Via Bayes’ rule (Eq.2), we obtain the posterior distribution, which in turn helps us find its mean and variance. However, the normalisation constant involves an intractable integral that is hard to calculate. Due to the large number of samples and parameters, this computation is intensive. Additionally, the problem of integration is much harder than that of optimisation. Keeping this challenge in mind, his team used natural gradients for a fast approximation of the integral term.
Variational Inference with Gradients
Assume a normal distribution N, with a mean (μ) and a variance (σ²). The goal is to find the values of μ and σ² such that the distance between the posterior distribution and N is minimized (eqn.3). But here’s the catch: what if μ and σ² are 10⁶-dimensional vectors?
Intuitively, the KL divergence (eqn.4) measures the similarity, or match, between two distributions. Using it, we can select a neural network out of many candidates based on its ability to explain the data well.
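For two univariate Gaussians the KL divergence has a closed form, which makes for a quick sanity check (the function name `kl_gauss` is my own, not from the talk):

```python
import math

def kl_gauss(mu1, var1, mu2, var2):
    """Closed-form KL( N(mu1, var1) || N(mu2, var2) ), univariate."""
    return 0.5 * (math.log(var2 / var1)
                  + (var1 + (mu1 - mu2) ** 2) / var2 - 1.0)

kl_gauss(0.0, 1.0, 0.0, 1.0)   # identical distributions give zero divergence
```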
Now that we’ve defined the KL divergence, what is variational inference? In short, variational inference approximates the intractable posterior distribution with a tractable variational distribution by minimising the KL divergence between them. Traditionally, a gradient-descent approach is used to update μ and σ with a fixed step size or learning rate (ρ). Eqn.5,6 represent this.
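As a hedged toy version of these updates, here is plain gradient descent on the closed-form KL between q = N(μ, σ²) and a known target Gaussian. In practice the posterior is intractable and the gradients are stochastic; the target values below are invented purely for illustration.

```python
import math

# Target "posterior" (known in closed form here purely for illustration):
mu_star, var_star = 3.0, 0.5

def kl_to_target(mu, var):
    """KL( N(mu, var) || N(mu_star, var_star) ) in closed form."""
    return 0.5 * (math.log(var_star / var)
                  + (var + (mu - mu_star) ** 2) / var_star - 1.0)

mu, var, rho = 0.0, 2.0, 0.1   # initial mean, variance and learning rate
for _ in range(500):
    g_mu = (mu - mu_star) / var_star             # dKL/dmu
    g_var = 0.5 * (1.0 / var_star - 1.0 / var)   # dKL/dvar
    mu -= rho * g_mu                             # fixed-step updates (eqn.5,6)
    var -= rho * g_var
# mu and var drift towards mu_star and var_star as the KL shrinks
```

Note the catch mentioned above: with 10⁶-dimensional parameters, these fixed-step updates become painfully slow, which is what motivates natural gradients.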
Natural Gradient Descent
By now, you must be contemplating whether they have tweaked gradient descent significantly. Surprisingly, a single Fisher Information Matrix term does the trick! Let’s begin with the Fisher Information Matrix, given by eqn.7. It can be perceived as the variance of the observed information and is used to compute the covariance matrices associated with maximum-likelihood estimates.
Why do we need the Fisher Information Matrix? You will find your answer in a few lines. Eqn.8 represents the update step of gradient descent, while eqn.9 gives the natural-gradient descent.
Based on the following relation, we can infer that the natural gradients (L.H.S.) are obtained from the gradients used in gradient descent (R.H.S.) by preconditioning with the inverse Fisher Information Matrix. Furthermore, Khan et al. propose that if we choose the right parameterisation for the Fisher Information Matrix, the natural gradients can be computed quickly.
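For a univariate Gaussian the Fisher Information Matrix is known in closed form, so the preconditioning step can be shown concretely. A minimal sketch, with an invented gradient vector:

```python
import numpy as np

# Fisher Information Matrix of a univariate Gaussian N(mu, var) w.r.t.
# the parameters (mu, var) is diagonal: diag(1/var, 1/(2*var^2)).
def fisher(var):
    return np.diag([1.0 / var, 1.0 / (2.0 * var ** 2)])

var = 2.0
grad = np.array([0.4, -0.1])                   # an illustrative gradient
nat_grad = np.linalg.solve(fisher(var), grad)  # natural gradient = F^{-1} g
```

Because F depends on the current variance, the same Euclidean gradient is stretched or shrunk depending on where q sits, which is exactly the geometry-awareness the bullets below describe.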
- Natural-gradients are defined on a manifold and not in the Euclidean space.
- Euclidean distance is not an apt measure. Assume two Gaussian distributions with means 0 and 25 respectively. The Euclidean distance between the means is fixed at 25, yet the distributions can have different variances: if the variances are small, the distributions won’t overlap, while larger variances may lead to overlapping regions.
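The second point can be made concrete with a quick calculation. For two Gaussians with the same variance, KL( N(μ1, σ²) || N(μ2, σ²) ) reduces to (μ1 - μ2)² / (2σ²), so the same Euclidean distance of 25 between the means gives wildly different divergences depending on the variance (a toy check; `kl_same_var` is my own name):

```python
def kl_same_var(mu1, mu2, var):
    # KL( N(mu1, var) || N(mu2, var) ) reduces to (mu1 - mu2)^2 / (2 * var)
    return (mu1 - mu2) ** 2 / (2.0 * var)

# Same Euclidean distance (25) between the means in both cases:
kl_narrow = kl_same_var(0.0, 25.0, 1.0)    # small variances: no overlap
kl_wide = kl_same_var(0.0, 25.0, 400.0)    # large variances: clear overlap
```

The narrow pair is hundreds of times further apart in KL terms than the wide pair, even though the means are equally far in Euclidean distance.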
Variational Inference with Natural Gradient Descent
The Natural-Gradient variational inference can be found by the equations given below. Eqn.10,11 show the update steps for the mean and the variance.
- The updates can be derived when q belongs to an exponential family. This generalises methods like the Kalman filter.
- The learning rate (β) is scaled by the variance (in blue). Based on the uncertainty, the system can decide whether to take larger or smaller steps.
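A hedged sketch of what such an update looks like for a univariate Gaussian q = N(μ, σ²): the precision (inverse variance) gets a moving-average update, and the mean step is scaled by the variance, echoing the second bullet. The function `ngvi_step` and the quadratic test loss are my own simplifications, not the paper’s exact updates.

```python
# `g` and `h` stand in for stochastic gradient and Hessian estimates of the
# expected loss under q; this is a simplification for illustration only.
def ngvi_step(mu, var, g, h, beta):
    prec_new = (1.0 - beta) * (1.0 / var) + beta * h   # precision update
    var_new = 1.0 / prec_new
    mu_new = mu - beta * var_new * g                   # step scaled by variance
    return mu_new, var_new

# For a quadratic loss 0.5 * (theta - 2)^2 we have g = mu - 2 and h = 1:
mu, var = 0.0, 2.0
for _ in range(300):
    mu, var = ngvi_step(mu, var, mu - 2.0, 1.0, beta=0.1)
# mu approaches 2 and var approaches 1 (the inverse Hessian)
```

When the variance is large (high uncertainty), the mean takes bold steps; as the precision accumulates evidence, the steps shrink automatically.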
Fast Computation of Uncertainty
“We approximate by a Gaussian distribution, and find it by ‘perturbing’ the parameters during backpropagation.” - Dr. Khan
At an abstract level, the quotation above can be understood by assuming that suitable noise is added before back-propagation. This way, variances corresponding to different noises can be obtained. Visually, the ultimate goal is to get bounds on the spread in variance (green region). This approach is analogous to Adam optimisation in the sense that the initial performance (variance spread) is relatively similar in both cases. You’re probably racking your brains to understand the Vadam (Variational Adam) algorithm. Wait a second; there’s an easy way out! Dr. Khan summarised Vadam in 5 simple steps :)
Vadam (Variational Adam)
At this point, you may want to revisit the naive approach and Bayesian inference mentioned earlier. A Gaussian distribution (standard normal distribution) is used as the approximation. We assume the following.
After that, the steps in Fig.4 are performed. Different neural network models may settle into different local or global minima. If you’re still asking yourself why Vadam is significant: it helps us answer which other models we could have used to get a better fit.
Lastly, the gradient step is given by eqn.12, and the modified part is highlighted in blue. See how simple the update is?
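To make the flavour of the update concrete, here is a heavily simplified, hypothetical Vadam-style step in NumPy: perturb the weights with noise whose scale comes from Adam’s second-moment vector, then take an Adam-like step. It omits bias correction and the exact prior terms of the published algorithm; all names and defaults below are my own.

```python
import numpy as np

rng = np.random.default_rng(0)

def vadam_step(mu, s, grad_fn, m, lr=0.01, beta1=0.9, beta2=0.999,
               n_data=100, prior_prec=1.0, eps=1e-8):
    """One simplified, hypothetical Vadam-style step (weight-perturbed Adam).

    mu: mean weights; s: second-moment (scale) vector; m: first moment;
    grad_fn: gradient of the loss at a given weight vector.
    """
    sigma = 1.0 / np.sqrt(n_data * (s + prior_prec))    # per-weight std dev
    theta = mu + sigma * rng.standard_normal(mu.shape)  # perturb the weights
    g = grad_fn(theta)
    m = beta1 * m + (1.0 - beta1) * g                   # Adam-style first moment
    s = beta2 * s + (1.0 - beta2) * g ** 2              # second moment -> scale
    mu = mu - lr * m / (np.sqrt(s) + eps)
    return mu, s, m
```

The key difference from vanilla Adam is the single perturbation line: the gradient is evaluated at noisy weights, so the scale vector `s` doubles as a variance estimate during training rather than being computed afterwards.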
- Vadam converges to the same variance as Adam, and the variance is found during the training stage rather than after it.
- Once the bounds on the uncertainty are obtained, this can be transferred to a problem statement with similar data distribution.
- Vadam helps avoid poor local minima and reduces over-fitting.
- Replacing gradients with natural-gradients is a faster and more robust approach.
The following graph shows that Vadam is more stable while converging. If you pay close attention to the first 2000 iterations, it’s evident that the drop in loss is more gradual in the case of Vadam.
Fast and Scalable Bayesian Deep Learning by Weight-Perturbation in Adam, M.E. Khan et al., Thirty-fifth International Conference on Machine Learning (ICML), 2018.
IISc presentation slides on Dr. Khan’s website.
Fast yet Simple Natural-Gradient Descent for Variational Inference in Complex Models, M.E. Khan et al., ISITA 2018.
Note: This summary is based on both the notes taken by me and the original seminar slides. The proposed work is not mine, nor do the images belong to me. Feel free to point out any corrections in case of ambiguity :)