How to add uncertainty to your neural network

Alvaro Durán Tovar
Deep Learning made easy
5 min read · Nov 15, 2019

Recently at my job I was asked to add uncertainty to our models: to find a way to return not just a prediction but also how certain the model is about it, so that we can compute standard deviations, percentiles, confidence intervals, etc.

TensorFlow Probability and this video quickly came to my mind, and indeed the project has been a quick success:

Using TF Probability was super simple. In this article I'm going to explain how to do it yourself with PyTorch, just for fun.

First of all, calling it uncertainty sounds super cool, but in reality what we are doing is obtaining a probability distribution and then using it, for example, to calculate the uncertainty. The key thing is having the probability distribution.

Another super cool article that I had in the back of my mind for a long time is this one: Uncertainty for CTR Prediction: One Model to Clarify Them All. It explains how they calculate uncertainty for regression problems.

So, probability distributions… what do we need to have a probability distribution? Well, that depends on the family of the distribution. As you probably know there are many, and the most widely used is the normal distribution. To define a normal distribution we only need two parameters: the mean and the standard deviation, sometimes called location and scale, because the mean specifies where the center of the distribution is and the standard deviation how wide it is, in other words the scale. Two values, keep that in mind for a while.
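In PyTorch this family is available directly in torch.distributions. A trivial example, just to show the two parameters in action:

```python
import torch
from torch.distributions import Normal

# Location (mean) and scale (standard deviation) fully define the distribution.
dist = Normal(loc=torch.tensor(0.0), scale=torch.tensor(1.0))
print(dist.mean, dist.stddev)  # tensor(0.) tensor(1.)
```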

How do we connect MSE with a normal distribution?

Suppose you have a regression problem; very likely the loss you select is the MSE. What this loss does is find the value that minimizes the error above and below, the value with the minimum squared distance to all samples. In other words, it calculates the mean. Well, that's obvious, as the loss is named "mean squared error". Ok, ok, that's super simple, let's move on.

The thing is, maybe you didn't notice it, but you already have one of the two values we need to define a normal distribution: we already have the mean. We can't use it exactly as is, but what I'm trying to say is that it isn't hard to obtain a distribution from a neural network, you just have to do things in a different way.
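You can convince yourself of the "MSE finds the mean" part with a tiny brute-force check (toy numbers of my own, not from the notebook):

```python
import torch

y = torch.tensor([1.0, 2.0, 6.0])
candidates = torch.linspace(0.0, 8.0, steps=801)          # constants to try
mse = ((y.unsqueeze(0) - candidates.unsqueeze(1)) ** 2).mean(dim=1)
print(candidates[mse.argmin()])  # ~3.0, which is exactly y.mean()
```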

The trick

And the trick is… return two values. WTF!? Yeah, that's it, you just have to return two values from your model instead of one, as simple as that. The details to make this work are a bit harder though, but the main idea is that simple. Here you have an example where we return two numbers, directly instantiate a normal distribution with them, and then return the distribution itself.
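Something along these lines; this is a minimal sketch with made-up layer sizes and names, not the exact code from the notebook:

```python
import torch
import torch.nn as nn
from torch.distributions import Normal

class RegressionWithUncertainty(nn.Module):
    """Toy regressor that returns a Normal distribution instead of a scalar."""
    def __init__(self, in_features: int, hidden: int = 32):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(in_features, hidden), nn.ReLU())
        # Two heads: one value for the mean, one for the standard deviation.
        self.mean_head = nn.Linear(hidden, 1)
        self.std_head = nn.Linear(hidden, 1)

    def forward(self, x: torch.Tensor) -> Normal:
        h = self.body(x)
        mean = self.mean_head(h)
        # The std must be strictly positive, hence the clamp (more on this below).
        std = torch.clamp(self.std_head(h), min=1e-3)
        return Normal(mean, std)
```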

The mean of the distribution should align with the value you would obtain if you were using MSE.

The loss

We aren't returning a scalar anymore; we are returning a probability distribution, so what we have to do is calculate "how likely is this distribution to produce these values". As we are dealing with probabilities, maybe you have the answer in your mind already: we will use the negative log likelihood as the loss.
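A minimal sketch of that loss, assuming the model returns a torch.distributions.Normal as above:

```python
import torch
from torch.distributions import Normal

def nll_loss(dist: Normal, target: torch.Tensor) -> torch.Tensor:
    # Negative log likelihood of the targets under the predicted distribution.
    return -dist.log_prob(target).mean()
```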

It's super simple: "dist" is the output from the model, what you might otherwise call "out" or "y_hat", except that instead of being a number it is now a distribution, and target is obviously the target, the known "y".

Why does this work?

So why (and not how) this works is because of the same magic that happens with every neural network use case: the magic of backpropagation. If we ask the network to act in some way (and we set everything up correctly), the network will do its best to do so. How can we categorize images as cats or dogs? How can we make fake images, fake news, do style transfer, etc.? Because we ask the network to do so, and somehow it works. This might sound kind of stupid, but it still amazes me that if you ask the network to return a mean and a standard deviation, it will. I still feel like a child seeing magic with these things.

How does this work?

This took me a bit of time to understand, but I finally got it. For a given input we obtain the parameters mean and std. The network might need to move the mean and/or change the std to increase the probability of the observed target.

As you can see in the loss, we are using the logarithm of the probability. We use the logarithm because we will obtain very, very small numbers close to 0, and using the log helps us avoid problems with floating point precision (it's easier to handle -100 than 3.720075976020836e-44; log(3.720075976020836e-44) == -100).
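You can check that number yourself:

```python
import math

p = 3.720075976020836e-44   # a tiny probability, awkward to work with
print(math.log(p))          # ~ -100.0, a much friendlier number
```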

Say we have the dataset below. We use 7.5 as input and obtain (8, 1) as mean and std respectively. What's the probability (density, strictly speaking) of seeing 4 under this distribution? 0.00013383022576488537, quite low. In code it might look like the following:
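Something along these lines, using torch.distributions directly (my own reconstruction, not the exact code from the notebook):

```python
import torch
from torch.distributions import Normal

# The model predicted mean=8 and std=1 for the input 7.5.
dist = Normal(torch.tensor(8.0), torch.tensor(1.0))
target = torch.tensor(4.0)

prob = dist.log_prob(target).exp()   # density of the target under the prediction
print(prob)  # tensor(0.0001) -- i.e. about 0.000134, quite low
```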

And the loss would be:
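Continuing the same snippet:

```python
loss = -dist.log_prob(target)
print(loss)  # tensor(8.9189), i.e. -log(0.000134...)
```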

Then the network changes its weights to reduce the loss, that is, to increase the probability.

What does this look like?

Of course it depends on your specific problem, but with the toy problem I have been using it looks like this:

The dashed lines are the mean ± 2 standard deviations. Remember that ±2 standard deviations covers roughly 95% of the samples. Some samples still land outside of that range, but that's expected.

Lessons learned

  • I wasted a lot of time trying to figure out why the model wasn't learning at all, and it was because I was using incorrect shapes for the parameters of the distribution. They must be shaped [batch_size, 1], not [batch_size], otherwise some operations end up with the wrong result because of broadcasting.
  • I also struggled with lots of NaNs, and in the end I figured out it was because of negative values for the standard deviation; that's why I use torch.clamp, to ensure it's always a positive number. Both pitfalls are sketched in the snippet after this list.
  • Initialization matters; using Kaiming initialization seems to help on this dataset.
  • Never give up, and when you don't know what is failing, take a closer look at the values flowing through the code.
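A small sketch of those two pitfalls (toy shapes, not the notebook's code):

```python
import torch
from torch.distributions import Normal

batch_size = 4
# Keep an explicit feature dimension: [batch_size, 1].
mean = torch.randn(batch_size, 1)
raw_std = torch.randn(batch_size, 1)

# A negative std gives NaNs (or an error); clamping keeps it strictly positive.
std = torch.clamp(raw_std, min=1e-3)
dist = Normal(mean, std)

good_target = torch.randn(batch_size, 1)
bad_target = good_target.squeeze(-1)     # shape [batch_size]

print(dist.log_prob(good_target).shape)  # torch.Size([4, 1]) -- as expected
print(dist.log_prob(bad_target).shape)   # torch.Size([4, 4]) -- silent broadcasting bug
```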

Code

You can find the notebook on Google Colab here: https://colab.research.google.com/drive/18BGOeQjk1VgxuxjYzCXPy1DpC3LqVjVR
