I have a Haskell function that causes more than 50% of all the allocations in my program, with the result that 60% of my run time is taken by the GC. I run with a small stack (-K10K) so there is no stack overflow, but can I make this function faster, with less allocation?
The goal here is to calculate the product of a matrix by a vector. I cannot use hmatrix, for example, because this is part of a bigger function using the ad Automatic Differentiation package, so I need to use lists of Num. At runtime I suppose the use of the Numeric.AD module means my types must be Scalar Double.
listMProd :: (Num a) => [a] -> [a] -> [a]
listMProd mdt vdt = go mdt vdt 0
  where
    go [] _ s = [s]
    go ls [] s = s : go ls vdt 0
    go (y:ys) (x:xs) ix = go ys xs (y*x+ix)
Basically we loop through the matrix, multiplying and adding into an accumulator until we reach the end of the vector, store the result, then restart the vector and continue. I have a QuickCheck test verifying that I get the same result as the matrix/vector product in hmatrix.
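To make the loop described above concrete, here is a minimal sketch that puts the function next to a row-by-row reference, assuming the matrix is stored as a flat, row-major list whose row length equals the vector length. The names rowsOf and refMProd are hypothetical helpers introduced only for this comparison; they are not part of the question's code or of any library.

```haskell
-- The function from the question, unchanged apart from layout.
listMProd :: Num a => [a] -> [a] -> [a]
listMProd mdt vdt = go mdt vdt 0
  where
    go [] _ s = [s]
    go ls [] s = s : go ls vdt 0
    go (y:ys) (x:xs) ix = go ys xs (y * x + ix)

-- Hypothetical helper: split a flat list into rows of length n.
rowsOf :: Int -> [a] -> [[a]]
rowsOf n xs
  | n <= 0 || null xs = []
  | otherwise = let (row, rest) = splitAt n xs in row : rowsOf n rest

-- Hypothetical reference: dot product of each row with the vector.
-- Assumption: the matrix is row-major and each row is as long as the vector.
refMProd :: Num a => [a] -> [a] -> [a]
refMProd mdt vdt = map (sum . zipWith (*) vdt) (rowsOf (length vdt) mdt)
```

Loaded in GHCi, listMProd [1,2,3,4,5,6] [1,0,2] and refMProd [1,2,3,4,5,6] [1,0,2] both give [7,16] for this 2x3 example. The two differ on an empty matrix, where listMProd returns [0] and the reference returns [], so any comparison should be restricted to non-empty inputs.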
I have tried foldl, foldr, etc. Nothing I've tried makes the function faster (and some things, like foldr, cause a space leak).
Profiling tells me that, on top of this function being where most of the time and allocation is spent, loads of Cells are being created, Cells being a data type from the ad package.
A simple test to run:
{-# LANGUAGE ScopedTypeVariables #-}  -- needed for the `m :: [Double] = ...` bindings
import Numeric.AD

main = do
    let m :: [Double] = replicate 400 0.2
        v :: [Double] = replicate 4 0.1
        mycost v m = sum $ listMProd m v
        mygrads = gradientDescent (mycost (map auto v)) (map auto m)
    print $ mygrads !! 1000
This on my machine tells me GC is busy 47% of the time.
Any ideas?
A very simple optimization is to make the go function strict in its accumulator parameter, because the accumulator is small, can be unboxed if a is primitive, and always needs to be fully evaluated anyway:
{-# LANGUAGE BangPatterns #-}
listMProd :: (Num a) => [a] -> [a] -> [a]
listMProd mdt vdt = go mdt vdt 0
  where
    go [] _ !s = [s]
    go ls [] !s = s : go ls vdt 0
    go (y:ys) (x:xs) !ix = go ys xs (y*x+ix)
On my machine, it gives a 3-4x speedup (compiled with -O2).
On the other hand, the intermediate lists shouldn't be made strict, so that they can still be fused.
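For readers who prefer not to enable BangPatterns, the strict accumulator can also be written with seq. This is a minimal sketch of roughly the same loop, not a measured alternative; the name listMProdSeq is hypothetical. Note that seq, like a bang pattern, only evaluates to weak head normal form, so for a non-primitive a (such as the types used by ad) it may not force the whole value.

```haskell
-- Hypothetical variant: same loop as above, with the accumulator
-- forced by seq in each clause instead of by a bang pattern.
listMProdSeq :: Num a => [a] -> [a] -> [a]
listMProdSeq mdt vdt = go mdt vdt 0
  where
    -- End of the matrix: emit the last row's sum.
    go [] _ s = s `seq` [s]
    -- End of the vector: emit this row's sum and restart the vector.
    go ls [] s = s `seq` (s : go ls vdt 0)
    -- Force the running sum before recursing, so no chain of thunks builds up.
    go (y:ys) (x:xs) acc = acc `seq` go ys xs (y * x + acc)
```

The result list itself stays lazy: only the accumulator is forced, and each element is produced as it is demanded.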