
Scala dot product very slow compared to Java

I am very new to Scala, and I would like to translate my Java code to Scala while keeping the same level of performance.

Given N float vectors and an additional vector, I have to compute all N dot products and find the maximum one.

In Java this is pretty straightforward for me:

import java.util.Random;

public static void main(String[] args) {

    int N = 5000000;
    int R = 200;
    float[][] t = new float[N][R];
    float[] u = new float[R];

    Random r = new Random();

    for (int i = 0;i<N;i++) {
        for (int j = 0;j<R;j++) {
            if (i == 0) {
                u[j] = r.nextFloat();
            }
            t[i][j] = r.nextFloat();
        }
    }

    long ts = System.currentTimeMillis();
    float maxScore = -1.0f;

    for (int i = 0;i < N;i++) {
        float score = 0.0f;
        for (int j = 0; i < R;i++) {
            score += u[j] * t[i][j];
        }
        if (score > maxScore) {
            maxScore = score;
        }

    }

    System.out.println(System.currentTimeMillis() - ts);
    System.out.println(maxScore);

}

The compute time is 6 ms on my machine.

Now I have to do the same in Scala:

val N = 5000000
val R = 200
val t = Array.ofDim[Float](N, R)
val u = Array.ofDim[Float](R)

// Filling with random floats like in Java

val ts = System.currentTimeMillis()
var maxScore: Float = -1.0f

for (i <- 0 until N) {
  var score = 0.0f
  for (j <- 0 until R) {
    score += u(j) * t(i)(j)
  }
  if (score > maxScore) {
    maxScore = score
  }
}

println(System.currentTimeMillis() - ts)
println(maxScore)

The above code takes more than one second on my machine. My guess is that Scala has no primitive array type like Java's float[] and uses a collection instead, so accessing index i is slower than with a primitive array in Java.
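As a side note on that hypothesis: in Scala, Array[Float] is not a wrapper collection. It is compiled to the JVM primitive array type float[], so u(j) is a plain array load. A minimal sketch that checks this at runtime (the object name ArrayCheck is only illustrative):

```scala
object ArrayCheck {
  def main(args: Array[String]): Unit = {
    val u = Array.ofDim[Float](200)
    val t = Array.ofDim[Float](2, 3)
    // JVM binary class names: "[F" is float[], "[[F" is float[][]
    println(u.getClass.getName) // prints [F
    println(t.getClass.getName) // prints [[F
    // Indexing compiles to a direct load from the underlying float[]
    println(u(0))               // prints 0.0
  }
}
```

So the slowdown has to come from somewhere other than the array representation itself.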

The following code is even slower:

val maxScore = t.map( r => r zip u map Function.tupled(_*_) reduceLeft (_+_)).max

It takes 26 s.

How can I efficiently iterate over my two arrays to compute this?

Thanks a lot

asked Nov 28 '22 at 20:11 by ogen


1 Answer

Well, sorry to say, but the odd thing here is how fast your Java implementation is, not how slow your Scala one is. 6 ms for traversing 1 billion (!) cells (5,000,000 × 200) sounds too good to be true, and indeed, you have a typo in the Java implementation that makes this code do much less work:

Instead of for (int j = 0; j < R; j++), you have for (int j = 0; i < R;i++), which makes the inner loop body run only 200 times in total instead of 1 billion times. On the first outer iteration the inner loop advances i from 0 to 200; after that, i is always past R, so the inner loop never runs again.
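The effect of the typo can be checked by counting how often the inner loop body executes. The sketch below mirrors both loop shapes with Scala while loops, which, like Java's for, test and increment whatever variable you tell them to. The object and method names are made up for illustration, and the buggy version omits j because the typo never touches it:

```scala
object LoopCount {
  // Mirrors: for (int i = 0; i < n; i++) { for (int j = 0; i < r; i++) { ... } }
  def buggyCount(n: Int, r: Int): Long = {
    var count = 0L
    var i = 0
    while (i < n) {
      while (i < r) { // typo reproduced: the condition and increment use i, not j
        count += 1
        i += 1
      }
      i += 1
    }
    count
  }

  // Mirrors the intended: for (int i = 0; i < n; i++) { for (int j = 0; j < r; j++) { ... } }
  def fixedCount(n: Int, r: Int): Long = {
    var count = 0L
    var i = 0
    while (i < n) {
      var j = 0
      while (j < r) {
        count += 1
        j += 1
      }
      i += 1
    }
    count
  }
}
```

With the question's N = 5000000 and R = 200, buggyCount returns 200, while fixedCount would return 1000000000.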

If you fix this, the Scala and Java performance are comparable.
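If you want the Scala hot loop to stay as close to the Java version as possible, one common idiom is to write it with while loops, which compile to plain JVM loops instead of going through Range.foreach and a closure. Whether that is measurably faster than the for version depends on how well the JIT optimizes the latter. A minimal sketch, assuming every row of t has the same length as u (the names MaxDot and maxDot are illustrative, and the sizes in main are reduced so it runs quickly):

```scala
object MaxDot {
  // Returns the maximum over all rows t(i) of the dot product u · t(i).
  // Assumes every row has the same length as u.
  def maxDot(t: Array[Array[Float]], u: Array[Float]): Float = {
    var maxScore = Float.NegativeInfinity
    var i = 0
    while (i < t.length) {
      val row = t(i) // look the row up once instead of once per element
      var score = 0.0f
      var j = 0
      while (j < u.length) {
        score += u(j) * row(j)
        j += 1
      }
      if (score > maxScore) maxScore = score
      i += 1
    }
    maxScore
  }

  def main(args: Array[String]): Unit = {
    val n = 1000
    val r = 200
    val rnd = new scala.util.Random(42)
    val t = Array.fill(n, r)(rnd.nextFloat())
    val u = Array.fill(r)(rnd.nextFloat())
    val start = System.nanoTime()
    val best = maxDot(t, u)
    println(s"max = $best, took ${(System.nanoTime() - start) / 1000000} ms")
  }
}
```

Starting from Float.NegativeInfinity instead of -1.0f keeps the result correct even when every dot product is below -1; with the question's random data in [0, 1), both starting values give the same answer.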

This, BTW, is actually an advantage of Scala: it's harder to get for (j <- 0 until R) wrong :)

answered Dec 11 '22 at 02:12 by Tzach Zohar