I am very new to Scala, and I would like to translate my Java code to Scala with the same level of performance.
Given N float vectors and an additional vector, I have to compute all N dot products and get the maximum one.
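Written as a formula, with u the additional vector and t_i the i-th of the N vectors (each of length R), the quantity to compute is the following. Its cost is N · R multiply-adds, which for the sizes used in the code (N = 5,000,000, R = 200) is 5,000,000 × 200 = 10^9:

$$\text{maxScore} = \max_{0 \le i < N} \, u \cdot t_i = \max_{0 \le i < N} \sum_{j=0}^{R-1} u_j \, t_{i,j}$$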
In Java this is pretty straightforward for me:
public static void main(String[] args) {
    int N = 5000000;
    int R = 200;
    float[][] t = new float[N][R];
    float[] u = new float[R];
    Random r = new Random();
    for (int i = 0; i < N; i++) {
        for (int j = 0; j < R; j++) {
            if (i == 0) {
                u[j] = r.nextFloat();
            }
            t[i][j] = r.nextFloat();
        }
    }
    long ts = System.currentTimeMillis();
    float maxScore = -1.0f;
    for (int i = 0; i < N; i++) {
        float score = 0.0f;
        for (int j = 0; i < R; i++) {
            score += u[j] * t[i][j];
        }
        if (score > maxScore) {
            maxScore = score;
        }
    }
    System.out.println(System.currentTimeMillis() - ts);
    System.out.println(maxScore);
}
The compute time is 6 ms on my machine.
Now I have to do the same in Scala:
val t = Array.ofDim[Float](N, R)
val u = Array.ofDim[Float](R)
// Filling with random floats like in Java
val ts = System.currentTimeMillis()
var maxScore: Float = -1.0f
for (i <- 0 until N) {
  var score = 0.0f
  for (j <- 0 until R) {
    score += u(j) * t(i)(j)
  }
  if (score > maxScore) {
    maxScore = score
  }
}
println(System.currentTimeMillis() - ts)
println(maxScore)
The above code takes more than one second on my machine. My guess is that Scala has no primitive array type like Java's float[] and uses a collection instead, so accessing index i is slower than with a primitive array in Java.
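One quick way to test that guess is to ask the JVM for the runtime class of a Scala Array[Float]. The sketch below (the object and method names are made up for illustration) prints the JVM binary names: "[F" is the name of the primitive float[] type and "[[F" that of float[][], so a Scala Array[Float] is not a wrapper collection at runtime.

```scala
object ArrayRepr {
  // Runtime (JVM) class name of Scala's Array[Float]; "[F" means float[].
  def floatArrayName: String = classOf[Array[Float]].getName

  // Array.ofDim[Float](n, r) builds an array of float[] rows, i.e. float[][].
  def matrixName: String = Array.ofDim[Float](2, 3).getClass.getName

  def main(args: Array[String]): Unit = {
    println(floatArrayName) // [F
    println(matrixName)     // [[F
  }
}
```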
The following code is even slower:
val maxScore = t.map(r => r zip u map Function.tupled(_ * _) reduceLeft (_ + _)).max
It takes 26 s.
How can I efficiently iterate over my two arrays to compute this?
Thanks a lot!
Well, sorry to say, but the odd thing here is how fast your Java implementation is, not how slow your Scala one is. 6 ms for traversing 1 billion (5,000,000 × 200) cells sounds too good to be true, and indeed, you have a typo in the Java implementation that makes the code do much less work:
instead of for (int j = 0; j < R; j++)
you have for (int j = 0; i < R; i++)
which tests and increments the outer index i instead of j, so the inner loop body runs only 200 times in total instead of 1 billion times.
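To see the effect of the typo, here is a minimal Scala sketch that mirrors the Java loops with while loops (so that, as in Java, the inner loop can mutate the outer index) and counts how often the inner body runs. LoopTypoDemo, buggyInnerCount and fixedInnerCount are hypothetical names used only for this illustration.

```scala
object LoopTypoDemo {
  // Mirrors the Java version with the typo: the inner loop tests and
  // increments the OUTER index i instead of j. Returns how many times
  // the inner loop body (the multiply-add) executes.
  def buggyInnerCount(n: Int, r: Int): Long = {
    var count = 0L
    var i = 0
    while (i < n) {     // for (int i = 0; i < N; i++)
      while (i < r) {   // for (int j = 0; i < R; i++)  <- typo
        count += 1      // score += u[j] * t[i][j], with j stuck at 0
        i += 1
      }
      i += 1            // outer i++; after the first pass i >= r, so the inner loop never runs again
    }
    count
  }

  // The intended loops: for (int j = 0; j < R; j++) inside the outer loop.
  def fixedInnerCount(n: Int, r: Int): Long = {
    var count = 0L
    var i = 0
    while (i < n) {
      var j = 0
      while (j < r) {
        count += 1
        j += 1
      }
      i += 1
    }
    count
  }

  def main(args: Array[String]): Unit = {
    println(buggyInnerCount(5000000, 200)) // 200
    println(fixedInnerCount(5000, 200))    // 1000000
  }
}
```

With the typo, the count stays at R no matter how large N is, whereas the fixed loops perform N · R iterations.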
Once you fix this, the Scala and Java versions perform comparably.
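If you still want to squeeze the Scala version, a minimal sketch is to write the hot loop with while loops over Array[Float] (a JVM float[]). A `for (j <- 0 until R)` loop desugars to a `foreach` call on a Range with a closure, and the captured `var score` is then wrapped in a `scala.runtime.FloatRef`; the while version avoids both, though with a modern JIT the gain may be small, so measure it. MaxDotProduct and maxDot are hypothetical names, and starting from Float.NegativeInfinity instead of -1.0f is my own choice so the result is also correct when every dot product is negative. It assumes each row of t has at least u.length elements.

```scala
object MaxDotProduct {
  // Returns the largest dot product between u and any row of t.
  // Array[Float] is a JVM float[], so these while loops read primitive
  // floats directly, with no boxing and no closures.
  def maxDot(t: Array[Array[Float]], u: Array[Float]): Float = {
    var maxScore = Float.NegativeInfinity // correct even if every score is negative
    var i = 0
    while (i < t.length) {
      val row = t(i) // look the row up once per outer iteration
      var score = 0.0f
      var j = 0
      while (j < u.length) {
        score += u(j) * row(j)
        j += 1
      }
      if (score > maxScore) maxScore = score
      i += 1
    }
    maxScore
  }

  def main(args: Array[String]): Unit = {
    val t = Array(Array(1f, 2f), Array(3f, 4f), Array(0f, 1f))
    val u = Array(1f, 1f)
    println(maxDot(t, u)) // 7.0
  }
}
```

To time it on the question's data, call maxDot(t, u) between the two System.currentTimeMillis() calls. Keep in mind that a 5,000,000 × 200 float matrix alone holds 10^9 floats, i.e. about 4 GB of heap.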
This, BTW, is actually an advantage of Scala: it's harder to get for (j <- 0 until R) wrong :)