
Avoid overflow with softplus function in python

I am trying to implement the following softplus function:

log(1 + exp(x))

I've tried it with math/numpy and float64 as data type, but whenever x gets too large (e.g. x = 1000) the result is inf.
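A minimal reproduction of the problem, as a sketch (the variable names are illustrative, not from the original post). Note that the two libraries fail differently: NumPy returns inf with an overflow warning, while math.exp raises OverflowError for an argument this large.

```python
import math
import numpy as np

x = 1000.0

# NumPy: exp(1000) overflows float64 to inf, so the log is inf as well.
# errstate only silences the RuntimeWarning; it does not change the result.
with np.errstate(over="ignore"):
    naive_np = np.log(1.0 + np.exp(x))
print(naive_np)

# math: exp(1000) raises OverflowError instead of returning inf.
try:
    naive_math = math.log(1.0 + math.exp(x))
except OverflowError as err:
    naive_math = None
    print("math.exp overflowed:", err)
```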

How can I evaluate this function for large x without overflowing?

Lotzki asked May 28 '17 18:05


2 Answers

TLDR:

import math
import numpy as np

def softplus_np(x):
    # Elementwise on arrays; also accepts scalars.
    return np.log1p(np.exp(-np.abs(x))) + np.maximum(x, 0)

def softplus_math(x):
    # Scalar-only version using the standard library.
    return math.log1p(math.exp(-abs(x))) + max(x, 0)

Explanation: There is a relation which one can use:

log(1+exp(x)) = log(1+exp(x)) - log(exp(x)) + x = log(1+exp(-x)) + x

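For x >= 0 the right-hand side is safe because exp(-x) <= 1; for x < 0 the original form log(1+exp(x)) is already safe because exp(x) < 1. Both cases combine into a single expression that never exponentiates a positive number, so it cannot overflow and is still mathematically exact: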

log(1+exp(-abs(x))) + max(x,0)

This works with both the math and numpy functions (e.g. np.log1p, np.exp, np.abs, np.maximum). log1p is preferable to log(1 + ...) here because it stays accurate when exp(-abs(x)) is tiny.
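As a quick sanity check, the stable formula can be compared against np.logaddexp(0, x), which computes log(exp(0) + exp(x)) = log(1 + exp(x)). This is only a sketch: the function name softplus and the test values are my own choices, and using logaddexp as the reference is an assumption of this example, not part of the answer above.

```python
import numpy as np

def softplus(x):
    # log(1 + exp(-|x|)) + max(x, 0): exp only ever sees non-positive inputs.
    x = np.asarray(x, dtype=np.float64)
    return np.log1p(np.exp(-np.abs(x))) + np.maximum(x, 0)

xs = np.array([-1000.0, -30.0, -1.0, 0.0, 1.0, 30.0, 1000.0])
stable = softplus(xs)
reference = np.logaddexp(0.0, xs)  # NumPy's built-in log(exp(a) + exp(b))

print(stable)
print(np.allclose(stable, reference))
```

If a dependency on NumPy is acceptable, np.logaddexp(0, x) can also be used directly as a one-line softplus.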

David S. answered Oct 17 '22 23:10


Since log(1+exp(x)) ~= log(exp(x)) = x for x > 30, a simple stable implementation is:

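import numpy as np

def safe_softplus(x, limit=30):
  # Scalar x only: above the limit, softplus(x) equals x to within 1e-10.
  if x>limit:
    return x
  else:
    # Here exp(x) <= exp(limit), so it cannot overflow.
    return np.log1p(np.exp(x))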

In fact |log(1+exp(30)) - 30| = log(1+exp(-30)) < 1e-10 (it is about 9.4e-14), so this implementation makes errors smaller than 1e-10 and never overflows. The error shrinks like exp(-x) as x grows, so for x=1000 it is far below float64 resolution and cannot even be measured on the computer.
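The function above uses a plain if, so it does not work on arrays with more than one element. A possible vectorized variant is sketched below; the name safe_softplus_vec and the clamping with np.minimum are my own additions, not part of the answer. The clamp matters because np.where evaluates both branches, so without it np.exp would still see the large values.

```python
import numpy as np

def safe_softplus_vec(x, limit=30.0):
    x = np.asarray(x, dtype=np.float64)
    # Clamp the argument so np.exp never receives a value above `limit`.
    small = np.log1p(np.exp(np.minimum(x, limit)))
    # Above the limit, return x itself; elsewhere use the exact formula.
    return np.where(x > limit, x, small)

xs = np.array([-5.0, 0.0, 30.0, 31.0, 1000.0])
print(safe_softplus_vec(xs))

# Size of the approximation error right at the threshold.
print(abs(np.log1p(np.exp(30.0)) - 30.0))
```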

lejlot answered Oct 17 '22 23:10