Are You Messing With Me, Softmax?

Numerical instability and weirdness of the softmax function

Lahiru Nuwan Wijayasingha
The Startup
6 min read · Aug 12, 2020


Once upon a time, I was trying to train a speaker recognition model on the TIMIT dataset. I used AlexNet since I wanted to try this with a smaller model first, with a softmax layer at the end. The inputs were spectrograms of the voices of different people and the labels were the speaker IDs. MSELoss was used with the PyTorch library. I left the model to train for hours, but to no avail. I wondered why.

I checked the output from the model (the output from the softmax). The elements of the output array were all equal to each other, for every input I tried. This was really annoying; it seemed the model had not learned anything at all. So I set out to investigate. This article contains some of my findings about the softmax function. First, let's examine the softmax function itself.

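For a vector x with elements x_i, softmax is defined (in LaTeX notation) as:

\mathrm{softmax}(x)_i = \frac{e^{x_i}}{\sum_{j} e^{x_j}}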
The equation above shows the softmax function of a vector x. As we can see, the softmax function contains exponential terms, and the result of the exponential function can get very large as its input grows. Therefore, for sufficiently large inputs, overflow errors can occur! So we need to make sure the input does not get large enough to cause this (here, by input I mean the input to the softmax function). I tried to find where the exponential term gives an overflow error. The largest value without overflow turned out to be 709 (at least on my machine).

Note that this value could change from machine to machine and from library to library.
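If you want to probe this limit on your own setup, a quick check along these lines works (my own snippet, using Python's math.exp, which raises OverflowError when the result no longer fits in a double):

```python
import math

# Find the largest integer x for which exp(x) still fits in a float64.
x = 0
while True:
    try:
        math.exp(x + 1)
        x += 1
    except OverflowError:
        break

print("largest integer input to exp without overflow:", x)  # 709 for IEEE double precision
```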

Next, I set out to explore how softmax behaves for large and small inputs. I created input arrays sampled from normal distributions with varying mean values and then plotted statistics of the softmax output. The size of each input array was chosen to be 1000.

The code I used to do this is in Python.
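In outline, it boils down to something like the following minimal NumPy/SciPy sketch (the plotting code is omitted, and the sweep range and variable names are my own choices):

```python
import numpy as np
from scipy.special import softmax  # an off-the-shelf softmax implementation

rng = np.random.default_rng(0)
size = 1000                            # number of elements in each input feature vector
means = np.linspace(-200, 200, 81)     # mean values to sweep over

out_mean, out_max, out_std = [], [], []
for m in means:
    x = rng.normal(loc=m, scale=1.0, size=size)  # input sampled from N(m, 1)
    s = softmax(x)
    out_mean.append(s.mean())
    out_max.append(s.max())
    out_std.append(s.std())

# out_mean, out_max and out_std are then plotted against `means`
```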

The following figure shows the mean value of the softmax output plotted against the mean value of the input feature vector. As expected, it is constant: the softmax outputs always sum to 1, so with 1000 elements their mean is always 1/1000 regardless of the input. Well, that is good so far.

Now let’s plot the max value of softmax vs the mean value of the input feature vector.

We can see that for very small inputs, every element of the softmax output is 0.001 (I had to print the array values to see this). The input array had 1000 elements, so it seems that for small inputs, softmax divides the output probability equally (1/1000) among the components even though the elements of the input feature array are not equal.

Further insight can be gained by looking at the plot of the standard deviation of the softmax values against the mean value of the input feature vector.

The SD reaches 0 (meaning no variation among the probability values) when the input is small. Well, this is not good. Let's look at a numerical example.
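Here is one in PyTorch (my own snippet; the exact numbers in the original notebook may differ):

```python
import torch

x = torch.tensor([1e-8, 3e-8, 2e-8, 4e-8])   # tiny, unequal inputs
print(torch.softmax(x, dim=0))
# tensor([0.2500, 0.2500, 0.2500, 0.2500])   -- every element gets the same probability
```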

As we can see, inputs on the scale of 1e-8 cause softmax to output nearly identical values, making the result useless. Well, this is what happened to my model.

And again, something awful happens when the inputs are very large. The max value of the softmax reaches 1, which means the other values must be close to 0. It can be seen that the SD value also plateaus. Now for larger values:
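For example (again my own snippet, chosen so that the 3rd element is by far the largest):

```python
import torch

x = torch.tensor([1000., 2000., 7000., 3000.])   # the 3rd element dominates
print(torch.softmax(x, dim=0))
# tensor([0., 0., 1., 0.])   -- all of the probability mass lands on one element
```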

Only the 3rd element of the softmax output is 1. The others are almost zero.

We can see that softmax does not represent the distribution of the inputs well when they are too large or too small. So if our model produces values in these ranges before the softmax, it will not learn anything, because the softmax output carries no useful signal.

What can we do about this?

Scaling the input

One solution to these problems is standardizing the inputs before we send them to softmax.
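Concretely, that means shifting each feature vector to zero mean and rescaling it to unit standard deviation before the softmax. A sketch of the idea (the notebook may use a slightly different scaling):

```python
import numpy as np
from scipy.special import softmax

def standardize(x, eps=1e-12):
    # Shift to zero mean and scale to unit standard deviation.
    return (x - x.mean()) / (x.std() + eps)

x = np.random.default_rng(0).normal(loc=500.0, scale=1.0, size=1000)  # large raw values
s = softmax(standardize(x))
print(s.max(), s.std())  # no longer saturated at 1 or collapsed to a uniform 1/1000
```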

After doing this, the plot of the max value of the softmax against the mean value of the input features looks like the one below.

It looks like those awkward values at very small and very large inputs are gone now, which is good.

Using log-softmax

Sometimes taking the log of the softmax can make the operation more stable. The equation for log-softmax is simply the log of the softmax (obviously!). But there are certain implications of doing this.

On closer inspection, we can see that log-softmax can be converted to the following form:
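\log \mathrm{softmax}(x)_i = \log \frac{e^{x_i}}{\sum_{j} e^{x_j}} = x_i - \log \sum_{j} e^{x_j}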

This simplifies things a lot. The second term on the right-hand side of the equation can be computed with the method commonly called the log-sum-exp trick. This prevents overflow and underflow errors, making log-softmax more stable than bare softmax. Most libraries that calculate log-softmax use this trick. For more about this, read this article:
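To make this concrete, here is a minimal NumPy sketch of a log-softmax built around the log-sum-exp trick (my own illustration, not the exact code any particular library uses):

```python
import numpy as np

def log_softmax(x):
    # log-sum-exp trick: factor out the maximum so the largest exponent is
    # exp(0) = 1, which can neither overflow nor underflow to zero.
    m = x.max()
    log_sum_exp = m + np.log(np.sum(np.exp(x - m)))
    return x - log_sum_exp

x = np.array([1000.0, 2000.0, 7000.0, 3000.0])
print(log_softmax(x))                        # finite values: [-6000. -5000.     0. -4000.]
print(np.log(np.exp(x) / np.exp(x).sum()))   # the naive route overflows and prints nan values
```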

Let’s take a look at how the mean value of log-softmax varies as we change the mean value of its input.

Next, the max value of log-softmax:

Then the standard deviation:

It looks like the standard deviation increases as we increase the mean input feature value, and it does not collapse to zero as the input feature mean grows, unlike in the earlier problematic cases.

In fact, in my experiments the model trained faster with log-softmax than with min-max scaling of the inputs. If you are using PyTorch, CrossEntropyLoss can be used, since it already contains a log-softmax.
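In that case the model should output raw scores (logits) with no softmax layer at the end, and the loss handles the rest. A minimal sketch (shapes and names are made up):

```python
import torch
import torch.nn as nn

num_speakers = 10
logits = torch.randn(8, num_speakers, requires_grad=True)  # raw model outputs for a batch of 8
labels = torch.randint(0, num_speakers, (8,))               # integer speaker IDs

criterion = nn.CrossEntropyLoss()   # applies log-softmax + negative log-likelihood internally
loss = criterion(logits, labels)
loss.backward()
```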

For my Jupyter Notebook (with plots and all) go to

Takeaway

Sometimes softmax can be numerically unstable (giving overflow or underflow errors) or useless (all the outputs are the same, or otherwise weird). So if your model is reluctant to learn anything, this could be why. In that case, try something like log-softmax. But some solutions may not work for a particular application, so we may have to experiment a bit.

Follow me on Twitter https://twitter.com/lahirunuwan

Linkedin : www.linkedin.com/in/lahiru-nuwan-59568a88
