Understanding State Monad

Tags:

haskell

Looking at Learn You a Haskell's definition of the State Monad:

instance Monad (State s) where  
    return x = State $ \s -> (x,s)  
    (State h) >>= f = State $ \s -> let (a, newState) = h s  
                                        (State g) = f a  
                                    in  g newState  

I don't understand the types of h s and g newState in the lower right-hand side.

Can you please explain their types and what's going on?

Kevin Meredith asked Dec 26 '22 05:12


1 Answer

State s a is a name for (a wrapper around) a function type, the "state transformer function":

s -> (a, s)

In other words, a state transformer takes an input state s and produces an updated state while also returning a result, a. This gives a very general framework for "pure state". If our state is an integer, we can write a function that increments that integer and returns the new value; this acts like a source of unique numbers:

upd :: Int -> (Int, Int)
upd s = let s' = s + 1 in (s', s')

Here, a and s end up being the same type.


This is all well and good, except that we're in trouble if we'd like to get two fresh numbers. For that we must somehow run upd twice, passing the state returned by the first call into the second.
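
Done by hand, that means threading the state explicitly: each call's output state becomes the next call's input. A minimal sketch (twoByHand is a hypothetical name, not part of the answer):

```haskell
-- The answer's unique-number source: increment the state, return the new value.
upd :: Int -> (Int, Int)
upd s = let s' = s + 1 in (s', s')

-- Hypothetical helper: run upd twice, threading the state manually.
-- s0 is the initial state; s1 and s2 are the states after each call.
twoByHand :: Int -> ((Int, Int), Int)
twoByHand s0 =
  let (a, s1) = upd s0   -- first call consumes s0
      (b, s2) = upd s1   -- second call must be given s1, not s0
  in  ((a, b), s2)
```

twoByHand 0 evaluates to ((1,2),2). Passing s0 to the second call by mistake would still type-check but return ((1,1),1), which is exactly the kind of plumbing error a reusable combinator avoids.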

The final result is going to be another state transformer function, so we're looking for a "state transformer transformer". I'll call it compose:

compose :: (s -> (a, s))         -- the initial state transformer
        -> (a -> (s -> (b, s)))  -- a new state transformer, built using the "result"
                                 -- of the previous one
        -> (s -> (b, s))         -- the result state transformer

This looks a little hairy, but honestly it's fairly easy to write. The types guide you to the answer:

compose f f' = \s -> let (a, s')  = f s
                         (b, s'') = f' a s'
                     in  (b, s'')

Notice that the s-typed variables s, s', and s'' "flow downward", indicating that the state moves from the first computation through the second and into the result.

We can use compose to build a function that gets two unique numbers using upd:

twoUnique :: Int -> ((Int, Int), Int)
twoUnique = compose upd (\a s -> let (a', s') = upd s in ((a, a'), s'))
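
Collected into one file, the definitions above run as-is. twoUnique' is a hypothetical variant (not from the answer) that nests compose a second time instead of calling upd inside the lambda, which hints at how chaining will look once the pattern is extracted:

```haskell
-- The answer's unique-number source: increment the state, return the new value.
upd :: Int -> (Int, Int)
upd s = let s' = s + 1 in (s', s')

-- Run the first transformer, then build and run the second from its result,
-- threading the state s -> s' -> s''.
compose :: (s -> (a, s)) -> (a -> (s -> (b, s))) -> (s -> (b, s))
compose f f' = \s -> let (a, s')  = f s
                         (b, s'') = f' a s'
                     in  (b, s'')

-- The answer's version: the continuation calls upd directly.
twoUnique :: Int -> ((Int, Int), Int)
twoUnique = compose upd (\a s -> let (a', s') = upd s in ((a, a'), s'))

-- Hypothetical variant: nest compose again, so the innermost lambda
-- only passes the final state through unchanged.
twoUnique' :: Int -> ((Int, Int), Int)
twoUnique' = compose upd (\a -> compose upd (\a' s -> ((a, a'), s)))
```

In GHCi, twoUnique 0 evaluates to ((1,2),2): the two results are 1 and 2, and the final state is 2.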

These are the basics of State. The only difference is that the State monad recognizes the common pattern inside compose and extracts it as (>>=). Its type, first with the State wrapper and then with the wrapper removed, looks like this:

(>>=) :: State s a     -> (a -> State s b    ) -> State s b
(>>=) :: (s -> (a, s)) -> (a -> (s -> (b, s))) -> (s -> (b, s))

It's implemented the same way, too. We just need to "wrap" and "unwrap" the State bit, which is the purpose of the State constructor and the runState function:

State    :: (s -> (a, s)) -> State s a
runState :: State s a     -> (s -> (a, s))
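
Where do State and runState come from? In Learn You a Haskell they come from a newtype with a record field, sketched below under that assumption (current mtl/transformers instead define State s as StateT s Identity and provide a lowercase state function for wrapping). bindState, fresh, and twoFresh are hypothetical names; bindState is simply compose with the wrapping and unwrapping added:

```haskell
-- LYAH-style definition: State wraps a state transformer function,
-- and record syntax generates the unwrapping accessor runState.
newtype State s a = State { runState :: s -> (a, s) }

-- The answer's compose, on bare functions.
compose :: (s -> (a, s)) -> (a -> (s -> (b, s))) -> (s -> (b, s))
compose f f' = \s -> let (a, s')  = f s
                         (b, s'') = f' a s'
                     in  (b, s'')

-- Hypothetical name: (>>=) for State as "unwrap, compose, rewrap".
-- runState . k unwraps every State that the continuation k returns.
bindState :: State s a -> (a -> State s b) -> State s b
bindState m k = State (compose (runState m) (runState . k))

-- Hypothetical example action: the answer's upd, wrapped with State.
fresh :: State Int Int
fresh = State (\s -> let s' = s + 1 in (s', s'))

-- Two fresh numbers, built with bindState instead of compose.
twoFresh :: State Int (Int, Int)
twoFresh = bindState fresh (\a ->
           bindState fresh (\a' ->
           State (\s -> ((a, a'), s))))
```

runState twoFresh 0 evaluates to ((1,2),2), the same as twoUnique 0.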

Now we can compare compose with (>>=):

compose f f'       =         \s -> let (a, s')  = f s
                                       (b, s'') =           f' a  s'
                                   in  (b, s'')

(>>=) (State f) f' = State $ \s -> let (a, s')  = f s
                                       (b, s'') = runState (f' a) s'
                                   in  (b, s'')
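
Mapping this back to the question's names: LYAH's h corresponds to f in the (>>=) above, LYAH's f corresponds to f', and g is runState (f' a). So h s :: (a, s) and g newState :: (b, s). The sketch below makes those types explicit with annotations. It assumes GHC with the InstanceSigs and ScopedTypeVariables extensions, and it adds Functor and Applicative instances because GHC 7.10 and later reject a Monad instance without them; fresh and twoUniqueM are hypothetical names:

```haskell
{-# LANGUAGE InstanceSigs, ScopedTypeVariables #-}

-- LYAH-style State: a wrapped state transformer function.
newtype State s a = State { runState :: s -> (a, s) }

-- Modern GHC requires Functor and Applicative before Monad.
instance Functor (State s) where
  fmap f (State h) = State $ \s -> let (a, s') = h s in (f a, s')

instance Applicative (State s) where
  pure x = State $ \s -> (x, s)
  State hf <*> State hx = State $ \s ->
    let (f, s')  = hf s
        (x, s'') = hx s'
    in  (f x, s'')

instance Monad (State s) where
  -- return is not defined here; it defaults to pure.
  (>>=) :: forall a b. State s a -> (a -> State s b) -> State s b
  (State h) >>= f = State $ \s ->
    let (a, newState) = (h s :: (a, s))   -- h :: s -> (a, s)
        State g       = f a               -- g :: s -> (b, s)
    in  (g newState :: (b, s))            -- the final (result, state) pair

-- The answer's upd, wrapped as a State action (hypothetical name).
fresh :: State Int Int
fresh = State $ \s -> let s' = s + 1 in (s', s')

-- twoUnique again, with (>>=) hidden behind do-notation.
twoUniqueM :: State Int (Int, Int)
twoUniqueM = do
  a  <- fresh
  a' <- fresh
  return (a, a')
```

runState twoUniqueM 0 evaluates to ((1,2),2), matching twoUnique 0.
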
J. Abrahamson answered Feb 14 '23 14:02