I'm working with data types of this shape, using `V` from `linear`:

```haskell
type Foo n = V (n * 3) Double -> Double
```
Having it fixed on `n` is pretty important, because I want to be able to ensure that I'm passing in the right number of elements at compile time. This is a part of my program that already works well, independent of what I'm doing here.
For any `KnownNat n`, I can generate a `Foo n` satisfying the behavior that my program needs. For the purposes of this question it can be something silly like

```haskell
mkFoo :: KnownNat (n * 3) => Foo n
mkFoo = sum
```
For a more meaningful example, it could generate a random `V` of the same length and use `dot` on the two. The `KnownNat` constraint is redundant here, but in reality it's needed to make a `Foo`. I make one `Foo` and use it for my entire program (or with multiple inputs), so this guarantees that whenever I use it, I'm using it on things of the same length, and on things whose structure the `Foo` dictates.
Finally, I have a function that makes inputs for a `Foo`:

```haskell
bar :: KnownNat (n * 3) => Proxy n -> [V (n * 3) Double]
```
`bar` is actually the reason I'm using `n * 3` as a type function instead of just expanding it out manually: `bar` might do its job by taking three vectors of length `n` and appending them into one vector of length `n * 3`. Also, `n` is a much more meaningful parameter to the function, semantically, than `n * 3`. It also rules out improper lengths, such as ones that aren't multiples of 3.
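To make the `n * 3` bookkeeping concrete, here is a base-only sketch of what `bar`'s appending step could look like. Since `linear` isn't in base, it assumes a hypothetical list-backed `Vec` standing in for `V`, plus hypothetical helpers `fromListN`, `vecToList` and `append3`; none of these are `linear`'s API.

```haskell
{-# LANGUAGE DataKinds, KindSignatures, NoStarIsType, ScopedTypeVariables, TypeOperators #-}
module Main where

import Data.Proxy (Proxy (..))
import GHC.TypeLits

-- Hypothetical stand-in for linear's V: a list whose length is meant to be n.
newtype Vec (n :: Nat) a = Vec [a]
  deriving (Eq, Show)

vecToList :: Vec n a -> [a]
vecToList (Vec xs) = xs

-- Hypothetical smart constructor: succeeds only when the list has exactly n elements.
fromListN :: forall n a. KnownNat n => [a] -> Maybe (Vec n a)
fromListN xs
  | toInteger (length xs) == natVal (Proxy :: Proxy n) = Just (Vec xs)
  | otherwise                                          = Nothing

-- Appending three length-n vectors gives a length-(n * 3) vector. The type
-- records the product, so a caller only ever has to choose n.
append3 :: Vec n a -> Vec n a -> Vec n a -> Vec (n * 3) a
append3 (Vec xs) (Vec ys) (Vec zs) = Vec (xs ++ ys ++ zs)

main :: IO ()
main =
  case ( fromListN [1, 2] :: Maybe (Vec 2 Double)
       , fromListN [3, 4]
       , fromListN [5, 6] ) of
    (Just a, Just b, Just c) -> print (vecToList (append3 a b c))
    _                        -> putStrLn "length mismatch"
```

Only `n` ever appears as a parameter; the `n * 3` in `append3`'s result is derived, which is the property the question relies on.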
Previously, everything worked fine as long as I defined a type synonym at the beginning:

```haskell
type N = 5
```

Then I can just pass `Proxy :: Proxy N` to `bar` and use `mkFoo :: Foo N`, and everything works:
```haskell
-- works fine
doStuff :: [Double]
doStuff = let inps = bar (Proxy :: Proxy N)
          in  map (mkFoo :: Foo N) inps
```
But now I want to be able to set `N` at runtime, by loading it from a file or from command-line arguments.
I tried doing it by calling `reifyNat`:

```haskell
doStuff :: Integer -> [Double]
doStuff n = reifyNat n $ \(Proxy :: Proxy n) ->
  let inps = bar (Proxy :: Proxy n)
  in  map (mkFoo :: Foo n) inps
```

But `bar` and `mkFoo` require `KnownNat (n * 3)`, while `reifyNat` only gives me `KnownNat n`.

Is there any way to extend the proof that `reifyNat` gives me so that it also satisfies the `KnownNat (n * 3)` constraint of `bar` and `mkFoo`?
So, three months later: I went back and forth on good ways to accomplish this, but I finally settled on a very succinct trick that doesn't require any throwaway newtypes. It uses a `Dict` from the `constraints` library; you can easily write:

```haskell
natDict :: KnownNat n => Proxy n -> Dict (KnownNat n)
natDict _ = Dict

triple :: KnownNat n => Proxy n -> Dict (KnownNat (n * 3))
triple p = reifyNat (natVal p * 3) $
  \p3 -> unsafeCoerce (natDict p3)
```
Once you have a `Dict (KnownNat (n * 3))`, you can pattern match on it to bring the `KnownNat (n * 3)` instance into scope:

```haskell
case triple (Proxy :: Proxy n) of
  Dict -> -- KnownNat (n * 3) is in scope
```
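For experimenting without extra packages, here is a self-contained sketch of the `triple` trick applied to the original runtime-`n` problem. It assumes a local `Dict` GADT in place of the one from `constraints`, uses base's `someNatVal` in place of `reflection`'s `reifyNat`, and `inputLength` and `withRuntimeN` are hypothetical stand-ins for `bar`/`mkFoo` and `doStuff`.

```haskell
{-# LANGUAGE ConstraintKinds, DataKinds, FlexibleContexts, GADTs, NoStarIsType,
             ScopedTypeVariables, TypeOperators #-}
module Main where

import Data.Proxy (Proxy (..))
import GHC.TypeLits
import Unsafe.Coerce (unsafeCoerce)

-- Local stand-in for Data.Constraint.Dict from the constraints package.
data Dict c where
  Dict :: c => Dict c

natDict :: KnownNat n => Proxy n -> Dict (KnownNat n)
natDict _ = Dict

-- Reify the Integer n * 3 as some fresh m, then coerce its Dict to claim
-- KnownNat (n * 3). This is only sound because m's value really is n * 3.
triple :: forall n. KnownNat n => Proxy n -> Dict (KnownNat (n * 3))
triple p = case someNatVal (natVal p * 3) of
  Just (SomeNat p3) -> unsafeCoerce (natDict p3)
  Nothing           -> error "triple: negative natural (impossible)"

-- Hypothetical stand-in for bar/mkFoo: anything that needs KnownNat (n * 3).
inputLength :: forall n. KnownNat (n * 3) => Proxy n -> Integer
inputLength _ = natVal (Proxy :: Proxy (n * 3))

-- Hypothetical stand-in for doStuff: n arrives at runtime, and triple
-- unlocks KnownNat (n * 3) inside the Dict branch.
withRuntimeN :: Integer -> Integer
withRuntimeN k = case someNatVal k of
  Nothing -> error "withRuntimeN: negative size"
  Just (SomeNat (pn :: Proxy n)) -> case triple pn of
    Dict -> inputLength pn

main :: IO ()
main = mapM_ (print . withRuntimeN) [1, 5, 10]
```

`inputLength` takes a `Proxy n` rather than being used at a bare type, so `n` stays unambiguous even though it only appears under `*` in the constraint.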
You can set these up generically, too:

```haskell
addNats :: (KnownNat n, KnownNat m) => Proxy n -> Proxy m -> Dict (KnownNat (n + m))
addNats px py = reifyNat (natVal px + natVal py) $
  \pz -> unsafeCoerce (natDict pz)
```
Or you can make them operators and use them to "combine" `Dict`s:

```haskell
infixl 6 %+
infixl 7 %*
(%+) :: Dict (KnownNat n) -> Dict (KnownNat m) -> Dict (KnownNat (n + m))
(%*) :: Dict (KnownNat n) -> Dict (KnownNat m) -> Dict (KnownNat (n * m))
```
Then you can do things like:

```haskell
case d1 %* d2 %+ d3 of
  Dict -> -- in here, KnownNat (n1 * n2 + n3) is in scope
```
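The answer only gives the operators' signatures. One possible way to fill them in, following the same reify-then-`unsafeCoerce` idea, is sketched below. Again `Dict` is a local stand-in for the `constraints` type, `someNatVal` stands in for `reifyNat`, and `unsafeNatDict` and `combine` are hypothetical helpers.

```haskell
{-# LANGUAGE ConstraintKinds, DataKinds, GADTs, NoStarIsType,
             ScopedTypeVariables, TypeOperators #-}
module Main where

import Data.Proxy (Proxy (..))
import GHC.TypeLits
import Unsafe.Coerce (unsafeCoerce)

-- Local stand-in for Data.Constraint.Dict from the constraints package.
data Dict c where
  Dict :: c => Dict c

infixl 6 %+
infixl 7 %*

-- Hypothetical helper: reify i and relabel its KnownNat dictionary as KnownNat k.
-- Callers must ensure that k really denotes i.
unsafeNatDict :: forall k. Integer -> Dict (KnownNat k)
unsafeNatDict i = case someNatVal i of
  Just (SomeNat (_ :: Proxy m)) -> unsafeCoerce (Dict :: Dict (KnownNat m))
  Nothing                       -> error "unsafeNatDict: negative natural"

-- Matching on both Dicts brings KnownNat n and KnownNat m into scope.
(%+) :: forall n m. Dict (KnownNat n) -> Dict (KnownNat m) -> Dict (KnownNat (n + m))
Dict %+ Dict = unsafeNatDict (natVal (Proxy :: Proxy n) + natVal (Proxy :: Proxy m))

(%*) :: forall n m. Dict (KnownNat n) -> Dict (KnownNat m) -> Dict (KnownNat (n * m))
Dict %* Dict = unsafeNatDict (natVal (Proxy :: Proxy n) * natVal (Proxy :: Proxy m))

-- Hypothetical demo: reify three runtime sizes and read back n1 * n2 + n3.
combine :: Integer -> Integer -> Integer -> Integer
combine a b c = case (someNatVal a, someNatVal b, someNatVal c) of
  (Just (SomeNat (_ :: Proxy n1)), Just (SomeNat (_ :: Proxy n2)), Just (SomeNat (_ :: Proxy n3))) ->
    let d1 = Dict :: Dict (KnownNat n1)
        d2 = Dict :: Dict (KnownNat n2)
        d3 = Dict :: Dict (KnownNat n3)
    in case d1 %* d2 %+ d3 of
         Dict -> natVal (Proxy :: Proxy (n1 * n2 + n3))
  _ -> error "combine: negative input"

main :: IO ()
main = print (combine 2 3 4)
```

The fixities mirror those of the type-level `*` and `+` in `GHC.TypeLits`, so `d1 %* d2 %+ d3` produces exactly the `n1 * n2 + n3` that the pattern match is expected to provide.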
I've wrapped this up in a library, `typelits-witnesses`, that I've been using. Thank you all for your help!
I'm posting another answer because this approach is more direct; editing the previous one wouldn't make sense.

It uses the trick from `reflection`'s `reifyNat` (popularised, if not invented, by Edward Kmett):
```haskell
{-# LANGUAGE GADTs #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE NoStarIsType #-}  -- GHC >= 8.6: lets `*` mean TypeLits multiplication

import GHC.TypeLits
import Data.Proxy
import Numeric.Natural (Natural)
import Unsafe.Coerce

-- A function that receives the KnownNat (n * 3) dictionary as its first argument
newtype MagicNat3 r = MagicNat3 (forall (n :: Nat). KnownNat (n * 3) => Proxy n -> r)

-- Since base 4.10 (GHC 8.2) a KnownNat dictionary wraps a Natural, not an Integer
trickValue :: Integer -> Natural
trickValue = fromInteger . (* 3)

-- No type-level guarantee that the function will be called with (n * 3);
-- you have to believe us
trick :: forall a n. KnownNat n => Proxy n -> (forall m. KnownNat (m * 3) => Proxy m -> a) -> a
trick p f = unsafeCoerce (MagicNat3 f :: MagicNat3 a) (trickValue (natVal p)) Proxy

test :: forall m. KnownNat (m * 3) => Proxy m -> Integer
test _ = natVal (Proxy :: Proxy (m * 3))
```
When you run it:

```
λ *Main > :t trick (Proxy :: Proxy 4) test :: Integer
trick (Proxy :: Proxy 4) test :: Integer :: Integer
λ *Main > trick (Proxy :: Proxy 4) test :: Integer
12
```
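The trick relies on the fact that in GHC, dictionaries for single-method classes with no superclasses (like `KnownNat`) are represented by the method itself. For `KnownNat` that is the number: an `Integer` in the GHC versions this trick was first written for, and a `Natural` since GHC 8.2 (base 4.10). So we just `unsafeCoerce` the value into place. The universal quantification in `trick`'s type keeps it sound from the outside.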
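A variant that avoids depending on the dictionary's runtime representation is to build a genuine `KnownNat` dictionary with base's `someNatVal` and forge only a type equality, by `unsafeCoerce`-ing a `Refl`. This is a swapped-in technique, not the `reifyNat` trick itself; `withTriple` is a hypothetical name, and the result is still only as trustworthy as the claim that the reified number equals `n * 3`.

```haskell
{-# LANGUAGE DataKinds, FlexibleContexts, GADTs, NoStarIsType, RankNTypes,
             ScopedTypeVariables, TypeOperators #-}
module Main where

import Data.Proxy (Proxy (..))
import Data.Type.Equality ((:~:) (Refl))
import GHC.TypeLits
import Unsafe.Coerce (unsafeCoerce)

-- Run k with KnownNat (n * 3) in scope. Only the equality m ~ n * 3 is forged;
-- the KnownNat m dictionary is a real one built by someNatVal.
withTriple :: forall n r. KnownNat n => Proxy n -> (KnownNat (n * 3) => r) -> r
withTriple p k = case someNatVal (natVal p * 3) of
  Nothing -> error "withTriple: negative natural (impossible)"
  Just (SomeNat (_ :: Proxy m)) ->
    case unsafeCoerce (Refl :: m :~: m) :: m :~: (n * 3) of
      Refl -> k

-- Same result as the answer's test, but callers only need KnownNat m.
test :: forall m. KnownNat m => Proxy m -> Integer
test p = withTriple p (natVal (Proxy :: Proxy (m * 3)))

main :: IO ()
main = print (test (Proxy :: Proxy 4))
```

Because `Refl` carries no runtime data, coercing it does not depend on how GHC lays out class dictionaries, which is the part of the `MagicNat3` approach that changed between GHC versions.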