The `go` worker tail-recursive loop pattern seems to work very well for writing pure code. What would be the equivalent way to write that kind of loop for the `ST` monad? More specifically, I want to avoid new heap allocation in the loop iterations. My guess is that it involves either a CPS transformation or `fixST` to rewrite the code so that all the values that change across the loop are passed along on each iteration, making register locations (or stack slots, in case of a spill) available for those values across iterations. I have a simplified example below (don't try running it; it will likely crash with a segmentation fault!) involving a function called `findSnakes`, which has a `go` worker pattern, but the changing state values are not passed through accumulator arguments:
```haskell
{-# LANGUAGE BangPatterns #-}
module Test where

import Data.Vector.Unboxed.Mutable as MU
import Data.Vector.Unboxed as U hiding (mapM_)
import Control.Monad.ST as ST
import Control.Monad.Primitive (PrimState)
import Control.Monad as CM (when,forM_)
import Data.Int

type MVI1 s = MVector (PrimState (ST s)) Int

-- function to find previous y
findYP :: MVI1 s -> Int -> Int -> ST s Int
findYP fp k offset = do
  y0 <- MU.unsafeRead fp (k+offset-1) >>= \x -> return $ 1+x
  y1 <- MU.unsafeRead fp (k+offset+1)
  if y0 > y1 then return y0
    else return y1
{-# INLINE findYP #-}

findSnakes :: Vector Int32 -> MVI1 s -> Int -> Int -> (Int -> Int -> Int) -> ST s ()
findSnakes a fp !k !ct !op = go 0 k
  where
    offset = 1+U.length a
    go x k'
      | x < ct = do
          yp <- findYP fp k' offset
          MU.unsafeWrite fp (k'+offset) (yp + k')
          go (x+1) (op k' 1)
      | otherwise = return ()
{-# INLINE findSnakes #-}
```
Looking at the Cmm output in GHC 7.6.1 (with my limited knowledge of Cmm, so please correct me if I got it wrong), I see this kind of call flow, with a loop in `s1tb_info` (which causes a heap allocation and a heap check in each iteration):

```
findSnakes_info -> a1_r1qd_info -> $wa_r1qc_info (new stack allocation, SpLim check)
  -> s1sy_info -> s1sj_info: if arg > 1 then s1w8_info else R1 (can't figure out
     what that register points to)

-- I am guessing this one below is for the go loop
s1w8_info -> s1w7_info (big heap allocation, HpLim check) -> s1tb_info: if arg >= 1
  then s1td_info else R1
s1td_info (big heap allocation, HpLim check) -> if arg >= 1 then s1tb_info
  (a loop) else s1tb_info (after executing a different block of code)
```
My guess is that the check of the form `arg >= 1` in the Cmm code determines whether the `go` loop has terminated. If that is correct, it seems that unless the `go` loop is rewritten to pass `yp` across iterations, heap allocation will happen on every iteration for the new values (I am guessing `yp` is what causes that heap allocation). What would be an efficient way to write the `go` loop in the example above? I guess `yp` will have to be passed as an argument of the `go` loop, or the equivalent done through `fixST` or a CPS transformation. I can't think of a good way to rewrite the `go` loop above to remove the heap allocations, and would appreciate help with it.
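For comparison, the pattern the question describes (every value that changes between iterations passed as a strict argument of `go`) can be written directly in `ST` as plain recursion, with no CPS transformation or `fixST`. The sketch below uses a hypothetical `prefixSums` function that is not part of the question: it writes running sums into a fresh mutable vector while carrying the index and the running total as bang-patterned `Int` arguments. With `-O`, GHC can typically unbox such arguments into `Int#` values in the worker, but whether a particular loop really avoids allocation is best confirmed in the `-ddump-simpl` output.

```haskell
{-# LANGUAGE BangPatterns #-}
import Control.Monad.ST (runST)
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as M

-- Hypothetical example: out[i] = xs[0] + ... + xs[i].
-- The loop index i and the running total acc are the only values that
-- change between iterations, and both are strict arguments of go.
prefixSums :: U.Vector Int -> U.Vector Int
prefixSums xs = runST $ do
  let n = U.length xs
  out <- M.new n
  let go !i !acc
        | i >= n    = return ()
        | otherwise = do
            let acc' = acc + U.unsafeIndex xs i
            M.unsafeWrite out i acc'
            go (i + 1) acc'
  go 0 0
  U.unsafeFreeze out
```

For example, `prefixSums (U.fromList [1,2,3,4,5])` gives `[1,3,6,10,15]`.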
I rewrote your functions to avoid any explicit recursion and removed some redundant operations computing the offsets. This compiles to much nicer Core than your original functions.

Core, by the way, is probably the better way to analyze your compiled code for this kind of profiling. Use `ghc -ddump-simpl` to see the generated Core output, or tools like `ghc-core`.
```haskell
import Control.Monad.Primitive
import Control.Monad.ST
import Data.Int
import qualified Data.Vector.Unboxed.Mutable as M
import qualified Data.Vector.Unboxed as U

type MVI1 s = M.MVector (PrimState (ST s)) Int

findYP :: MVI1 s -> Int -> ST s Int
findYP fp offset = do
  y0 <- M.unsafeRead fp (offset+0)
  y1 <- M.unsafeRead fp (offset+2)
  return $ max (y0 + 1) y1

findSnakes :: U.Vector Int32 -> MVI1 s -> Int -> Int -> (Int -> Int -> Int) -> ST s ()
findSnakes a fp k0 ct op = U.mapM_ writeAt $ U.iterateN ct (`op` 1) k0
  where writeAt k = do
          let offset = U.length a + k
          yp <- findYP fp offset
          M.unsafeWrite fp (offset + 1) (yp + k)
```

Or inline `findYP` manually:

```haskell
  where writeAt k = do
          let offset = U.length a + k
          y0 <- M.unsafeRead fp (offset + 0)
          y1 <- M.unsafeRead fp (offset + 2)
          M.unsafeWrite fp (offset + 1) (k + max (y0 + 1) y1)
```
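To see what the rewritten loop computes, here is a small driver around the first version of the answer's code. `runSnakes` is a hypothetical helper, not part of the answer: it thaws an initial array, runs `findSnakes` on it, and freezes the result. The sizes in the example are an assumption, chosen so that every index from `U.length a + k` to `U.length a + k + 2` stays in bounds; with real data the caller has to guarantee that, because `unsafeRead` and `unsafeWrite` do no bounds checking.

```haskell
import Control.Monad.Primitive (PrimState)
import Control.Monad.ST (ST, runST)
import Data.Int (Int32)
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as M

type MVI1 s = M.MVector (PrimState (ST s)) Int

-- Copied from the answer above.
findYP :: MVI1 s -> Int -> ST s Int
findYP fp offset = do
  y0 <- M.unsafeRead fp (offset + 0)
  y1 <- M.unsafeRead fp (offset + 2)
  return $ max (y0 + 1) y1

-- Copied from the answer above.
findSnakes :: U.Vector Int32 -> MVI1 s -> Int -> Int -> (Int -> Int -> Int) -> ST s ()
findSnakes a fp k0 ct op = U.mapM_ writeAt $ U.iterateN ct (`op` 1) k0
  where
    writeAt k = do
      let offset = U.length a + k
      yp <- findYP fp offset
      M.unsafeWrite fp (offset + 1) (yp + k)

-- Hypothetical driver: copy `initial` into a mutable vector, run the loop,
-- and return a frozen copy of the final contents.
runSnakes :: U.Vector Int32 -> U.Vector Int -> Int -> Int -> (Int -> Int -> Int) -> U.Vector Int
runSnakes a initial k0 ct op = runST $ do
  fp <- U.thaw initial
  findSnakes a fp k0 ct op
  U.freeze fp
```

With `a` of length 2, a seven-element zero-filled `initial`, `k0 = 0`, `ct = 3` and `op = (+)`, the loop visits `k = 0, 1, 2`, writes indices 3, 4 and 5, and `runSnakes` returns `[0,0,0,1,3,6,0]`.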
Also, you pass a `U.Vector Int32` to `findSnakes` only to compute its length, and never use `a` again. Why not pass in the length directly?
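Passing the length directly might look like the following sketch. `findSnakesLen` is a hypothetical name; the body is the answer's inlined `writeAt`, with `U.length a` replaced by a `len` parameter, so the `Int32` vector is no longer needed at all.

```haskell
import Control.Monad.Primitive (PrimState)
import Control.Monad.ST (ST, runST)
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as M

type MVI1 s = M.MVector (PrimState (ST s)) Int

-- Hypothetical variant of findSnakes that takes the length instead of the vector.
findSnakesLen :: Int -> MVI1 s -> Int -> Int -> (Int -> Int -> Int) -> ST s ()
findSnakesLen len fp k0 ct op = U.mapM_ writeAt $ U.iterateN ct (`op` 1) k0
  where
    writeAt k = do
      let offset = len + k
      y0 <- M.unsafeRead fp offset
      y1 <- M.unsafeRead fp (offset + 2)
      M.unsafeWrite fp (offset + 1) (k + max (y0 + 1) y1)
```

A caller that still has the vector can simply pass `U.length a`. For example, with `len = 2`, a zero-filled seven-element `fp`, `k0 = 0`, `ct = 3` and `op = (+)`, it leaves `fp` as `[0,0,0,1,3,6,0]`.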