 

Elegant implementation of n-dimensional matrix multiplication using lists?

List functions allow us to implement arbitrarily-dimensional vector math quite elegantly. For example:

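{-# LANGUAGE NoMonomorphismRestriction #-} -- keeps add, dot, ... polymorphic
import Control.Monad (join)                -- join dot v = dot v v

on   = (.) . (.)   -- on f g x y = f (g x y)
add  = zipWith (+)
sub  = zipWith (-)
mul  = zipWith (*)
dist = len `on` sub
dot  = sum `on` mul
len  = sqrt . join dot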

And so on.

main = do
  print $ add [1,2,3] [1,1,1] -- [2,3,4]
  print $ len [1,1,1]         -- 1.7320508075688772
  print $ dot [2,0,0] [2,0,0] -- 4
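The least obvious definition above is `on = (.) . (.)`. It is not `Data.Function.on`, which has a different type. Here it composes a one-argument function after a two-argument one. A small sketch unfolding it (the names `onPF` and `onPointful` are just for illustration):

```haskell
-- The question's combinator, point-free.
onPF :: (c -> d) -> (a -> b -> c) -> a -> b -> d
onPF = (.) . (.)

-- The same combinator with explicit arguments: apply the binary
-- function g first, then pass its result to the unary function f.
--   ((.) . (.)) f g x y
--     = ((.) f . g) x y
--     = (f . g x) y
--     = f (g x y)
onPointful :: (c -> d) -> (a -> b -> c) -> a -> b -> d
onPointful f g x y = f (g x y)
```

For instance, `onPF sum (zipWith (*))` is exactly the question's `dot`, and applied to `[1,2,3]` and `[4,5,6]` it gives 32.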

Of course, this is not the most efficient solution, but it is insightful to look at, since it shows how map, zipWith and friends generalize these vector operations. There is one operation I couldn't implement elegantly, though: the cross product. Since one possible n-dimensional generalization of the cross product is the determinant of an n-by-n matrix, how can I implement matrix multiplication elegantly?
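For the matrix multiplication the question asks about, one common list-based formulation (a sketch, not the answer's approach) treats a matrix as a list of rows and uses `Data.List.transpose` to reach the columns, so each entry is just a `dot`. It assumes rectangular inputs whose inner dimensions match; `mmul` is a hypothetical name:

```haskell
import Data.List (transpose)

-- Dot product, equivalent to the question's dot (sum `on` mul).
dot :: Num a => [a] -> [a] -> a
dot xs ys = sum (zipWith (*) xs ys)

-- A matrix is a list of rows. Entry (i, j) of the product is the dot
-- product of row i of a with column j of b; `transpose b` turns the
-- columns of b into rows so they can be paired with `dot`.
mmul :: Num a => [[a]] -> [[a]] -> [[a]]
mmul a b = [ [ dot row col | col <- cols ] | row <- a ]
  where
    cols = transpose b
```

With this representation, `mmul [[1,2],[3,4]] [[5,6],[7,8]]` is `[[19,22],[43,50]]`. Because `zipWith` truncates, mismatched inner dimensions are not reported as errors; the array-based answer below checks bounds instead.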

Edit: yes, I asked a question completely unrelated to the problem I set up. Fml.
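Since the question's underlying goal was the cross product, here is a sketch of the determinant route it mentions, again on plain lists (not from the answer below). Assumptions: matrices are square lists of rows, `cross` receives n-1 vectors of length n, and the sign convention puts the basis vectors in the first row of the formal determinant. `det` and `cross` are hypothetical names; `deleteAt` is the same helper the answer defines.

```haskell
-- Remove the element at position n (same helper as the answer's deleteAt).
deleteAt :: Int -> [a] -> [a]
deleteAt n xs = take n xs ++ drop (n + 1) xs

-- Determinant by Laplace (cofactor) expansion along the first row.
-- Exponential time: fine for small matrices, not for large ones.
det :: Num a => [[a]] -> a
det [] = 1
det (row : rows) =
  sum [ sign * x * det (map (deleteAt j) rows)
      | (j, x, sign) <- zip3 [0 ..] row (cycle [1, -1]) ]

-- Generalized cross product: given n-1 vectors of length n, component i
-- is (-1)^i times the determinant of the (n-1)x(n-1) matrix obtained by
-- deleting column i from the stacked vectors.
cross :: Num a => [[a]] -> [a]
cross vs =
  [ sign * det (map (deleteAt i) vs)
  | (i, sign) <- zip [0 .. n - 1] (cycle [1, -1]) ]
  where
    n = length vs + 1
```

For three-dimensional inputs this reduces to the usual cross product, e.g. `cross [[1,2,3],[4,5,6]]` is `[-3,6,-3]`. The result is orthogonal to every input vector, since its dot product with any of them is a determinant with a repeated row.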

MaiaVictor asked Jul 02 '15 22:07



1 Answer

It just so happens that I have some code lying around for doing n-dimensional matrix operations, which I thought was quite cute, at least when I wrote it:

{-# LANGUAGE NoMonomorphismRestriction #-}
module MultiArray where

import Control.Arrow
import Control.Monad
import Data.Ix
import Data.Maybe

import Data.Array (Array)
import qualified Data.Array as A

-- {{{ from Dmwit.hs
deleteAt n   xs = take n xs ++ drop (n + 1) xs
insertAt n x xs = take n xs ++ x : drop n xs

doublify f g xs ys = f (uncurry g) (zip xs ys)
any2 = doublify any
all2 = doublify all
-- }}}

-- makes the most sense when ls and hs have the same length
instance Ix a => Ix [a] where
    range     = sequence . map range . uncurry zip
    inRange   = all2 inRange . uncurry zip
    rangeSize = product . uncurry (zipWith (curry rangeSize))

    -- row-major: the last coordinate varies fastest, matching range's order
    index (ls, hs) xs = fst . foldr step (0, 1) $ zip indices sizes where
        indices = zipWith index (zip ls hs) xs
        sizes   = map rangeSize $ zip ls hs
        step (i, b) (s, p) = (s + p * i, p * b)

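-- fold f n a: collapse dimension n of a by applying f to the list of
-- elements along it (fold sum 1 sums over the second index)
fold :: (Enum i, Ix i) => ([a] -> b) -> Int -> Array [i] a -> Array [i] b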
fold f n a = A.array newBound assocs where
    (oldLowBound, oldHighBound) = A.bounds a
    (newLowBoundBeg , dimLow : newLowBoundEnd ) = splitAt n oldLowBound
    (newHighBoundBeg, dimHigh: newHighBoundEnd) = splitAt n oldHighBound
    assocs   = [(beg ++ end, f [a A.! (beg ++ i : end) | i <- [dimLow..dimHigh]])
               | beg <- range (newLowBoundBeg, newHighBoundBeg)
               , end <- range (newLowBoundEnd, newHighBoundEnd)
               ]
    newBound = (newLowBoundBeg ++ newLowBoundEnd, newHighBoundBeg ++ newHighBoundEnd)

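-- view a one-dimensional Array [i] e as an Array i e; mzero unless the
-- bounds have exactly one coordinate
flatten a = check a >> return value where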
    check = guard . (1==) . length . fst . A.bounds
    value = A.ixmap ((head *** head) . A.bounds $ a) return a

-- combine two arrays pointwise with f; mzero if their bounds differ
elementWise :: (MonadPlus m, Ix i) => (a -> b -> c) -> Array i a -> Array i b -> m (Array i c)
elementWise f a b = check >> return value where
    check = guard $ A.bounds a == A.bounds b
    value = A.listArray (A.bounds a) (zipWith f (A.elems a) (A.elems b))

unsafeFlatten       a   = fromJust $ flatten       a
unsafeElementWise f a b = fromJust $ elementWise f a b

-- multiply an (i, k) matrix a by a (k, j) matrix b, both indexed by [i]
matrixMult a b = fold sum 1 $ unsafeElementWise (*) a' b' where
    aBounds = (join (***) (!!0)) $ A.bounds a
    bBounds = (join (***) (!!1)) $ A.bounds b
    a' = copy 2 bBounds a
    b' = copy 0 aBounds b

-- reindex a: f maps the old bounds to new ones, g maps each new index back to an old one
bijection f g a = A.ixmap ((f *** f) . A.bounds $ a) g a
unFlatten       = bijection return head
matrixTranspose = bijection reverse reverse
-- insert a new dimension at position n with bounds (low, high); a is repeated along it
copy n (low, high) a = A.ixmap (newBounds a) (deleteAt n) a where
    newBounds = (insertAt n low *** insertAt n high) . A.bounds

The cute bit here is matrixMult, which is one of the few operations in this module specialized to two-dimensional arrays. It expands its first argument along a new last dimension (by putting a copy of the two-dimensional matrix into each slice of a three-dimensional array), expands its second argument along a new first dimension, multiplies the two pointwise (now in a three-dimensional array), and then collapses the shared middle dimension by summing. Quite nice.
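The expand / multiply pointwise / sum recipe does not depend on `Data.Array`. As an illustration only (this is not the answer's code, and the names `productCube`, `sumOverK` and `matMulCube` are hypothetical), here is the same idea on nested lists, with the cube indexed `[i][k][j]` like the `[i, k, j]` bounds of `a'` and `b'` in `matrixMult`:

```haskell
import Data.List (transpose)

-- Step 1: the three-dimensional "cube" with cube !! i !! k !! j = a_ik * b_kj.
-- Each a_ik is spread along the j axis by `map (x *)`, and each row b_k is
-- reused for every row i of a, mirroring the two `copy` calls.
productCube :: Num a => [[a]] -> [[a]] -> [[[a]]]
productCube a b = [ [ map (x *) brow | (x, brow) <- zip arow b ] | arow <- a ]

-- Step 2: collapse the shared middle axis k by summing, like `fold sum 1`.
-- For a fixed i the slice is indexed [k][j]; transposing it to [j][k]
-- lets `sum` run over k.
sumOverK :: Num a => [[[a]]] -> [[a]]
sumOverK = map (map sum . transpose)

matMulCube :: Num a => [[a]] -> [[a]] -> [[a]]
matMulCube a b = sumOverK (productCube a b)
```

Unlike the array version, nothing here checks that the inner dimensions agree: `zip` silently truncates, whereas `unsafeElementWise` fails when the bounds differ.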

Daniel Wagner answered Nov 15 '22 06:11
