
Does Haskell provide a way to evaluate IO actions immediately?

Tags:

haskell

I am currently writing a ray tracing program in Haskell. As I am a complete beginner in Haskell, I don't clearly understand the evaluation strategy of the IO monad.

The problem is the memory usage of a long list of "IO a", which is "IO Vec" in my code.

Each element of the list is computed by a recursive function that returns an IO Vec representing the color of a pixel. Therefore, the length of the list equals width x height.

In addition, I take multiple samples per pixel. In total, the function radiance, which computes a pixel value, is called width x height x samples times.

At first I implemented this program simply using a list comprehension. The code looks like this:

main = do
    ...
    let rays = ...  -- one ray for every pair in [0..w-1] x [0..h-1]
    pixels <- sequence [ sumOfRadiance scene ray samples | ray <- rays ]

My understanding is that, since pixels is not used until it is written to a file, Haskell keeps data for the pending function calls inside pixels, which is a list of IO Vec. As a result, memory consumption keeps growing as the recursive function radiance is called to compute the pixel values.

If I change the program to evaluate the pixel values one by one using unsafePerformIO, this strange memory usage goes away:

main = do
    ...
    let rays = ...  -- one ray for every pair in [0..w-1] x [0..h-1]
    let pixels = [ unsafePerformIO (sumOfRadiance scene ray samples) | ray <- rays ]

I know unsafePerformIO is a bad solution, so I'd like to know whether Haskell provides another way to evaluate an IO action immediately. Below is my whole code (sorry, it's a bit long...).

Thank you for your help.

-- Small path tracing with Haskell
import System.Environment
import System.Random.Mersenne
import System.IO.Unsafe
import Control.Monad
import Codec.Picture
import Data.Time
import qualified Data.Word as W
import qualified Data.Vector.Storable as V

-- Parameters
eps :: Double
eps = 1.0e-4

inf :: Double
inf = 1.0e20

nc :: Double
nc  = 1.0

nt :: Double
nt  = 1.5

-- Vec
data Vec = Vec (Double, Double, Double) deriving (Show)
instance (Num Vec) where
    (Vec (x, y, z)) + (Vec (a, b, c)) = Vec (x + a, y + b, z + c)
    (Vec (x, y, z)) - (Vec (a, b, c)) = Vec (x - a, y - b, z - c)
    (Vec (x, y, z)) * (Vec (a, b, c)) = Vec (x * a, y * b, z * c)
    abs = undefined
    signum = undefined
    fromInteger x = Vec (dx, dx, dx) where dx = fromIntegral x

x :: Vec -> Double
x (Vec (x, _, _)) = x

y :: Vec -> Double
y (Vec (_, y, _)) = y

z :: Vec -> Double
z (Vec (_, _, z)) = z

mul :: Vec -> Double -> Vec
mul (Vec (x, y, z)) s = Vec (x * s, y * s, z * s)

dot :: Vec -> Vec -> Double
dot (Vec (x, y, z)) (Vec (a, b, c))  = x * a + y * b + z * c

norm :: Vec -> Vec
norm (Vec (x, y, z)) = Vec (x * invnrm, y * invnrm, z * invnrm)
    where invnrm = 1 / sqrt (x * x + y * y + z * z)

cross :: Vec -> Vec -> Vec
cross (Vec (x, y, z)) (Vec (a, b, c)) = Vec (y * c - b * z, z * a - c * x, x * b - a * y)

-- Ray
data Ray = Ray (Vec, Vec) deriving (Show)

org :: Ray -> Vec
org (Ray (org, _)) = org

dir :: Ray -> Vec
dir (Ray (_, dir)) = dir

-- Material
data Refl = Diff
          | Spec
          | Refr
          deriving Show

-- Sphere
data Sphere = Sphere (Double, Vec, Vec, Vec, Refl) deriving (Show)

rad :: Sphere -> Double
rad  (Sphere (rad, _, _, _, _   )) = rad

pos :: Sphere -> Vec
pos  (Sphere (_  , p, _, _, _   )) = p

emit :: Sphere -> Vec
emit (Sphere (_  , _, e, _, _   )) = e

col :: Sphere -> Vec
col  (Sphere (_  , _, _, c, _   )) = c

refl :: Sphere -> Refl
refl (Sphere (_  , _, _, _, refl)) = refl

intersect :: Sphere -> Ray -> Double
intersect sp ray =
    let op  = (pos sp) - (org ray)
        b   = op `dot` (dir ray)
        det = b * b - (op `dot` op) + ((rad sp) ** 2)
    in
        if det < 0.0
            then inf
            else
                let sqdet = sqrt det
                    t1    = b - sqdet
                    t2    = b + sqdet
                in ansCheck t1 t2
                      where ansCheck t1 t2
                                | t1 > eps  = t1
                                | t2 > eps  = t2
                                | otherwise = inf

-- Scene
type Scene = [Sphere]
sph :: Scene
sph = [ Sphere (1e5,  Vec ( 1e5+1,  40.8, 81.6),    Vec (0.0, 0.0, 0.0), Vec (0.75, 0.25, 0.25),  Diff)   -- Left
      , Sphere (1e5,  Vec (-1e5+99, 40.8, 81.6),    Vec (0.0, 0.0, 0.0), Vec (0.25, 0.25, 0.75),  Diff)   -- Right
      , Sphere (1e5,  Vec (50.0, 40.8,  1e5),       Vec (0.0, 0.0, 0.0), Vec (0.75, 0.75, 0.75),  Diff)   -- Back
      , Sphere (1e5,  Vec (50.0, 40.8, -1e5+170),   Vec (0.0, 0.0, 0.0), Vec (0.0, 0.0, 0.0),     Diff)   -- Front
      , Sphere (1e5,  Vec (50, 1e5, 81.6),          Vec (0.0, 0.0, 0.0), Vec (0.75, 0.75, 0.75),  Diff)   -- Bottom
      , Sphere (1e5,  Vec (50,-1e5+81.6,81.6),      Vec (0.0, 0.0, 0.0), Vec (0.75, 0.75, 0.75),  Diff)   -- Top
      , Sphere (16.5, Vec (27, 16.5, 47),           Vec (0.0, 0.0, 0.0), Vec (1,1,1) `mul` 0.999, Spec)   -- Mirror
      , Sphere (16.5, Vec (73, 16.5, 78),           Vec (0.0, 0.0, 0.0), Vec (1,1,1) `mul` 0.999, Refr)   -- Glass
      , Sphere (600,  Vec (50, 681.6 - 0.27, 81.6), Vec (12, 12, 12),    Vec (0, 0, 0),           Diff) ] -- Light

-- Utility functions
clamp :: Double -> Double
clamp = (max 0.0) . (min 1.0)

isectWithScene :: Scene -> Ray -> (Double, Int)
isectWithScene scene ray = foldr1 (min) $ zip [ intersect sph ray | sph <- scene ] [0..]

nextDouble :: IO Double
nextDouble = randomIO

lambert :: Vec -> Double -> Double -> (Vec, Double)
lambert n r1 r2 =
    let th  = 2.0 * pi * r1
        r2s = sqrt r2
        w = n
        u = norm $ (if (abs (x w)) > eps then Vec (0, 1, 0) else Vec (1, 0, 0)) `cross` w
        v = w `cross` u
        uu = u `mul` ((cos th) * r2s)
        vv = v `mul` ((sin th) * r2s)
        ww = w `mul` (sqrt (1.0 - r2))
        rdir = norm (uu + vv + ww)
    in (rdir, 1)

reflect :: Vec -> Vec -> (Vec, Double)
reflect v n =
    let rdir = v - (n `mul` (2.0 * n `dot` v))
    in (rdir, 1)

refract :: Vec -> Vec -> Vec -> Double -> (Vec, Double)
refract v n orn rr =
    let (rdir, _) = reflect v orn
        into = (n `dot` orn) > 0
        nnt  = if into then (nc / nt) else (nt / nc)
        ddn  = v `dot` orn
        cos2t = 1.0 - nnt * nnt * (1.0 - ddn * ddn)
    in
        if cos2t < 0.0
            then (rdir, 1.0)
            else
                let tdir = norm $ ((v `mul` nnt) -) $ n `mul` ((if into then 1 else -1) * (ddn * nnt + (sqrt cos2t)))
                    a = nt - nc
                    b = nt + nc
                    r0 = (a * a) / (b * b)
                    c = 1.0 - (if into then -ddn else (tdir `dot` n))
                    re = r0 + (1 - r0) * (c ** 5)
                    tr = 1.0 - re
                    pp = 0.25 + 0.5 * re
                in
                    if rr < pp
                         then (rdir, (pp / re))
                         else (tdir, ((1.0 - pp) / tr))

radiance :: Scene -> Ray -> Int -> IO Vec
radiance scene ray depth = do
    let (t, i) = (isectWithScene scene ray)
    if inf <= t
        then return (Vec (0, 0, 0))
        else do
            r0 <- nextDouble
            r1 <- nextDouble
            r2 <- nextDouble
            let obj = (scene !! i)
            let c = col obj
            let prob = (max (x c) (max (y c) (z c)))
            if depth >= 5 && r0 >= prob
                then return (emit obj)
                else do
                    let rlt = if depth < 5 then 1 else prob
                    let f = (col obj)
                    let d = (dir ray)
                    let x = (org ray) + (d `mul` t)
                    let n = norm $ x - (pos obj)
                    let orn = if (d `dot` n) < 0.0  then n else (-n)
                    let (ndir, pdf) = case (refl obj) of
                            Diff -> (lambert orn r1 r2)
                            Spec -> (reflect d orn)
                            Refr -> (refract d n orn r1)
                    nextRad <- (radiance scene (Ray (x, ndir)) (succ depth))
                    return $ ((emit obj) + ((f * nextRad) `mul` (1.0 / (rlt * pdf))))

toByte :: Double -> W.Word8
toByte x = truncate (((clamp x) ** (1.0 / 2.2)) * 255.0) :: W.Word8

accumulateRadiance :: Scene -> Ray -> Int -> Int -> IO Vec
accumulateRadiance scene ray d m = do
    let rays = take m $ repeat ray
    pixels <- sequence [radiance scene r 0 | r <- rays]
    return $ (foldr1 (+) pixels) `mul` (1 / fromIntegral m)

main :: IO ()
main = do
    args <- getArgs
    let argc = length args
    let w   = if argc >= 1 then (read (args !! 0)) else 400 :: Int
    let h   = if argc >= 2 then (read (args !! 1)) else 300 :: Int
    let spp = if argc >= 3 then (read (args !! 2)) else 4   :: Int

    startTime <- getCurrentTime

    putStrLn "-- Smallpt.hs --"
    putStrLn $ "  width = " ++ (show w)
    putStrLn $ " height = " ++ (show h)
    putStrLn $ "    spp = " ++ (show spp)

    let dw = fromIntegral w :: Double
    let dh = fromIntegral h :: Double

    let cam = Ray (Vec (50, 52, 295.6), (norm $ Vec (0, -0.042612, -1)))
    let cx  = Vec (dw * 0.5135 / dh, 0.0, 0.0)
    let cy  = (norm $ cx `cross` (dir cam)) `mul` 0.5135
    let dirs = [ norm $ (dir cam) + (cy `mul` (y / dh  - 0.5)) + (cx `mul` (x / dw - 0.5)) | y <- [dh-1,dh-2..0], x <- [0..dw-1] ]
    let rays = [ Ray ((org cam) + (d `mul` 140.0), (norm d)) | d <- dirs ]

    let pixels = [ (unsafePerformIO (accumulateRadiance sph r 0 spp)) | r <- rays ]

    let pixelData = map toByte (concatMap (\col -> [x col, y col, z col]) pixels)
    let pixelBytes = V.fromList pixelData :: V.Vector W.Word8
    let img = Image { imageHeight = h, imageWidth = w, imageData = pixelBytes } :: Image PixelRGB8
    writePng "image.png" img

    endTime <- getCurrentTime
    print $ diffUTCTime endTime startTime
tatsy asked Dec 25 '22 15:12


1 Answer

First, I think there is an error. When you talk about going from

pixels <- sequence [ (sumOfRadiance scene ray samples) | ray <- rays]

to

pixels <- sequence [ (unsafePerformIO (sumOfRadiance scene ray samples)) | ray <- rays]

that doesn't make sense: the types don't match up. The sequence function only makes sense when you are combining a bunch of things of type m a. It would be correct to do

let pixels = [ unsafePerformIO (sumOfRadiance scene ray samples) | ray <- rays ]

I will somewhat cavalierly assume that that is what you did and you simply made a mistake when entering your question.

If this is the case, then what you are actually looking for is a way to execute IO actions more lazily, not more immediately. The sequence call forces all the actions to be run right then, whereas the unsafePerformIO version simply creates a list of un-run actions (and the list itself is generated lazily, so it doesn't exist all at once); each action is run individually when its result is needed.
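To see the difference described above, here is a small sketch (not from the original post) that counts how many actions have run. `tick` is a hypothetical stand-in for an expensive action such as `sumOfRadiance`. The deferred variant uses `unsafeInterleaveIO` from `System.IO.Unsafe`, a swapped-in primitive that postpones an action until its result is demanded, which is roughly what the list of `unsafePerformIO` calls achieves.

```haskell
import Control.Exception (evaluate)
import Data.IORef (IORef, modifyIORef', newIORef, readIORef)
import System.IO.Unsafe (unsafeInterleaveIO)

-- Hypothetical stand-in for an expensive IO action such as sumOfRadiance:
-- it bumps a counter so we can observe when it actually runs.
tick :: IORef Int -> IO Int
tick ref = do
  modifyIORef' ref (+ 1)
  readIORef ref

-- 'sequence' runs every action before it returns, so all n results
-- exist at once before any of them is inspected.
strictRuns :: Int -> IO Int
strictRuns n = do
  ref <- newIORef 0
  _results <- sequence (replicate n (tick ref))
  readIORef ref

-- With unsafeInterleaveIO each action is deferred until its result is
-- demanded. Returns (actions run before forcing anything,
-- actions run after forcing only the first result).
deferredRuns :: Int -> IO (Int, Int)
deferredRuns n = do
  ref <- newIORef 0
  results <- sequence (replicate n (unsafeInterleaveIO (tick ref)))
  before <- readIORef ref
  case results of
    (r : _) -> evaluate r >> return ()
    []      -> return ()
  after <- readIORef ref
  return (before, after)
```

With `n = 5`, `strictRuns` reports 5 while `deferredRuns` reports `(0, 1)`. The trade-off is that when each effect happens now depends on evaluation order, which is exactly what makes these functions unsafe.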

It appears that the reason you need IO is to generate random numbers. Randomness can be kind of a pain. Usually MonadRandom does the job, but it still creates a sequential dependence between actions and may still not be lazy enough. I'd give it a try anyway: if you use it you get reproducibility (the same seed gives the same results, even after refactorings that respect the monad laws).
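A minimal, dependency-free sketch of the kind of state-threading random monad meant here (MonadRandom's `Rand` is built on a state monad over the generator). `Rand`, `uniform01`, and `evalRand` are hypothetical names, and the linear congruential step is only an illustration, not a good generator. It shows both properties mentioned: each draw depends on the state left by the previous one, and the same seed reproduces the same results.

```haskell
import Data.Word (Word64)

-- Hypothetical stand-in for MonadRandom's Rand: a state monad that threads
-- the generator state from one action to the next.
newtype Rand a = Rand { runRand :: Word64 -> (a, Word64) }

instance Functor Rand where
  fmap f (Rand g) = Rand $ \s -> let (a, s') = g s in (f a, s')

instance Applicative Rand where
  pure a = Rand $ \s -> (a, s)
  Rand gf <*> Rand ga = Rand $ \s ->
    let (f, s1) = gf s   -- the left action runs first...
        (a, s2) = ga s1  -- ...and the right one sees the state it left behind
    in (f a, s2)

instance Monad Rand where
  Rand g >>= k = Rand $ \s ->
    let (a, s1) = g s
    in runRand (k a) s1

-- One step of a 64-bit LCG (Knuth's MMIX constants); the top 53 bits of
-- the new state become a Double in [0, 1). Illustration only.
uniform01 :: Rand Double
uniform01 = Rand $ \s ->
  let s' = s * 6364136223846793005 + 1442695040888963407
  in (fromIntegral (s' `div` 2048) / 9007199254740992, s')

-- Run a computation from a given seed and keep only the result.
evalRand :: Rand a -> Word64 -> a
evalRand m seed = fst (runRand m seed)
```

For example, `evalRand (sequence (replicate 5 uniform01)) 42` returns the same five numbers every time. Because every draw waits for the state produced by the previous one, the whole computation forms one sequential chain; that is the sequential dependence mentioned above.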

If MonadRandom doesn't work and you need to generate random numbers in a more on-demand way, the way to go is to make your own randomness monad that does the same thing as your unsafePerformIO solution, but properly encapsulated. I'm going to show you what I consider to be the Haskell Way To Cheat. First, a lovely pure implementation sketch:

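{-# LANGUAGE DeriveFunctor #-}

-- A seed tells you how to generate random numbers
data Seed = ...
splitSeed :: Seed -> (Seed, Seed)
random :: Seed -> Double

-- A Cloud is a probability distribution of a's, or an a which
-- depends on a random seed.  This monad is just as lazy as a
-- pure computation.
newtype Cloud a = Cloud { runCloud :: Seed -> a }
    deriving (Functor)

-- Modern GHC (7.10 and later) requires an Applicative instance for every Monad.
instance Applicative Cloud where
    pure = Cloud . const
    mf <*> mx = Cloud $ \seed ->
        let (seed1, seed2) = splitSeed seed in
        runCloud mf seed1 (runCloud mx seed2)

instance Monad Cloud where
    m >>= f = Cloud $ \seed ->
        let (seed1, seed2) = splitSeed seed in
        runCloud (f (runCloud m seed1)) seed2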

(I think I got that right. The point is that at every bind you split the seed in two and pass one to the left and the other to the right.)
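To make the sketch concrete, here is one possible, hypothetical `Seed`: a `Word64` split using the mixing function from SplitMix64. It is a toy, not a vetted splittable generator. It also shows that the monad really is as lazy as pure code: `samples` describes an infinite list of draws, and `take` consumes only a prefix of it.

```haskell
import Data.Bits (shiftR, xor)
import Data.Word (Word64)

-- Hypothetical concrete seed (the answer leaves Seed abstract).
newtype Seed = Seed Word64

-- SplitMix64's finalizer: scrambles a 64-bit value (it is a bijection).
mix64 :: Word64 -> Word64
mix64 z0 =
  let z1 = (z0 `xor` (z0 `shiftR` 30)) * 0xbf58476d1ce4e5b9
      z2 = (z1 `xor` (z1 `shiftR` 27)) * 0x94d049bb133111eb
  in z2 `xor` (z2 `shiftR` 31)

-- Toy split: derive two distinct child seeds. Not statistically vetted.
splitSeed :: Seed -> (Seed, Seed)
splitSeed (Seed s) = (Seed (mix64 (s + golden)), Seed (mix64 (s + 2 * golden)))
  where golden = 0x9e3779b97f4a7c15

-- Top 53 bits of the scrambled seed, scaled into [0, 1).
random :: Seed -> Double
random (Seed s) = fromIntegral (mix64 s `shiftR` 11) / 9007199254740992

newtype Cloud a = Cloud { runCloud :: Seed -> a }

instance Functor Cloud where
  fmap f (Cloud g) = Cloud (f . g)

-- Every combination splits the seed: left half to the left, right half to the right.
instance Applicative Cloud where
  pure = Cloud . const
  mf <*> mx = Cloud $ \seed ->
    let (seed1, seed2) = splitSeed seed
    in runCloud mf seed1 (runCloud mx seed2)

instance Monad Cloud where
  m >>= f = Cloud $ \seed ->
    let (seed1, seed2) = splitSeed seed
    in runCloud (f (runCloud m seed1)) seed2

-- A single uniform sample in [0, 1).
sample :: Cloud Double
sample = Cloud random

-- An infinite stream of samples; laziness lets callers take any prefix.
samples :: Cloud [Double]
samples = (:) <$> sample <*> samples
```

For instance, `take 5 (runCloud samples (Seed 42))` terminates even though `samples` is infinite. Note that `return x >>= f` runs `f x` on a split-off seed rather than the original one, so the left-identity law holds only in distribution, not exactly.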

Now this is a perfectly pure implementation of randomness... with a couple of catches. (1) There is no non-trivial splitSeed that will strictly respect the monad laws, and (2) even if we allow the laws to be broken, random number generators based on splitting can be pretty slow. But if we give up determinism, that is, if all we care about is getting a good sample from the distribution rather than exactly the same result, then we don't need to strictly respect the monad laws. At that point we cheat and pretend there is a suitable Seed type:

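import System.IO.Unsafe (unsafePerformIO)
import System.Random (randomIO)

data Seed = Seed

splitSeed :: Seed -> (Seed, Seed)
splitSeed Seed = (Seed, Seed)

-- Always NOINLINE functions with unsafePerformIO to keep the
-- optimizer from messing with you.
{-# NOINLINE random #-}
random :: Seed -> Double
random Seed = unsafePerformIO randomIO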

We should hide this inside a module to keep the abstraction barrier clear. The Cloud constructor and runCloud should not be exposed, since they allow us to violate purity; expose only

runCloudIO :: Cloud a -> IO a
runCloudIO m = return (runCloud m Seed)

which doesn't technically need IO, but communicates that this will not be deterministic. Then you can build up whatever you need as a value in the Cloud monad, and run it once in your main program.

You might ask why we have a Seed type at all if it doesn't carry any information. Well, I think splitSeed is just a nod to purity and isn't actually doing anything (you could remove it), but we need Cloud to be a function type so that the implicit caching of laziness doesn't break our semantics. Otherwise

let foo = random in liftM2 (,) foo foo

would always return a pair with two identical components, since the random value was really associated with foo. I am not sure about these details: at this point we are at war with the optimizer, and it takes some experimentation.

Happy cheating. :-)

luqui answered May 06 '23 20:05