 

What is the correct way to perform constant-space nested loops in Haskell?


There are two obvious, "idiomatic" ways to perform nested loops in Haskell: using the list monad or using forM_ to replace traditional fors. I've set up a benchmark to determine whether those are compiled to tight loops:

import Control.Monad.Loop
import Control.Monad.Primitive
import Control.Monad
import Control.Monad.IO.Class
import qualified Data.Vector.Unboxed.Mutable as MV
import qualified Data.Vector.Unboxed as V

times = 100000
side  = 100

-- Using `forM_` to replace traditional fors
test_a mvec = 
    forM_ [0..times-1] $ \ n -> do
        forM_ [0..side-1] $ \ y -> do
            forM_ [0..side-1] $ \ x -> do
                MV.write mvec (y*side+x) 1

-- Using the list monad to replace traditional fors
test_b mvec = sequence_ $ do
    n <- [0..times-1]
    y <- [0..side-1]
    x <- [0..side-1]
    return $ MV.write mvec (y*side+x) 1

main = do
    let vec = V.generate (side*side) (const 0)
    mvec <- V.unsafeThaw vec :: IO (MV.MVector (PrimState IO) Int)
    -- test_a mvec
    -- test_b mvec
    vec' <- V.unsafeFreeze mvec :: IO (V.Vector Int)
    print $ V.sum vec'

This test creates a 100x100 vector, writes 1 to each index using a nested loop and repeats that 100k times. Compiling those with just ghc -O2 test.hs -o test (ghc version 7.8.4), the results are 3.853s for the forM_ version and 10.460s for the list monad. As a reference, I also programmed this test in JavaScript:

var side  = 100;
var times = 100000;
var vec   = [];

for (var i=0; i<side*side; ++i)
    vec.push(0);

for (var n=0; n<times; ++n)
    for (var y=0; y<side; ++y)
        for (var x=0; x<side; ++x)
            vec[x+y*side] = 1;

var s = 0;
for (var i=0; i<side*side; ++i)
    s += vec[i];

console.log(s);

This equivalent JavaScript program takes 1s to complete, beating Haskell's unboxed vectors. That is unusual, and it suggests that Haskell is not running the loop in constant space but is allocating instead. I then found a library, Control.Monad.Loop, that claims to provide type-guaranteed tight loops:

-- Using `for` from Control.Monad.Loop
test_c mvec = exec_ $ do
    n <- for 0 (< times) (+ 1)
    x <- for 0 (< side) (+ 1)
    y <- for 0 (< side) (+ 1)
    liftIO (MV.write mvec (y*side+x) 1)

This runs in 1s. That library isn't widely used and is far from idiomatic, though. So what is the idiomatic way to get fast, constant-space two-dimensional computations? (Note that this isn't a case for REPA, as I want to perform arbitrary IO actions on the grid.)

asked Sep 02 '15 at 05:09 by MaiaVictor



1 Answer

Writing tight mutating code with GHC can be tricky sometimes. I'm going to write about a couple of different things, probably in a manner that is more rambling and tl;dr than I would prefer.

For starters, we should use GHC 7.10 in any case, since otherwise the forM_ and list monad solutions never fuse.

Also, I replaced MV.write with MV.unsafeWrite, partly because it's faster, but more importantly it reduces some of the clutter in the resultant Core. From now on runtime statistics refer to code with unsafeWrite.

The dreaded let floating

Even with GHC 7.10, we should first notice all those [0..times-1] and [0..side-1] expressions, because they will ruin performance every time if we don't take necessary steps. The issue is that they are constant ranges, and -ffull-laziness (which is enabled by default on -O) floats them out to top level. This prevents list fusion, and iterating over an Int# range is cheaper than iterating over a list of boxed Int-s anyway, so it's a really bad optimization.
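One way to sidestep the problem entirely is to write the loops as explicit recursion over Int, so there is no constant list for -ffull-laziness to float out in the first place. The following is a minimal sketch under a couple of assumptions: it uses the array package's STUArray (a GHC boot package) in place of vector's MVector so that it is self-contained, and fillGrid is a hypothetical name, not something from the question.

```haskell
{-# LANGUAGE BangPatterns #-}
module Main where

import Data.Array.ST (newArray, runSTUArray, writeArray)
import Data.Array.Unboxed (UArray, elems)

-- | Write 1 into every cell of a side*side grid, repeating the full sweep
-- 'times' times. Each loop is a local tail-recursive function over Int,
-- so no [0..n] list exists that GHC could float out and fail to fuse.
fillGrid :: Int -> Int -> UArray Int Int
fillGrid times side = runSTUArray $ do
  arr <- newArray (0, side * side - 1) 0
  let outer !n
        | n >= times = return ()
        | otherwise  = rows 0 >> outer (n + 1)
      rows !y
        | y >= side = return ()
        | otherwise = cols y 0 >> rows (y + 1)
      cols !y !x
        | x >= side = return ()
        | otherwise = writeArray arr (y * side + x) 1 >> cols y (x + 1)
  outer 0
  return arr

main :: IO ()
main = print (sum (elems (fillGrid 1000 100)))
```

The price is readability: this is exactly the boilerplate that forM_ or a for combinator is supposed to hide.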

Let's see some runtimes in seconds for the unchanged (aside from using unsafeWrite) code. ghc -O2 -fllvm is used, and I use +RTS -s for timing.

test_a: 1.6
test_b: 6.2
test_c: 0.6

For GHC Core viewing I used ghc -O2 -ddump-simpl -dsuppress-all -dno-suppress-type-signatures.

In the case of test_a, the [0..99] ranges are lifted out:

main4 :: [Int]
main4 = eftInt 0 99 -- means "enumFromTo" for Int.

although the outermost [0..99999] loop is fused into a tail-recursive helper:

letrec {
          a3_s7xL :: Int# -> State# RealWorld -> (# State# RealWorld, () #)
          a3_s7xL =
            \ (x_X5zl :: Int#) (s1_X4QY :: State# RealWorld) ->
              case a2_s7xF 0 s1_X4QY of _ { (# ipv2_a4NA, ipv3_a4NB #) ->
              case x_X5zl of wild_X1S {
                __DEFAULT -> a3_s7xL (+# wild_X1S 1) ipv2_a4NA;
                99999 -> (# ipv2_a4NA, () #)
              }
              }; }

In the case of test_b, again only the [0..99] ranges are lifted. However, test_b is much slower, because it has to build and sequence actual [IO ()] lists. At least GHC is sensible enough to build only a single [IO ()] for the two inner loops and then sequence it 100000 times.

 let {
          lvl7_s4M5 :: [IO ()]
          lvl7_s4M5 = -- omitted
        letrec {
          a2_s7Av :: Int# -> State# RealWorld -> (# State# RealWorld, () #)
          a2_s7Av =
            \ (x_a5xi :: Int#) (eta_B1 :: State# RealWorld) ->
              letrec {
                a3_s7Au
                  :: [IO ()] -> State# RealWorld -> (# State# RealWorld, () #)
                a3_s7Au =
                  \ (ds_a4Nu :: [IO ()]) (eta1_X1c :: State# RealWorld) ->
                    case ds_a4Nu of _ {
                      [] ->
                        case x_a5xi of wild1_X1y {
                          __DEFAULT -> a2_s7Av (+# wild1_X1y 1) eta1_X1c;
                          99999 -> (# eta1_X1c, () #)
                        };
                      : y_a4Nz ys_a4NA ->
                        case (y_a4Nz `cast` ...) eta1_X1c
                        of _ { (# ipv2_a4Nf, ipv3_a4Ng #) ->
                        a3_s7Au ys_a4NA ipv2_a4Nf
                        }
                    }; } in
              a3_s7Au lvl7_s4M5 eta_B1; } in
-- omitted
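At the source level, the Core above behaves roughly like the hand-written program below: the two inner loops become one shared list of IO actions, built once, which the outer loop then sequences on every iteration. This is only an illustrative approximation of the Core's shape, not what GHC literally emits; it uses an IOUArray from the array package instead of vector, and testBShape and fillAndSum are hypothetical names.

```haskell
module Main where

import Control.Monad (forM_)
import Data.Array.IO (IOUArray, getElems, newArray, writeArray)

-- | Approximates the floated test_b: 'innerActions' is a single list of
-- side*side write actions, shared across all outer iterations.
testBShape :: Int -> Int -> IOUArray Int Int -> IO ()
testBShape times side arr = do
  let innerActions :: [IO ()]
      innerActions =
        [ writeArray arr (y * side + x) 1 | y <- [0 .. side - 1], x <- [0 .. side - 1] ]
  -- The list is built on first use and then re-traversed every iteration.
  forM_ [0 .. times - 1] $ \_ -> sequence_ innerActions

fillAndSum :: Int -> Int -> IO Int
fillAndSum times side = do
  arr <- newArray (0, side * side - 1) 0
  testBShape times side arr
  sum <$> getElems arr

main :: IO ()
main = fillAndSum 1000 100 >>= print
```

Keeping a list of boxed IO actions alive and walking it for every iteration is what makes this shape slower than a fused loop over Int#.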

How can we remedy this? We could nuke the problem with {-# OPTIONS_GHC -fno-full-laziness #-}. This indeed helps a lot in our case:

test_a: 0.5
test_b: 0.48
test_c: 0.5
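For concreteness, here is a sketch of the pragma applied to test_a's nested forM_ loops. It assumes the array package's STUArray in place of vector so that it is self-contained; fillGridForM is a hypothetical name. Even though times and side are function arguments here rather than top-level constants, full laziness could still float the ranges out of the loop lambdas, so the pragma remains relevant.

```haskell
{-# OPTIONS_GHC -fno-full-laziness #-}
module Main where

import Control.Monad (forM_)
import Data.Array.ST (newArray, runSTUArray, writeArray)
import Data.Array.Unboxed (UArray, elems)

-- | test_a's loop structure. With full laziness off for this module, the
-- [0 .. side - 1] ranges stay inside the lambdas, where (per the answer,
-- on GHC 7.10 or later) they can fuse with forM_.
fillGridForM :: Int -> Int -> UArray Int Int
fillGridForM times side = runSTUArray $ do
  arr <- newArray (0, side * side - 1) 0
  forM_ [0 .. times - 1] $ \_ ->
    forM_ [0 .. side - 1] $ \y ->
      forM_ [0 .. side - 1] $ \x ->
        writeArray arr (y * side + x) 1
  return arr

main :: IO ()
main = print (sum (elems (fillGridForM 1000 100)))
```

Note that the pragma affects the whole module, so any sharing you actually want has to be arranged by hand.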

Alternatively, we could fiddle around with INLINE pragmas. Apparently inlining functions after the let floating is done preserves good performance. I found that GHC inlines our test functions even without a pragma, but an explicit pragma causes it to inline only after let floating. For example, this results in good performance without -fno-full-laziness:

test_a mvec = 
    forM_ [0..times-1] $ \ n -> 
        forM_ [0..side-1] $ \ y -> 
            forM_ [0..side-1] $ \ x -> 
                MV.unsafeWrite mvec (y*side+x) 1
{-# INLINE test_a #-}

But inlining too early results in poor performance:

test_a mvec = 
    forM_ [0..times-1] $ \ n -> 
        forM_ [0..side-1] $ \ y -> 
            forM_ [0..side-1] $ \ x -> 
                MV.unsafeWrite mvec (y*side+x) 1
{-# INLINE [~2] test_a #-} -- "inline before the first phase please"

The problem with this INLINE solution is that it's rather fragile in the face of GHC's floating onslaught. For example, manual inlining does not preserve performance. The following code is slow because, like INLINE [~2], it gives GHC a chance to float out:

main = do
    let vec = V.generate (side*side) (const 0)
    mvec <- V.unsafeThaw vec :: IO (MV.MVector (PrimState IO) Int)
    forM_ [0..times-1] $ \ n -> 
        forM_ [0..side-1] $ \ y -> 
            forM_ [0..side-1] $ \ x -> 
                MV.unsafeWrite mvec (y*side+x) 1    

So what should we do?

First, I think using -fno-full-laziness is a perfectly viable and even preferable option for those who'd like to write high performance code and have a good idea what they are doing. For example, it's used in unordered-containers. With it we have more precise control over sharing, and we can always just float out or inline manually.
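As a small sketch of what "floating out manually" can look like once full laziness is off: sharing is obtained simply by binding the loop-invariant value outside the loop. The functions below are hypothetical illustrations, not code from the answer; whether the second version really rebuilds the list on each iteration depends on what else GHC does with it (it may simply fuse away).

```haskell
{-# OPTIONS_GHC -fno-full-laziness #-}
{-# LANGUAGE BangPatterns #-}
module Main where

-- | 'table' is bound once per call, outside 'go', so every iteration
-- reuses the same evaluated list. GHC does not need to float anything.
sumRepeated :: Int -> Int -> Int
sumRepeated n reps = go 0 0
  where
    table = map (* 2) [0 .. n - 1] :: [Int]
    go !acc !k
      | k >= reps = acc
      | otherwise = go (acc + sum table) (k + 1)

-- | The same computation with the list written inside the loop body.
-- With full laziness off, GHC will not hoist it out of 'go' by itself.
sumRecomputed :: Int -> Int -> Int
sumRecomputed n reps = go 0 0
  where
    go !acc !k
      | k >= reps = acc
      | otherwise = go (acc + sum (map (* 2) [0 .. n - 1])) (k + 1)

main :: IO ()
main = print (sumRepeated 1000 1000, sumRecomputed 1000 1000)
```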

For more regular code, I believe there's nothing wrong with using Control.Monad.Loop or any other package that provides the functionality. Many Haskell users are not scrupulous about depending on small "fringe" libraries. We can also just reimplement for, in a desired generality. For instance, the following performs just as well as the other solutions:

{-# LANGUAGE BangPatterns #-} -- needed for the bang pattern on go's argument

for :: Monad m => a -> (a -> Bool) -> (a -> a) -> (a -> m ()) -> m ()
for init while step body = go init where
  go !i | while i = body i >> go (step i)
  go _ = return ()
{-# INLINE for #-}
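A usage sketch of this for combinator for the grid benchmark, mirroring test_c's three nested loops. It assumes the array package's STUArray instead of vector's MVector so that it runs without extra dependencies, renames the init parameter to start to avoid shadowing Prelude's init, and uses the hypothetical name fillGridFor.

```haskell
{-# LANGUAGE BangPatterns #-}
module Main where

import Data.Array.ST (newArray, runSTUArray, writeArray)
import Data.Array.Unboxed (UArray, elems)

-- | Essentially the 'for' from the answer: a C-style loop over any type.
for :: Monad m => a -> (a -> Bool) -> (a -> a) -> (a -> m ()) -> m ()
for start while step body = go start
  where
    go !i
      | while i   = body i >> go (step i)
      | otherwise = return ()
{-# INLINE for #-}

-- | Three nested loops written with 'for', in the shape of test_c.
fillGridFor :: Int -> Int -> UArray Int Int
fillGridFor times side = runSTUArray $ do
  arr <- newArray (0, side * side - 1) 0
  for 0 (< times) (+ 1) $ \_ ->
    for 0 (< side) (+ 1) $ \y ->
      for 0 (< side) (+ 1) $ \x ->
        writeArray arr (y * side + x) 1
  return arr

main :: IO ()
main = print (sum (elems (fillGridFor 1000 100)))
```

Because the loop bounds are plain Int comparisons, there is no list for full laziness to float out, which is presumably why this style is robust.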

Looping in really constant space

I was at first very puzzled by the +RTS -s data on heap allocation. test_a allocated non-trivially with -fno-full-laziness, and so did test_c with full laziness, and these allocations scaled linearly with the number of times iterations, but test_b (and test_c) without full laziness allocated only for the vector:

-- with -fno-full-laziness, no INLINE pragmas
test_a: 242,521,008 bytes
test_b: 121,008 bytes
test_c: 121,008 bytes -- but 240,120,984 with full laziness!

Also, INLINE pragmas for test_c did not help at all in this case.
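Allocation figures like these can also be read from inside a program. The sketch below assumes GHC 8.2 or newer (newer than the GHC 7.10 used here), where GHC.Stats provides getRTSStats, and an IOUArray in place of vector. It does not attempt to reproduce the exact numbers above; allocated_bytes is the total allocated since program start, not just by the loop, and fillAndSum is a hypothetical name.

```haskell
module Main where

import Control.Monad (forM_, when)
import Data.Array.IO (IOUArray, getElems, newArray, writeArray)
import GHC.Stats (allocated_bytes, getRTSStats, getRTSStatsEnabled)

-- | test_a's loop shape over an IOUArray; returns the sum of the grid.
fillAndSum :: Int -> Int -> IO Int
fillAndSum times side = do
  arr <- newArray (0, side * side - 1) 0 :: IO (IOUArray Int Int)
  forM_ [0 .. times - 1] $ \_ ->
    forM_ [0 .. side - 1] $ \y ->
      forM_ [0 .. side - 1] $ \x ->
        writeArray arr (y * side + x) 1
  sum <$> getElems arr

main :: IO ()
main = do
  total <- fillAndSum 1000 100
  print total
  -- Statistics are only collected when running with +RTS -T (or -s);
  -- getRTSStats would fail otherwise, hence the check.
  enabled <- getRTSStatsEnabled
  when enabled $ do
    stats <- getRTSStats
    putStrLn ("allocated_bytes: " ++ show (allocated_bytes stats))
```

Running it as ./prog +RTS -T prints the total allocation after the grid sum.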

I spent some time trying to find signs of heap allocation in the Core for the relevant programs, without success, until the realization struck me: GHC stack frames are on the heap, including the frames of the main thread, and the functions that were doing heap allocation were essentially running the thrice-nested loops in at most three stack frames. The heap allocation registered by +RTS -s is just the constant popping and pushing of stack frames.

This is pretty much apparent from the Core for the following code:

{-# OPTIONS_GHC -fno-full-laziness #-}

-- ...

test_a mvec = 
    forM_ [0..times-1] $ \ n -> 
        forM_ [0..side-1] $ \ y -> 
            forM_ [0..side-1] $ \ x -> 
                MV.unsafeWrite mvec (y*side+x) 1
main = do
    let vec = V.generate (side*side) (const 0)
    mvec <- V.unsafeThaw vec :: IO (MV.MVector (PrimState IO) Int)
    test_a mvec

I'm including it here in all its glory; feel free to skip it.

main1 :: State# RealWorld -> (# State# RealWorld, () #)
main1 =
  \ (s_a5HK :: State# RealWorld) ->
    case divInt# 9223372036854775807 8 of ww4_a5vr { __DEFAULT ->

    -- start of vector creation ----------------------
    case tagToEnum# (># 10000 ww4_a5vr) of _ {
      False ->
        case newByteArray# 80000 (s_a5HK `cast` ...)
        of _ { (# ipv_a5fv, ipv1_a5fw #) ->
        letrec {
          $s$wa_s8jS
            :: Int#
               -> Int#
               -> State# (PrimState IO)
               -> (# State# (PrimState IO), Int #)
          $s$wa_s8jS =
            \ (sc_s8jO :: Int#)
              (sc1_s8jP :: Int#)
              (sc2_s8jR :: State# (PrimState IO)) ->
              case tagToEnum# (<# sc1_s8jP 10000) of _ {
                False -> (# sc2_s8jR, I# sc_s8jO #);
                True ->
                  case writeIntArray# ipv1_a5fw sc_s8jO 0 (sc2_s8jR `cast` ...)
                  of s'#_a5Gn { __DEFAULT ->
                  $s$wa_s8jS (+# sc_s8jO 1) (+# sc1_s8jP 1) (s'#_a5Gn `cast` ...)
                  }
              }; } in
        case $s$wa_s8jS 0 0 (ipv_a5fv `cast` ...)
        -- end of vector creation -------------------

        of _ { (# ipv6_a4Hv, ipv7_a4Hw #) ->
        letrec {
          a2_s7MJ :: Int# -> State# RealWorld -> (# State# RealWorld, () #)
          a2_s7MJ =
            \ (x_a5Ho :: Int#) (eta_B1 :: State# RealWorld) ->
              letrec {
                a3_s7ME :: Int# -> State# RealWorld -> (# State# RealWorld, () #)
                a3_s7ME =
                  \ (x1_X5Id :: Int#) (eta1_XR :: State# RealWorld) ->
                    case ipv7_a4Hw of _ { I# dt4_a5x6 ->
                    case writeIntArray#
                           (ipv1_a5fw `cast` ...) (*# x1_X5Id 100) 1 (eta1_XR `cast` ...)
                    of s'#_a5Gn { __DEFAULT ->
                    letrec {
                      a4_s7Mz :: Int# -> State# RealWorld -> (# State# RealWorld, () #)
                      a4_s7Mz =
                        \ (x2_X5J8 :: Int#) (eta2_X1U :: State# RealWorld) ->
                          case writeIntArray#
                                 (ipv1_a5fw `cast` ...)
                                 (+# (*# x1_X5Id 100) x2_X5J8)
                                 1
                                 (eta2_X1U `cast` ...)
                          of s'#1_X5Hf { __DEFAULT ->
                          case x2_X5J8 of wild_X2o {
                            __DEFAULT -> a4_s7Mz (+# wild_X2o 1) (s'#1_X5Hf `cast` ...);
                            99 -> (# s'#1_X5Hf `cast` ..., () #)
                          }
                          }; } in
                    case a4_s7Mz 1 (s'#_a5Gn `cast` ...)
                    of _ { (# ipv2_a4QH, ipv3_a4QI #) ->
                    case x1_X5Id of wild_X1e {
                      __DEFAULT -> a3_s7ME (+# wild_X1e 1) ipv2_a4QH;
                      99 -> (# ipv2_a4QH, () #)
                    }
                    }
                    }
                    }; } in
              case a3_s7ME 0 eta_B1 of _ { (# ipv2_a4QH, ipv3_a4QI #) ->
              case x_a5Ho of wild_X1a {
                __DEFAULT -> a2_s7MJ (+# wild_X1a 1) ipv2_a4QH;
                99999 -> (# ipv2_a4QH, () #)
              }
              }; } in
        a2_s7MJ 0 (ipv6_a4Hv `cast` ...)
        }
        };
      True ->
        case error
               (unpackAppendCString#
                  "Primitive.basicUnsafeNew: length to large: "#
                  (case $wshowSignedInt 0 10000 ([])
                   of _ { (# ww5_a5wm, ww6_a5wn #) ->
                   : ww5_a5wm ww6_a5wn
                   }))
        of wild_00 {
        }
    }
    }

main :: IO ()
main = main1 `cast` ...

main2 :: State# RealWorld -> (# State# RealWorld, () #)
main2 = runMainIO1 (main1 `cast` ...)

main :: IO ()
main = main2 `cast` ...

We can also nicely demonstrate the allocation of frames the following way. Let's change test_a:

test_a mvec = 
    forM_ [0..times-1] $ \ n -> 
        forM_ [0..side-1] $ \ y -> 
            forM_ [0..side-50] $ \ x -> -- change here
                MV.unsafeWrite mvec (y*side+x) 1

Now the heap allocation stays exactly the same, because the innermost loop is tail-recursive and uses a single frame. With the following change, the heap allocation halves (to 124,921,008 bytes), because we push and pop half as many frames:

test_a mvec = 
    forM_ [0..times-1] $ \ n -> 
        forM_ [0..side-50] $ \ y -> -- change here
            forM_ [0..side-1] $ \ x -> 
                MV.unsafeWrite mvec (y*side+x) 1

test_b and test_c (with no full laziness) instead compile to code that uses a nested case construct inside a single stack frame, and walks over the indices to see which one should be incremented. See the Core for the following main:

{-# LANGUAGE BangPatterns #-} -- later I'll talk about this
{-# OPTIONS_GHC -fno-full-laziness #-}

main = do
    let vec = V.generate (side*side) (const 0)
    !mvec <- V.unsafeThaw vec :: IO (MV.MVector (PrimState IO) Int)
    test_c mvec

Voila:

main1 :: State# RealWorld -> (# State# RealWorld, () #)
main1 =
  \ (s_a5Iw :: State# RealWorld) ->
    case divInt# 9223372036854775807 8 of ww4_a5vT { __DEFAULT ->

    -- start of vector creation ----------------------
    case tagToEnum# (># 10000 ww4_a5vT) of _ {
      False ->
        case newByteArray# 80000 (s_a5Iw `cast` ...)
        of _ { (# ipv_a5g3, ipv1_a5g4 #) ->
        letrec {
          $s$wa_s8ji
            :: Int#
               -> Int#
               -> State# (PrimState IO)
               -> (# State# (PrimState IO), Int #)
          $s$wa_s8ji =
            \ (sc_s8je :: Int#)
              (sc1_s8jf :: Int#)
              (sc2_s8jh :: State# (PrimState IO)) ->
              case tagToEnum# (<# sc1_s8jf 10000) of _ {
                False -> (# sc2_s8jh, I# sc_s8je #);
                True ->
                  case writeIntArray# ipv1_a5g4 sc_s8je 0 (sc2_s8jh `cast` ...)
                  of s'#_a5GP { __DEFAULT ->
                  $s$wa_s8ji (+# sc_s8je 1) (+# sc1_s8jf 1) (s'#_a5GP `cast` ...)
                  }
              }; } in
        case $s$wa_s8ji 0 0 (ipv_a5g3 `cast` ...)
        of _ { (# ipv6_a4MX, ipv7_a4MY #) ->
        case ipv7_a4MY of _ { I# dt4_a5xy ->
        -- end of vector creation

        letrec {
          a2_s7Q6 :: Int# -> State# RealWorld -> (# State# RealWorld, () #)
          a2_s7Q6 =
            \ (x_a5HT :: Int#) (eta_B1 :: State# RealWorld) ->
              letrec {
                a3_s7Q5 :: Int# -> State# RealWorld -> (# State# RealWorld, () #)
                a3_s7Q5 =
                  \ (x1_X5J9 :: Int#) (eta1_XP :: State# RealWorld) ->
                    letrec {
                      a4_s7MZ :: Int# -> State# RealWorld -> (# State# RealWorld, () #)
                      a4_s7MZ =
                        \ (x2_X5Jl :: Int#) (s1_X4Xb :: State# RealWorld) ->
                          case writeIntArray#
                                 (ipv1_a5g4 `cast` ...)
                                 (+# (*# x1_X5J9 100) x2_X5Jl)
                                 1
                                 (s1_X4Xb `cast` ...)
                          of s'#_a5GP { __DEFAULT ->

                          -- the interesting part! ------------------
                          case x2_X5Jl of wild_X1y {
                            __DEFAULT -> a4_s7MZ (+# wild_X1y 1) (s'#_a5GP `cast` ...);
                            99 ->
                              case x1_X5J9 of wild1_X1o {
                                __DEFAULT -> a3_s7Q5 (+# wild1_X1o 1) (s'#_a5GP `cast` ...);
                                99 ->
                                  case x_a5HT of wild2_X1c {
                                    __DEFAULT -> a2_s7Q6 (+# wild2_X1c 1) (s'#_a5GP `cast` ...);
                                    99999 -> (# s'#_a5GP `cast` ..., () #)
                                  }
                              }
                          }
                          }; } in
                    a4_s7MZ 0 eta1_XP; } in
              a3_s7Q5 0 eta_B1; } in
        a2_s7Q6 0 (ipv6_a4MX `cast` ...)
        }
        }
        };
      True ->
        case error
               (unpackAppendCString#
                  "Primitive.basicUnsafeNew: length to large: "#
                  (case $wshowSignedInt 0 10000 ([])
                   of _ { (# ww5_a5wO, ww6_a5wP #) ->
                   : ww5_a5wO ww6_a5wP
                   }))
        of wild_00 {
        }
    }
    }

main :: IO ()
main = main1 `cast` ...

main2 :: State# RealWorld -> (# State# RealWorld, () #)
main2 = runMainIO1 (main1 `cast` ...)

main :: IO ()
main = main2 `cast` ...
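The nested case in this Core can be read as a single loop that carries all three indices and, after each write, checks x, then y, then n to decide which index to advance, much like an odometer. Below is a source-level sketch of that shape. It is an illustration under assumptions (STUArray from the array package instead of vector, hypothetical name fillGridOdometer), and I have not checked whether GHC compiles this particular source to a single stack frame.

```haskell
{-# LANGUAGE BangPatterns #-}
module Main where

import Control.Monad (when)
import Data.Array.ST (newArray, runSTUArray, writeArray)
import Data.Array.Unboxed (UArray, elems)

-- | One tail-recursive loop over (n, y, x). 'go' writes a cell, then
-- 'next' picks which index to advance, mirroring the nested case above.
fillGridOdometer :: Int -> Int -> UArray Int Int
fillGridOdometer times side = runSTUArray $ do
  arr <- newArray (0, side * side - 1) 0
  let go !n !y !x = writeArray arr (y * side + x) 1 >> next n y x
      next !n !y !x
        | x < side - 1  = go n y (x + 1)
        | y < side - 1  = go n (y + 1) 0
        | n < times - 1 = go (n + 1) 0 0
        | otherwise     = return ()
  -- The body writes before checking bounds, so guard against empty ranges.
  when (times > 0 && side > 0) (go 0 0 0)
  return arr

main :: IO ()
main = print (sum (elems (fillGridOdometer 1000 100)))
```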

I have to admit that I basically don't know why some code avoids stack frame creation and some doesn't. I suspect that inlining from "the inside" out helps, and a quick inspection informed me that Control.Monad.Loop uses a CPS encoding, which might be relevant here, although the Monad.Loop solution is sensitive to let floating, and I couldn't determine on short notice from the Core why test_c with let floating fails to run in a single stack frame.
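To make "CPS encoding" a bit more concrete, here is a deliberately minimal, hypothetical CPS loop monad. It is NOT the actual definition used by Control.Monad.Loop; the names Loop, for, liftL and exec_ are illustrative assumptions, and it uses the array package's STUArray instead of vector. A loop receives a yield continuation, called once per element together with "the rest of the loop", and a done action, so binding two loops nests them without building any intermediate list.

```haskell
{-# LANGUAGE RankNTypes #-}
module Main where

import Data.Array.ST (newArray, runSTUArray, writeArray)
import Data.Array.Unboxed (UArray, elems)

-- | Hypothetical CPS loop: given 'yield' (element, rest of loop) and 'done'.
newtype Loop m a = Loop { runLoop :: forall r. (a -> m r -> m r) -> m r -> m r }

instance Functor (Loop m) where
  fmap f (Loop g) = Loop (\yield done -> g (yield . f) done)

instance Applicative (Loop m) where
  pure a = Loop (\yield done -> yield a done)
  Loop ff <*> Loop fa = Loop (\yield done ->
    ff (\f next -> fa (yield . f) next) done)

-- | Bind nests loops: for every outer element, run the inner loop, whose
-- 'done' continues with the rest of the outer loop.
instance Monad (Loop m) where
  Loop g >>= k = Loop (\yield done ->
    g (\a next -> runLoop (k a) yield next) done)

-- | A C-style loop yielding start, step start, ... while cond holds.
for :: a -> (a -> Bool) -> (a -> a) -> Loop m a
for start cond step = Loop (\yield done ->
  let go i
        | cond i    = yield i (go (step i))
        | otherwise = done
  in go start)

-- | Run an action of the underlying monad inside a loop.
liftL :: Monad m => m a -> Loop m a
liftL act = Loop (\yield done -> act >>= \a -> yield a done)

-- | Run a loop for its effects only.
exec_ :: Monad m => Loop m a -> m ()
exec_ (Loop g) = g (\_ next -> next) (return ())

-- | test_c's shape, written against this sketch.
fillGridCPS :: Int -> Int -> UArray Int Int
fillGridCPS times side = runSTUArray $ do
  arr <- newArray (0, side * side - 1) 0
  exec_ $ do
    _ <- for 0 (< times) (+ 1)
    y <- for 0 (< side) (+ 1)
    x <- for 0 (< side) (+ 1)
    liftL (writeArray arr (y * side + x) 1)
  return arr

main :: IO ()
main = print (sum (elems (fillGridCPS 1000 100)))
```

With exec_'s yield being \_ next -> next, each for's inner go reduces to a plain tail call once inlined, which is plausibly part of why this encoding can produce tight loops.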

Now, the performance benefit of running in a single stack frame is small. We've seen that test_b is only slightly faster than test_a. I include this detour in the answer because I found it edifying.

The state hack and strict bindings

The so-called state hack makes GHC aggressive in inlining into IO and ST actions. I think I should mention it here, because besides let floating this is the other thing that can thoroughly ruin performance.

The state hack is enabled by -O and can, in some cases, slow programs down asymptotically. A simple example from Reid Barton:

import Control.Monad
import Debug.Trace

expensive :: String -> String
expensive x = trace "$$$" x

main :: IO ()
main = do
  str <- fmap expensive getLine
  replicateM_ 3 $ print str

With GHC 7.10.2, this prints "$$$" once without optimizations but three times with -O2. And it seems that with GHC 7.10 we can't get rid of this behavior with -fno-state-hack (which is the subject of Reid Barton's ticket).

Strict monadic bindings reliably get rid of this problem:

{-# LANGUAGE BangPatterns #-} -- needed for the strict binding below

main :: IO ()
main = do
  !str <- fmap expensive getLine
  replicateM_ 3 $ print str

I think it's good habit to do strict bindings in IO and ST. And I have some experience (not definitive though; I'm far from being a GHC expert) that strict bindings are especially needed if we use -fno-full-laziness. Apparently full laziness can help get rid of some of the work duplication introduced by the inlining caused by the state hack; with test_b and no full laziness, omitting the strict binding on !mvec <- V.unsafeThaw vec caused a slight slowdown and extremely ugly Core output.

answered Nov 01 '22 at 02:11 by András Kovács