Numerically stable softmax

Is there a numerically stable way to compute the softmax function below? I am getting values that become NaNs in my neural network code.

np.exp(x)/np.sum(np.exp(x))

Abhishek Bhatia asked Mar 04 '17 18:03



3 Answers

The softmax exp(x)/sum(exp(x)) is actually numerically well-behaved. It has only positive terms, so we needn't worry about loss of significance, and the denominator is at least as large as the numerator, so the result is guaranteed to fall between 0 and 1.

The only thing that can go wrong is overflow or underflow in the exponentials. If a single element of exp(x) overflows, or if every element underflows, the output becomes more or less useless.
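
To make both failure modes concrete, here is a minimal sketch assuming 64-bit floats. The helper name naive_softmax is hypothetical and simply transcribes the formula from the question:

```python
import numpy as np


def naive_softmax(x):
    # Direct transcription of exp(x) / sum(exp(x)); hypothetical helper name.
    e = np.exp(np.asarray(x, dtype=np.float64))
    return e / np.sum(e)


# Silence the floating-point warnings so the nan results are easy to inspect.
with np.errstate(over="ignore", invalid="ignore"):
    # One huge entry: exp(1000) overflows to inf, and inf / inf is nan.
    overflowed = naive_softmax([1.0, 2.0, 1000.0])
    # All entries very negative: every exp underflows to 0, and 0 / 0 is nan.
    underflowed = naive_softmax([-1000.0, -1001.0, -1002.0])

print(overflowed)
print(underflowed)
```

In the first case only the overflowing entry turns into nan, while in the second case every entry does.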

But it is easy to guard against that by using the identity softmax(x) = softmax(x + c), which holds for any scalar c. Subtracting max(x) from x leaves a vector with only non-positive entries, which rules out overflow, and with at least one entry equal to zero, which rules out a vanishing denominator (underflow in some but not all entries is harmless).
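
The shift identity can be checked numerically. The following is a small sketch, not part of the original answer; softmax_shifted is a hypothetical helper that evaluates the plain formula on x + c:

```python
import numpy as np


def softmax_shifted(x, c):
    # Softmax evaluated on x + c; hypothetical helper for illustrating the identity.
    e = np.exp(np.asarray(x, dtype=np.float64) + c)
    return e / np.sum(e)


x = np.array([1.0, 2.0, 3.0])
reference = softmax_shifted(x, 0.0)

# Any scalar shift gives the same result up to rounding, as long as nothing overflows.
for c in (-5.0, 0.5, 10.0):
    assert np.allclose(softmax_shifted(x, c), reference)

# Choosing c = -max(x) makes the largest entry exactly 0, so exp never exceeds 1
# and the denominator is at least exp(0) = 1.
big = np.array([1000.0, 1001.0, 1002.0])
stable = softmax_shifted(big, -np.max(big))
print(stable)
```

Because big differs from x only by a constant offset, the shifted result matches the softmax of x even though exp(1000) on its own would overflow.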

Footnote: theoretically, catastrophic accidents in the sum are possible, but you'd need a ridiculous number of terms. For example, even using 16-bit floats, which can only resolve 3 decimals (compared to 15 decimals of a "normal" 64-bit float), we'd need between 2^1431 (~6 x 10^430) and 2^1432 terms to get a sum that is off by a factor of two.
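
The resolution figures quoted in the footnote can be read off NumPy's type information. This sketch only illustrates the limited resolution of 16-bit floats; it does not reproduce the footnote's estimate of the number of terms:

```python
import numpy as np

# Approximate number of reliable decimal digits for each type.
print(np.finfo(np.float16).precision)
print(np.finfo(np.float64).precision)

# With an 11-bit significand, float16 cannot represent 2049,
# so adding 1 to 2048 leaves the running total unchanged.
total = np.float16(2048.0)
stalled = total + np.float16(1.0)
print(stalled == total)
```

This kind of stalled accumulation is the loss of significance the footnote refers to; in softmax it only matters for sums with an absurd number of terms.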

Paul Panzer answered Oct 05 '22 18:10


The softmax function is prone to two issues: overflow and underflow.

Overflow: occurs when numbers too large to represent are approximated as infinity.

Underflow: occurs when very small numbers (close to zero on the number line) are rounded to zero.
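
As a rough illustration of where these limits sit for np.exp, the sketch below derives the overflow threshold from the largest finite value of each float type. The helper exp_overflow_threshold is hypothetical, introduced only for this example:

```python
import numpy as np


def exp_overflow_threshold(dtype):
    # Hypothetical helper: log of the largest finite value, i.e. roughly the
    # largest input for which np.exp still returns a finite number in `dtype`.
    return float(np.log(np.finfo(dtype).max))


for dtype in (np.float32, np.float64):
    print(dtype.__name__, exp_overflow_threshold(dtype))

with np.errstate(over="ignore", under="ignore"):
    print(np.exp(np.float64(710.0)))    # overflow: inf
    print(np.exp(np.float64(-1000.0)))  # underflow: 0.0
```

The thresholds come out a little below 89 for 32-bit floats and a little below 710 for 64-bit floats, so inputs of ordinary size for unnormalized network outputs can already overflow.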

To combat these issues when doing softmax computation, a common trick is to shift the input vector by subtracting the maximum element in it from all elements. For the input vector x, define z such that:

z = x - max(x)

Then take the softmax of the new (stable) vector z.


Example:

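import numpy as np

def stable_softmax(x):
    # Shift so the largest entry is 0; exp() can then never overflow.
    z = x - np.max(x)
    numerator = np.exp(z)
    denominator = np.sum(numerator)  # >= 1, because exp(0) = 1 is one of the terms
    softmax = numerator / denominator

    return softmax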

# input vector
In [267]: vec = np.array([1, 2, 3, 4, 5])
In [268]: stable_softmax(vec)
Out[268]: array([ 0.01165623,  0.03168492,  0.08612854,  0.23412166,  0.63640865])

# input vector with really large number, prone to overflow issue
In [269]: vec = np.array([12345, 67890, 99999999])
In [270]: stable_softmax(vec)
Out[270]: array([ 0.,  0.,  1.])

In the above case, we safely avoided the overflow problem by using stable_softmax().


For more details, see the chapter "Numerical Computation" in the Deep Learning book.

kmario23 answered Oct 05 '22 17:10


Extending @kmario23's answer to support 1- or 2-dimensional NumPy arrays or lists (common if you're passing a batch of results through the softmax function):

import numpy as np


def stable_softmax(x):
    z = x - np.max(x, axis=-1, keepdims=True)
    numerator = np.exp(z)
    denominator = np.sum(numerator, axis=-1, keepdims=True)
    softmax = numerator / denominator
    return softmax


test1 = np.array([12345, 67890, 99999999])  # 1D numpy
test2 = np.array([[12345, 67890, 99999999], # 2D numpy
                  [123, 678, 88888888]])    #
test3 = [12345, 67890, 999999999]           # 1D list
test4 = [[12345, 67890, 999999999]]         # 2D list

print(stable_softmax(test1))
print(stable_softmax(test2))
print(stable_softmax(test3))
print(stable_softmax(test4))

Output:

[0. 0. 1.]
[[0. 0. 1.]
 [0. 0. 1.]]
[0. 0. 1.]
[[0. 0. 1.]]
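
A short note on why keepdims=True matters in the batched version. The sketch below uses made-up example values to show the shapes involved: keeping the reduced axis lets NumPy broadcast each row's maximum back across that row.

```python
import numpy as np

batch = np.array([[1.0, 2.0, 3.0],
                  [10.0, 20.0, 30.0]])

row_max_kept = np.max(batch, axis=-1, keepdims=True)  # shape (2, 1)
row_max_flat = np.max(batch, axis=-1)                 # shape (2,)

# (2, 3) - (2, 1) broadcasts row-wise, so each row is shifted by its own maximum.
shifted = batch - row_max_kept
print(shifted)

# (2, 3) - (2,) does not broadcast: the trailing dimensions 3 and 2 are incompatible.
try:
    batch - row_max_flat
    broadcast_failed = False
except ValueError:
    broadcast_failed = True
print(broadcast_failed)
```

With a square batch the flat version would broadcast without an error but subtract the wrong maxima, which is a further reason to keep the reduced axis.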

David Parks answered Oct 05 '22 17:10