
Efficient summation in OCaml

Please note I am almost a complete newbie in OCaml. In order to learn a bit, and test its performance, I tried to implement a module that approximates Pi using the Leibniz series.

My first attempt led to a stack overflow (the actual error, not this site). Knowing from Haskell that this may come from too many "thunks", or promises to compute something, while recursing over the addends, I looked for some way of keeping just the last result while summing with the next. I found the following tail-recursive implementations of sum and map in the notes of an OCaml course, here and here, and expected the compiler to produce an efficient result.
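
For comparison, the OCaml standard library already has tail-recursive building blocks that match these hand-written helpers. `List.fold_left` is tail-recursive, and `List.rev_map` is documented as a tail-recursive equivalent of `List.rev (List.map f l)`. The sketch below assumes only the stdlib; the names `sum_floats` and `map_tr` are illustrative, not from the course notes. Note that, like the original helpers, it avoids the stack overflow but still allocates whole intermediate lists.

```ocaml
(* Tail-recursive sum of a float list: fold_left carries only the running
   total, so the stack does not grow with the length of the list. *)
let sum_floats (l : float list) : float = List.fold_left ( +. ) 0. l

(* Tail-recursive, order-preserving map: List.rev_map builds the mapped
   list in reverse without growing the stack; List.rev restores the order. *)
let map_tr (f : 'a -> 'b) (l : 'a list) : 'b list = List.rev (List.rev_map f l)

let () =
  (* Both helpers still allocate complete intermediate lists. *)
  let xs = map_tr float_of_int [1; 2; 3; 4] in
  Printf.printf "%F\n" (sum_floats xs)
```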

However, the resulting executable, compiled with ocamlopt, is much slower (about 20 times, judging by the timings below) than a C++ version compiled with clang++. Is this code as efficient as possible? Is there some optimization flag I am missing?

My complete code is:

let (--) i j =
  let rec aux n acc =
    if n < i then acc else aux (n-1) (n :: acc)
    in aux j [];;


let sum_list_tr l =
  let rec helper a l = match l with
    | [] -> a
    | h :: t -> helper (a +. h) t
  in helper 0. l


let rec tailmap f l a = match l with
  | [] -> a
  | h :: t -> tailmap f t (f h :: a);;


let rev l =
    let rec helper l a = match l with
      | [] -> a
      | h :: t -> helper t (h :: a)
    in helper l [];;


let efficient_map f l = rev (tailmap f l []);;


let summand n =
  let m = float_of_int n
  in (-1.) ** m /. (2. *. m +. 1.);;


let pi_approx n =
  4. *. sum_list_tr (efficient_map summand (0 -- n));;


let n = int_of_string Sys.argv.(1);;
Printf.printf "%F\n" (pi_approx n);;

Just for reference, here are the measured times on my machine:

❯❯❯ time ocaml/main 10000000
3.14159275359
ocaml/main 10000000  3,33s user 0,30s system 99% cpu 3,625 total

❯❯❯ time cpp/main 10000000
3.14159
cpp/main 10000000  0,17s user 0,00s system 99% cpu 0,174 total

For completeness, let me state that the first helper function, an equivalent to Python's range, comes from this SO thread, and that this is run using OCaml version 4.01.0, installed via MacPorts on a Darwin 13.1.0.

asked May 06 '14 by logc




3 Answers

As I noted in a comment, OCaml's floats are boxed, which puts OCaml at a disadvantage compared to Clang.
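
What "boxed" means can be made visible with the `Obj` module, which exposes the runtime representation. This is purely an illustration (Obj is unsafe and has no place in real code): an `int` is an immediate value, whereas a `float` used as an ordinary value is a pointer to a separate heap block holding the double. Roughly speaking, ocamlopt can keep floats unboxed inside a single function, but floats stored in lists or refs, or passed across function boundaries, generally live in such blocks.

```ocaml
(* Illustration only: Obj exposes the runtime representation and is
   unsafe, so it does not belong in real code. *)
let () =
  (* An int is an immediate value, not a pointer to a heap block. *)
  Printf.printf "int is a heap block:   %b\n" (Obj.is_block (Obj.repr 42));
  (* A float used as an ordinary value is a pointer to a separate
     heap block holding the 8-byte double, i.e. it is boxed. *)
  Printf.printf "float is a heap block: %b\n" (Obj.is_block (Obj.repr 3.14))
```

With a standard ocamlopt build this should print `false`, then `true`.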

However, I may also be noticing another typical rough edge of trying OCaml after Haskell: if I understand what your program is doing, you are creating a list, then mapping a function over that list, and finally folding it into a result.

In Haskell, you could more or less expect such a program to be automatically “deforested” at compile-time, so that the resulting generated code was an efficient implementation of the task at hand.

In OCaml, the fact that functions can have side effects, and in particular functions passed to higher-order functions such as map and fold, means that it would be much harder for the compiler to deforest automatically. The programmer has to do it by hand.

In other words: stop building huge short-lived data structures such as 0 -- n and (efficient_map summand (0 -- n)). When your program decides to tackle a new summand, make it do all it wants to do with that summand in a single pass. You can see this as an exercise in applying the principles in Wadler's article (again, by hand, because for various reasons the compiler will not do it for you despite your program being pure).
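
Deforesting by hand can be packaged as a reusable combinator. The sketch below uses a hypothetical helper, `sum_range`, that fuses "range", "map" and "sum" into one tail-recursive loop, so `sum (map f (i -- j))` never materializes either list. Like the question's `--`, the range is inclusive of `j`; the summand is copied from the question.

```ocaml
(* Hypothetical helper: sum_range f i j computes f i +. f (i+1) +. ... +. f j
   without building the list (i -- j) or the mapped list. The range, the
   map and the sum are fused into a single tail-recursive loop. *)
let sum_range (f : int -> float) (i : int) (j : int) : float =
  let rec go k acc =
    if k > j then acc
    else go (k + 1) (acc +. f k)
  in
  go i 0.

(* Same summand as in the question. *)
let summand n =
  let m = float_of_int n in
  (-1.) ** m /. (2. *. m +. 1.)

let () =
  let n = 1_000_000 in
  Printf.printf "%F\n" (4. *. sum_range summand 0 n)
```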


Here are some results:

$ ocamlopt v2.ml
$ time ./a.out 1000000
3.14159165359

real    0m0.020s
user    0m0.013s
sys     0m0.003s
$ ocamlopt v1.ml
$ time ./a.out 1000000
3.14159365359

real    0m0.238s
user    0m0.204s
sys     0m0.029s

v1.ml is your version. v2.ml is what you might consider an idiomatic OCaml version:

let rec q_pi_approx p n acc =
  if n = p
  then acc
  else q_pi_approx (succ p) n (acc +. (summand p))

let n = int_of_string Sys.argv.(1);;

Printf.printf "%F\n" (4. *. (q_pi_approx 0 n 0.));;

(reusing summand from your code)

It might be more accurate to sum from the last terms to the first, instead of from the first to the last. This is orthogonal to your question, but you may consider it as an exercise in modifying a function that has been forcefully made tail-recursive. Besides, the (-1.) ** m expression in summand is mapped by the compiler to a call to the pow() function on the host, and that's a bag of hurt you may want to avoid.
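
Both suggestions can be sketched together; this is one possible solution to the exercise, not code from the answer. It sums the same n terms as `q_pi_approx 0 n`, but from the last (smallest) term back to the first, and it derives the sign from the parity of the index instead of calling `( ** )`, avoiding `pow()` entirely. The name `pi_approx_backward` is illustrative.

```ocaml
(* Sum the n Leibniz terms from index n-1 down to 0 (smallest magnitude
   first, which tends to reduce rounding error), and compute the sign
   from the parity of k instead of (-1.) ** m, which calls pow(). *)
let pi_approx_backward n =
  (* go k acc adds the terms k, k-1, ..., 0 to acc; still tail-recursive. *)
  let rec go k acc =
    if k < 0 then acc
    else
      let sign = if k land 1 = 0 then 1. else -1. in
      let term = sign /. (2. *. float_of_int k +. 1.) in
      go (k - 1) (acc +. term)
  in
  4. *. go (n - 1) 0.

let () = Printf.printf "%F\n" (pi_approx_backward 1_000_000)
```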

answered Oct 18 '22 by Pascal Cuoq


I've also tried several variants; here are my conclusions:

  1. Using arrays
  2. Using recursion
  3. Using imperative loop

The recursive function is about 30% more efficient than the array implementation. The imperative loop is about as efficient as the recursion (maybe even a little slower).

Here are my implementations:

Array:

open Core.Std

let pi_approx n =
  let f m = (-1.) ** m /. (2. *. m +. 1.) in
  let qpi = Array.init n ~f:Float.of_int |>
            Array.map ~f |>
            Array.reduce_exn ~f:(+.) in
  qpi *. 4.0

Recursion:

let pi_approx n =
  let rec loop n acc m =
    if m = n
    then acc *. 4.0
    else
      let acc = acc +. (-1.) ** m /. (2. *. m +. 1.) in
      loop n acc (m +. 1.0) in
  let n = float_of_int n in
  loop n 0.0 0.0

This can be optimized further by moving the local function loop outside, so that the compiler can inline it.
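
One possible reading of "moving the local function outside", matching the "non-local recursion" entry in the scoreboard: the loop becomes a top-level function and `pi_approx` simply calls it. The answer does not show this version, so the name `pi_loop` is hypothetical, and whether ocamlopt actually benefits depends on the compiler version and flags.

```ocaml
(* Hypothetical top-level version of the local loop: same arithmetic,
   but the recursive function is no longer nested inside pi_approx. *)
let rec pi_loop n acc m =
  if m = n then acc *. 4.0
  else
    let acc = acc +. (-1.) ** m /. (2. *. m +. 1.) in
    pi_loop n acc (m +. 1.0)

let pi_approx n = pi_loop (float_of_int n) 0.0 0.0

let () = Printf.printf "%F\n" (pi_approx 1_000_000)
```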

Imperative loop:

let pi_approx n =
  let sum = ref 0. in
  for m = 0 to n - 1 do
    let m = float_of_int m in
    sum := !sum +. (-1.) ** m /. (2. *. m +. 1.)
  done;
  4.0 *. !sum

However, in the code above, the ref holding the sum incurs boxing/unboxing on each step. We can optimize this further with the float_ref trick:

type float_ref = { mutable value : float}

let pi_approx n =
  let sum = {value = 0.} in
  for m = 0 to n - 1 do
    let m = float_of_int m in
    sum.value <- sum.value +. (-1.) ** m /. (2. *. m +. 1.)
  done;
  4.0 *. sum.value
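
The trick relies on how OCaml lays out records: a record whose fields are all floats is stored flat (with `Double_array_tag`), while `float ref` is an ordinary one-field block pointing to a separately boxed float. The `Obj`-based sketch below only illustrates the two representations, reusing the `float_ref` type from above; how much the loop itself pays for the ref depends on what the compiler does with it, so treat this as an illustration rather than a benchmark.

```ocaml
(* Illustration only (Obj is unsafe): compare the memory layout of a
   float ref and of an all-float record. *)
type float_ref = { mutable value : float }

let () =
  let r = ref 0. in
  let fr = { value = 0. } in
  (* A ref is an ordinary block (tag 0) whose single field points to a
     separately boxed float. *)
  Printf.printf "ref: tag %d, field boxed: %b\n"
    (Obj.tag (Obj.repr r)) (Obj.is_block (Obj.field (Obj.repr r) 0));
  (* An all-float record is stored flat with Double_array_tag, so the
     float value lives inline in the record. *)
  Printf.printf "float_ref uses Double_array_tag: %b\n"
    (Obj.tag (Obj.repr fr) = Obj.double_array_tag)
```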

Scoreboard (speed relative to the fastest version; higher is faster)

for-loop (with float_ref record) : 1.0
non-local recursion              : 0.89
local recursion                  : 0.86
Pascal's version                 : 0.77
for-loop (with plain float ref)  : 0.62
array                            : 0.47
original                         : 0.08

Update

I've updated the answer, as I've found a way to get a 40% speedup (or 33% compared with @Pascal's answer).

answered Oct 18 '22 by ivg


I would like to add that although floats are boxed in OCaml, float arrays are unboxed. Here is a program that builds a float array corresponding to the Leibniz sequence and uses it to approximate π:

open Array

let q_pi_approx n =
  let summand n  =
    let m = float_of_int n
    in (-1.) ** m /. (2. *. m +. 1.) in
  let a = Array.init n summand in
  Array.fold_left (+.) 0. a

let n = int_of_string Sys.argv.(1);;
Printf.printf "%F\n" (4. *. (q_pi_approx n));;

Obviously, it is still slower than code that doesn't build any data structure at all. Execution times (the version with the array is the last one):

time ./v1 10000000
3.14159275359

real    0m2.479s
user    0m2.380s
sys 0m0.104s

time ./v2 10000000
3.14159255359

real    0m0.402s
user    0m0.400s
sys 0m0.000s

time ./a 10000000
3.14159255359

real    0m0.453s
user    0m0.432s
sys 0m0.020s
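
The difference between the two data structures can be observed with `Obj` (illustration only; Obj is unsafe). A float array block carries `Double_array_tag` and stores raw doubles inline, while each cell of a float list holds a pointer to a separately boxed float. This assumes a compiler configured with flat float arrays, which is the default; OCaml 4.08 and later also provide `Float.Array`, which is always flat.

```ocaml
(* Illustration only (Obj is unsafe): a float array stores its elements
   inline, whereas a float list points to a boxed float in each cell. *)
let () =
  let arr = Array.init 3 float_of_int in
  let lst = [0.; 1.; 2.] in
  (* The array block has Double_array_tag: its elements are raw doubles. *)
  Printf.printf "float array is flat: %b\n"
    (Obj.tag (Obj.repr arr) = Obj.double_array_tag);
  (* The head field of the first list cell points to a boxed float. *)
  Printf.printf "list element is boxed: %b\n"
    (Obj.is_block (Obj.field (Obj.repr lst) 0))
```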

answered Oct 18 '22 by Zoyd