 

Java 2D array fill - innocent optimization caused terrible slowdown

I tried to optimize filling a square two-dimensional Java array with the sum of the indices at each element by computing each sum only once for the two elements that mirror each other across the main diagonal. But instead of a speedup, or at least comparable performance, I got code that is 23(!) times slower.

My code:

import java.util.concurrent.TimeUnit;
import org.openjdk.jmh.annotations.*;

@State(Scope.Benchmark)
@BenchmarkMode(Mode.AverageTime)
@OperationsPerInvocation(ArrayFill.N * ArrayFill.N)
@OutputTimeUnit(TimeUnit.NANOSECONDS)
public class ArrayFill {
    public static final int N = 8189;
    public int[][] g;

    @Setup
    public void setup() { g = new int[N][N]; }

    @GenerateMicroBenchmark
    public int simple(ArrayFill state) {
        int[][] g = state.g;
        for(int i = 0; i < g.length; i++) {
            for(int j = 0; j < g[i].length; j++) {
                g[i][j] = i + j;
            }
        }
        return g[g.length - 1][g[g.length - 1].length - 1];
    }

    @GenerateMicroBenchmark
    public int optimized(ArrayFill state) {
        int[][] g = state.g;
        for(int i = 0; i < g.length; i++) {
            for(int j = 0; j <= i; j++) {
                g[j][i] = g[i][j] = i + j;
            }
        }
        return g[g.length - 1][g[g.length - 1].length - 1];
    }
}

Benchmark results:

Benchmark               Mode     Mean   Mean error    Units
ArrayFill.simple        avgt    0.907        0.008    ns/op
ArrayFill.optimized     avgt   21.188        0.049    ns/op


The question:
How can such a tremendous performance drop be explained?

P.S. The Java version is 1.8.0-ea-b124 on a 64-bit 3.2 GHz AMD processor; the benchmarks were executed in a single thread.

leventov asked Feb 07 '14 23:02


3 Answers

A side note: your "optimized" version might not be faster at all, even if we set all the possible problems aside. A modern CPU has multiple resources, and saturating any one of them can block further improvement. What I mean is that the speed may be memory-bound, so trying to write twice as much in one iteration may change nothing at all.
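To put the memory-bound remark into numbers, here is a back-of-the-envelope calculation from the question's own results. It is only a sketch: it assumes each reported "op" is one 4-byte int store (which is what @OperationsPerInvocation(N * N) implies) and ignores object headers and the fact that "optimized" writes the diagonal twice. The machine's actual memory bandwidth isn't given in the question, so whether roughly 4.4 GB/s is close to its limit is left open. The class name is hypothetical.

```java
// Back-of-the-envelope numbers derived from the question's JMH results.
// Assumption: one "op" = one 4-byte int store (per @OperationsPerInvocation(N * N)).
public class FillBandwidth {
    static final int N = 8189;          // matrix size used in the question
    static final int BYTES_PER_INT = 4;

    // Size of the int payload only (row objects' headers are ignored).
    static long payloadBytes() {
        return (long) N * N * BYTES_PER_INT;
    }

    // Implied store rate in bytes per nanosecond, which equals GB/s (10^9 bytes/s).
    static double bytesPerNs(double nsPerOp) {
        return BYTES_PER_INT / nsPerOp;
    }

    // Implied time for one complete fill of the N x N matrix, in milliseconds.
    static double msPerFill(double nsPerOp) {
        return (double) N * N * nsPerOp / 1_000_000.0;
    }

    public static void main(String[] args) {
        System.out.printf("payload: %d bytes%n", payloadBytes());
        System.out.printf("simple:    %.2f GB/s, %.1f ms per fill%n",
                bytesPerNs(0.907), msPerFill(0.907));
        System.out.printf("optimized: %.2f GB/s, %.1f ms per fill%n",
                bytesPerNs(21.188), msPerFill(21.188));
    }
}
```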

I can see three possible reasons:

  • Your access pattern may force bounds checks. In the "simple" loop they can obviously be eliminated; in the "optimized" one only if the array is known to be square. It is, but this information is available only outside of the method (moreover, a different piece of code could change it!).

  • The memory locality in your "optimized" loop is bad. It accesses essentially random memory locations, because there's nothing like a true 2D array in Java (only an array of arrays, for which new int[N][N] is a shortcut). When iterating column-wise, you use only a single int from each loaded cache line, i.e., 4 bytes out of 64.

  • The memory prefetcher may have a problem with your access pattern. At 8189 * 8189 * 4 bytes (about 256 MiB), the array is too big to fit in any cache. Modern CPUs have a prefetcher that loads cache lines in advance when it spots a regular access pattern, and the capabilities of prefetchers vary a lot. This might be irrelevant here, as you're only writing, but I'm not sure whether it's possible to write into a cache line that hasn't been fetched.
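If locality is the main issue, one standard remedy, loop tiling (blocking), which neither the question nor this answer tried, is to walk the lower triangle in square blocks so that the mirrored column-wise writes g[j][i] stay within a small set of rows that is likely still cached. A minimal sketch, assuming a square array; the class name and the block size of 64 are hypothetical, and no timings were measured for it, so any speedup would have to be checked with JMH.

```java
import java.util.Arrays;

// Sketch: symmetric fill g[i][j] = g[j][i] = i + j using loop tiling (blocking).
// BLOCK = 64 is a hypothetical tuning value, not a measured optimum.
public class TiledFill {
    static final int BLOCK = 64;

    static void tiledSymmetric(int[][] g) {
        int n = g.length; // assumes g is square: every row has length n
        for (int ib = 0; ib < n; ib += BLOCK) {
            int iEnd = Math.min(ib + BLOCK, n);
            // Visit only blocks on or below the diagonal (jb <= ib).
            for (int jb = 0; jb <= ib; jb += BLOCK) {
                int jEnd = Math.min(jb + BLOCK, n);
                for (int i = ib; i < iEnd; i++) {
                    int[] row = g[i];
                    // In the diagonal block, stop at j == i (lower triangle only).
                    int jStop = Math.min(jEnd, i + 1);
                    for (int j = jb; j < jStop; j++) {
                        int v = i + j;
                        row[j] = v;   // lower triangle, written row-wise
                        g[j][i] = v;  // mirror, column-wise but confined to this block's rows
                    }
                }
            }
        }
    }

    // Reference implementation: the question's "simple" loop.
    static void simple(int[][] g) {
        for (int i = 0; i < g.length; i++)
            for (int j = 0; j < g[i].length; j++)
                g[i][j] = i + j;
    }

    public static void main(String[] args) {
        int n = 1000; // deliberately not a multiple of BLOCK
        int[][] a = new int[n][n], b = new int[n][n];
        simple(a);
        tiledSymmetric(b);
        System.out.println(Arrays.deepEquals(a, b)); // true
    }
}
```

Each (i, j) pair with j <= i is visited exactly once, so the work matches the "optimized" loop; only the visiting order changes.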

I guess the memory locality is the main culprit:

I added a method "reversed" which works just like "simple", but with

g[j][i] = i + j;

instead of

g[i][j] = i + j;

This "innocuous" change is a performance disaster:

Benchmark                                Mode   Samples         Mean   Mean error    Units
o.o.j.s.ArrayFillBenchmark.optimized     avgt        20       10.484        0.048    ns/op
o.o.j.s.ArrayFillBenchmark.reversed      avgt        20       20.989        0.294    ns/op
o.o.j.s.ArrayFillBenchmark.simple        avgt        20        0.693        0.003    ns/op
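The "reversed" method is described only in words above. Here is a standalone sketch of what it presumably looks like, next to "simple", written as plain static methods so it runs without JMH (the class name FillOrder is hypothetical). Because i + j is symmetric, both methods leave identical contents in a square array; the only difference is that the inner loop of "reversed" stores into a different int[] row on every iteration.

```java
import java.util.Arrays;

// Sketch of "simple" vs. the "reversed" variant described above.
// The same values end up in the array; only the order of the stores differs.
public class FillOrder {
    // Row-wise: the inner loop walks along one int[] row (consecutive addresses).
    static void simple(int[][] g) {
        for (int i = 0; i < g.length; i++) {
            for (int j = 0; j < g[i].length; j++) {
                g[i][j] = i + j;
            }
        }
    }

    // Column-wise: the inner loop jumps to a different row array on every store.
    // Assumes g is square, as in the question; otherwise g[j][i] may be out of bounds.
    static void reversed(int[][] g) {
        for (int i = 0; i < g.length; i++) {
            for (int j = 0; j < g[i].length; j++) {
                g[j][i] = i + j;
            }
        }
    }

    public static void main(String[] args) {
        int n = 512;
        int[][] a = new int[n][n], b = new int[n][n];
        simple(a);
        reversed(b);
        // i + j == j + i, so both traversal orders produce the same contents.
        System.out.println(Arrays.deepEquals(a, b)); // true
    }
}
```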
maaartinus answered Dec 05 '22 11:12


I wrote a version that runs faster than "simple", but I don't know why it is faster. Here is the code:

class A {
  public static void main(String[] args) {
    int n = 8009;

    long st, en;

    // one
    int gg[][] = new int[n][n];
    st = System.nanoTime();
    for(int i = 0; i < n; i++) {
      for(int j = 0; j < n; j++) {
        gg[i][j] = i + j; 
      }
    }
    en = System.nanoTime();

    System.out.println("\nOne time " + (en - st)/1000000.d + " msc");

    // two
    int g[][] = new int[n][n];
    st = System.nanoTime();
    int odd = (n%2), l=n-odd;
    for(int i = 0; i < l; ++i) {
      int t0, t1;   
      int a0[] = g[t0 = i];
      int a1[] = g[t1 = ++i];
      for(int j = 0; j < n; ++j) {
        a0[j] = t0 + j;
        a1[j] = t1 + j;
      }
    }
    if(odd != 0)
    {
      int i = n-1;
      int a[] = g[i];
      for(int j = 0; j < n; ++j) {
        a[j] = i + j;
      }
    }
    en = System.nanoTime();
    System.out.println("\nOptimized time " + (en - st)/1000000.d + " msc");

    int r = g[0][0]
    //  + gg[0][0]
    ;
    System.out.println("\nZZZZ = " + r);

  }
}

The results are:

One time 165.177848 msc

Optimized time 99.536178 msc

ZZZZ = 0

Can someone explain to me why it is faster?

Chen Gupta answered Dec 05 '22 10:12


http://www.learn-java-tutorial.com/Arrays.cfm#Multidimensional-Arrays-in-Memory

Picture: http://www.learn-java-tutorial.com/images/4715/Arrays03.gif

int[][] === array of arrays of values

int[] === array of values
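A small demonstration of the "array of arrays" point (the class and method names are hypothetical): every row of new int[n][n] is a separate int[] object, and any row can later be replaced, even by an array of a different length or by another row. This is why code that only sees an int[][] cannot, in general, assume it is square.

```java
// Demonstrates that Java's int[][] is an array of references to separate int[] rows.
public class ArrayOfArrays {
    // Builds a 3x3 int[][] and then changes its row structure.
    static int[][] reshaped() {
        int[][] g = new int[3][3]; // shortcut for: an array holding 3 separate new int[3] rows
        g[1] = new int[5];         // replace a row with a longer one: the array is now "jagged"
        g[2] = g[0];               // make two rows alias the same int[] object
        g[2][0] = 42;              // therefore also visible as g[0][0]
        return g;
    }

    public static void main(String[] args) {
        int[][] fresh = new int[3][3];
        System.out.println(fresh[0] != fresh[1]); // true: each row is a distinct int[] object

        int[][] g = reshaped();
        System.out.println(g[1].length);          // 5
        System.out.println(g[0][0]);              // 42
    }
}
```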

class A {
    public static void main(String[] args) {
        int n = 5000;

        int g[][] = new int[n][n];
        long st, en;

        // one
        st = System.nanoTime();
        for(int i = 0; i < n; i++) {
            for(int j = 0; j < n; j++) {
                g[i][j] = 10; 
            }
        }
        en = System.nanoTime();
        System.out.println("\nOne time " + (en - st)/1000000.d + " msc");

        // two
        st = System.nanoTime();
        for(int i = 0; i < n; i++) {
            g[i][i] =  20;
            for(int j = 0; j < i; j++) {
                g[j][i] = g[i][j] = 20; 
            }
        }
        en = System.nanoTime();
        System.out.println("\nTwo time " + (en - st)/1000000.d + " msc");

        // 3
        int arrLen = n*n;
        int[] arr = new int[arrLen];
        st = System.nanoTime();
        for(int i : arr) {
            // Caution: for-each yields the element values, not indices. All values
            // are 0, so this loop reads every element but only ever writes arr[0].
            arr[i] = 30;
        }
        en = System.nanoTime();
        System.out.println("\n3   time " + (en - st)/1000000.d + " msc");

        // 4
        st = System.nanoTime();
        int i, j;
        for(i = 0; i < n; i++) {
            for(j = 0; j < n; j++) {
                arr[i*n+j] = 40;
            }
        }
        en = System.nanoTime();
        System.out.println("\n4   time " + (en - st)/1000000.d + " msc");
    }
}

One time 71.998012 msc

Two time 551.664166 msc

3 time 63.74851 msc

4 time 57.215167 msc
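Sections 3 and 4 store the whole matrix in a single int[] of length n*n. A minimal sketch of that layout wrapped in a class (FlatMatrix and its methods are hypothetical names, not part of the answer's code): element (i, j) lives at index i * n + j (row-major order), so a row-wise fill writes consecutive indices, just like the inner loop of section 4.

```java
// Row-major flat storage for an n x n int matrix, as in section 4 (arr[i*n+j]).
public class FlatMatrix {
    final int n;
    final int[] data;

    FlatMatrix(int n) {
        this.n = n;
        this.data = new int[Math.multiplyExact(n, n)]; // throws instead of silently overflowing
    }

    // No per-coordinate check: an out-of-range j silently addresses the next row.
    int get(int i, int j) { return data[i * n + j]; }
    void set(int i, int j, int v) { data[i * n + j] = v; }

    // Row-wise fill: the flat index advances by exactly 1 on every store.
    void fillSums() {
        int k = 0;
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                data[k++] = i + j; // here k (before increment) == i * n + j
            }
        }
    }

    public static void main(String[] args) {
        FlatMatrix m = new FlatMatrix(4);
        m.fillSums();
        System.out.println(m.get(2, 3)); // 5
    }
}
```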

P.S. I'm not a Java specialist =)

VoidVolker answered Dec 05 '22 10:12