 

Lookup table Fast Sigmoidal function

Tags:

java

private static double[] sigtab = new double[1001];  // sigmoid(x) for x = 0.00, 0.01, ..., 10.00

static {
  for(int i=0; i<1001; i++) {
      double ifloat = i;
      ifloat /= 100;
      sigtab[i] = 1.0/(1.0 + Math.exp(-ifloat));
  }
}

public static double fast_sigmoid (double x) {
    if (x <= -10)
        return 0.0;
    else if (x >= 10)
        return 1.0;
    else {
        double normx = Math.abs(x*100);
        int i = (int)normx;
        double lookup = sigtab[i] + (sigtab[i+1] - sigtab[i])*(normx - Math.floor(normx));
        if (x > 0)
            return lookup;
        else // (x <= 0): sigmoid(-t) = 1 - sigmoid(t)
            return (1 - lookup);
    }
}

Does anyone know why this "fast sigmoid" actually runs slower than the exact version using Math.exp?

asked Feb 13 '26 23:02 by ShahQermez


1 Answer

You should profile your code, but I'll bet it's the call to Math.floor that is taking around half your CPU cycles (it is slow because it calls the native method StrictMath.floor(double), incurring JNI overhead).
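A minimal sketch of the question's function with the Math.floor call removed: because normx is never negative, the (int) cast already truncates toward the floor, so normx - i is the fractional part. The class name FastSigmoid and the constant names are hypothetical; the table layout (step 0.01 over [0, 10]) is the question's own. Whether this closes the gap with Math.exp on a given JVM is still something to confirm with a profiler.

```java
// Hypothetical class name; the table matches the question's sigtab (step 0.01 over [0, 10]).
final class FastSigmoid {
    private static final int STEPS_PER_UNIT = 100;                  // table step = 1/100
    private static final int TABLE_SIZE = 10 * STEPS_PER_UNIT + 1;  // entries for 0.00 .. 10.00
    private static final double[] SIGTAB = new double[TABLE_SIZE];

    static {
        for (int i = 0; i < TABLE_SIZE; i++) {
            double t = (double) i / STEPS_PER_UNIT;
            SIGTAB[i] = 1.0 / (1.0 + Math.exp(-t));
        }
    }

    static double fastSigmoid(double x) {
        if (x <= -10) return 0.0;
        if (x >= 10) return 1.0;
        double normx = Math.abs(x) * STEPS_PER_UNIT;  // in [0, 1000)
        int i = (int) normx;        // truncation equals floor because normx >= 0
        double frac = normx - i;    // fractional part, no Math.floor call
        double lookup = SIGTAB[i] + (SIGTAB[i + 1] - SIGTAB[i]) * frac;  // linear interpolation
        return x > 0 ? lookup : 1.0 - lookup;  // sigmoid(-t) = 1 - sigmoid(t)
    }
}
```

With a 0.01 step, linear interpolation of the logistic curve stays within roughly 1e-6 of the exact value inside (-10, 10).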

It is possible to compute (less accurate) versions of sigmoid functions faster than the exact built-in implementations. Here's an example for tanh, which should be easy to adapt to your function (yours is the logistic function, expit(x) = 1/(1 + exp(-x))).

Two tricks used here are often useful in LUT-based approximations:

  • Simulate rounding by adding a large constant (this forces the FPU to round the sum, because a float has too few mantissa bits to represent it exactly)
  • Make your table step a power of 2 (here 1/64), which saves one multiply per call
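The two tricks above can be seen in isolation in the small sketch below (the class and method names are hypothetical). The bias is the same value the tanh code builds, 2^17, whose float ulp is 2^-6 = 1/64: adding x rounds it to the nearest 1/64, the low mantissa bits of the sum are then round(x * 64), which is the table index with no multiply, and subtracting the bias back out recovers the rounded x.

```java
// Hypothetical demo class showing the bias-rounding trick used by fastTanH.
final class BiasRoundingDemo {
    // 0x96 (150) is the biased exponent of 2^23, so this float is 2^(23 - 6) = 2^17.
    // At that magnitude a float's ulp is 2^-6 = 1/64.
    static final float BIAS = Float.intBitsToFloat((0x96 - 6) << 23);

    // Returns round(x * 64) for 0 <= x <= 8 using one add and a bit cast.
    static int roundToSixtyFourths(float x) {
        float biased = BIAS + x;  // the FPU rounds the sum to a multiple of 1/64
        return Float.floatToRawIntBits(biased) & 0xFFFF;  // BIAS has zero mantissa, so low bits = index
    }

    // Returns x rounded to the nearest multiple of 1/64 (the "b" in fastTanH).
    static float roundedValue(float x) {
        return (BIAS + x) - BIAS;
    }
}
```

Because the step is a power of 2, the index falls directly out of the mantissa bits; with a step of 1/100 the code would still need an explicit multiply.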

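public static float fastTanH(float x) {
    if (x < 0) return -fastTanH(-x);   // tanh is odd: tanh(-x) = -tanh(x)
    if (x > 8) return 1f;              // 1 - tanh(x) < 2.3e-7 here, so clamp to 1
    float xp = TANH_FRAC_BIAS + x;     // FPU rounds x to the nearest 1/64
    short ind = (short) Float.floatToRawIntBits(xp);  // low mantissa bits = table index
    float tanha = TANH_TAB[ind];       // tanh(a), where a = x rounded to 1/64
    float b = xp - TANH_FRAC_BIAS;     // b = a, the rounded value of x
    x -= b;                            // x is now the small remainder x - a
    return tanha + x * (1f - tanha * tanha);  // first-order Taylor step: tanh'(a) = 1 - tanh(a)^2
}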

private static final int TANH_FRAC_EXP = 6; // LUT precision == 2 ** -6 == 1/64
private static final int TANH_LUT_SIZE = (1 << TANH_FRAC_EXP) * 8 + 1; // covers [0, 8] inclusive
// 0x96 is the biased exponent of 2^23, so this is 2^(23 - TANH_FRAC_EXP), whose ulp is 2^-TANH_FRAC_EXP
private static final float TANH_FRAC_BIAS =
    Float.intBitsToFloat((0x96 - TANH_FRAC_EXP) << 23);
private static final float[] TANH_TAB = new float[TANH_LUT_SIZE];
static {
    for (int i = 0; i < TANH_LUT_SIZE; ++i) {
        TANH_TAB[i] = (float) Math.tanh((double) i / (1 << TANH_FRAC_EXP)); // tanh at multiples of 1/64
    }
}
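One way to adapt this to the logistic function is the identity sigmoid(x) = (1 + tanh(x / 2)) / 2, which costs only an extra multiply and add on top of fastTanH. The sketch below is self-contained: it reproduces fastTanH and its table from above, and the class name FastLogistic and the method fastSigmoid are hypothetical.

```java
// Hypothetical class name. The tanh part reproduces fastTanH above;
// fastSigmoid adapts it via sigmoid(x) = (1 + tanh(x / 2)) / 2.
final class FastLogistic {
    private static final int TANH_FRAC_EXP = 6;                             // table step 1/64
    private static final int TANH_LUT_SIZE = (1 << TANH_FRAC_EXP) * 8 + 1;  // covers [0, 8]
    private static final float TANH_FRAC_BIAS =
        Float.intBitsToFloat((0x96 - TANH_FRAC_EXP) << 23);                 // 2^17, ulp = 1/64
    private static final float[] TANH_TAB = new float[TANH_LUT_SIZE];

    static {
        for (int i = 0; i < TANH_LUT_SIZE; ++i) {
            TANH_TAB[i] = (float) Math.tanh((double) i / (1 << TANH_FRAC_EXP));
        }
    }

    static float fastTanH(float x) {
        if (x < 0) return -fastTanH(-x);
        if (x > 8) return 1f;
        float xp = TANH_FRAC_BIAS + x;                    // round x to the nearest 1/64
        short ind = (short) Float.floatToRawIntBits(xp);  // table index from mantissa bits
        float tanha = TANH_TAB[ind];
        float b = xp - TANH_FRAC_BIAS;                    // rounded x
        x -= b;                                           // remainder
        return tanha + x * (1f - tanha * tanha);          // first-order correction
    }

    // Logistic sigmoid 1 / (1 + e^-x), expressed through tanh.
    static float fastSigmoid(float x) {
        return 0.5f * (1f + fastTanH(0.5f * x));
    }
}
```

The first-order correction keeps the tanh error on the order of 2e-5 with a 1/64 step, and the sigmoid error is half that. Whether this beats 1.0/(1.0 + Math.exp(-x)) on a particular JVM is again something to check by profiling.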
answered Feb 15 '26 12:02 by finnw


