
Understanding `bind` of `newtype Prob`

Tags:

haskell

Learn You a Haskell presents the Prob newtype:

newtype Prob a = Prob { getProb :: [(a,Rational)] } deriving Show

Here are Prob's instance definitions:

instance Functor Prob where  
    fmap f (Prob xs) = Prob $ map (\(x,p) -> (f x,p)) xs  

instance Monad Prob where
    return x = Prob [(x, 1%1)]
    p >>= f  = flatten (fmap f p)  

And then the supporting functions:

flatten :: Prob (Prob a) -> Prob a
flatten = Prob . convert . getProb 

convert :: [(Prob a, Rational)] -> [(a, Rational)]
convert = concat . (map f)

f :: (Prob a, Rational) -> [(a, Rational)]
f (p, r) = map (mult r) (getProb p)

mult :: Rational -> (a, Rational) -> (a, Rational)
mult r (x, y) = (x, r*y)

I wrote the flatten, convert, f, and mult functions, so I'm comfortable with them.
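For anyone loading this into a recent GHC: since GHC 7.10 every Monad must also be an Applicative, so the instances as quoted will not compile on their own. Below is a minimal sketch of a self-contained module. The Applicative instance is an addition that is not in the code above, and flatten is collapsed into a single list comprehension that is intended to behave the same as the flatten/convert/f/mult chain.

```haskell
import Data.Ratio ((%))

newtype Prob a = Prob { getProb :: [(a, Rational)] } deriving Show

instance Functor Prob where
    fmap f (Prob xs) = Prob $ map (\(x, p) -> (f x, p)) xs

-- Added so the module compiles on modern GHC; not part of the quoted code.
instance Applicative Prob where
    pure x    = Prob [(x, 1 % 1)]
    pf <*> px = pf >>= \f -> fmap f px

-- return is left at its default, which is pure.
instance Monad Prob where
    p >>= f = flatten (fmap f p)

-- For every inner distribution, scale its probabilities by the outer one.
flatten :: Prob (Prob a) -> Prob a
flatten (Prob xs) = Prob [ (x, r * q) | (Prob ys, r) <- xs, (x, q) <- ys ]
```

Loaded in GHCi, `getProb (return 'x' :: Prob Char)` should give a one-element list pairing 'x' with probability 1.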

Then we apply >>= to the following example, involving a data type, Coin:

data Coin = Heads | Tails deriving (Show, Eq)

coin :: Prob Coin 
coin = Prob [(Heads, 1%2), (Tails, 1%2)]

loadedCoin :: Prob Coin
loadedCoin = Prob [(Heads, 1%10), (Tails, 9%10)]

LYAH asks: "If we throw all the coins at once, what are the odds of all of them landing tails?"

flipTwo :: Prob Bool
flipTwo = do
  a <- coin       -- a has type `Coin`
  b <- loadedCoin -- similarly
  return (all (== Tails) [a,b])

Calling flipTwo returns:

Prob {getProb = [(False,1 % 20),(False,9 % 20),(False,1 % 20),(True,9 % 20)]}

flipTwo can be re-written with >>=:

flipTwoBind' :: Prob Bool
flipTwoBind' = coin >>= 
                    \x -> loadedCoin   >>= 
                                       \y -> return (all (== Tails) [x,y])

I don't understand the type of return (all (== Tails) [x,y]). Since it's the right-hand side of >>=, its type must be a -> m b (where Monad m).

My understanding is that (all (==Tails) [x,y]) returns True or False, but how does return lead to the above result:

Prob {getProb = [(False,1 % 20),(False,9 % 20),(False,1 % 20),(True,9 % 20)]}?

Kevin Meredith asked Mar 19 '23 14:03


2 Answers

Note that the RHS of the >>= operator is a lambda expression, not the application of return:

\y -> return (all (== Tails) [x,y])

This lambda has a type of the form (Monad m) => a -> m b, as expected.

Let's build up the type from the bottom:

As you say, all (== Tails) [x,y] returns True or False. In other words, its type is Bool.

Now, checking the type of return in ghci, we see that it is:

Prelude> :t return
return :: Monad m => a -> m a

So return (all (== Tails) [x,y]) has type Monad m => m Bool.

Wrapping this in a lambda then gives the type (Monad m) => Coin -> m Bool. The argument type is Coin because y is compared with Tails.

(Note that somewhere along the way, the compiler will deduce that the concrete monad type is Prob.)

You should think of return as taking a regular value and wrapping it in a monadic value.
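As a small illustration of that wrapping, here is a sketch that uses only Prelude monads (Maybe and lists) rather than Prob. The names asMaybe, asList and negated are made up for this example. The expression return True is the same each time, and only the type annotation selects the monad.

```haskell
-- The same expression at two different monads.
asMaybe :: Maybe Bool
asMaybe = return True   -- Just True

asList :: [Bool]
asList = return True    -- [True]

-- Wrapping return in a lambda yields a function of the shape a -> m b,
-- which is what the right-hand side of (>>=) expects.
negated :: Monad m => Bool -> m Bool
negated = \y -> return (not y)
```

For Prob, the instance in the question makes return x equal to Prob [(x, 1%1)], that is, a single outcome with probability 1.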

Addition:

Let's analyze the type of

flipTwoBind' = coin >>= 
                \x -> loadedCoin   >>= 
                                   \y -> return (all (== Tails) [x,y])

We start by noting that the outermost expression here is an application of (>>=) which has type:

Prelude> :t (>>=)
(>>=) :: Monad m => m a -> (a -> m b) -> m b

The LHS is coin which has type Prob Coin, so we immediately deduce that m is Prob and a is Coin. This means that the RHS must have type Coin -> Prob b for some type b. So let's look at the RHS now:

\x -> loadedCoin >>= \y -> return (all (== Tails) [x,y])

Here we have a lambda that returns the result of an application of (>>=), so the lambda has type

(Monad m) => a -> m b

This matches the expected type for the application of the first (>>=), so a here is Coin and m is Prob.

Now analyzing the inner application of (>>=), we see that its type is deduced to be

(>>=) :: Prob Coin -> (Coin -> Prob b) -> Prob b

We already analyzed the RHS of the second (>>=), and so b is deduced to be Bool.

(Note, this may not be the exact order that the compiler uses to deduce the types. It just happens to be the order which my thoughts followed as I analyzed the types for this answer.)
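One way to check this reasoning is to give each lambda a name and an explicit signature and let the compiler confirm them. The following is a sketch under a few assumptions: inner and outer are hypothetical names, inner takes x as an explicit first argument instead of capturing it, the Applicative instance is added for modern GHC, and (>>=) is written as a list comprehension that should be equivalent to flatten (fmap f p).

```haskell
import Data.Ratio ((%))

newtype Prob a = Prob { getProb :: [(a, Rational)] } deriving Show

instance Functor Prob where
    fmap f (Prob xs) = Prob [ (f x, p) | (x, p) <- xs ]

instance Applicative Prob where   -- required by modern GHC; not in the question
    pure x    = Prob [(x, 1 % 1)]
    pf <*> px = pf >>= \f -> fmap f px

instance Monad Prob where
    Prob xs >>= f = Prob [ (y, p * q) | (x, p) <- xs, (y, q) <- getProb (f x) ]

data Coin = Heads | Tails deriving (Show, Eq)

coin, loadedCoin :: Prob Coin
coin       = Prob [(Heads, 1 % 2),  (Tails, 1 % 2)]
loadedCoin = Prob [(Heads, 1 % 10), (Tails, 9 % 10)]

-- The innermost lambda, with x passed explicitly.
inner :: Coin -> Coin -> Prob Bool
inner x y = return (all (== Tails) [x, y])

-- The right-hand side of the first (>>=): a is Coin, m is Prob, b is Bool.
outer :: Coin -> Prob Bool
outer x = loadedCoin >>= inner x

flipTwoBind' :: Prob Bool
flipTwoBind' = coin >>= outer
```

If any of the signatures were wrong, for example outer :: Prob Coin -> Prob Bool, the module would be rejected by the type checker.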

Code-Apprentice answered Mar 31 '23 18:03


(I'll call your coin fairCoin.) You have:

flipTwoBind' :: Prob Bool
flipTwoBind' = fairCoin     >>=  g   where
   g x       = loadedCoin   >>=  h   where
     h y     = return z              where 
       z     = all (== Tails) [x,y]

From the type of (>>=) we get:

fairCoin ::         Prob Coin
(>>=) :: Monad m => m    a    ->  (a -> m b) -> m b       | m ~ Prob, a ~ Coin
                    fairCoin  >>=    g       :: m b       | g :: Coin -> Prob b
flipTwoBind'                              :: Prob Bool    | m ~ Prob, b ~ Bool

so that g :: Coin -> Prob Bool and g x :: Prob Bool provided that x :: Coin.

Since g x = loadedCoin >>= h, we have

loadedCoin ::       Prob Coin
(>>=) :: Monad m => m    a    ->  (a -> m b) -> m    b 
                  loadedCoin  >>=    h       :: Prob Bool

So, h :: Coin -> Prob Bool, z :: Bool and return z :: Prob Bool:

all ::  (a -> Bool) -> [a] -> Bool
all        p           []  :: Bool

return :: (Monad m) => a -> m a
z      ::           Bool
return                 z :: m Bool           | m ~ Prob so return z :: Prob Bool

Since Prob a is essentially a tagged assoc-list of pairs of a outcomes and their corresponding probabilities, Prob Bool is a list of pairings of Bool outcomes and their probabilities.


Translated with the specific Prob monadic code, inlining all the functions, flipTwoBind' becomes

flipTwoBind' = fairCoin     >>=  g
   = flatten (fmap g fairCoin)
   = Prob . convert . getProb $ 
             Prob $ map (\(x,p) -> (g x,p)) $ getProb fairCoin
   = Prob . concat . map (\(x,p) -> map (\(x, y) -> (x, p*y)) $ getProb x)
                  . map (\(x,p) -> (g x,p)) $ getProb fairCoin

(see how nicely the Prob and getProb cancel each other there on the inside...).

Switching to plain list-based code (with gL x = getProb (g x) and fairCoinL = getProb fairCoin etc.), it is equivalent to

   = concat . map (\(x,p) -> map (second (p*)) x)
            . map (\(x,p) -> (gL x,p)) $ fairCoinL
   = concat . map (\(x,p) -> map (second (p*)) $ gL x) $ fairCoinL
   = [(v,p*q) | (x,p) <- fairCoinL, (v,q) <- gL x]
   = ....
   = [(z,r)   | (x,p) <- [(Heads,    1%2  ), (Tails,    1%2  )],   -- do a <- fairCoin
                (y,q) <- [(Heads, p*(1%10)), (Tails, p*(9%10))],   --    b <- loadedCoin
                (z,r) <- [(all (== Tails) [x,y],     q*(1%1) )] ]  --    return ... all ...
   = [(False,1 % 20),(False,9 % 20),(False,1 % 20),(True,9 % 20)]

Of course, the second-to-last line in the derivation above could equally be written as

   = [(all (== Tails) [x,y], q)                                    -- ... all ... <$>
              | (x,p) <- [(Heads,    1%2  ), (Tails,    1%2  )],   --   fairCoin <*>
                (y,q) <- [(Heads, p*(1%10)), (Tails, p*(9%10))] ]  --   loadedCoin

because (>>= return . f) === fmap f.
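Both comprehensions can be checked directly as ordinary lists, with no Prob wrapper. In this sketch the names viaReturn, viaFmap and expected are made up for the example, and the products are written with explicit parentheses, p * (1 % 10), because % and * share the same precedence.

```haskell
import Data.Ratio ((%))

data Coin = Heads | Tails deriving (Show, Eq)

-- Three generators, mirroring coin, loadedCoin and the final return.
viaReturn :: [(Bool, Rational)]
viaReturn =
    [ (z, r)
    | (x, p) <- [(Heads, 1 % 2), (Tails, 1 % 2)]
    , (y, q) <- [(Heads, p * (1 % 10)), (Tails, p * (9 % 10))]
    , (z, r) <- [(all (== Tails) [x, y], q * (1 % 1))]
    ]

-- Two generators: the singleton from return is folded into the result expression.
viaFmap :: [(Bool, Rational)]
viaFmap =
    [ (all (== Tails) [x, y], q)
    | (x, p) <- [(Heads, 1 % 2), (Tails, 1 % 2)]
    , (y, q) <- [(Heads, p * (1 % 10)), (Tails, p * (9 % 10))]
    ]

-- The list shown in the question.
expected :: [(Bool, Rational)]
expected = [(False, 1 % 20), (False, 9 % 20), (False, 1 % 20), (True, 9 % 20)]
```

In GHCi, viaReturn == expected and viaFmap == expected should both evaluate to True.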

Will Ness answered Mar 31 '23 17:03