How does this State monad code work?

This code is from this article

I've been able to follow it until this part.

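module Test where

import Control.Monad (ap, liftM)

type State = Int

-- A state transformer: given a state, produce a result and a new state
data ST a = S (State -> (a, State))

-- Run a state transformer from a given initial state
apply        :: ST a -> State -> (a,State)
apply (S f) x = f x

-- Return the current counter as the result, and increment the state
fresh :: ST Int
fresh =  S (\n -> (n, n+1))

-- Since GHC 7.10, every Monad must also be an Applicative (and a Functor)
instance Functor ST where
    fmap = liftM

instance Applicative ST where
    -- pure :: a -> ST a
    pure x   = S (\s -> (x,s))
    (<*>)    = ap

instance Monad ST where
    -- return :: a -> ST a   (defaults to pure)

    -- (>>=)  :: ST a -> (a -> ST b) -> ST b
    st >>= f   = S (\s -> let (x,s') = apply st s in apply (f x) s')

data Tree a = Leaf a | Node (Tree a) (Tree a) deriving (Show)


mlabel  :: Tree a -> ST (Tree (a,Int))
-- THIS IS THE PART I DON'T UNDERSTAND:
mlabel (Leaf x) = do n <- fresh
                     return (Leaf (x,n))
mlabel (Node l r) =  do l' <- mlabel l
                        r' <- mlabel r
                        return (Node l' r')

label :: Tree a -> Tree (a,Int)
label t = fst (apply (mlabel t) 0)

tree :: Tree Char
tree = Node (Node (Leaf 'a') (Leaf 'b')) (Leaf 'c')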

And label tree produces:

Node (Node (Leaf ('a',0)) (Leaf ('b',1))) (Leaf ('c',2))
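The 0, 1, 2 in that output come from the initial state 0 that label passes to apply, and the counter is not thrown away by mlabel itself, only by the fst in label. A minimal sketch, assuming the question's definitions plus the Functor/Applicative instances that modern GHC requires; labelAt is a hypothetical helper introduced only for illustration.

```haskell
import Control.Monad (ap, liftM)

type State = Int

data ST a = S (State -> (a, State))

apply :: ST a -> State -> (a, State)
apply (S f) x = f x

instance Functor ST where
  fmap = liftM

instance Applicative ST where
  pure x = S (\s -> (x, s))
  (<*>) = ap

instance Monad ST where
  st >>= f = S (\s -> let (x, s') = apply st s in apply (f x) s')

fresh :: ST Int
fresh = S (\n -> (n, n + 1))

data Tree a = Leaf a | Node (Tree a) (Tree a) deriving (Show, Eq)

mlabel :: Tree a -> ST (Tree (a, Int))
mlabel (Leaf x) = do n <- fresh
                     return (Leaf (x, n))
mlabel (Node l r) = do l' <- mlabel l
                       r' <- mlabel r
                       return (Node l' r')

tree :: Tree Char
tree = Node (Node (Leaf 'a') (Leaf 'b')) (Leaf 'c')

-- The question's label: start at state 0, keep only the tree
label :: Tree a -> Tree (a, Int)
label t = fst (apply (mlabel t) 0)

-- labelAt (hypothetical): start from any state and keep the final counter too
labelAt :: State -> Tree a -> (Tree (a, Int), State)
labelAt start t = apply (mlabel t) start
```

In GHCi, labelAt 0 tree gives back the same tree as label tree paired with the final counter 3 (one fresh per leaf), and labelAt 5 tree labels the leaves 5, 6 and 7.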

I can see that the >>= operator is the tool for 'chaining' functions that return monadic values (or something like that).

And while I think I understand the general idea, I don't understand how this particular code works.

Specifically, do n <- fresh. We haven't passed any argument to fresh, right? What does n <- fresh produce in that case? I don't understand that at all. Maybe it has something to do with currying?
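One way to see why fresh takes no argument: it is not a function call but a value of type ST Int, a wrapped function that is still waiting for the state. The state is only supplied when something runs it with apply, and >>= is what hands the state left over by one step to the next. A minimal sketch, assuming the question's definitions plus the Functor/Applicative instances modern GHC requires; twoFresh is a hypothetical name for illustration.

```haskell
import Control.Monad (ap, liftM)

type State = Int

data ST a = S (State -> (a, State))

apply :: ST a -> State -> (a, State)
apply (S f) x = f x

instance Functor ST where
  fmap = liftM

instance Applicative ST where
  pure x = S (\s -> (x, s))
  (<*>) = ap

instance Monad ST where
  st >>= f = S (\s -> let (x, s') = apply st s in apply (f x) s')

-- fresh is a plain value: no argument is needed to define or use it
fresh :: ST Int
fresh = S (\n -> (n, n + 1))

-- twoFresh (hypothetical) uses fresh twice; >>= feeds the state left
-- over by the first fresh into the second one
twoFresh :: ST (Int, Int)
twoFresh = fresh >>= \a -> fresh >>= \b -> return (a, b)
```

In GHCi, apply fresh 0 evaluates to (0,1), and apply twoFresh 10 evaluates to ((10,11),12). Inside a do block, n <- fresh simply names the result component (the first element of the pair) that fresh produces once a state arrives.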

user1685095 asked Dec 02 '22 16:12


2 Answers

Specifically do n <- fresh. We haven't passed any argument to fresh, right?

Exactly. We are waiting for an argument that will only be passed to fresh when we, for instance, do something like apply (mlabel someTree) 5. A nice exercise that will help you see more clearly what is going on is to first write mlabel with explicit (>>=) instead of do-notation, and then replace (>>=) and return with their definitions from the Monad instance.
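Here is a sketch of that exercise for the Leaf case, assuming the question's definitions plus the Functor/Applicative instances modern GHC requires (return therefore defaults to pure). The names leafDo, leafBind, leafInlined and leafSimplified are hypothetical; each is one rewriting step of the same computation.

```haskell
import Control.Monad (ap, liftM)

type State = Int

data ST a = S (State -> (a, State))

apply :: ST a -> State -> (a, State)
apply (S f) x = f x

instance Functor ST where
  fmap = liftM

instance Applicative ST where
  pure x = S (\s -> (x, s))
  (<*>) = ap

instance Monad ST where
  st >>= f = S (\s -> let (x, s') = apply st s in apply (f x) s')

fresh :: ST Int
fresh = S (\n -> (n, n + 1))

data Tree a = Leaf a | Node (Tree a) (Tree a) deriving (Show, Eq)

-- Step 0: the question's do-notation (Leaf case only)
leafDo :: a -> ST (Tree (a, Int))
leafDo x = do n <- fresh
              return (Leaf (x, n))

-- Step 1: do-notation desugared into an explicit (>>=)
leafBind :: a -> ST (Tree (a, Int))
leafBind x = fresh >>= \n -> return (Leaf (x, n))

-- Step 2: (>>=), fresh and return replaced by their definitions
leafInlined :: a -> ST (Tree (a, Int))
leafInlined x =
  S (\s -> let (n, s') = (\m -> (m, m + 1)) s
           in (\t -> (Leaf (x, n), t)) s')

-- Step 3: simplified: the incoming state becomes the label,
-- and the state passed on is one larger
leafSimplified :: a -> ST (Tree (a, Int))
leafSimplified x = S (\n -> (Leaf (x, n), n + 1))
```

All four versions behave identically; for example, apply (leafDo 'a') 5 and apply (leafSimplified 'a') 5 both evaluate to (Leaf ('a',5),6). Step 3 shows that the n bound by n <- fresh is just the state that arrives when the computation is eventually run.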

duplode answered Jan 02 '23 22:01


The key thing to realise is that do notation is just syntactic sugar for the Monad operator >>=, so

do n <- fresh
   return (Leaf (x,n))

is short for

fresh >>= (\n -> 
           return (Leaf (x,n))  )

and

do l' <- mlabel l
   r' <- mlabel r
   return (Node l' r')

is short for

mlabel l >>= (\l' -> 
              mlabel r >>= (\r' ->
                            return (Node l' r') ))
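If you inline >>= in the Node case as well, the monad turns out to be hiding exactly this state plumbing: the state coming out of the left subtree goes into the right subtree, and the state coming out of the right subtree is the final one. A minimal sketch, assuming the question's definitions plus the Functor/Applicative instances modern GHC requires; labelFrom is a hypothetical hand-threaded version written only for comparison.

```haskell
import Control.Monad (ap, liftM)

type State = Int

data ST a = S (State -> (a, State))

apply :: ST a -> State -> (a, State)
apply (S f) x = f x

instance Functor ST where
  fmap = liftM

instance Applicative ST where
  pure x = S (\s -> (x, s))
  (<*>) = ap

instance Monad ST where
  st >>= f = S (\s -> let (x, s') = apply st s in apply (f x) s')

fresh :: ST Int
fresh = S (\n -> (n, n + 1))

data Tree a = Leaf a | Node (Tree a) (Tree a) deriving (Show, Eq)

-- The question's monadic version
mlabel :: Tree a -> ST (Tree (a, Int))
mlabel (Leaf x) = do n <- fresh
                     return (Leaf (x, n))
mlabel (Node l r) = do l' <- mlabel l
                       r' <- mlabel r
                       return (Node l' r')

-- labelFrom (hypothetical): the same algorithm with the state passed by hand
labelFrom :: Tree a -> State -> (Tree (a, Int), State)
labelFrom (Leaf x) n = (Leaf (x, n), n + 1)
labelFrom (Node l r) s0 =
  let (l', s1) = labelFrom l s0   -- label the left subtree first
      (r', s2) = labelFrom r s1   -- continue from where the left one stopped
  in (Node l' r', s2)

tree :: Tree Char
tree = Node (Node (Leaf 'a') (Leaf 'b')) (Leaf 'c')
```

Tracing labelFrom tree 0 by hand: leaf 'a' gets 0 (state becomes 1), leaf 'b' gets 1 (state 2), leaf 'c' gets 2 (state 3), which matches apply (mlabel tree) 0. The do-notation version expresses the same left-to-right threading without naming s0, s1 and s2.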

This will hopefully allow you to continue figuring out the code's meaning, but for more help, you should read up on the do notation for Monads.

AndrewC answered Jan 03 '23 00:01