Optimize a list function that creates too much garbage (not stack overflow)

I have this Haskell function, which causes more than 50% of all the allocations in my program, so that 60% of my run time is taken by the GC. I run with a small stack (-K10K), so there is no stack overflow, but can I make this function faster, with less allocation?

The goal here is to calculate the product of a matrix by a vector. I cannot use hmatrix, for example, because this is part of a bigger function using the ad (automatic differentiation) package, so I need to use lists of Num. I suppose that, at run time, using the Numeric.AD module means my types must be Scalar Double.

listMProd :: (Num a) => [a] -> [a] -> [a]
listMProd mdt vdt = go mdt vdt 0
  where
    go [] _  s = [s]
    go ls [] s = s : go ls vdt 0
    go (y:ys) (x:xs) ix = go ys xs (y*x+ix)

Basically, we loop through the matrix (stored row by row as a flat list), multiplying and adding into an accumulator until we reach the end of the vector; then we store the result and continue, starting again from the beginning of the vector. I have a QuickCheck test verifying that I get the same result as the matrix/vector product in hmatrix.
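A naive reference version can make the intended semantics explicit and serve as the oracle for an equivalence test like the QuickCheck one mentioned. The sketch below uses a hypothetical `listMProdRef` (not from the question) that splits the flat, row-major matrix into rows with `splitAt` and takes one dot product per row. It assumes `v` is non-empty and that `m` is non-empty with a length that is a multiple of `length v` (for an empty `m`, `listMProd` returns `[0]` while this version returns `[]`).

```haskell
-- The question's function, reproduced unchanged.
listMProd :: (Num a) => [a] -> [a] -> [a]
listMProd mdt vdt = go mdt vdt 0
  where
    go [] _  s = [s]
    go ls [] s = s : go ls vdt 0
    go (y:ys) (x:xs) ix = go ys xs (y*x+ix)

-- Hypothetical reference implementation: split the row-major matrix into
-- rows of length (length v), then take one dot product per row.
-- Assumes v is non-empty (splitAt 0 would never shrink the list).
listMProdRef :: (Num a) => [a] -> [a] -> [a]
listMProdRef m v = map (\row -> sum (zipWith (*) row v)) (rows m)
  where
    n = length v
    rows [] = []
    rows xs = let (r, rest) = splitAt n xs in r : rows rest
```

For a 3x2 matrix `[1, 2, 3, 4, 5, 6]` and vector `[10, 20]`, both functions should give `[50, 110, 170]`.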

I have tried foldl, foldr, etc. Nothing I've tried makes the function faster, and some variants, such as foldr, cause a space leak.
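A common pitfall with fold-based rewrites is that `foldl`, and even `Data.List.foldl'` with a lazy tuple as its accumulator, can still build chains of unevaluated `(+)` thunks, because `foldl'` only forces the accumulator to weak head normal form. A minimal sketch of a strict left fold, assuming a hypothetical state type `St` whose running-sum field is strict, is shown below; it reproduces `listMProd`'s results but has not been benchmarked against it, and it works on plain `Num` values rather than the ad types.

```haskell
import Data.List (foldl')

-- Hypothetical fold state: running sum for the current row, the part of
-- the vector still to consume, and finished row results (in reverse).
-- The strict field (!a) means that when foldl' forces the state to WHNF,
-- the running sum is evaluated as well, so no (+) thunk chain builds up.
data St a = St !a [a] [a]

-- Same results as listMProd; assumes v is non-empty (otherwise the
-- "row finished" clause below would loop forever).
listMProdFold :: (Num a) => [a] -> [a] -> [a]
listMProdFold m v = finish (foldl' step (St 0 v []) m)
  where
    step (St acc (x:xs) done) y = St (y * x + acc) xs done
    step (St acc [] done) y     = step (St 0 v (acc : done)) y  -- row finished
    finish (St acc _ done) = reverse (acc : done)
```

One trade-off of this design: the row results are collected in reverse and only returned once the whole matrix has been consumed, whereas the original `go` emits each row's result lazily.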

Profiling shows that, besides this function being where most of the time and allocation is spent, loads of Cells are being created, Cell being a data type from the ad package.

A simple test to run:

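import Numeric.AD

-- listMProd (defined above) must be in this module.
main :: IO ()
main = do
    -- Annotating the expressions avoids the pattern type signatures
    -- (let m :: [Double] = ...), which need ScopedTypeVariables.
    let m = replicate 400 0.2 :: [Double]  -- 100 x 4 matrix, row-major
        v = replicate 4 0.1 :: [Double]
        mycost v m = sum $ listMProd m v
        mygrads = gradientDescent (mycost (map auto v)) (map auto m)
    print $ mygrads !! 1000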

On my machine, this reports that the GC is busy 47% of the time.

Any ideas?

asked Sep 24 '15 at 15:09 by JP Moresmau


1 Answer

A very simple optimization is to make the go function strict in its accumulator parameter, because the accumulator is small, can be unboxed if a is a primitive type, and always needs to be fully evaluated anyway:

{-# LANGUAGE BangPatterns #-}
listMProd :: (Num a) => [a] -> [a] -> [a]
listMProd mdt vdt = go mdt vdt 0
  where
    go [] _  !s = [s]
    go ls [] !s = s : go ls vdt 0
    go (y:ys) (x:xs) !ix = go ys xs (y*x+ix)

On my machine, this gives a 3-4x speedup (compiled with -O2).
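A bang pattern on a function argument forces that argument when the clause is matched, so the same strictness can be written with `seq` and no language extension. The sketch below, a hypothetical `listMProdSeq`, also adds a `SPECIALIZE` pragma: unboxing can only happen in a copy of the function specialized to a concrete primitive type such as `Double`. In the question's ad setting the elements are AD mode values rather than plain `Double`s, so this specialization is an assumption that would only help plain-`Double` callers.

```haskell
-- Strict accumulator via seq instead of BangPatterns.
listMProdSeq :: (Num a) => [a] -> [a] -> [a]
listMProdSeq mdt vdt = go mdt vdt 0
  where
    go [] _ s           = s `seq` [s]
    go ls [] s          = s `seq` (s : go ls vdt 0)
    go (y:ys) (x:xs) ix = ix `seq` go ys xs (y * x + ix)

-- Ask GHC for a Double-specific copy, where the accumulator may be unboxed.
{-# SPECIALIZE listMProdSeq :: [Double] -> [Double] -> [Double] #-}
```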

On the other hand, the intermediate lists should not be made strict, so that they can still be fused.
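The strictness added above applies only to the accumulator: each row's result is still consed onto a lazily produced output list, so a consumer can demand rows one at a time. A small check of that, assuming the BangPatterns version from this answer (renamed `listMProd'` here) and a hypothetical helper `firstRows`, is that taking a prefix of the result for an infinite matrix terminates:

```haskell
{-# LANGUAGE BangPatterns #-}

-- The answer's strict-accumulator version, renamed.
listMProd' :: (Num a) => [a] -> [a] -> [a]
listMProd' mdt vdt = go mdt vdt 0
  where
    go [] _  !s = [s]
    go ls [] !s = s : go ls vdt 0
    go (y:ys) (x:xs) !ix = go ys xs (y * x + ix)

-- Only the accumulator is forced; the output list is still lazy, so
-- taking k rows of the product with an infinite matrix terminates.
firstRows :: Int -> [Integer]
firstRows k = take k (listMProd' (cycle [1, 2]) [1, 1])
```

Whether GHC actually fuses the output list with its consumer depends on optimization settings and rewrite rules, so this only shows that laziness is preserved, not that fusion happens.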

answered Sep 28 '22 at 00:09 by Yuuri