Writing an efficient iterative loop for the ST monad

The go worker tail-recursive loop pattern seems to work very well for writing pure code. What would be the equivalent way to write that kind of loop for the ST monad? More specifically, I want to avoid new heap allocation in the loop iterations. My guess is that it involves either a CPS transformation or fixST to rewrite the code so that all the values that change across the loop are passed along on each iteration, so that those values can stay in registers (or on the stack, in case of a spill) across iterations. I have a simplified example below (don't try running it; it will likely crash with a segmentation fault!) involving a function called findSnakes, which uses the go worker pattern but does not pass the changing state values through accumulator arguments:

{-# LANGUAGE BangPatterns #-}
module Test where

import Data.Vector.Unboxed.Mutable as MU
import Data.Vector.Unboxed as U hiding (mapM_)
import Control.Monad.ST as ST
import Control.Monad.Primitive (PrimState)
import Control.Monad as CM (when,forM_)
import Data.Int

type MVI1 s  = MVector (PrimState (ST s)) Int

-- function to find previous y
findYP :: MVI1 s -> Int -> Int -> ST s Int
findYP fp k offset = do
              y0 <- MU.unsafeRead fp (k+offset-1) >>= \x -> return $ 1+x
              y1 <- MU.unsafeRead fp (k+offset+1)
              if y0 > y1 then return y0
              else return y1
{-#INLINE findYP #-}

findSnakes :: Vector Int32 -> MVI1 s ->  Int -> Int -> (Int -> Int -> Int) -> ST s ()
findSnakes a fp !k !ct !op = go 0 k
     where
           offset=1+U.length a
           go x k'
            | x < ct = do
                yp <- findYP fp k' offset
                MU.unsafeWrite fp (k'+offset) (yp + k')
                go (x+1) (op k' 1)
            | otherwise = return ()
{-#INLINE findSnakes #-}

Looking at the Cmm output from GHC 7.6.1 (with my limited knowledge of Cmm, so please correct me if I got it wrong), I see the following call flow, with the loop in s1tb_info, which causes a heap allocation and a heap check in each iteration:

findSnakes_info -> a1_r1qd_info -> $wa_r1qc_info (new stack allocation, SpLim check)
-> s1sy_info -> s1sj_info: if arg > 1 then s1w8_info else R1 (can't figure out what that register points to)

-- I am guessing this one below is for go loop
s1w8_info -> s1w7_info (big heap allocation, HpLim check) -> s1tb_info: if arg >= 1 then s1td_info else R1

s1td_info (big heap allocation, HpLim check) -> if arg >= 1 then s1tb_info (a loop) else s1tb_info (after executing a different block of code)

My guess is that the check of the form arg >= 1 in the Cmm code determines whether the go loop has terminated. If that is correct, it seems that unless the go loop is rewritten to pass yp across iterations, a heap allocation for the new values will happen on every iteration (I am guessing yp is causing that heap allocation). What would be an efficient way to write the go loop in the example above? I suspect yp will have to be passed as an argument to go, or the equivalent done through fixST or a CPS transformation. I can't think of a good way to rewrite the go loop above to remove the heap allocations, and would appreciate help with it.
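One way to get the accumulator-passing shape asked about here is sketched below, under stated assumptions. Everything that changes between iterations (the counter x and the current k) is a strict argument of go, and the intermediate values are bound with bang patterns so nothing is left as a thunk; GHC's strictness analysis and worker/wrapper transformation can then usually keep x and k unboxed. Neither fixST nor a CPS transformation is used. To stay dependency-free, the sketch swaps vector's MVector for STUArray from the array package, with unsafeRead/unsafeWrite from Data.Array.Base (which take zero-based Int offsets). findSnakesAcc, runSnakes and the len parameter (standing in for U.length a) are hypothetical names. Whether the loop really allocates nothing per iteration still has to be confirmed in the generated Core or Cmm.

```haskell
{-# LANGUAGE BangPatterns #-}
import Control.Monad.ST (ST, runST)
import Data.Array.ST (STUArray, newArray, getElems)
import Data.Array.Base (unsafeRead, unsafeWrite)

-- Hypothetical accumulator-passing version of the question's findSnakes.
-- len stands in for U.length a; STUArray replaces vector's MVector (assumption).
findSnakesAcc :: Int -> STUArray s Int Int -> Int -> Int -> (Int -> Int -> Int) -> ST s ()
findSnakesAcc !len fp !k0 !ct op = go 0 k0
  where
    -- x (iteration counter) and k (current diagonal) are strict arguments,
    -- so all loop state is passed explicitly from one iteration to the next.
    go !x !k
      | x >= ct   = return ()
      | otherwise = do
          let !base = len + k
          y0 <- unsafeRead fp base          -- k+offset-1 in the question
          y1 <- unsafeRead fp (base + 2)    -- k+offset+1 in the question
          let !yp = max (y0 + 1) y1         -- evaluated before the write
          unsafeWrite fp (base + 1) (yp + k) -- k+offset in the question
          go (x + 1) (op k 1)

-- Runs the loop on a zero-filled array of n elements and returns its contents.
runSnakes :: Int -> Int -> Int -> Int -> (Int -> Int -> Int) -> [Int]
runSnakes n len k0 ct op = runST $ do
  fp <- newArray (0, n - 1) 0
  findSnakesAcc len fp k0 ct op
  getElems fp
```

For example, runSnakes 10 2 0 3 (+) starts from ten zeros and returns [0,0,0,1,3,6,0,0,0,0].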

Sal asked Jun 15 '13 at 18:06


1 Answer

I rewrote your functions to avoid any explicit recursion and removed some redundant operations computing the offsets. This compiles to much nicer Core than your original functions.

Core, by the way, is probably the better way to analyze your compiled code for this kind of profiling. Use ghc -ddump-simpl to see the generated Core output, or tools like ghc-core.
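A minimal sketch of that workflow, assuming you want the dump for a single module: GHC's OPTIONS_GHC pragma can switch the flags on per file. The extra flags here are an assumption, not part of the answer: -dsuppress-all and -dsuppress-uniques trim type annotations and unique suffixes from the dump, and -ddump-to-file writes it to a file ending in .dump-simpl instead of the terminal. sumTo is a hypothetical strict accumulator loop in ST; in the optimised Core you would hope to see its loop compiled to a worker taking unboxed Int# arguments, which indicates the loop variables are not boxed on the heap each iteration.

```haskell
{-# OPTIONS_GHC -O2 -ddump-simpl -dsuppress-all -dsuppress-uniques -ddump-to-file #-}
{-# LANGUAGE BangPatterns #-}
import Control.Monad.ST (ST, runST)

-- Hypothetical example loop: sums 1..n with strict accumulators in ST.
-- Compile this module and inspect the .dump-simpl file for the worker of go.
sumTo :: Int -> Int
sumTo n = runST (go 1 0)
  where
    go :: Int -> Int -> ST s Int
    go !i !acc
      | i > n     = return acc
      | otherwise = go (i + 1) (acc + i)
```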

import Control.Monad.Primitive
import Control.Monad.ST
import Data.Int
import qualified Data.Vector.Unboxed.Mutable as M
import qualified Data.Vector.Unboxed as U

type MVI1 s  = M.MVector (PrimState (ST s)) Int

-- Reads the cells on either side of offset+1 and returns max (y0 + 1) y1.
findYP :: MVI1 s -> Int -> ST s Int
findYP fp offset = do
    y0 <- M.unsafeRead fp (offset+0)
    y1 <- M.unsafeRead fp (offset+2)
    return $ max (y0 + 1) y1

-- Visits the ct values k0, op k0 1, op (op k0 1) 1, ... and updates fp at each one.
findSnakes :: U.Vector Int32 -> MVI1 s -> Int -> Int -> (Int -> Int -> Int) -> ST s ()
findSnakes a fp k0 ct op = U.mapM_ writeAt $ U.iterateN ct (`op` 1) k0
    where writeAt k = do
              let offset = U.length a + k
              yp <- findYP fp offset
              M.unsafeWrite fp (offset + 1) (yp + k)

          -- or inline findYP manually (use this definition instead of the one above):
          -- writeAt k = do
          --    let offset = U.length a + k
          --    y0 <- M.unsafeRead fp (offset + 0)
          --    y1 <- M.unsafeRead fp (offset + 2)
          --    M.unsafeWrite fp (offset + 1) (k + max (y0 + 1) y1)
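The rewrite depends on U.iterateN ct (`op` 1) k0 generating exactly the values that k' takes in the question's go loop: k0, op k0 1, op (op k0 1) 1, and so on, ct values in total. The sketch below checks that equivalence on plain lists, with Prelude's take and iterate standing in for U.iterateN (an assumption made to avoid the vector dependency); goKs and iterKs are hypothetical helper names.

```haskell
-- Sequence of k' values visited by the question's explicit go loop.
goKs :: Int -> (Int -> Int -> Int) -> Int -> [Int]
goKs ct op k0 = go 0 k0
  where
    go x k
      | x < ct    = k : go (x + 1) (op k 1)
      | otherwise = []

-- The same sequence built the way the answer builds it, with take/iterate
-- playing the role of U.iterateN ct (`op` 1) k0.
iterKs :: Int -> (Int -> Int -> Int) -> Int -> [Int]
iterKs ct op k0 = take ct (iterate (`op` 1) k0)
```

For instance, both goKs 5 (+) 3 and iterKs 5 (+) 3 give [3,4,5,6,7].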

Also, you pass a U.Vector Int32 to findSnakes, only to compute its length and never use a again. Why not pass in the length directly?
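A sketch of that suggestion, assuming the caller already knows the length: findSnakesLen (a hypothetical name) takes len where the answer's version takes a, and otherwise keeps the answer's mapM_-over-indices structure. To stay self-contained it uses STUArray from the array package with Data.Array.Base's unsafeRead/unsafeWrite instead of vector's MVector, and a list built with take and iterate instead of U.iterateN; whether GHC fuses that list away is not guaranteed, so check the Core if it matters. runSnakesLen is likewise a hypothetical test harness.

```haskell
import Control.Monad.ST (ST, runST)
import Data.Array.ST (STUArray, newArray, getElems)
import Data.Array.Base (unsafeRead, unsafeWrite)

-- Hypothetical variant of findSnakes that receives len (the length of a)
-- instead of the vector itself. STUArray stands in for MVI1 s (assumption).
findSnakesLen :: Int -> STUArray s Int Int -> Int -> Int -> (Int -> Int -> Int) -> ST s ()
findSnakesLen len fp k0 ct op = mapM_ writeAt (take ct (iterate (`op` 1) k0))
  where
    -- Same update as the answer's inlined writeAt.
    writeAt k = do
      let offset = len + k
      y0 <- unsafeRead fp offset
      y1 <- unsafeRead fp (offset + 2)
      unsafeWrite fp (offset + 1) (k + max (y0 + 1) y1)

-- Runs findSnakesLen on a zero-filled array of n elements and returns its contents.
runSnakesLen :: Int -> Int -> Int -> Int -> (Int -> Int -> Int) -> [Int]
runSnakesLen n len k0 ct op = runST $ do
  fp <- newArray (0, n - 1) 0
  findSnakesLen len fp k0 ct op
  getElems fp
```

For example, runSnakesLen 10 2 0 3 (+) returns [0,0,0,1,3,6,0,0,0,0].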

12 revs answered Oct 24 '22 at 08:10