 

Why is the performance of these matrix multiplications so different?

I wrote two matrix classes in Java just to compare the performance of their matrix multiplications. One class (Mat1) stores a double[][] A member where row i of the matrix is A[i]. The other class (Mat2) stores A and T where T is the transpose of A.

Let's say we have a square matrix M and we want the product M.mult(M). Call the product P.

When M is a Mat1 instance, the algorithm used was the straightforward one:

for (int i = 0; i < M.A.length; i++)
    for (int j = 0; j < M.A.length; j++)
        for (int k = 0; k < M.A.length; k++)
            P[i][j] += M.A[i][k] * M.A[k][j];

In the case where M is a Mat2, I used the same loops with this inner statement:

P[i][j] += M.A[i][k] * M.T[j][k]

which is the same algorithm because T[j][k]==A[k][j]. On 1000x1000 matrices the second algorithm takes about 1.2 seconds on my machine, while the first one takes at least 25 seconds. I was expecting the second one to be faster, but not by this much. The question is, why is it this much faster?
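The question doesn't include the source of Mat1 and Mat2, so here is a minimal hypothetical reconstruction for comparison. It assumes square matrices and reuses the field names A and T from the description; MatBench is a made-up name. The only difference between the two mult methods is whether the inner loop reads the second operand down a column (other.A[k][j]) or along a row (other.T[j][k]).

```java
// Hypothetical reconstruction of the two classes described in the question;
// the original Mat1/Mat2 source was not posted. Square matrices assumed.
final class Mat1 {
    final double[][] A; // row i of the matrix is A[i]

    Mat1(double[][] a) { this.A = a; }

    // Straightforward i-j-k product: the inner loop walks DOWN column j of other.A.
    Mat1 mult(Mat1 other) {
        int n = A.length;
        double[][] P = new double[n][n];
        for (int i = 0; i < n; i++)
            for (int j = 0; j < n; j++)
                for (int k = 0; k < n; k++)
                    P[i][j] += A[i][k] * other.A[k][j];
        return new Mat1(P);
    }
}

final class Mat2 {
    final double[][] A; // row i of the matrix is A[i]
    final double[][] T; // transpose: T[j][k] == A[k][j]

    Mat2(double[][] a) {
        this.A = a;
        int n = a.length;
        this.T = new double[n][n];
        for (int i = 0; i < n; i++)
            for (int j = 0; j < n; j++)
                T[j][i] = a[i][j];
    }

    // Same i-j-k product, but the inner loop walks ALONG row j of other.T,
    // so both operands are read sequentially.
    Mat2 mult(Mat2 other) {
        int n = A.length;
        double[][] P = new double[n][n];
        for (int i = 0; i < n; i++)
            for (int j = 0; j < n; j++)
                for (int k = 0; k < n; k++)
                    P[i][j] += A[i][k] * other.T[j][k];
        return new Mat2(P);
    }
}

final class MatBench {
    public static void main(String[] args) {
        int n = 1000;
        double[][] data = new double[n][n];
        java.util.Random rnd = new java.util.Random(42);
        for (double[] row : data)
            for (int j = 0; j < n; j++) row[j] = rnd.nextDouble();

        long t0 = System.nanoTime();
        new Mat1(data).mult(new Mat1(data));
        long t1 = System.nanoTime();
        // Also includes building three transposes (two operands, one result),
        // which is O(n^2) work, small next to the O(n^3) product.
        new Mat2(data).mult(new Mat2(data));
        long t2 = System.nanoTime();
        System.out.printf("Mat1: %.2f s, Mat2: %.2f s%n", (t1 - t0) / 1e9, (t2 - t1) / 1e9);
    }
}
```

MatBench is only a rough single-shot timing with no JIT warm-up, so treat its numbers as indicative rather than a proper benchmark.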

My only guess is that the second one makes better use of the CPU caches. Data is pulled into the caches in chunks larger than one word, and the second algorithm benefits from this by traversing only rows. The first one ignores the data pulled into the caches by jumping immediately to the row below (about 1000 words away in memory, because arrays are stored in row-major order), none of which is cached.

I asked someone, and he thought it was because of friendlier memory access patterns (i.e. that the second version would result in fewer TLB misses). I hadn't thought of this at all, but I can sort of see how it would result in fewer TLB misses.
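To put rough numbers on the two hypotheses, the sketch below counts how many distinct cache lines and pages one inner loop touches for n = 1000. It assumes a flat row-major layout starting at address 0, 64-byte cache lines and 4 KiB pages (typical values, but assumptions here); AccessPatterns and distinctBlocks are hypothetical names. A Java double[][] actually stores each row as a separate array, so real addresses differ, but the same element in two different rows is still at least 8000 bytes apart, which gives the same picture for the column walk.

```java
// Counts distinct memory blocks (cache lines or pages) touched by one inner loop,
// assuming a hypothetical flat row-major layout of doubles starting at address 0.
final class AccessPatterns {
    static final int DOUBLE_BYTES = 8;

    // Distinct blocks touched when reading n doubles spaced strideElems elements apart.
    static long distinctBlocks(int n, int strideElems, int blockBytes) {
        java.util.Set<Long> blocks = new java.util.HashSet<>();
        for (int k = 0; k < n; k++) {
            long address = (long) k * strideElems * DOUBLE_BYTES;
            blocks.add(address / blockBytes); // index of the block containing this address
        }
        return blocks.size();
    }

    public static void main(String[] args) {
        int n = 1000;
        int cacheLine = 64; // assumed cache-line size in bytes
        int page = 4096;    // assumed page size in bytes
        // Stride 1 = walking along a row (Mat2's T[j][k]); stride n = walking down a column (Mat1's A[k][j]).
        System.out.println("cache lines, row:    " + distinctBlocks(n, 1, cacheLine));
        System.out.println("cache lines, column: " + distinctBlocks(n, n, cacheLine));
        System.out.println("pages, row:          " + distinctBlocks(n, 1, page));
        System.out.println("pages, column:       " + distinctBlocks(n, n, page));
    }
}
```

Under these assumptions, walking along a row touches 125 cache lines and 2 pages, while walking down a column touches 1000 of each, so both effects favor the row-wise version.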

So, which is it? Or is there some other reason for the performance difference?

CromTheDestroyer asked Oct 27 '10 00:10



1 Answer

This is because of the locality of your data.

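In RAM a matrix, although two-dimensional from your point of view, is ultimately stored as linear memory. With a flat row-major layout, the only difference from a 1D array is that the offset is computed by combining both indices you use. (In Java, a double[][] is actually an array of separate row arrays, but each row is itself contiguous, so the same reasoning applies within a row.)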

This means that if you access the element at position (x, y), its offset is computed as x*row_length + y, and that offset is used to locate the element.
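A minimal sketch of that offset calculation, using a hypothetical FlatMatrix class that keeps every element in one contiguous double[] in row-major order (unlike Java's own double[][]):

```java
// Hypothetical flat row-major matrix: element (x, y) lives at offset x * cols + y
// in a single contiguous double[].
final class FlatMatrix {
    final int rows, cols;
    final double[] data;

    FlatMatrix(int rows, int cols) {
        this.rows = rows;
        this.cols = cols;
        this.data = new double[rows * cols];
    }

    // cols plays the role of row_length in the formula above.
    int offset(int x, int y) { return x * cols + y; }

    double get(int x, int y) { return data[offset(x, y)]; }

    void set(int x, int y, double v) { data[offset(x, y)] = v; }
}
```

For example, in a 3x4 matrix, element (1, 2) is at offset 1*4 + 2 = 6. Stepping along a row changes the offset by 1, while stepping down a column changes it by 4 (the row length), which is why column-wise access jumps around in memory.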

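What happens is that a big matrix spans many memory pages (the chunks the OS uses to manage RAM), and the CPU cache is filled in small blocks called cache lines (typically 64 bytes). If you access an element that is not already in the cache, the line containing it has to be loaded from RAM, and if the address translation for its page isn't in the TLB, that has to be looked up as well.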

As long as you access memory contiguously during the multiplication, there's no problem: you use most of the data in each loaded block before moving on to the next one. If you invert the indices, however, consecutive elements may each sit in a different cache line and even a different memory page, so almost every single multiplication needs a fresh fetch from RAM. This is why the difference is so marked.
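If keeping a second transposed copy is undesirable, a common alternative (not something the question or this answer used) is loop interchange: reorder the loops to i, k, j. This is a sketch assuming rectangular double[][] inputs with compatible dimensions; LoopOrder and multiplyIkj are hypothetical names. In the innermost loop both c[i] and b[k] are walked along a row, so the accesses stay contiguous.

```java
// Loop interchange (i-k-j order): an alternative to storing a transpose.
// Assumes a is n x m and b is m x p, both rectangular.
final class LoopOrder {
    static double[][] multiplyIkj(double[][] a, double[][] b) {
        int n = a.length, m = b.length, p = b[0].length;
        double[][] c = new double[n][p];
        for (int i = 0; i < n; i++) {
            double[] ci = c[i];           // row i of the result, written sequentially
            for (int k = 0; k < m; k++) {
                double aik = a[i][k];     // reused across the whole inner loop
                double[] bk = b[k];       // row k of b, read sequentially
                for (int j = 0; j < p; j++)
                    ci[j] += aik * bk[j];
            }
        }
        return c;
    }
}
```

It performs the same multiply-adds as the i-j-k version, just accumulated in a different order, so floating-point results can differ in the last bits for general inputs.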

(I've simplified the whole explanation quite a bit; it's just to give you the basic idea behind this problem.)

In any case, I don't think this is caused by the JVM itself. It comes from the memory hierarchy (CPU caches and TLB), and may also be related to how your OS manages the memory of the Java process.

Jack answered Oct 07 '22 02:10
