 

Haskell GADTs - making type-safe Tensor types for Riemannian geometry


I want to make a type-safe implementation of tensor calculus in Haskell using GADTs, with the following rules:

  1. Tensors are n-dimensional matrices whose indices can be 'upstairs' or 'downstairs'. For example, a tensor with no indices is a scalar, a tensor can have a single 'upstairs' index, or it can have several 'upstairs' and 'downstairs' indices.
  2. You can ADD tensors of the same type, meaning they have the same index signature: the 0th index of the first tensor is of the same type (upstairs or downstairs) as the 0th index of the second tensor, and so on.

    (image: adding two tensors with the same index signature) ~~~~ OK

    (image: adding two tensors with different index signatures) ~~~~ NOT OK

  3. You can MULTIPLY tensors to get bigger tensors, whose indices are the concatenation of the factors' indices.
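
As an illustration of rules 2 and 3 in index notation (the tensor names A, B, C and the index labels are illustrative choices, not taken from the original images):

```latex
\begin{aligned}
&A^{\mu}{}_{\nu} + B^{\mu}{}_{\nu}
  && \text{OK: both signatures are (up, down)} \\
&A^{\mu}{}_{\nu} + B_{\mu}{}^{\nu}
  && \text{NOT OK: (up, down) vs.\ (down, up)} \\
&A^{\mu}{}_{\nu}\, C^{\rho} = (A \otimes C)^{\mu}{}_{\nu}{}^{\rho}
  && \text{indices concatenated: (up, down, up)}
\end{aligned}
```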

So I want Haskell's type checker to reject any code that doesn't follow these rules: such code should not compile.

Here is my attempt using GADTs:

{-# LANGUAGE GADTs #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE TypeOperators #-}

data Direction = T | X | Y | Z
data Index = Zero | Up Index | Down Index deriving (Eq, Show)

plus :: Index -> Index -> Index
plus Zero x = x
plus (Up x) y = Up (plus x y)
plus (Down x) y = Down (plus x y)

data Tensor a = (a ~ Zero) => Scalar Double | 
                forall b. (a ~ Up b) => Cov (Direction -> Tensor b) |
                forall b. (a ~ Down b) => Con (Direction -> Tensor b) 

add :: Tensor a -> Tensor a -> Tensor a
add (Scalar x) (Scalar y) = (Scalar (x + y))
add (Cov f) (Cov g) = (Cov (\d -> add (f d) (g d)))
add (Con f) (Con g) = (Con (\d -> add (f d) (g d)))

mul :: Tensor a -> Tensor b -> Tensor (plus a b)
mul (Scalar x) (Scalar y) = (Scalar (x*y))
mul (Scalar x) (Cov f) = (Cov (\d -> mul (Scalar x) (f d)))
mul (Scalar x) (Con f) = (Con (\d -> mul (Scalar x) (f d)))
mul (Cov f) y = (Cov (\d -> mul (f d) y))
mul (Con f) y = (Con (\d -> mul (f d) y))

But I'm getting:

Couldn't match type 'Down with `plus ('Down b1)'
    Expected type: Tensor (plus a b)
      Actual type: Tensor ('Down b)
    Relevant bindings include
      f :: Direction -> Tensor b1 (bound at main.hs:28:10)
      mul :: Tensor a -> Tensor b -> Tensor (plus a b)
        (bound at main.hs:24:1)
    In the expression: (Con (\ d -> mul (f d) y))
    In an equation for `mul':
        mul (Con f) y = (Con (\ d -> mul (f d) y))

What is the problem?

asked Apr 01 '17 12:04 by user47376


1 Answer

plus is just a function on values of type Index:

>>> plus Zero Zero
Zero
>>> plus Zero (Up Zero)
Up Zero

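so, as things stand, it can't be used in a type signature. Inside a type, a lowercase name such as plus is read as a fresh type variable, not as a reference to the function; that is why GHC complains that it cannot match 'Down with plus ('Down b1). You want to use the 'promoted' kind (via DataKinds), where Zero, Up Zero, etc. are types. Then you can write a type-level function (a type family) and everything compiles.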
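
A minimal sketch of the distinction, assuming a recent GHC: the value-level plus and a type-level Plus side by side, plus a demonstration that a lowercase plus in a signature is only a type variable. The names sameShape and upDown are hypothetical, introduced just for this example.

```haskell
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE TypeFamilies #-}

import Data.Proxy (Proxy (..))

data Index = Zero | Up Index | Down Index deriving (Eq, Show)

-- Value level: an ordinary function on values of type Index.
plus :: Index -> Index -> Index
plus Zero x = x
plus (Up x) y = Up (plus x y)
plus (Down x) y = Down (plus x y)

-- Type level: DataKinds promotes 'Zero, 'Up and 'Down to types of kind Index,
-- and this closed type family computes on those types.
type family Plus (i :: Index) (j :: Index) :: Index where
  Plus 'Zero x = x
  Plus ('Up x) y = 'Up (Plus x y)
  Plus ('Down x) y = 'Down (Plus x y)

-- Hypothetical name. Inside a signature, lowercase `plus` is a fresh type
-- variable, unrelated to the function above (this is what the question hit).
sameShape :: plus a -> plus a
sameShape = id

-- Hypothetical name. Compiles only because GHC reduces the family application
-- Plus ('Up 'Zero) ('Down 'Zero) to 'Up ('Down 'Zero).
upDown :: Proxy (Plus ('Up 'Zero) ('Down 'Zero))
upDown = Proxy :: Proxy ('Up ('Down 'Zero))

main :: IO ()
main = do
  print (plus (Up Zero) (Down Zero)) -- prints: Up (Down Zero)
  print (sameShape (Just 'x'))       -- here plus := Maybe and a := Char
```

The type family mirrors the value function equation by equation; the difference is that GHC evaluates Plus during type checking, which is what a signature like Tensor (Plus a b) needs.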

{-# LANGUAGE GADTs #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeFamilies #-}

data Direction = T | X | Y | Z
data Index = Zero | Up Index | Down Index deriving (Eq, Show)

-- type function Plus
type family Plus (i :: Index) (j :: Index) :: Index where
  Plus Zero x = x
  Plus (Up x) y  = Up (Plus x y)
  Plus (Down x) y = Down (Plus x y)

-- value function plus
plus :: Index -> Index -> Index
plus Zero x = x
plus (Up x) y = Up (plus x y)
plus (Down x) y = Down (plus x y)

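data Tensor (a :: Index) where
  Scalar :: Double -> Tensor Zero
  -- NB: by the usual convention, covariant indices are written downstairs and
  -- contravariant ones upstairs, so the names Cov/Con are swapped here.
  Cov :: (Direction -> Tensor b) -> Tensor (Up b)
  Con :: (Direction -> Tensor b) -> Tensor (Down b)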

add :: Tensor a -> Tensor a -> Tensor a
add (Scalar x) (Scalar y) = (Scalar (x + y))
add (Cov f) (Cov g) = (Cov (\d -> add (f d) (g d)))
add (Con f) (Con g) = (Con (\d -> add (f d) (g d)))

mul :: Tensor a -> Tensor b -> Tensor (Plus a b)
mul (Scalar x) (Scalar y) = (Scalar (x*y))
mul (Scalar x) (Cov f) = (Cov (\d -> mul (Scalar x) (f d)))
mul (Scalar x) (Con f) = (Con (\d -> mul (Scalar x) (f d)))
mul (Cov f) y = (Cov (\d -> mul (f d) y))
mul (Con f) y = (Con (\d -> mul (f d) y))
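
To see the types at work, here is a small usage sketch built on the definitions above (repeated so the block compiles on its own, with promoted constructors ticked). The helper components, the Eq/Enum/Bounded deriving on Direction, and the example tensors v and w with their component values are hypothetical additions, not part of the original answer.

```haskell
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE TypeFamilies #-}

-- Assumption: Direction also derives Eq, Enum and Bounded so it can be enumerated.
data Direction = T | X | Y | Z deriving (Eq, Enum, Bounded)
data Index = Zero | Up Index | Down Index

type family Plus (i :: Index) (j :: Index) :: Index where
  Plus 'Zero x = x
  Plus ('Up x) y = 'Up (Plus x y)
  Plus ('Down x) y = 'Down (Plus x y)

data Tensor (a :: Index) where
  Scalar :: Double -> Tensor 'Zero
  Cov :: (Direction -> Tensor b) -> Tensor ('Up b)
  Con :: (Direction -> Tensor b) -> Tensor ('Down b)

add :: Tensor a -> Tensor a -> Tensor a
add (Scalar x) (Scalar y) = Scalar (x + y)
add (Cov f) (Cov g) = Cov (\d -> add (f d) (g d))
add (Con f) (Con g) = Con (\d -> add (f d) (g d))

mul :: Tensor a -> Tensor b -> Tensor (Plus a b)
mul (Scalar x) (Scalar y) = Scalar (x * y)
mul (Scalar x) (Cov f) = Cov (\d -> mul (Scalar x) (f d))
mul (Scalar x) (Con f) = Con (\d -> mul (Scalar x) (f d))
mul (Cov f) y = Cov (\d -> mul (f d) y)
mul (Con f) y = Con (\d -> mul (f d) y)

-- Hypothetical helper: all components, with the first index varying slowest.
components :: Tensor a -> [Double]
components (Scalar x) = [x]
components (Cov f) = concatMap (components . f) [minBound .. maxBound]
components (Con f) = concatMap (components . f) [minBound .. maxBound]

-- Hypothetical example tensors with made-up components.
v :: Tensor ('Up 'Zero)
v = Cov (\d -> Scalar (fromIntegral (fromEnum d))) -- components 0, 1, 2, 3

w :: Tensor ('Down 'Zero)
w = Con (\d -> Scalar (if d == T then 1 else 0)) -- components 1, 0, 0, 0

-- The result type records the concatenated signature: up, then down.
vw :: Tensor ('Up ('Down 'Zero))
vw = mul v w

main :: IO ()
main = do
  print (components (add v v)) -- [0.0,2.0,4.0,6.0]
  print (components vw)        -- 16 components v_i * w_j, row-major
  -- add v w is rejected at compile time: 'Up 'Zero does not match 'Down 'Zero.
```

Rule 2 is enforced by add's single type variable a, and rule 3 by the Plus in mul's result type; the components helper is only there to make the numbers visible.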

There was no ambiguity in Plus, but I could have used the disambiguating tick ' to signal that I was dealing with the type-level Zero, Up, etc.:

type family Plus (i :: Index) (j :: Index) :: Index where
  Plus 'Zero x = x
  Plus ('Up x) y  = 'Up (Plus x y)
  Plus ('Down x) y = 'Down (Plus x y)

TypeOperators would let you write a + b rather than Plus a b above:

type family (i :: Index) + (j :: Index) :: Index where
  Zero + x = x
  Up x + y  = Up (x + y)
  Down x + y = Down (x + y) 
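
A sketch of the operator form wired into mul's signature, assuming the same Tensor GADT as in the answer (repeated, with ticked constructors, so the block is self-contained):

```haskell
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}

data Direction = T | X | Y | Z
data Index = Zero | Up Index | Down Index

-- Infix closed type family: i + j concatenates two index signatures.
type family (i :: Index) + (j :: Index) :: Index where
  'Zero + x = x
  'Up x + y = 'Up (x + y)
  'Down x + y = 'Down (x + y)

data Tensor (a :: Index) where
  Scalar :: Double -> Tensor 'Zero
  Cov :: (Direction -> Tensor b) -> Tensor ('Up b)
  Con :: (Direction -> Tensor b) -> Tensor ('Down b)

-- Same equations as the answer's mul; only the result type uses the operator.
mul :: Tensor a -> Tensor b -> Tensor (a + b)
mul (Scalar x) (Scalar y) = Scalar (x * y)
mul (Scalar x) (Cov f) = Cov (\d -> mul (Scalar x) (f d))
mul (Scalar x) (Con f) = Con (\d -> mul (Scalar x) (f d))
mul (Cov f) y = Cov (\d -> mul (f d) y)
mul (Con f) y = Con (\d -> mul (f d) y)

main :: IO ()
main = case mul (Scalar 2) (Scalar 3) of
  -- 'Zero + 'Zero reduces to 'Zero, so Scalar is the only possible case.
  Scalar x -> print x -- prints 6.0
```

One caveat to check in your own setup: GHC.TypeLits also exports a type-level + (on Nat), so importing it unqualified alongside this family would likely make + ambiguous.
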
answered Sep 25 '22 11:09 by Michael