I want to make a type-safe implementation of tensor calculus in Haskell using GADTs, so the rules are:
You can ADD tensors of the same type, meaning they have the same index signature: the 0th index of the first tensor is of the same type (upstairs or downstairs) as the 0th index of the second tensor, and so on.
You can MULTIPLY tensors to get bigger tensors, with the indices concatenated.
So I want Haskell's type checker to reject any code that doesn't follow those rules: such code should not compile.
Here is my attempt using GADTs:
{-# LANGUAGE GADTs #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE TypeOperators #-}
data Direction = T | X | Y | Z
data Index = Zero | Up Index | Down Index deriving (Eq, Show)
plus :: Index -> Index -> Index
plus Zero x = x
plus (Up x) y = Up (plus x y)
plus (Down x) y = Down (plus x y)
data Tensor a = (a ~ Zero) => Scalar Double |
    forall b. (a ~ Up b) => Cov (Direction -> Tensor b) |
    forall b. (a ~ Down b) => Con (Direction -> Tensor b)
add :: Tensor a -> Tensor a -> Tensor a
add (Scalar x) (Scalar y) = (Scalar (x + y))
add (Cov f) (Cov g) = (Cov (\d -> add (f d) (g d)))
add (Con f) (Con g) = (Con (\d -> add (f d) (g d)))
mul :: Tensor a -> Tensor b -> Tensor (plus a b)
mul (Scalar x) (Scalar y) = (Scalar (x*y))
mul (Scalar x) (Cov f) = (Cov (\d -> mul (Scalar x) (f d)))
mul (Scalar x) (Con f) = (Con (\d -> mul (Scalar x) (f d)))
mul (Cov f) y = (Cov (\d -> mul (f d) y))
mul (Con f) y = (Con (\d -> mul (f d) y))
But I'm getting:
Couldn't match type 'Down with `plus ('Down b1)'
Expected type: Tensor (plus a b)
Actual type: Tensor ('Down b)
Relevant bindings include
f :: Direction -> Tensor b1 (bound at main.hs:28:10)
mul :: Tensor a -> Tensor b -> Tensor (plus a b)
(bound at main.hs:24:1)
In the expression: (Con (\ d -> mul (f d) y))
In an equation for `mul':
mul (Con f) y = (Con (\ d -> mul (f d) y))
What is the problem?
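plus is just a function on values of type Index:
>>> plus Zero Zero
Zero
>>> plus Zero (Up Zero)
Up Zero
so it can't appear in a type signature as things stand. (A lowercase name in a type is read as a type variable, which is why GHC reports a mismatch against plus ('Down b1) instead of computing anything.) You want to use the 'promoted' types, where Zero, Up Zero, etc. are types. Then you can write a type function (a closed type family) and everything compiles.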
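A minimal sketch of the distinction, assuming GHC with DataKinds and TypeFamilies: the value-level plus runs on Index values at run time, while the closed type family Plus is reduced by the type checker. The proof upDown only compiles because GHC reduces Plus ('Up 'Zero) ('Down 'Zero) to 'Up ('Down 'Zero). The class KnownIndex and its method indexVal are hypothetical helpers (not part of the original code) that reflect a type-level index back to a value so the two functions can be compared.

```haskell
{-# LANGUAGE DataKinds, TypeFamilies, TypeOperators #-}
{-# LANGUAGE KindSignatures, ScopedTypeVariables, FlexibleInstances #-}
import Data.Proxy (Proxy (..))
import Data.Type.Equality ((:~:) (..))

data Index = Zero | Up Index | Down Index deriving (Eq, Show)

-- Value-level function: evaluated at run time on Index values.
plus :: Index -> Index -> Index
plus Zero x = x
plus (Up x) y = Up (plus x y)
plus (Down x) y = Down (plus x y)

-- Type-level function: reduced by the type checker on promoted indices.
type family Plus (i :: Index) (j :: Index) :: Index where
  Plus 'Zero x = x
  Plus ('Up x) y = 'Up (Plus x y)
  Plus ('Down x) y = 'Down (Plus x y)

-- Compiles only because GHC reduces the left-hand type to the right-hand one.
upDown :: Plus ('Up 'Zero) ('Down 'Zero) :~: 'Up ('Down 'Zero)
upDown = Refl

-- Hypothetical helper class: turns a type-level Index back into a value.
class KnownIndex (i :: Index) where
  indexVal :: Proxy i -> Index

instance KnownIndex 'Zero where
  indexVal _ = Zero

instance KnownIndex i => KnownIndex ('Up i) where
  indexVal _ = Up (indexVal (Proxy :: Proxy i))

instance KnownIndex i => KnownIndex ('Down i) where
  indexVal _ = Down (indexVal (Proxy :: Proxy i))
```

In GHCi, indexVal (Proxy :: Proxy (Plus ('Up 'Zero) ('Down 'Zero))) and plus (Up Zero) (Down Zero) should both evaluate to Up (Down Zero): the same computation, once in the type checker and once at run time.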
{-# LANGUAGE GADTs #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeFamilies #-}
data Direction = T | X | Y | Z
data Index = Zero | Up Index | Down Index deriving (Eq, Show)
-- type function Plus
type family Plus (i :: Index) (j :: Index) :: Index where
  Plus Zero x = x
  Plus (Up x) y = Up (Plus x y)
  Plus (Down x) y = Down (Plus x y)
-- value function plus
plus :: Index -> Index -> Index
plus Zero x = x
plus (Up x) y = Up (plus x y)
plus (Down x) y = Down (plus x y)
data Tensor (a :: Index) where
  Scalar :: Double -> Tensor Zero
  Cov :: (Direction -> Tensor b) -> Tensor (Up b)
  Con :: (Direction -> Tensor b) -> Tensor (Down b)
add :: Tensor a -> Tensor a -> Tensor a
add (Scalar x) (Scalar y) = (Scalar (x + y))
add (Cov f) (Cov g) = (Cov (\d -> add (f d) (g d)))
add (Con f) (Con g) = (Con (\d -> add (f d) (g d)))
mul :: Tensor a -> Tensor b -> Tensor (Plus a b)
mul (Scalar x) (Scalar y) = (Scalar (x*y))
mul (Scalar x) (Cov f) = (Cov (\d -> mul (Scalar x) (f d)))
mul (Scalar x) (Con f) = (Con (\d -> mul (Scalar x) (f d)))
mul (Cov f) y = (Cov (\d -> mul (f d) y))
mul (Con f) y = (Con (\d -> mul (f d) y))
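A small usage sketch of the Tensor, add and mul definitions above. It repeats them (with ticks on the promoted constructors) so it compiles on its own, and adds Eq and Show to Direction. The helpers upVec, downVec, at1 and at2 and the sample tensors u and w are hypothetical, introduced only for illustration.

```haskell
{-# LANGUAGE GADTs, DataKinds, KindSignatures, TypeFamilies #-}

data Direction = T | X | Y | Z deriving (Eq, Show)
data Index = Zero | Up Index | Down Index

type family Plus (i :: Index) (j :: Index) :: Index where
  Plus 'Zero x = x
  Plus ('Up x) y = 'Up (Plus x y)
  Plus ('Down x) y = 'Down (Plus x y)

data Tensor (a :: Index) where
  Scalar :: Double -> Tensor 'Zero
  Cov :: (Direction -> Tensor b) -> Tensor ('Up b)
  Con :: (Direction -> Tensor b) -> Tensor ('Down b)

add :: Tensor a -> Tensor a -> Tensor a
add (Scalar x) (Scalar y) = Scalar (x + y)
add (Cov f) (Cov g) = Cov (\d -> add (f d) (g d))
add (Con f) (Con g) = Con (\d -> add (f d) (g d))

mul :: Tensor a -> Tensor b -> Tensor (Plus a b)
mul (Scalar x) (Scalar y) = Scalar (x * y)
mul (Scalar x) (Cov f) = Cov (\d -> mul (Scalar x) (f d))
mul (Scalar x) (Con f) = Con (\d -> mul (Scalar x) (f d))
mul (Cov f) y = Cov (\d -> mul (f d) y)
mul (Con f) y = Con (\d -> mul (f d) y)

-- Hypothetical helpers (not from the answer) for building and reading tensors.
upVec :: (Direction -> Double) -> Tensor ('Up 'Zero)
upVec f = Cov (Scalar . f)

downVec :: (Direction -> Double) -> Tensor ('Down 'Zero)
downVec f = Con (Scalar . f)

at1 :: Tensor ('Up 'Zero) -> Direction -> Double
at1 (Cov f) d = case f d of Scalar x -> x

at2 :: Tensor ('Up ('Down 'Zero)) -> Direction -> Direction -> Double
at2 (Cov f) i j = case f i of Con g -> case g j of Scalar x -> x

-- Hypothetical sample tensors.
u :: Tensor ('Up 'Zero)
u = upVec (\d -> if d == X then 2 else 1)

w :: Tensor ('Down 'Zero)
w = downVec (const 3)

-- Type checks: the product's index is the concatenation 'Up ('Down 'Zero).
uw :: Tensor ('Up ('Down 'Zero))
uw = mul u w

-- Type checks: both arguments have the index 'Up 'Zero.
uu :: Tensor ('Up 'Zero)
uu = add u u

-- Rejected by the type checker if uncommented (index mismatch):
-- bad = add u w
```

Uncommenting bad = add u w should make GHC reject the module, because 'Up 'Zero and 'Down 'Zero are different indices; that is the compile-time check the question asks for. Evaluating at2 uw X Y should give 6.0, the product of u's X component (2) and w's Y component (3).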
There was no ambiguity in Plus, but I could have used the disambiguating tick ' to signal that I was dealing with the type-level Zero, Up, etc.:
type family Plus (i :: Index) (j :: Index) :: Index where
  Plus 'Zero x = x
  Plus ('Up x) y = 'Up (Plus x y)
  Plus ('Down x) y = 'Down (Plus x y)
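A sketch of when the tick actually matters, assuming a hypothetical ordinary type that is also named Zero (it is not in the original code). With both in scope, an unticked Zero in a type refers to the ordinary type, and 'Zero is required to reach the promoted constructor of kind Index.

```haskell
{-# LANGUAGE DataKinds #-}
import Data.Proxy (Proxy (..))

data Index = Zero | Up Index | Down Index

-- Hypothetical: an ordinary type that happens to share the name Zero.
data Zero = MkZero

-- Unticked: GHC picks the type Zero declared just above (kind Type).
ordinary :: Proxy Zero
ordinary = Proxy

-- Ticked: the promoted data constructor Zero, of kind Index.
promoted :: Proxy 'Zero
promoted = Proxy

-- Promoted constructors nest as usual: 'Up 'Zero also has kind Index.
nested :: Proxy ('Up 'Zero)
nested = Proxy
```

Without the extra data Zero declaration, the unticked and ticked spellings mean the same thing, which is why the answer's Plus works either way.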
TypeOperators would permit you to write a + b rather than Plus a b above:
type family (i :: Index) + (j :: Index) :: Index where
  Zero + x = x
  Up x + y = Up (x + y)
  Down x + y = Down (x + y)
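A sketch of the operator version in use, assuming the answer's Tensor definition with mul's result written as Tensor (a + b) instead of Tensor (Plus a b). concatCheck is a hypothetical compile-time test (not from the original) showing that + concatenates index signatures.

```haskell
{-# LANGUAGE DataKinds, GADTs, KindSignatures, TypeFamilies, TypeOperators #-}
import Data.Type.Equality ((:~:) (..))

data Direction = T | X | Y | Z
data Index = Zero | Up Index | Down Index

-- Closed type family written as an infix operator (needs TypeOperators).
type family (i :: Index) + (j :: Index) :: Index where
  'Zero + x = x
  'Up x + y = 'Up (x + y)
  'Down x + y = 'Down (x + y)

data Tensor (a :: Index) where
  Scalar :: Double -> Tensor 'Zero
  Cov :: (Direction -> Tensor b) -> Tensor ('Up b)
  Con :: (Direction -> Tensor b) -> Tensor ('Down b)

-- The answer's mul, with Tensor (a + b) in place of Tensor (Plus a b).
mul :: Tensor a -> Tensor b -> Tensor (a + b)
mul (Scalar x) (Scalar y) = Scalar (x * y)
mul (Scalar x) (Cov f) = Cov (\d -> mul (Scalar x) (f d))
mul (Scalar x) (Con f) = Con (\d -> mul (Scalar x) (f d))
mul (Cov f) y = Cov (\d -> mul (f d) y)
mul (Con f) y = Con (\d -> mul (f d) y)

-- Hypothetical compile-time check: [Up] followed by [Down, Up] gives [Up, Down, Up].
concatCheck :: ('Up 'Zero + 'Down ('Up 'Zero)) :~: 'Up ('Down ('Up 'Zero))
concatCheck = Refl
```

Only the type family's spelling changes here; the value-level + used on Double stays the ordinary Num operator, since types and values live in separate namespaces.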