Choice of algorithm for incremental floating point mean (java)

I want to compute the mean of a stream of doubles. This is a simple task that only requires storing a double and an int. I was doing this with the Apache Commons SummaryStatistics class. However, while testing I noticed that the SummaryStatistics mean had floating point errors that my own Python implementation did not. On further inspection I found that Commons uses a version of the following algorithm:

static double incMean(double[] data) {
    double mean = 0;
    int number = 0;
    for (double val : data) {
        ++number;
        mean += (val - mean) / number;
    }
    return mean;
}

This sometimes results in small floating point errors, e.g.:

System.out.println(incMean(new double[] { 10, 9, 14, 11, 8, 12, 7, 13 }));
// Prints 10.500000000000002

This is also the mean algorithm used by the Guava utility DoubleMath.mean. It seems strange to me that both use the above algorithm instead of the more naive one:

static double cumMean(double[] data) {
    double sum = 0;
    int number = 0;
    for (double val : data) {
        ++number;
        sum += val;
    }
    return sum / number;
}

System.out.println(cumMean(new double[] { 10, 9, 14, 11, 8, 12, 7, 13 }));
// Prints 10.5

There are two reasons I can conceive of for preferring the former algorithm. One is that if you query the mean frequently while streaming, it might be more efficient to copy a value than to perform a division; except that the update step looks significantly slower, which would almost always outweigh this saving (note: I haven't actually timed the difference).
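Since I haven't timed it, a rough harness along these lines could check that assumption (just a sketch: a trustworthy comparison would need warm-up runs and a framework like JMH, and Means is a placeholder name for wherever incMean and cumMean live):

static long timeMean(java.util.function.ToDoubleFunction<double[]> f, double[] data) {
    long start = System.nanoTime();
    double result = f.applyAsDouble(data); // use the result so the JIT can't discard the call
    long elapsed = System.nanoTime() - start;
    System.out.println("mean=" + result + " in " + elapsed / 1_000_000 + " ms");
    return elapsed;
}

// After a few warm-up calls:
// timeMean(Means::incMean, data);
// timeMean(Means::cumMean, data);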

The other explanation is that the former prevents overflow issues. That isn't really the case with floating point numbers: at worst it should result in a degradation of the mean. To check for this error, we can compare the results against the same cumMean computed with the BigDecimal class. That gives the following function:

public static double accurateMean(double[] data) {
    // Requires java.math.BigDecimal and java.math.MathContext.
    BigDecimal sum = BigDecimal.ZERO;
    int num = 0;
    for (double d : data) {
        sum = sum.add(new BigDecimal(d));
        ++num;
    }
    // A MathContext is needed here: divide() without one throws
    // ArithmeticException when the quotient doesn't terminate.
    return sum.divide(new BigDecimal(num), MathContext.DECIMAL128).doubleValue();
}

This should reasonably be the most accurate mean we can get. From a few anecdotal runs of the following code, there doesn't seem to be a significant difference between either mean and the most accurate one. Anecdotally, they tend to differ from the accurate mean only in the last digits, and neither is consistently closer than the other.

Random rand = new Random();
double[] data = new double[1 << 29];
for (int i = 0; i < data.length; ++i)
    data[i] = rand.nextDouble();

System.out.println(accurateMean(data)); // 0.4999884843826727
System.out.println(incMean(data));      // 0.49998848438246
System.out.println(cumMean(data));      // 0.4999884843827622

Does anyone have a justification for why both Apache Commons and Guava chose the former method instead of the latter?

Edit: The answer to my question seems clear: Knuth proposed this method in The Art of Computer Programming Vol. II, 4.2.2 (15) (thanks to Louis Wasserman for the tip to look at the Guava source). However, in the book Knuth proposes the method as a way to bootstrap a robust calculation of the standard deviation, not necessarily claiming it is the optimal mean calculation. Based on reading more of the chapter, I implemented a fourth mean:

static double kahanMean(double[] data) {
    double sum = 0, c = 0; // c holds the running compensation for lost low-order bits
    int num = 0;
    for (double d : data) {
        ++num;
        double y = d - c;   // re-apply the correction from the previous step
        double t = sum + y; // low-order bits of y may be lost here
        c = (t - sum) - y;  // recover exactly what was lost
        sum = t;
    }
    return sum / num;
}

Performing the same tests as above (a handful of times, nothing statistically significant), I get exactly the same result as the BigDecimal implementation. I can imagine that the Knuth mean update is faster than the more complicated summation, but the more complicated method empirically seems more accurate at estimating the mean, which I would naively expect to also yield better standard deviation updates. Is there any reason to use the Knuth method other than that it's likely faster?
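For reference, the standard deviation recurrence that Knuth bootstraps with this mean update (often credited to Welford) looks roughly like this; this is a sketch from my reading of the chapter, not the book's exact code:

static double incVariance(double[] data) {
    double mean = 0, m2 = 0; // m2 accumulates the sum of squared deviations
    int n = 0;
    for (double x : data) {
        ++n;
        double delta = x - mean;
        mean += delta / n;        // the same update as incMean
        m2 += delta * (x - mean); // deviation from old mean times deviation from new mean
    }
    return n > 1 ? m2 / (n - 1) : 0; // sample variance; stddev is its square root
}

The mean update feeds directly into the m2 accumulator, which is presumably why Knuth presents the mean in that form.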

asked Feb 20 '14 by Erik

1 Answer

Short answer: the incremental update approach is preferred as a default because it avoids numerical errors, and it doesn't take much more time/space than the sum-and-divide approach.

The incremental update approach is more numerically stable when taking the average of a large number of samples. Notice that in incMean all of the variables are always on the order of a typical data value, whereas in the summed version the variable sum is of order N*mean. This difference in scale can cause problems due to the finite precision of floating point math.

In the case of floats (32 bits, with a 24-bit significand), one can construct artificial problem cases: e.g. a few rare samples are O(10^6) and the rest are O(1) or smaller; or generally, if you have millions of data points, the incremental update will provide more accurate results.
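For instance, here's a sketch in single precision (my example, using floats since the effect is easy to trigger there): once the running sum reaches 2^24 = 16777216, adding 1.0f no longer changes it, so the naive mean of a long stream of ones collapses while the incremental update holds at 1:

static void floatMeanDemo() {
    int n = 100_000_000;
    float sum = 0f, mean = 0f;
    for (int i = 1; i <= n; ++i) {
        sum += 1f;               // saturates at 16777216 (2^24)
        mean += (1f - mean) / i; // stays exactly 1.0
    }
    System.out.println(sum / n); // ~0.168, badly wrong
    System.out.println(mean);    // 1.0
}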

These problematic cases are less likely with doubles (which is why your test cases all give pretty much the same result), but for very large data sets with a large spread of values the same numerical problems can crop up, so it is generally accepted good practice to use the incremental approach for averages (and other moments!).

The advantages of the Kahan method are:

  1. There is only one division operation (the incremental approach requires N divisions).

  2. The funky, almost circular math is a technique to mitigate floating point errors that arise in brute-force summation; think of the variable c as a "correction" to apply to the next iteration.

However, the incremental approach is easier to code (and read).
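To see the correction at work, here's a small sketch (my example; the exact digits depend on rounding, but the qualitative gap is reproducible). Each addition of a value smaller than half an ulp of the running sum is rounded away in the naive version, while c recovers it:

static void kahanDemo() {
    int n = 10_000_000;
    double tiny = 1e-16; // below half an ulp of 1.0, so 1.0 + tiny rounds back to 1.0

    double naive = 1.0;
    double sum = 1.0, c = 0.0;
    for (int i = 0; i < n; ++i) {
        naive += tiny;       // rounds away every time: naive never moves

        double y = tiny - c; // re-inject the bits lost on the previous step
        double t = sum + y;  // low-order bits of y may be lost here
        c = (t - sum) - y;   // recover exactly what was lost
        sum = t;
    }
    System.out.println(naive); // 1.0 -- all ten million additions vanished
    System.out.println(sum);   // ~1.000000001, close to the true total
}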

answered Oct 14 '22 by Dave