Building values dynamically with GADTs using Data Kinds

Why is it harder to build values with DataKinds, while it's relatively easy to pattern match on them?

{-# LANGUAGE  KindSignatures
            , GADTs
            , DataKinds
            , Rank2Types
 #-}

data Nat = Zero | Succ Nat

data Direction = Center | Up | Down | UpDown deriving (Show, Eq)

data Chain :: Nat -> Nat -> * -> * where
    Nil    :: Chain Zero Zero a
    AddUp  :: a -> Chain nUp nDn a -> Chain (Succ nUp) nDn a
    AddDn  :: a -> Chain nUp nDn a -> Chain nUp (Succ nDn) a
    AddUD  :: a -> Chain nUp nDn a -> Chain (Succ nUp) (Succ nDn) a
    Add    :: a -> Chain nUp nDn a -> Chain nUp nDn a

lengthChain :: Num b => Chain (Succ Zero) (Succ Zero) a -> b
lengthChain = lengthChain'

lengthChain' :: forall (t::Nat) (t1::Nat) a b. Num b => Chain t t1 a -> b
lengthChain' Nil = 0
lengthChain' (Add   _ rest) = 1 + lengthChain' rest
lengthChain' (AddUp _ rest) = 1 + lengthChain' rest
lengthChain' (AddDn _ rest) = 1 + lengthChain' rest
lengthChain' (AddUD _ rest) = 1 + lengthChain' rest

chainToList :: Chain (Succ Zero) (Succ Zero) a -> [(a, Direction)]
chainToList = chainToList'

chainToList' :: forall (t::Nat) (t1::Nat) a. Chain t t1 a -> [(a, Direction)]
chainToList' Nil = []
chainToList' (Add a rest) = (a, Center):chainToList' rest
chainToList' (AddUp a rest) = (a, Up):chainToList' rest
chainToList' (AddDn a rest) = (a, Down):chainToList' rest
chainToList' (AddUD a rest) = (a, UpDown):chainToList' rest

listToChain :: forall (t::Nat) (t1::Nat) b. [(b, Direction)] -> Chain t t1 b
listToChain ((x, Center): xs) = Add x (listToChain xs)
listToChain ((x, Up):xs) = AddUp x (listToChain xs)
listToChain ((x, Down): xs) = AddDn x (listToChain xs)
listToChain ((x, UpDown): xs) = AddUD x (listToChain xs)
listToChain _ = Nil

I am trying to build a data type for a list-like structure, with the difference that elements can carry arrows. Furthermore, I require that some functions operate only on lists where the number of Up arrows and the number of Down arrows are each exactly 1.

In the above code, the function listToChain fails to compile, while chainToList compiles normally. How can we fix the listToChain code?

asked May 04 '14 00:05 by banx


1 Answer

If you think about it for a bit, you'll see that the type of your listToChain can never work. It takes values of type (b, Direction), which carry no type-level information about the direction, yet it is somehow supposed to determine the direction-indexed type of the resulting Chain at compile time. Worse, because t and t1 are universally quantified, the caller gets to choose them, so listToChain would have to produce a chain with whatever counts the caller asks for, regardless of the list's contents. That's clearly impossible, since at run time the values could be entered by the user, read from a socket, etc.

You need to either skip the intermediate list and build up your chain directly from compile-time verified values, or wrap the resulting chain in an existential type and perform a run-time check to refine the existential to a more precise type.
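
As a minimal sketch of the first option, a chain can be written directly with the constructors, so that GHC computes and checks the indices itself. The name oneUpOneDown and the Char payload are assumptions made for illustration, and Type from Data.Kind is used here as the modern spelling of the kind *.

```haskell
{-# LANGUAGE DataKinds, GADTs, KindSignatures #-}

import Data.Kind (Type)  -- Type is the modern spelling of the kind *

data Nat = Zero | Succ Nat

data Chain :: Nat -> Nat -> Type -> Type where
    Nil   :: Chain Zero Zero a
    AddUp :: a -> Chain nUp nDn a -> Chain (Succ nUp) nDn a
    AddDn :: a -> Chain nUp nDn a -> Chain nUp (Succ nDn) a
    AddUD :: a -> Chain nUp nDn a -> Chain (Succ nUp) (Succ nDn) a
    Add   :: a -> Chain nUp nDn a -> Chain nUp nDn a

-- Accepts only chains with exactly one Up and one Down arrow.
lengthChain :: Chain (Succ Zero) (Succ Zero) a -> Int
lengthChain = lengthChain'

lengthChain' :: Chain t t1 a -> Int
lengthChain' Nil            = 0
lengthChain' (Add   _ rest) = 1 + lengthChain' rest
lengthChain' (AddUp _ rest) = 1 + lengthChain' rest
lengthChain' (AddDn _ rest) = 1 + lengthChain' rest
lengthChain' (AddUD _ rest) = 1 + lengthChain' rest

-- Illustrative value (hypothetical name): the indices are determined by
-- the constructors alone, so the annotation is checked at compile time.
oneUpOneDown :: Chain (Succ Zero) (Succ Zero) Char
oneUpOneDown = Add 'a' (AddUp 'b' (AddDn 'c' Nil))

-- Rejected at compile time, because this chain has two Up arrows:
-- twoUps :: Chain (Succ Zero) (Succ Zero) Char
-- twoUps = AddUp 'a' (AddUD 'b' Nil)
```

Loaded in GHCi, lengthChain oneUpOneDown evaluates to 3, and uncommenting twoUps produces a type error rather than a run-time failure.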

So, given an existential wrapper like

data SomeChain a where
    SomeChain :: Chain nu nd a -> SomeChain a

you can implement listToChain as

listToChain :: [(b, Direction)] -> SomeChain b
listToChain ((x, Center): xs) = withSome (SomeChain . Add x)   (listToChain xs)
listToChain ((x, Up):xs)      = withSome (SomeChain . AddUp x) (listToChain xs)
listToChain ((x, Down): xs)   = withSome (SomeChain . AddDn x) (listToChain xs)
listToChain ((x, UpDown): xs) = withSome (SomeChain . AddUD x) (listToChain xs)
listToChain _                 = SomeChain Nil

using the helper function withSome for more convenient wrapping and unwrapping of the existential.

withSome :: (forall nu nd. Chain nu nd b -> r) -> SomeChain b -> r
withSome f (SomeChain c) = f c
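
To see the wrapper in use, here is a small self-contained sketch that assembles the pieces above. The helper lengthSome is a hypothetical name added for illustration; it works because lengthChain' accepts any indices, so it can be passed straight to withSome.

```haskell
{-# LANGUAGE DataKinds, GADTs, KindSignatures, RankNTypes #-}

import Data.Kind (Type)  -- Type is the modern spelling of the kind *

data Nat = Zero | Succ Nat

data Direction = Center | Up | Down | UpDown deriving (Show, Eq)

data Chain :: Nat -> Nat -> Type -> Type where
    Nil   :: Chain Zero Zero a
    AddUp :: a -> Chain nUp nDn a -> Chain (Succ nUp) nDn a
    AddDn :: a -> Chain nUp nDn a -> Chain nUp (Succ nDn) a
    AddUD :: a -> Chain nUp nDn a -> Chain (Succ nUp) (Succ nDn) a
    Add   :: a -> Chain nUp nDn a -> Chain nUp nDn a

-- The existential hides the two indices.
data SomeChain a where
    SomeChain :: Chain nu nd a -> SomeChain a

-- The continuation must work for every choice of indices.
withSome :: (forall nu nd. Chain nu nd b -> r) -> SomeChain b -> r
withSome f (SomeChain c) = f c

listToChain :: [(b, Direction)] -> SomeChain b
listToChain ((x, Center) : xs) = withSome (SomeChain . Add x)   (listToChain xs)
listToChain ((x, Up)     : xs) = withSome (SomeChain . AddUp x) (listToChain xs)
listToChain ((x, Down)   : xs) = withSome (SomeChain . AddDn x) (listToChain xs)
listToChain ((x, UpDown) : xs) = withSome (SomeChain . AddUD x) (listToChain xs)
listToChain []                 = SomeChain Nil

lengthChain' :: Chain t t1 a -> Int
lengthChain' Nil            = 0
lengthChain' (Add   _ rest) = 1 + lengthChain' rest
lengthChain' (AddUp _ rest) = 1 + lengthChain' rest
lengthChain' (AddDn _ rest) = 1 + lengthChain' rest
lengthChain' (AddUD _ rest) = 1 + lengthChain' rest

-- Hypothetical helper: an index-agnostic function needs no run-time check.
lengthSome :: SomeChain a -> Int
lengthSome = withSome lengthChain'
```

In GHCi, lengthSome (listToChain [('a', Up), ('b', Center), ('c', Down)]) evaluates to 3. A function that demands specific counts, such as lengthChain, cannot be passed to withSome this way, which is what the validation step addresses.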

Now we have an existential which we can pass around that hides the precise up and down types. When we want to call a function like lengthChain that expects specific up and down counts we need to validate the contents at run-time. One way to do this is to define a type-class.

class ChainProof pnu pnd where
    proveChain :: Chain nu nd b -> Maybe (Chain pnu pnd b)

The proveChain function takes a chain of any nu and nd and tries to prove that it conforms to the specific pnu and pnd. Implementing ChainProof requires a bit of repetitive boilerplate but it can then provide proof for any desired combination of ups and downs in addition to the one-one case we need for lengthChain.

instance ChainProof Zero Zero where
    proveChain Nil          = Just Nil
    proveChain (Add a rest) = Add a <$> proveChain rest
    proveChain _            = Nothing

instance ChainProof u Zero => ChainProof (Succ u) Zero where
    proveChain (Add a rest)   = Add a   <$> proveChain rest
    proveChain (AddUp a rest) = AddUp a <$> proveChain rest
    proveChain _              = Nothing

instance ChainProof Zero d => ChainProof Zero (Succ d) where
    proveChain (Add a rest)   = Add a   <$> proveChain rest
    proveChain (AddDn a rest) = AddDn a <$> proveChain rest
    proveChain _              = Nothing

instance (ChainProof u (Succ d), ChainProof (Succ u) d, ChainProof u d) => ChainProof (Succ u) (Succ d) where
    proveChain (Add a rest)   = Add a   <$> proveChain rest
    proveChain (AddUp a rest) = AddUp a <$> proveChain rest
    proveChain (AddDn a rest) = AddDn a <$> proveChain rest
    proveChain (AddUD a rest) = AddUD a <$> proveChain rest
    proveChain _              = Nothing

The above requires the language extensions MultiParamTypeClasses and FlexibleContexts, and I'm using <$> from Control.Applicative (with base 4.8 or later it is also exported by the Prelude).

Now we can use the proving mechanism to create a safe wrapper for any function expecting specific up and down counts

safe :: ChainProof nu nd => (Chain nu nd b -> r) -> SomeChain b -> Maybe r
safe f = withSome (fmap f . proveChain)

This might seem like an unsatisfactory solution since we still need to handle the failure case (i.e. Nothing) but at least the check is only required at the top level. Inside the given f we have static guarantees about the structure of the chain and don't need to do any additional validation.
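
The following sketch puts the first solution together so the behaviour of safe can be tried out. It reuses the definitions from this answer; checkedLength is a hypothetical name, FlexibleInstances is enabled only as a precaution, and <$> is assumed to come from the Prelude (base 4.8 or later).

```haskell
{-# LANGUAGE DataKinds, GADTs, KindSignatures, RankNTypes,
             MultiParamTypeClasses, FlexibleContexts, FlexibleInstances #-}

import Data.Kind (Type)  -- Type is the modern spelling of the kind *

data Nat = Zero | Succ Nat
data Direction = Center | Up | Down | UpDown deriving (Show, Eq)

data Chain :: Nat -> Nat -> Type -> Type where
    Nil   :: Chain Zero Zero a
    AddUp :: a -> Chain nUp nDn a -> Chain (Succ nUp) nDn a
    AddDn :: a -> Chain nUp nDn a -> Chain nUp (Succ nDn) a
    AddUD :: a -> Chain nUp nDn a -> Chain (Succ nUp) (Succ nDn) a
    Add   :: a -> Chain nUp nDn a -> Chain nUp nDn a

data SomeChain a where
    SomeChain :: Chain nu nd a -> SomeChain a

withSome :: (forall nu nd. Chain nu nd b -> r) -> SomeChain b -> r
withSome f (SomeChain c) = f c

listToChain :: [(b, Direction)] -> SomeChain b
listToChain ((x, Center) : xs) = withSome (SomeChain . Add x)   (listToChain xs)
listToChain ((x, Up)     : xs) = withSome (SomeChain . AddUp x) (listToChain xs)
listToChain ((x, Down)   : xs) = withSome (SomeChain . AddDn x) (listToChain xs)
listToChain ((x, UpDown) : xs) = withSome (SomeChain . AddUD x) (listToChain xs)
listToChain []                 = SomeChain Nil

-- Try to rebuild a chain at the requested indices pnu and pnd.
class ChainProof pnu pnd where
    proveChain :: Chain nu nd b -> Maybe (Chain pnu pnd b)

instance ChainProof Zero Zero where
    proveChain Nil          = Just Nil
    proveChain (Add a rest) = Add a <$> proveChain rest
    proveChain _            = Nothing

instance ChainProof u Zero => ChainProof (Succ u) Zero where
    proveChain (Add a rest)   = Add a   <$> proveChain rest
    proveChain (AddUp a rest) = AddUp a <$> proveChain rest
    proveChain _              = Nothing

instance ChainProof Zero d => ChainProof Zero (Succ d) where
    proveChain (Add a rest)   = Add a   <$> proveChain rest
    proveChain (AddDn a rest) = AddDn a <$> proveChain rest
    proveChain _              = Nothing

instance (ChainProof u (Succ d), ChainProof (Succ u) d, ChainProof u d)
      => ChainProof (Succ u) (Succ d) where
    proveChain (Add a rest)   = Add a   <$> proveChain rest
    proveChain (AddUp a rest) = AddUp a <$> proveChain rest
    proveChain (AddDn a rest) = AddDn a <$> proveChain rest
    proveChain (AddUD a rest) = AddUD a <$> proveChain rest
    proveChain _              = Nothing

safe :: ChainProof nu nd => (Chain nu nd b -> r) -> SomeChain b -> Maybe r
safe f = withSome (fmap f . proveChain)

lengthChain :: Chain (Succ Zero) (Succ Zero) a -> Int
lengthChain = lengthChain'

lengthChain' :: Chain t t1 a -> Int
lengthChain' Nil            = 0
lengthChain' (Add   _ rest) = 1 + lengthChain' rest
lengthChain' (AddUp _ rest) = 1 + lengthChain' rest
lengthChain' (AddDn _ rest) = 1 + lengthChain' rest
lengthChain' (AddUD _ rest) = 1 + lengthChain' rest

-- Hypothetical wrapper: the one-Up/one-Down requirement is checked at run time.
checkedLength :: SomeChain a -> Maybe Int
checkedLength = safe lengthChain
```

In GHCi, checkedLength (listToChain [('a', Up), ('b', Down)]) evaluates to Just 2, while checkedLength (listToChain [('a', Up), ('b', Up)]) evaluates to Nothing because the chain has two Up arrows and no Down arrow.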

Alternative solution

The above solution, while simple to implement, has to traverse and re-construct the whole chain every time it is validated. Another option is to store the up and down counts as singleton naturals in the existential.

data SNat :: Nat -> * where
    SZero :: SNat Zero
    SSucc :: SNat n -> SNat (Succ n)

data SomeChain a where
    SomeChain :: SNat nu -> SNat nd -> Chain nu nd a -> SomeChain a

The SNat type is the value-level equivalent of the Nat kind: for each type n of kind Nat there is exactly one value of type SNat n. This means that even when the index t of SNat t is hidden by the existential, we can fully recover it by pattern matching on the value. By extension, we can recover the full type of the Chain in the existential merely by pattern matching on the naturals, without having to traverse the chain itself.
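
A small sketch of what "recovering the type by pattern matching" means in practice. The functions snatToInt, onlyOne and describe are hypothetical names used only for illustration; the point is that inside the SSucc SZero branch GHC knows the index is Succ Zero, so the value can be handed to a function that demands exactly that index.

```haskell
{-# LANGUAGE DataKinds, GADTs, KindSignatures #-}

import Data.Kind (Type)  -- Type is the modern spelling of the kind *

data Nat = Zero | Succ Nat

data SNat :: Nat -> Type where
    SZero :: SNat Zero
    SSucc :: SNat n -> SNat (Succ n)

-- Hypothetical helper: read the singleton back as an ordinary number.
snatToInt :: SNat n -> Int
snatToInt SZero     = 0
snatToInt (SSucc n) = 1 + snatToInt n

-- Hypothetical function that only accepts the singleton for one.
onlyOne :: SNat (Succ Zero) -> String
onlyOne _ = "exactly one"

-- The index n is unknown to the caller, but matching on the constructors
-- refines it: in the first equation GHC has learned that n is Succ Zero.
describe :: SNat n -> String
describe n@(SSucc SZero) = onlyOne n
describe _               = "not one"
```

In GHCi, describe (SSucc SZero) evaluates to "exactly one" and snatToInt (SSucc (SSucc SZero)) evaluates to 2.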

Building the chain gets a little bit more verbose

listToChain :: [(b, Direction)] -> SomeChain b
listToChain ((x, Center): xs) = case listToChain xs of
    SomeChain u d c -> SomeChain u d (Add x c)
listToChain ((x, Up):xs)      = case listToChain xs of
    SomeChain u d c -> SomeChain (SSucc u) d (AddUp x c)
listToChain ((x, Down): xs)   = case listToChain xs of
    SomeChain u d c -> SomeChain u (SSucc d) (AddDn x c)
listToChain ((x, UpDown): xs) = case listToChain xs of
    SomeChain u d c -> SomeChain (SSucc u) (SSucc d) (AddUD x c)
listToChain _                 = SomeChain SZero SZero Nil

But on the other hand, the proof gets shorter (although with somewhat hairy type signatures).

proveChain :: forall pnu pnd b. (ProveNat pnu, ProveNat pnd) => SomeChain b -> Maybe (Chain pnu pnd b)
proveChain (SomeChain (u :: SNat u) (d :: SNat d) c)
    = case (proveNat u :: Maybe (Refl u pnu), proveNat d :: Maybe (Refl d pnd)) of
        (Just Refl, Just Refl) -> Just c
        _ -> Nothing

This uses ScopedTypeVariables to explicitly choose the type-class instances for ProveNat we want to use. If we get proof that the naturals match the requested values then the type-checker is happy to let us return Just c without examining it further.

ProveNat is defined as

{-# LANGUAGE PolyKinds #-}

data Refl a b where
    Refl :: Refl a a

class ProveNat n where
    proveNat :: SNat m -> Maybe (Refl m n)

The Refl type (reflexivity) is a commonly used pattern for making the type checker unify two unknown types: pattern matching on the Refl constructor brings the equality into scope (and PolyKinds lets it be generic in the kind, so we can use it with Nats). So while proveNat accepts forall m. SNat m, if we can pattern match on Just Refl afterwards, we (and more importantly, the type-checker) can be sure that m and n are actually the same type.
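
As a tiny illustration of what matching on Refl buys us, the sketch below uses the same Refl definition to convert between two types once they are known to be equal. The names castWith' and sym' are hypothetical (primed to avoid confusion with library functions); the base library ships an equivalent type, (:~:), with its own Refl constructor in Data.Type.Equality, which could be used instead of a hand-rolled one.

```haskell
{-# LANGUAGE GADTs, PolyKinds #-}

data Refl a b where
    Refl :: Refl a a

-- Before the match, a and b are unrelated type variables.
-- Matching on Refl tells the type checker they are the same type,
-- so returning x :: a where a b is expected is accepted.
castWith' :: Refl a b -> a -> b
castWith' Refl x = x

-- Equality proofs can be flipped the same way.
sym' :: Refl a b -> Refl b a
sym' Refl = Refl
```

In GHCi, castWith' Refl (3 :: Int) evaluates to 3. Without the pattern match on Refl, the body of castWith' would be rejected, since nothing would connect a and b.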

The instances for ProveNat are pretty simple but require, again, some explicit types to help inference.

instance ProveNat Zero where
    proveNat SZero = Just Refl
    proveNat _ = Nothing

instance ProveNat n => ProveNat (Succ n) where
    proveNat m@(SSucc _) = proveNat' m where
        proveNat' :: forall p. ProveNat n => SNat (Succ p) -> Maybe (Refl (Succ p) (Succ n))
        proveNat' (SSucc p) = case proveNat p :: Maybe (Refl p n) of
            Just Refl -> Just Refl
            _         -> Nothing
    proveNat _ = Nothing
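
For completeness, here is a sketch that assembles the pieces of the alternative solution into one module. The definitions are the ones from this answer; only checkedLength is a hypothetical name, and all LANGUAGE pragmas are moved to the top of the file, where GHC expects them.

```haskell
{-# LANGUAGE DataKinds, GADTs, KindSignatures, PolyKinds, ScopedTypeVariables #-}

import Data.Kind (Type)  -- Type is the modern spelling of the kind *

data Nat = Zero | Succ Nat
data Direction = Center | Up | Down | UpDown deriving (Show, Eq)

data Chain :: Nat -> Nat -> Type -> Type where
    Nil   :: Chain Zero Zero a
    AddUp :: a -> Chain nUp nDn a -> Chain (Succ nUp) nDn a
    AddDn :: a -> Chain nUp nDn a -> Chain nUp (Succ nDn) a
    AddUD :: a -> Chain nUp nDn a -> Chain (Succ nUp) (Succ nDn) a
    Add   :: a -> Chain nUp nDn a -> Chain nUp nDn a

data SNat :: Nat -> Type where
    SZero :: SNat Zero
    SSucc :: SNat n -> SNat (Succ n)

-- The existential now carries run-time witnesses of both counts.
data SomeChain a where
    SomeChain :: SNat nu -> SNat nd -> Chain nu nd a -> SomeChain a

data Refl a b where
    Refl :: Refl a a

class ProveNat n where
    proveNat :: SNat m -> Maybe (Refl m n)

instance ProveNat Zero where
    proveNat SZero = Just Refl
    proveNat _     = Nothing

instance ProveNat n => ProveNat (Succ n) where
    proveNat m@(SSucc _) = proveNat' m where
        proveNat' :: forall p. ProveNat n => SNat (Succ p) -> Maybe (Refl (Succ p) (Succ n))
        proveNat' (SSucc p) = case proveNat p :: Maybe (Refl p n) of
            Just Refl -> Just Refl
            _         -> Nothing
    proveNat _ = Nothing

listToChain :: [(b, Direction)] -> SomeChain b
listToChain ((x, Center) : xs) = case listToChain xs of
    SomeChain u d c -> SomeChain u d (Add x c)
listToChain ((x, Up) : xs)     = case listToChain xs of
    SomeChain u d c -> SomeChain (SSucc u) d (AddUp x c)
listToChain ((x, Down) : xs)   = case listToChain xs of
    SomeChain u d c -> SomeChain u (SSucc d) (AddDn x c)
listToChain ((x, UpDown) : xs) = case listToChain xs of
    SomeChain u d c -> SomeChain (SSucc u) (SSucc d) (AddUD x c)
listToChain []                 = SomeChain SZero SZero Nil

-- Only the two naturals are inspected; the chain itself is returned as-is.
proveChain :: forall pnu pnd b. (ProveNat pnu, ProveNat pnd)
           => SomeChain b -> Maybe (Chain pnu pnd b)
proveChain (SomeChain (u :: SNat u) (d :: SNat d) c)
    = case (proveNat u :: Maybe (Refl u pnu), proveNat d :: Maybe (Refl d pnd)) of
        (Just Refl, Just Refl) -> Just c
        _                      -> Nothing

lengthChain :: Chain (Succ Zero) (Succ Zero) a -> Int
lengthChain = lengthChain'

lengthChain' :: Chain t t1 a -> Int
lengthChain' Nil            = 0
lengthChain' (Add   _ rest) = 1 + lengthChain' rest
lengthChain' (AddUp _ rest) = 1 + lengthChain' rest
lengthChain' (AddDn _ rest) = 1 + lengthChain' rest
lengthChain' (AddUD _ rest) = 1 + lengthChain' rest

-- Hypothetical wrapper: validate the counts, then call the indexed function.
checkedLength :: SomeChain a -> Maybe Int
checkedLength sc = fmap lengthChain (proveChain sc)
```

In GHCi, checkedLength (listToChain [('a', Up), ('b', Down)]) evaluates to Just 2, and checkedLength (listToChain [('a', Up), ('b', Up)]) evaluates to Nothing. Compared with the first solution, the check costs time proportional to the counts rather than to the length of the chain.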

answered Sep 17 '22 23:09 by shang