Why is `filterM + mapM_` so much slower than `mapM_ + when`, with large lists?

I don't know very much about how Haskell optimization works internally, but I've been using filters quite a lot, hoping that they are optimized into something equivalent to a simple `if` in C++. For example,

mapM_ print $ filter (\n -> n `mod` 2 == 0) [0..10]

will compile into the equivalent of

for (int i = 0; i <= 10; i++)   // [0..10] is inclusive, so i <= 10
    if (i % 2 == 0)
        printf("%d\n", i);
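For comparison, here is a hand-written Haskell loop that mirrors the C++ loop directly: visit each `i` in order, test it, and act immediately, with no intermediate list. This is only an illustrative sketch; the name `forEvens_` is hypothetical, and it is written for any `Applicative` so the same loop can run in `IO` or be checked purely.

```haskell
import Control.Monad (when)

-- Hypothetical helper mirroring the C++ for-loop: for i = 0..hi,
-- test the condition and run the action right away (no list is built).
forEvens_ :: Applicative f => Int -> (Int -> f ()) -> f ()
forEvens_ hi act = go 0
  where
    go i
      | i > hi    = pure ()
      | otherwise = when (i `mod` 2 == 0) (act i) *> go (i + 1)

main :: IO ()
main = do
  -- The list-based version from the question ...
  mapM_ print $ filter (\n -> n `mod` 2 == 0) [0 .. 10]
  -- ... and the explicit loop; both print 0, 2, 4, 6, 8 and 10.
  forEvens_ 10 print
```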

With long lists (10 000 000 elements) this seems to hold for a basic filter, but there is a huge difference if I use the monadic `filterM`. I wrote a piece of code to test the speed, and it's obvious that using `filterM` takes much longer (about 250x) than a more imperative approach using `when`.

import Data.Array.IO
import Control.Monad
import System.CPUTime

main :: IO ()
main = do
  start <- getCPUTime
  arr <- newArray (0, 100) 0 :: IO (IOUArray Int Int)
  let
    okSimple i =
      i < 100

    ok i = do
      return $ i < 100
    -- -- of course we don't need IO for a simple i < 100
    -- -- but my goal is to ask for the contents of the array, e.g.
    -- ok i = do
    --   current <- readArray arr (i `mod` 101)
    --   return$ i `mod` 37 > current `mod` 37
    
    write :: Int -> IO ()
    write i =
      writeArray arr (i `mod` 101) i

    writeIfOkSimple :: Int -> IO ()
    writeIfOkSimple i =
      when (okSimple i) $ write i

    writeIfOk :: Int -> IO ()
    writeIfOk i =
      ok i >>= (\isOk -> when isOk $ write i)

  -------------------------------------------------------------------
  ---- these four methods have roughly the same execution time   ----
  ---- (but the last one is run on a list 250 times shorter)     ----
  -------------------------------------------------------------------
  -- mapM_ write$ filter okSimple [0..10000000*250] -- t = 20.694
  -- mapM_ writeIfOkSimple [0..10000000*250]        -- t = 20.698
  -- mapM_ writeIfOk [0..10000000*250]              -- t = 20.669
  filterM ok [0..10000000] >>= mapM_ write          -- t = 17.200

  -- evaluate array
  elems <- getElems arr
  print $ sum elems

  end <- getCPUTime
  print $ fromIntegral (end - start) / (10^12)

My question is: shouldn't both approaches (using `writeIfOk`, or using `filterM ok` followed by `write`) compile into the same code (iterate over the list, check the condition, write the data)? If not, can I do something (rewrite the code, add compilation flags, use an INLINE pragma or something similar) to make them computationally equivalent, or should I always use `when` when performance is critical?

Matej Vargovčík asked Mar 24 '21 12:03



1 Answer

Boiling this question down to its essence, you're asking about the difference between

f (filter g xs)

and

f =<< filterM (pure . g) xs
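To make the two forms concrete, here is a small sketch that instantiates them with hypothetical stand-ins: `g` is the even-number predicate from the question and `f` is `mapM_ print`. Both produce the same output; the difference discussed below is only in when the list elements become available to `f`.

```haskell
import Control.Monad (filterM)

-- Hypothetical stand-ins for the answer's g and f.
g :: Int -> Bool
g n = n `mod` 2 == 0

f :: [Int] -> IO ()
f = mapM_ print

-- Form 1: pure filter, consumed lazily by f.
viaFilter :: [Int] -> IO ()
viaFilter xs = f (filter g xs)

-- Form 2: monadic filterM with the pure predicate lifted by `pure`.
viaFilterM :: [Int] -> IO ()
viaFilterM xs = f =<< filterM (pure . g) xs

main :: IO ()
main = do
  viaFilter [0 .. 10]
  viaFilterM [0 .. 10]
```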

This basically comes down to laziness. filter g xs produces its result incrementally as it's demanded, only walking xs far enough to find the next element of the result. filterM is defined something like this:

filterM _p [] = pure []
filterM p (x : xs)
  = liftA2 (\q r -> if q then x : r else r)
           (p x)
           (filterM p xs)

Since IO is a "strict" applicative, this will not produce anything at all until it has walked the whole list, running every p x and accumulating the results in memory; only after that can f start consuming the filtered list.
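A small sketch of this point, assuming a reasonably recent GHC/base. With the lazy `Identity` applicative, `filterM` streams and even works on an infinite list; with a strict applicative such as `Maybe` (IO behaves the same way, but is harder to demonstrate purely) the whole list must be processed first. The last part factors out the streaming pattern the question already uses in `writeIfOk`; the name `filterMapM_` is hypothetical, not a base function.

```haskell
import Control.Monad (filterM, when)
import Data.Functor.Identity (Identity (..))

-- Lazy applicative: Identity's liftA2 does not force its arguments,
-- so filterM can hand out results incrementally, even from [0 ..].
lazyPrefix :: [Int]
lazyPrefix = take 3 (runIdentity (filterM (pure . even) [0 ..]))

-- Strict applicative: Maybe must inspect every p x before it knows
-- whether the result is Just or Nothing, so it needs the whole list.
-- (take 3 <$> filterM (Just . even) [0 ..] would not terminate.)
strictFinite :: Maybe [Int]
strictFinite = filterM (Just . even) [0 .. 10]

-- Hypothetical streaming helper: test and act per element, so no
-- intermediate list is built (the writeIfOk pattern from the question).
filterMapM_ :: Monad m => (a -> m Bool) -> (a -> m ()) -> [a] -> m ()
filterMapM_ p act = mapM_ (\x -> p x >>= \ok -> when ok (act x))

main :: IO ()
main = do
  print lazyPrefix    -- [0,2,4]
  print strictFinite  -- Just [0,2,4,6,8,10]
  filterMapM_ (pure . (< 5)) print [0 .. 10 :: Int]
```

With `filterMapM_ ok write [0..10000000]`, each element is checked and written before the next one is looked at, which is why the `when`-based version in the question does not pay for building the filtered list first.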

dfeuer answered Sep 28 '22 01:09
