 

How can I use a recursion scheme to express this probability distribution in Haskell?

This question is part theory / part implementation. Background assumption: I'm using the monad-bayes library to represent probability distributions as monads. A distribution p(a|b) can be represented as a function MonadDist m => b -> m a.

Suppose I have a conditional probability distribution s :: MonadDist m => [Char] -> m Char. I want to get a new probability distribution sUnrolled :: MonadDist m => [Char] -> m [Char], defined mathematically (I think) as:

sUnrolled(chars|st) =
              | len(chars)==1 -> s(chars[0]|st)
              | otherwise     -> s(chars[-1]|st++chars[:-1]) * sUnrolled(chars[:-1]|st)
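Unrolling that recursion, and reading the base case as "a single character c_1 has probability s(c_1 | st)", gives the usual chain-rule product with one factor of s per generated character, each conditioned on the seed plus everything generated so far:

$$
p(c_1 c_2 \cdots c_n \mid st) \;=\; \prod_{i=1}^{n} s\!\left(c_i \;\middle|\; st \mathbin{+\!\!+} c_1 \cdots c_{i-1}\right),
\qquad\text{e.g.}\quad
p(c_1 c_2 \mid st) = s(c_2 \mid st \mathbin{+\!\!+} c_1)\, s(c_1 \mid st).
$$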

Intuitively it's the distribution you get by taking st :: [Char], sampling a new char c from s st, feeding st++[c] back into s, and so on. I believe iterateM s is more or less what I want. To make it a distribution we could actually look at, let's say that if we hit a certain character, we stop. Then iterateMaybeM works.
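The sampling procedure just described (sample c from s st, feed st ++ [c] back in, stop at a sentinel character) can be written as a plain monadic loop before bringing in recursion schemes. This is a minimal sketch under stated assumptions: `Dist` is a hand-rolled finite distribution monad used as a stand-in for monad-bayes (it is not monad-bayes's API), and `sMock` and `unrollUntil` are hypothetical names chosen for illustration.

```haskell
-- Minimal finite probability monad: outcomes paired with exact weights.
-- Hypothetical stand-in for a monad-bayes distribution, for illustration only.
newtype Dist a = Dist { runDist :: [(a, Rational)] }

instance Functor Dist where
  fmap f (Dist xs) = Dist [ (f x, p) | (x, p) <- xs ]

instance Applicative Dist where
  pure x = Dist [(x, 1)]
  Dist fs <*> Dist xs = Dist [ (f x, p * q) | (f, p) <- fs, (x, q) <- xs ]

instance Monad Dist where
  -- The weight of each final outcome is the product of the weights
  -- along the path of choices that produced it.
  Dist xs >>= k = Dist [ (y, p * q) | (x, p) <- xs, (y, q) <- runDist (k x) ]

-- Hypothetical s(c | history): emit 'a' or the stop character '.' with equal
-- probability, and force a stop once the history has 4 characters.
sMock :: String -> Dist Char
sMock hist
  | length hist >= 4 = Dist [('.', 1)]
  | otherwise        = Dist [('a', 1 / 2), ('.', 1 / 2)]

-- Sample c from s hist; stop on the sentinel, otherwise append c to the
-- history and continue. Returns only the generated characters.
unrollUntil :: Monad m => Char -> (String -> m Char) -> String -> m String
unrollUntil stop s = go
  where
    go hist = do
      c <- s hist
      if c == stop
        then pure []
        else (c :) <$> go (hist ++ [c])
```

Evaluating `runDist (unrollUntil '.' sMock "x")` gives `[("aaa",1 % 8),("aa",1 % 8),("a",1 % 4),("",1 % 2)]`. Each weight is the chain-rule product, e.g. "aa" gets 1/2 · 1/2 · 1/2 (two 'a's, then the stop). Because `unrollUntil` only needs `Monad m`, the same loop should run in any probability monad.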

Theory Question: For various reasons, it would be really useful if I could express this distribution in more general terms, for instance in a way that generalized to the stochastic construction of a tree given a stochastic coalgebra. It looks like I have some sort of anamorphism here (I realize that the mathematical definition looks like a catamorphism, but in code I want to build up strings, not deconstruct them into probabilities) but I can't quite work out the details, not least because of the presence of the probability monad.
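One way to phrase the theory side (a sketch, not the only formulation): the sampler is a monadic coalgebra `a -> m (f a)` for some base functor `f`, and the unrolled distribution is the monadic anamorphism of that coalgebra into the fixed point `Fix f`. Because the definition only needs `Traversable f`, swapping the list functor for a tree functor gives stochastic tree construction with no other changes. The names `Fix`, `anaM`, `ListF`, `TreeF`, `fixToList`, `size`, `stringStep` and `treeStep` are defined locally here (they mirror, but are not imported from, recursion-schemes), and the list monad stands in for a probability monad by simply enumerating every possible outcome.

```haskell
{-# LANGUAGE DeriveTraversable #-}
import Control.Monad ((<=<))

-- Fixed point of a base functor f: the recursive structure being built.
newtype Fix f = Fix (f (Fix f))

-- Monadic anamorphism: run the coalgebra on the seed (one sampling step),
-- unfold every child seed with traverse, then wrap the layer in Fix.
anaM :: (Monad m, Traversable f) => (a -> m (f a)) -> a -> m (Fix f)
anaM coalg = fmap Fix . (traverse (anaM coalg) <=< coalg)

-- Base functors: e is the element type, r marks the recursive positions.
data ListF e r = NilF | ConsF e r deriving (Functor, Foldable, Traversable)
data TreeF e r = LeafF | BranchF e r r deriving (Functor, Foldable, Traversable)

fixToList :: Fix (ListF e) -> [e]
fixToList (Fix NilF)        = []
fixToList (Fix (ConsF e r)) = e : fixToList r

size :: Fix (TreeF e) -> Int
size (Fix LeafF)           = 0
size (Fix (BranchF _ l r)) = 1 + size l + size r

-- String coalgebra: the seed is the history; either stop, or emit 'a'/'b'
-- and continue with the extended history (a stop is forced at length 3).
stringStep :: String -> [ListF Char String]
stringStep hist
  | length hist >= 3 = [NilF]
  | otherwise        = [NilF, ConsF 'a' (hist ++ "a"), ConsF 'b' (hist ++ "b")]

-- Tree coalgebra: the seed is the depth; a branch spawns two child seeds.
treeStep :: Int -> [TreeF Int Int]
treeStep d
  | d >= 2    = [LeafF]
  | otherwise = [LeafF, BranchF d (d + 1) (d + 1)]
```

Here `map fixToList (anaM stringStep "x")` evaluates to `["","a","aa","ab","b","ba","bb"]` and `map size (anaM treeStep 0)` to `[0,1,2,2,3]`. Running the same `anaM` in a probability monad instead of the list monad would turn these enumerations into samplers.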

Practical Question: it would also be useful to implement this in Haskell in a way that used the recursion schemes library, for instance.

Reuben asked Apr 17 '18 04:04



1 Answer

I'm not smart enough to thread monads through the recursion schemes, so I relied on recursion-schemes-ext, which provides anaM: an anamorphism whose coalgebra returns each new layer inside a monad (a -> m (Base t a)).

I did a (really ugly) proof of concept here:

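{-# LANGUAGE FlexibleContexts #-}
import Data.Functor.Foldable (ListF(..), Base, Corecursive)
import Data.Functor.Foldable.Exotic (anaM)  -- from recursion-schemes-ext
import System.Random

-- Mock of s: with probability 1/2001 stop (Nothing); otherwise pick a
-- uniformly random character from the history st.
s :: String -> IO (Maybe Char)
s st = do
  continue <- getStdRandom $ randomR (0, 2000 :: Int)
  if continue /= 0
    then Just . (st !!) <$> getStdRandom (randomR (0, length st - 1))
    else return Nothing

-- Monadic anamorphism: unfold a seed into any Corecursive structure t,
-- running one monadic step per layer.
result :: (Corecursive t, Traversable (Base t), Monad m) => (String -> m (Base t String)) -> String -> m t
result f = anaM f

-- List coalgebra: Nothing ends the list; Just c emits c and continues with
-- c prepended to the history (so the seed stores the history newest-first).
example :: String -> IO (Base String String)
example st = maybe Nil (\c -> Cons c (c : st)) <$> s st

final :: IO String
final = result example "asdf"

main :: IO ()
main = final >>= print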

A couple of notes:

  1. I mocked out your s function in IO, since I'm not familiar with monad-bayes; it returns Nothing to play the role of the stop character.
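  2. Since our final list is built inside IO, anaM has to run every monadic step before it can return anything, so the list is constructed strictly. This forces us to make a finite list (my s function stops at random, with probability 1/2001 per character, so the list is around 2000 characters long on average).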
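A quick check of "around 2000 characters", assuming randomR (0, 2000 :: Int) is uniform: each call to the mocked s stops with probability p = 1/2001, independently, so the number N of characters generated before the stop (the seed is not part of the output) is geometric:

$$
P(N = k) = (1-p)^k\,p, \qquad
\mathbb{E}[N] = \frac{1-p}{p} = \frac{2000/2001}{1/2001} = 2000.
$$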

EDIT:

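Below is a modified version that confirms that other recursive structures (in this case, a binary tree) can be spawned by the result function. Apart from the new Tree and TreeF definitions with their supporting imports and instances, and a mocked s that now stops with probability 1/2 instead of 1/2001, only the type of final and the definition of example have changed from the previous code.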

{-# LANGUAGE FlexibleContexts, TypeFamilies, DeriveTraversable #-}
import Data.Functor.Foldable (Base, Corecursive(..))
import Data.Functor.Foldable.Exotic (anaM)  -- from recursion-schemes-ext
import System.Random

data Tree a = Branch a (Tree a) (Tree a) | Leaf
  deriving (Show, Eq, Functor)

-- Base functor of Tree: the recursive positions become the parameter b.
-- The derived Foldable/Traversable instances visit only the b fields,
-- which is what anaM needs in order to unfold the children.
data TreeF a b = BranchF a b b | LeafF
  deriving (Functor, Foldable, Traversable)

type instance Base (Tree a) = TreeF a

-- embed rebuilds one layer of Tree from one layer of TreeF.
instance Corecursive (Tree a) where
  embed LeafF = Leaf
  embed (BranchF a left right) = Branch a left right

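-- Mock of s: with probability 1/2 stop (Nothing, giving a leaf); otherwise
-- pick a uniformly random character from the history st.
s :: String -> IO (Maybe Char)
s st = do
  continue <- getStdRandom $ randomR (0, 1 :: Int)
  if continue /= 0
    then Just . (st !!) <$> getStdRandom (randomR (0, length st - 1))
    else return Nothing

result :: (Corecursive t, Traversable (Base t), Monad m) => (String -> m (Base t String)) -> String -> m t
result f = anaM f

-- Tree coalgebra: Nothing gives a leaf; Just c gives a branch labelled c whose
-- two children both continue from the history c : st (each samples independently).
example :: String -> IO (Base (Tree Char) String)
example st = maybe LeafF (\c -> BranchF c (c : st) (c : st)) <$> s st

final :: IO (Tree Char)
final = result example "asdf"

main :: IO ()
main = final >>= print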
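A similar back-of-the-envelope check for the tree version, assuming randomR (0, 1 :: Int) returns 0 and 1 with equal probability: each seed becomes a leaf with probability 1/2, or a branch with two independent child seeds with probability 1/2. The probability q that the tree grown from one seed is finite satisfies

$$
q = \tfrac{1}{2} + \tfrac{1}{2}\,q^2 \iff (q-1)^2 = 0 \iff q = 1,
$$

so the unfold terminates with probability 1. However, each branch has on average 2 · 1/2 = 1 branch child, so the expected number of branches is 1/2 at every depth, and summing over all depths gives an infinite expected tree size: occasional very large trees (and long runs) should be expected.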
rprospero answered Oct 04 '22 20:10
