
Can I get KnownNat n to imply KnownNat (n * 3), etc?

I'm working with data types of this shape, using V from linear:

type Foo n = V (n * 3) Double -> Double

Having it fixed on n is pretty important, because I want to be able to ensure that I'm passing in the right number of elements at compile-time. This is a part of my program that already works well, independent of what I'm doing here.

For any KnownNat n, I can generate a Foo n satisfying the behavior that my program needs. For the purposes of this question it can be something silly like

mkFoo :: KnownNat (n * 3) => Foo n
mkFoo = sum

Or, for a more meaningful example, it can generate a random V of the same length and use dot on the two. The KnownNat constraint here is redundant, but in reality it's needed to make a Foo. I make one Foo and use it for my entire program (or with multiple inputs), so this guarantees that whenever I use it, I'm using it on things with the same length, and on things that the structure of the Foo dictates.

And finally, I have a function that makes inputs for a Foo:

bar :: KnownNat (n * 3) => Proxy n -> [V (n * 3) Double]

bar is actually the reason why I'm using n * 3 as a type function instead of just manually expanding it out. bar might do its job by using three vectors of length n and appending them all together as a vector of length n * 3. Also, n is a much more meaningful parameter to the function, semantically, than n * 3. This also lets me disallow improper values, like lengths that aren't multiples of 3.

Now, before, everything worked fine as long as I defined a type synonym at the beginning:

type N = 5

Then I can just pass in Proxy :: Proxy N to bar and use mkFoo :: Foo N, and everything works fine.

-- works fine
doStuff :: [Double]
doStuff = let inps = bar (Proxy :: Proxy N)
          in  map (mkFoo :: Foo N) inps

But now I want to be able to adjust N during runtime by loading information from a file, or from command line arguments.

I tried doing it by calling reifyNat (from the reflection library):

doStuff :: Integer -> [Double]
doStuff n = reifyNat n $ \pn@(Proxy :: Proxy n) ->
              let inps = bar (Proxy :: Proxy n)
              in  map (mkFoo :: Foo n) inps

But... bar and mkFoo require KnownNat (n * 3), whereas reifyNat only gives me KnownNat n.

Is there any way I can generalize the proof that reifyNat gives me to satisfy bar and mkFoo?

Justin L. asked Sep 29 '15 08:09


2 Answers

So, three months later, after going back and forth on good ways to accomplish this, I finally settled on a very succinct trick that doesn't require any throwaway newtypes. It uses Dict from the constraints library. You can easily write:

natDict :: KnownNat n => Proxy n -> Dict (KnownNat n)
natDict _ = Dict

triple :: KnownNat n => Proxy n -> Dict (KnownNat (n * 3))
triple p = reifyNat (natVal p * 3) $
             \p3 -> unsafeCoerce (natDict p3)

And once you have a Dict (KnownNat (n * 3)), you can pattern match on it to bring the KnownNat (n * 3) instance into scope:

case triple (Proxy :: Proxy n) of
  Dict -> -- KnownNat (n * 3) is in scope
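For reference, here is a minimal self-contained sketch of the same idea that compiles against base alone. It makes a few assumptions that are not part of the answer above: Dict is a local stand-in for the one in constraints, someNatVal from GHC.TypeLits replaces reifyNat, and tripledLength and runTripled are hypothetical names standing in for functions like bar and doStuff. NoStarIsType is only needed on GHC 8.6 and later.

```haskell
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE NoStarIsType #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}

import Data.Kind (Constraint)
import Data.Proxy (Proxy (..))
import GHC.TypeLits
import Unsafe.Coerce (unsafeCoerce)

-- Local stand-in (assumption) for Dict from the constraints library.
data Dict (c :: Constraint) where
  Dict :: c => Dict c

natDict :: KnownNat n => Proxy n -> Dict (KnownNat n)
natDict _ = Dict

-- The unsafeCoerce is justified only because the reified number
-- really is three times natVal p.
triple :: forall n. KnownNat n => Proxy n -> Dict (KnownNat (n * 3))
triple p = case someNatVal (natVal p * 3) of
  Just (SomeNat p3) -> unsafeCoerce (natDict p3)
  Nothing           -> error "triple: unreachable, natVal is never negative"

-- Hypothetical consumer that demands KnownNat (n * 3), like bar or mkFoo.
tripledLength :: forall n. KnownNat (n * 3) => Proxy n -> Integer
tripledLength _ = natVal (Proxy :: Proxy (n * 3))

-- Hypothetical entry point: start from a runtime Integer.
-- Returns Nothing for negative input, which has no type-level Nat.
runTripled :: Integer -> Maybe Integer
runTripled k = case someNatVal k of
  Nothing          -> Nothing
  Just (SomeNat p) -> case triple p of
    Dict -> Just (tripledLength p)  -- KnownNat (n * 3) is in scope here
```

Matching on Dict inside runTripled is what discharges the KnownNat (n * 3) constraint of tripledLength, even though n is only known at runtime.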

You can actually set these up as generic, too:

addNats :: (KnownNat n, KnownNat m) => Proxy n -> Proxy m -> Dict (KnownNat (n + m))
addNats px py = reifyNat (natVal px + natVal py) $
                  \pz -> unsafeCoerce (natDict pz)

Or, you can make them operators and you can use them to "combine" Dicts:

infixl 6 %+
infixl 7 %*
(%+) :: Dict (KnownNat n) -> Dict (KnownNat m) -> Dict (KnownNat (n + m))
(%*) :: Dict (KnownNat n) -> Dict (KnownNat m) -> Dict (KnownNat (n * m))

And you can do things like:

case d1 %* d2 %+ d3 of
  Dict -> -- in here, KnownNat (n1 * n2 + n3) is in scope
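The signatures above leave out the bodies. One possible way to fill them in, as a sketch only and not necessarily how typelits-witnesses does it, is shown here. The helper unsafeKnownNat, the local Dict type, and the affine usage function are hypothetical names introduced for this illustration, and someNatVal from base is used in place of reifyNat.

```haskell
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE NoStarIsType #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}

import Data.Kind (Constraint)
import Data.Proxy (Proxy (..))
import GHC.TypeLits
import Unsafe.Coerce (unsafeCoerce)

-- Local stand-in (assumption) for Dict from the constraints library.
data Dict (c :: Constraint) where
  Dict :: c => Dict c

-- Hypothetical helper: build a KnownNat dictionary for a type chosen by
-- the caller. Sound only if the Integer equals that type-level number.
unsafeKnownNat :: forall n. Integer -> Dict (KnownNat n)
unsafeKnownNat k = case someNatVal k of
  Just (SomeNat p) -> unsafeCoerce (dictOf p)
  Nothing          -> error "unsafeKnownNat: negative number"
  where
    dictOf :: KnownNat m => Proxy m -> Dict (KnownNat m)
    dictOf _ = Dict

infixl 6 %+
infixl 7 %*

(%+) :: forall n m. Dict (KnownNat n) -> Dict (KnownNat m) -> Dict (KnownNat (n + m))
Dict %+ Dict = unsafeKnownNat (natVal (Proxy :: Proxy n) + natVal (Proxy :: Proxy m))

(%*) :: forall n m. Dict (KnownNat n) -> Dict (KnownNat m) -> Dict (KnownNat (n * m))
Dict %* Dict = unsafeKnownNat (natVal (Proxy :: Proxy n) * natVal (Proxy :: Proxy m))

-- Usage: combine three dictionaries and read the resulting number back.
affine :: forall a b c. Dict (KnownNat a) -> Dict (KnownNat b) -> Dict (KnownNat c) -> Integer
affine da db dc = case da %* db %+ dc of
  Dict -> natVal (Proxy :: Proxy (a * b + c))
```

The fixities mirror those of + and * in GHC.TypeLits, so da %* db %+ dc produces a dictionary for exactly the type a * b + c.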

I've wrapped this up in a nice library, typelits-witnesses, that I've been using. Thank you all for your help!

Justin L. answered Oct 22 '22 16:10


I'm posting this as a separate answer because it is more direct, and editing the previous one wouldn't make sense.

It uses the trick (popularised, if not invented, by Edward Kmett) behind reifyNat from the reflection library:

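{-# LANGUAGE GADTs #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE RankNTypes #-}
-- On GHC 8.6 and later, also enable NoStarIsType so that * in (n * 3) is
-- parsed as type-level multiplication rather than as the kind Type.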
import GHC.TypeLits
import Data.Proxy
import Unsafe.Coerce

newtype MagicNat3 r = MagicNat3 (forall (n :: Nat). KnownNat (n * 3) => Proxy n -> r)

trickValue :: Integer -> Integer
trickValue = (*3)

-- No type-level guarantee that the function will be called with (n * 3);
-- you have to believe us
trick :: forall a n. KnownNat n => Proxy n -> (forall m. KnownNat (m * 3) => Proxy m -> a) -> a
trick p f = unsafeCoerce (MagicNat3 f :: MagicNat3 a) (trickValue (natVal p)) Proxy

test :: forall m. KnownNat (m * 3) => Proxy m -> Integer
test _ = natVal (Proxy :: Proxy (m * 3))

So when you run it:

λ *Main > :t trick (Proxy :: Proxy 4) test :: Integer
trick (Proxy :: Proxy 4) test :: Integer :: Integer
λ *Main > trick (Proxy :: Proxy 4) test :: Integer
12

The trick relies on the fact that GHC represents the dictionary of a single-method class (like KnownNat) by the method itself. For KnownNat that turns out to be an Integer, so we just unsafeCoerce it into place. This is an implementation detail that can change between GHC versions. Universal quantification makes it sound from the outside.

phadej answered Oct 22 '22 17:10