Fast complex number arithmetic in Clojure

I was implementing some basic complex number arithmetic in Clojure, and noticed that it was about 10 times slower than roughly equivalent Java code, even with type hints.

Compare:

(defn plus [[^double x1 ^double y1] [^double x2 ^double y2]]
    [(+ x1 x2) (+ y1 y2)])

(defn times [[^double x1 ^double y1] [^double x2 ^double y2]]
    [(- (* x1 x2) (* y1 y2)) (+ (* x1 y2) (* y1 x2))])

(time (dorun (repeatedly 100000 #(plus [1 0] [0 1]))))
(time (dorun (repeatedly 100000 #(times [1 0] [0 1])))) 

output:

"Elapsed time: 69.429796 msecs"
"Elapsed time: 72.232479 msecs"

And the equivalent Java code:

public class ComplexBench {

  public static void main( String[] args ) {
    double[] z1 = new double[] { 1, 0 };
    double[] z2 = new double[] { 0, 1 };
    double[] z3 = null;

    long l_StartTimeMillis = System.currentTimeMillis();
    for ( int i = 0; i < 100000; i++ ) {
      z3 = plus( z1, z2 ); // assign result to dummy var to stop compiler from optimising the loop away
    }
    long l_EndTimeMillis = System.currentTimeMillis();
    long l_TimeTakenMillis = l_EndTimeMillis - l_StartTimeMillis;
    System.out.format( "Time taken: %d millis%n", l_TimeTakenMillis );

    l_StartTimeMillis = System.currentTimeMillis();
    for ( int i = 0; i < 100000; i++ ) {
      z3 = times( z1, z2 );
    }
    l_EndTimeMillis = System.currentTimeMillis();
    l_TimeTakenMillis = l_EndTimeMillis - l_StartTimeMillis;
    System.out.format( "Time taken: %d millis%n", l_TimeTakenMillis );

    doNothing( z3 );
  }

  private static void doNothing( double[] z ) {
  }

  public static double[] plus( double[] z1, double[] z2 ) {
    return new double[] { z1[0] + z2[0], z1[1] + z2[1] };
  }

  public static double[] times( double[] z1, double[] z2 ) {
    return new double[] { z1[0]*z2[0] - z1[1]*z2[1], z1[0]*z2[1] + z1[1]*z2[0] };
  }
}

output:

Time taken: 6 millis
Time taken: 6 millis

In fact, the type hints don't seem to make a difference: if I remove them I get approximately the same result. What's really strange is that if I run the Clojure script without a REPL, I get slower results:

"Elapsed time: 137.337782 msecs"
"Elapsed time: 214.213993 msecs"

So my questions are: how can I get close to the performance of the Java code? And why on Earth do the expressions take longer to evaluate when running Clojure without a REPL?

UPDATE ==============

Great, using deftype with type hints in the deftype and in the defns, and using dotimes rather than repeatedly gives performance as good as or better than the Java version. Thanks to both of you.

(deftype complex [^double real ^double imag])

(defn plus [^complex z1 ^complex z2]
  (let [x1 (double (.real z1))
        y1 (double (.imag z1))
        x2 (double (.real z2))
        y2 (double (.imag z2))]
    (complex. (+ x1 x2) (+ y1 y2))))

(defn times [^complex z1 ^complex z2]
  (let [x1 (double (.real z1))
        y1 (double (.imag z1))
        x2 (double (.real z2))
        y2 (double (.imag z2))]
    (complex. (- (* x1 x2) (* y1 y2)) (+ (* x1 y2) (* y1 x2)))))

(println "Warm up")
(time (dorun (repeatedly 100000 #(plus (complex. 1 0) (complex. 0 1)))))
(time (dorun (repeatedly 100000 #(times (complex. 1 0) (complex. 0 1)))))
(time (dorun (repeatedly 100000 #(plus (complex. 1 0) (complex. 0 1)))))
(time (dorun (repeatedly 100000 #(times (complex. 1 0) (complex. 0 1)))))
(time (dorun (repeatedly 100000 #(plus (complex. 1 0) (complex. 0 1)))))
(time (dorun (repeatedly 100000 #(times (complex. 1 0) (complex. 0 1)))))

(println "Try with dorun")
(time (dorun (repeatedly 100000 #(plus (complex. 1 0) (complex. 0 1)))))
(time (dorun (repeatedly 100000 #(times (complex. 1 0) (complex. 0 1)))))

(println "Try with dotimes")
(time (dotimes [_ 100000]
        (plus (complex. 1 0) (complex. 0 1))))

(time (dotimes [_ 100000]
        (times (complex. 1 0) (complex. 0 1))))

Output:

Warm up
"Elapsed time: 92.805664 msecs"
"Elapsed time: 164.929421 msecs"
"Elapsed time: 23.799012 msecs"
"Elapsed time: 32.841624 msecs"
"Elapsed time: 20.886101 msecs"
"Elapsed time: 18.872783 msecs"
Try with dorun
"Elapsed time: 19.238403 msecs"
"Elapsed time: 17.856938 msecs"
Try with dotimes
"Elapsed time: 5.165658 msecs"
"Elapsed time: 5.209027 msecs"
asked Aug 06 '12 08:08 by OpenSauce


2 Answers

The likely reasons for your slow performance are:

  • Clojure vectors are intrinsically more heavyweight data structures than Java double[] arrays. So you have quite a bit of extra overhead in creating and reading vectors.
  • You are boxing doubles as arguments to your functions and also when they are put into vectors. Boxing / unboxing is relatively expensive in this kind of low-level numerical code.
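  • The type hints (^double) are not helping you: while you can put primitive type hints on the arguments of normal Clojure functions, they don't help for values destructured out of vectors, because vector elements are always stored as boxed objects.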
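One way to see where this boxing happens is to let the compiler warn about it. A minimal sketch, assuming Clojure 1.7 or later (the first version in which *unchecked-math* accepts :warn-on-boxed); re-sum-boxed and re-sum-prim are hypothetical names used only for illustration. Be aware that setting *unchecked-math* also makes later integer arithmetic in the same file unchecked (overflow wraps instead of throwing).

```clojure
;; Ask the compiler to report reflective calls and (Clojure 1.7+) boxed math.
(set! *warn-on-reflection* true)
(set! *unchecked-math* :warn-on-boxed)

;; Elements of a persistent vector are Objects, so this addition goes through
;; boxed Numbers and should print a boxed-math warning when it is compiled.
(defn re-sum-boxed [v w]
  (+ (nth v 0) (nth w 0)))

;; Primitive hints on the arguments *and* the return value make this a
;; primitive function: direct calls can pass and return raw doubles.
(defn re-sum-prim ^double [^double a ^double b]
  (+ a b))
```

Primitive-hinted functions like re-sum-prim are limited to at most four arguments, and long and double are the only supported primitive types, so they suit small helpers rather than whole data structures.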

See this blog post on accelerating primitive arithmetic for some more details.
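To make the comparison with Java's double[] concrete, Clojure can also use primitive double arrays directly. A minimal sketch that mirrors the Java plus and times from the question; arr-plus and arr-times are hypothetical names. The ^doubles hints are what let aget and aset compile to primitive array access instead of reflective calls.

```clojure
(set! *warn-on-reflection* true)

;; A complex number is a length-2 double[]: index 0 = real, index 1 = imaginary.
(defn arr-plus
  ^doubles [^doubles z1 ^doubles z2]
  (let [^doubles out (double-array 2)]
    ;; with ^doubles hints, aget/aset resolve to primitive array access
    (aset out 0 (+ (aget z1 0) (aget z2 0)))
    (aset out 1 (+ (aget z1 1) (aget z2 1)))
    out))

(defn arr-times
  ^doubles [^doubles z1 ^doubles z2]
  ;; locals initialised from aget on a double[] are primitive doubles
  (let [x1 (aget z1 0) y1 (aget z1 1)
        x2 (aget z2 0) y2 (aget z2 1)
        ^doubles out (double-array 2)]
    (aset out 0 (- (* x1 x2) (* y1 y2)))
    (aset out 1 (+ (* x1 y2) (* y1 x2)))
    out))
```

Arrays are mutable and compare by identity, so this trades Clojure's value semantics for speed; the deftype approach keeps immutable values.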

If you really want fast complex numbers in Clojure, you will probably need to implement them using deftype, something like:

(deftype Complex [^double real ^double imag])

And then define all your complex functions using this type. This will enable you to use primitive arithmetic throughout, and should be roughly equivalent to the performance of well-written Java code.
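A minimal sketch of that approach, assuming Clojure 1.3 or later (primitive fields and primitive arithmetic). The names Complex, c+, c* and abs2 are illustrative only, not from any library. With the arguments hinted as ^Complex, field access such as (.real a) should already yield a primitive double, so the arithmetic can stay unboxed without explicit (double ...) casts.

```clojure
;; Primitive double fields, stored unboxed inside each Complex instance.
(deftype Complex [^double real ^double imag])

;; The ^Complex hints let the compiler resolve .real/.imag as direct field
;; reads (no reflection), and those reads yield primitive doubles.
(defn c+ ^Complex [^Complex a ^Complex b]
  (Complex. (+ (.real a) (.real b))
            (+ (.imag a) (.imag b))))

(defn c* ^Complex [^Complex a ^Complex b]
  ;; let-bound locals initialised from primitive fields stay primitive
  (let [x1 (.real a) y1 (.imag a)
        x2 (.real b) y2 (.imag b)]
    (Complex. (- (* x1 x2) (* y1 y2))
              (+ (* x1 y2) (* y1 x2)))))

;; Squared modulus; the ^double return hint makes this a primitive function.
(defn abs2 ^double [^Complex z]
  (let [x (.real z) y (.imag z)]
    (+ (* x x) (* y y))))
```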

answered Oct 01 '22 11:10 by mikera


  • I don't know much about benchmarking, but it seems that you need to warm up the JVM before you start measuring. When you run the code in the REPL, the JVM is already warmed up; when you run it as a script, it isn't yet.

  • In Java you run all the loops inside one method, and no methods other than plus and times are called. In Clojure you create an anonymous function and use repeatedly to call it, which builds a lazy sequence that dorun then walks. That takes extra time. You can replace it with dotimes.
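Both points can be combined into a small timing helper. This is a sketch, not a rigorous benchmark: time-loop is a hypothetical macro that inlines the body into a dotimes loop (so there is no lazy sequence and no extra anonymous-function call per iteration), and the measurement is repeated so that later rounds reflect a warmed-up JVM. Because the result is discarded, the JIT could in principle optimise some of the work away; for serious measurements a dedicated library such as Criterium handles warm-up and statistics.

```clojure
(defmacro time-loop
  "Evaluates body n times inside a dotimes loop and returns the elapsed
  wall-clock time in milliseconds."
  [n & body]
  `(let [start# (System/nanoTime)]
     (dotimes [_# ~n] ~@body)
     (/ (- (System/nanoTime) start#) 1e6)))

;; The vector-based plus from the question, without the ineffective hints.
(defn plus [[x1 y1] [x2 y2]]
  [(+ x1 x2) (+ y1 y2)])

;; Repeat the same measurement: early rounds include JIT warm-up,
;; later rounds approximate steady-state performance.
(dotimes [round 5]
  (println "round" round ":" (time-loop 100000 (plus [1 0] [0 1])) "ms"))
```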

My attempt:

(println "Warm up")
(time (dorun (repeatedly 100000 #(plus [1 0] [0 1]))))
(time (dorun (repeatedly 100000 #(times [1 0] [0 1]))))
(time (dorun (repeatedly 100000 #(plus [1 0] [0 1]))))
(time (dorun (repeatedly 100000 #(times [1 0] [0 1]))))
(time (dorun (repeatedly 100000 #(plus [1 0] [0 1]))))
(time (dorun (repeatedly 100000 #(times [1 0] [0 1]))))

(println "Try with dorun")
(time (dorun (repeatedly 100000 #(plus [1 0] [0 1]))))
(time (dorun (repeatedly 100000 #(times [1 0] [0 1]))))

(println "Try with dotimes")
(time (dotimes [_ 100000]
        (plus [1 0] [0 1])))

(time (dotimes [_ 100000]
        (times [1 0] [0 1])))

Results:

Warm up
"Elapsed time: 367.569195 msecs"
"Elapsed time: 493.547628 msecs"
"Elapsed time: 116.832979 msecs"
"Elapsed time: 46.862176 msecs"
"Elapsed time: 27.805174 msecs"
"Elapsed time: 28.584179 msecs"
Try with dorun
"Elapsed time: 26.540489 msecs"
"Elapsed time: 27.64626 msecs"
Try with dotimes
"Elapsed time: 7.3792 msecs"
"Elapsed time: 5.940705 msecs"
answered Oct 01 '22 12:10 by Mikita Belahlazau