Given a deep embedding of a simple data processing DSL [1]:
{-# LANGUAGE GADTs, StandaloneDeriving #-}
import Data.List hiding (singleton) -- we define our own singleton below
import Text.Show.Functions

data Dist e where
  Concat :: [Dist [a]] -> Dist [a]
  -- We use ConcatMap as a primitive because it can express e.g.
  -- both map and filter.
  ConcatMap :: (a -> [b]) -> Dist [a] -> Dist [b]
  -- Expensive to traverse input (think distributed file).
  Input :: Dist [a]
  -- The bound expression may have a different type from the body.
  Let :: Name -> Dist a -> Dist e -> Dist e
  -- We're not dealing with name collisions here for simplicity.
  Var :: Name -> Dist e

deriving instance Show (Dist e)

type Name = String
we can implement the familiar producer-consumer fusion like so:
-- ---------------------------------------------------------------------
-- Producer-consumer fusion

-- Fuses adjacent ConcatMaps, starting at the root (this simple version
-- does not recurse into Concat or Let).
fuseProducerConsumer :: Dist e -> Dist e
fuseProducerConsumer = go
  where
    go :: Dist e -> Dist e
    go (ConcatMap f (ConcatMap g e)) = ConcatMap (concatMap f . g) (go e)
    go e = e
A little example showing how it works:
-- Should be able to fuse this to a single ConcatMap.
producerConsumerFusable :: Dist [Int]
producerConsumerFusable = ConcatMap (singleton . (+ 1))
                            (ConcatMap (singleton . (* 2)) Input)

singleton :: a -> [a]
singleton = (: [])

-- Expected result after optimization.
expectedProducerConsumerResult :: Dist [Int]
expectedProducerConsumerResult =
  ConcatMap (concatMap (singleton . (+ 1)) . (singleton . (* 2))) Input
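The rewrite in fuseProducerConsumer, and the shape of expectedProducerConsumerResult, rest on a standard list identity: concatMap f (concatMap g xs) equals concatMap (concatMap f . g) xs. A quick sanity check on plain Haskell lists, outside the DSL (`unfused`, `fused`, `plusOne` and `timesTwo` are illustrative names, not part of the code above):

```haskell
-- Producer-consumer fusion law on plain lists:
--   concatMap f (concatMap g xs) == concatMap (concatMap f . g) xs
-- `unfused` and `fused` are illustrative names, not part of the DSL above.

-- Conceptually two passes: g builds an intermediate list that f consumes.
unfused :: (b -> [c]) -> (a -> [b]) -> [a] -> [c]
unfused f g xs = concatMap f (concatMap g xs)

-- One pass: each element goes through g and then straight into f.
fused :: (b -> [c]) -> (a -> [b]) -> [a] -> [c]
fused f g = concatMap (concatMap f . g)

-- The two functions used in producerConsumerFusable.
plusOne, timesTwo :: Int -> [Int]
plusOne = (: []) . (+ 1)
timesTwo = (: []) . (* 2)
```

Because the law holds for arbitrary f and g, including ones that drop or duplicate elements, the same rewrite is valid when a ConcatMap encodes filter rather than map.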
There's another, much less well-known [2] type of fusion called sibling fusion, which removes multiple traversals of the same input. The idea is to replace something like
(map f xs, map g xs)
with
let ys = map (\ x -> (f x, g x)) xs
in (map fst ys, map snd ys)
If traversing ys is much cheaper than traversing xs (e.g. if xs is a file on the network), or if we can later use producer-consumer fusion to fuse the untagging with some other traversal, this is a win.
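On plain lists the two programs compute the same pair, which is what makes the rewrite sound. A minimal sketch, with `siblingUnfused` and `siblingFused` as hypothetical names for the two versions and f, g arbitrary per-element functions:

```haskell
-- Before sibling fusion: xs is traversed once per consumer.
siblingUnfused :: (a -> b) -> (a -> c) -> [a] -> ([b], [c])
siblingUnfused f g xs = (map f xs, map g xs)

-- After sibling fusion: xs is traversed once to build the shared ys,
-- and the two projections then traverse ys instead.
siblingFused :: (a -> b) -> (a -> c) -> [a] -> ([b], [c])
siblingFused f g xs =
  let ys = map (\x -> (f x, g x)) xs
  in (map fst ys, map snd ys)
```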
While producer-consumer fusion is easily implementable using our standard AST above, I don't see how to implement sibling fusion using this representation.
-- ---------------------------------------------------------------------
-- Sibling fusion

-- Fuses ConcatMaps that consume the same input.
fuseSibling :: Dist e -> Dist e
fuseSibling = id -- ???
An example of what we want to happen:
-- The use of Concat below is not important; we just need some Dist e
-- that contains an opportunity for sibling fusion.
siblingFusable :: Dist [Int]
siblingFusable = Let "xs" Input $ -- shares one input
  Concat [ConcatMap (singleton . (+ 1)) (Var "xs"),
          ConcatMap (singleton . (* 2)) (Var "xs")]
-- Expected result after optimization.
expectedSiblingResult :: Dist [Int]
expectedSiblingResult =
  Let "xs" Input $
    Let "ys" (ConcatMap
               (mapTwo (singleton . (+ 1)) (singleton . (* 2)))
               (Var "xs")) -- only one traversal of "xs" and thus Input
             (Concat [ConcatMap lefts (Var "ys"),
                      ConcatMap rights (Var "ys")])
-- Some helper functions:
lefts :: Either a b -> [a]
lefts (Left x) = [x]
lefts _ = []
rights :: Either a b -> [b]
rights (Right x) = [x]
rights _ = []
mapTwo :: (a -> [b]) -> (a -> [c]) -> a -> [Either b c]
mapTwo f g x = map Left (f x) ++ map Right (g x)
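As a sanity check that expectedSiblingResult is a faithful rewrite of siblingFusable, the two pipelines can be evaluated on an ordinary list, reading Concat as (++). The helpers are copied from above; `original` and `fusedSibling` are hypothetical names for the two pipelines:

```haskell
-- Copies of the helpers from the question.
lefts :: Either a b -> [a]
lefts (Left x) = [x]
lefts _ = []

rights :: Either a b -> [b]
rights (Right x) = [x]
rights _ = []

mapTwo :: (a -> [b]) -> (a -> [c]) -> a -> [Either b c]
mapTwo f g x = map Left (f x) ++ map Right (g x)

-- What siblingFusable computes: two traversals of xs, then Concat.
original :: (a -> [b]) -> (a -> [b]) -> [a] -> [b]
original f g xs = concatMap f xs ++ concatMap g xs

-- What expectedSiblingResult computes: one traversal of xs builds the
-- tagged list ys, and lefts/rights untag it again.
fusedSibling :: (a -> [b]) -> (a -> [b]) -> [a] -> [b]
fusedSibling f g xs =
  let ys = concatMap (mapTwo f g) xs
  in concatMap lefts ys ++ concatMap rights ys
```

The order of each consumer's output survives the round trip because lefts and rights keep the relative order of their tagged elements.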
The issue is that while we can easily spot producer-consumer fusion opportunities by pattern matching on ConcatMap ... (ConcatMap ... ...), the two consumers of a single input which give rise to a sibling fusion opportunity aren't necessarily "close" to each other in the AST in the same way.
If we could traverse the AST in the opposite direction, i.e. starting from the Inputs, parallel consumers of one input would be much easier to spot. I cannot see how to do this, however, given that each operation only refers to its input, not its output(s).
Question: Can sibling fusion be implemented using this AST representation or is there some other (e.g. graph or continuation-based) representation that would allow us to implement sibling fusion? Preferably while still using a GADT for type safety.
I have created a monster that I will now unleash on the world. Here's an implementation of your transformation in Idris.
I started looking at this in Haskell first, and the problem is that we are essentially looking for a way to collect, for each variable, a set of functions f1 :: a -> b1, f2 :: a -> b2, and so on. Coming up with a good representation for this in Haskell is tricky because, on the one hand, we would like to hide the b1, b2, ... types behind an existential, but on the other hand, when we see a ConcatMap we need to construct a function which extracts the right coordinates from our [Either b1 (Either b2 (...))] at just the right type.
So, first of all, let's make sure our variable references are well-scoped and well-typed, by indexing Dist with the variables in scope and using De Bruijn indexing for the variable occurrences:
%default total

Ctx : Type
Ctx = List Type

data VarPtr : Ctx -> Type -> Type where
  here : VarPtr (a :: ctx) a
  there : VarPtr ctx b -> VarPtr (a :: ctx) b

data Dist : Ctx -> Type -> Type where
  Input : Dist ctx a
  Concat2 : Dist ctx a -> Dist ctx a -> Dist ctx a
  ConcatMap : (a -> List b) -> Dist ctx a -> Dist ctx b
  Let : Dist ctx a -> Dist (a :: ctx) b -> Dist ctx b
  Var : VarPtr ctx a -> Dist ctx a
As can be seen, I've made two simplifications to Dist:

1. Everything is always a list-like thing anyway, so e.g. ConcatMap's type is (a -> List b) -> Dist ctx a -> Dist ctx b instead of (a -> List b) -> Dist ctx (List a) -> Dist ctx (List b). With just the combinators provided in the original question, the only values of Dist one can build are lists anyway. This makes the implementation simpler (in other words, I was running into all kinds of unneeded complications before I made this change).

2. Concat2 is binary instead of n-ary. Changing fuseHoriz below to support n-ary concatenation is an exercise left for the reader.
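For readers following along in Haskell, a similar well-scoped, well-typed De Bruijn encoding can be approximated with DataKinds and a GADT. This sketch covers only the variable pointers, plus a hypothetical heterogeneous environment `Env` for looking them up; constructor names are capitalised because Haskell requires it:

```haskell
{-# LANGUAGE DataKinds, GADTs, KindSignatures, TypeOperators #-}
import Data.Kind (Type)

-- A typed De Bruijn index into a type-level context, mirroring VarPtr.
data VarPtr (ctx :: [Type]) (a :: Type) where
  Here :: VarPtr (a ': ctx) a
  There :: VarPtr ctx b -> VarPtr (a ': ctx) b

-- Hypothetical environment: one value per entry of the context.
data Env (ctx :: [Type]) where
  ENil :: Env '[]
  ECons :: a -> Env ctx -> Env (a ': ctx)

-- Total lookup: an out-of-scope or ill-typed index does not typecheck.
lookupEnv :: VarPtr ctx a -> Env ctx -> a
lookupEnv Here (ECons x _) = x
lookupEnv (There v) (ECons _ xs) = lookupEnv v xs
```

As with VarPtr in the Idris Dist, a pointer past the end of the context is a type error rather than a runtime failure.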
Let's implement vertical (producer-consumer) fusion first, just to get our feet wet:
fuseVert : Dist ctx a -> Dist ctx a
fuseVert Input = Input
fuseVert (Concat2 xs ys) = Concat2 (fuseVert xs) (fuseVert ys)
fuseVert (ConcatMap f d) = case fuseVert d of
  ConcatMap g d' => ConcatMap (concatMap f . g) d'
  d' => ConcatMap f d'
fuseVert (Let d0 d) = Let (fuseVert d0) (fuseVert d)
fuseVert (Var k) = Var k
So far so good:
namespace Examples
  f : Int -> List Int
  f = return . (+1)

  g : Int -> List Int
  g = return . (* 2)

  ex1 : Dist [] Int
  ex1 = ConcatMap f $ ConcatMap g $ Input

  ex1' : Dist [] Int
  ex1' = ConcatMap (concatMap f . g) $ Input

  prf : fuseVert ex1 = ex1'
  prf = Refl
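For comparison, here is a Haskell transliteration of fuseVert on a cut-down Dist. It is a sketch under stated assumptions: Let and Var are omitted, Input is fixed to a list of Int that a hypothetical `eval` function receives as an argument, and the hypothetical `countMaps` only exists to observe that fusion happened:

```haskell
{-# LANGUAGE GADTs #-}

-- Cut-down Dist without Let/Var; Input is assumed to hold Ints.
data Dist a where
  Input :: Dist Int
  Concat2 :: Dist a -> Dist a -> Dist a
  ConcatMap :: (b -> [a]) -> Dist b -> Dist a

-- Reference semantics, given the contents of Input.
eval :: [Int] -> Dist a -> [a]
eval inp Input = inp
eval inp (Concat2 x y) = eval inp x ++ eval inp y
eval inp (ConcatMap f d) = concatMap f (eval inp d)

-- Vertical fusion: fuse the child first, then merge with it if it
-- turned out to be a ConcatMap (same structure as the Idris version).
fuseVert :: Dist a -> Dist a
fuseVert Input = Input
fuseVert (Concat2 x y) = Concat2 (fuseVert x) (fuseVert y)
fuseVert (ConcatMap f d) = case fuseVert d of
  ConcatMap g d' -> ConcatMap (concatMap f . g) d'
  d' -> ConcatMap f d'

-- Number of ConcatMap nodes in an expression.
countMaps :: Dist a -> Int
countMaps Input = 0
countMaps (Concat2 x y) = countMaps x + countMaps y
countMaps (ConcatMap _ d) = 1 + countMaps d
```

Because the child is fused before its result is inspected, a whole chain of ConcatMaps collapses into one, not just the topmost pair.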
Now for the fun part. We need a good representation of "collection of functions from the same domain" and a way to point at a particular function (with a particular codomain) in that collection. We will be collecting these functions from ConcatMap f (Var v) calls, keyed by v, and then replacing the call itself with a hole that will be filled in once we have finished collecting everything.
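Stripped of types, this collection step amounts to inverting the "consumes" relation: walk the tree once and record, for every variable, the consumers applied directly to it, merging the records at each Concat. A hypothetical untyped sketch on a simplified AST (consumers are named by strings instead of carrying functions, and name collisions are ignored as in the question):

```haskell
import qualified Data.Map.Strict as M

-- Simplified, untyped stand-in for Dist: consumers are just names.
data Expr
  = Input
  | Concat [Expr]
  | ConcatMap String Expr -- consumer name, consumed expression
  | Let String Expr Expr
  | Var String
  deriving Show

-- For each variable, the consumers applied directly to it, in source order.
consumers :: Expr -> M.Map String [String]
consumers e = case e of
  Input -> M.empty
  Concat es -> M.unionsWith (++) (map consumers es)
  ConcatMap f (Var v) -> M.singleton v [f]
  ConcatMap _ e' -> consumers e'
  Let _ bound body -> M.unionWith (++) (consumers bound) (consumers body)
  Var _ -> M.empty

-- Variables with two or more direct consumers are sibling-fusion candidates.
siblingCandidates :: Expr -> [(String, [String])]
siblingCandidates = filter ((>= 2) . length . snd) . M.toList . consumers
```

The typed version has to do the same merging, but in a way that also keeps each consumer's codomain, which is where the binary tree below comes in.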
When we encounter Concat2 d1 d2, we will need to merge the functions collected from both sides, and then weaken the holes in d1 and d2 to be over this extended collection. I am using a binary tree instead of a flat list for this reason: it makes the weakening easy to implement.
It goes in its own namespace since I am reusing the here/there terminology:
namespace Funs
  data Funs : Type -> Type where
    None : Funs a
    Leaf : (a -> List b) -> Funs a
    Branch : Funs a -> Funs a -> Funs a

  instance Semigroup (Funs a) where
    (<+>) = Branch

  data FunPtr : Funs a -> Type -> Type where
    here : FunPtr (Leaf {b} _) b
    left : FunPtr fs b -> FunPtr (Branch fs _) b
    right : FunPtr fs b -> FunPtr (Branch _ fs) b
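The reason the binary tree makes weakening easy can be seen at the value level. In the untyped sketch below (`Tree` and `Path` are hypothetical stand-ins for Funs and FunPtr), merging two collections with Branch only requires prefixing each existing pointer with a constant step, whereas with a flat list every index into the right-hand collection would have to be shifted by the length of the left-hand one:

```haskell
-- Untyped stand-ins for Funs and FunPtr.
data Tree a = None | Leaf a | Branch (Tree a) (Tree a)
  deriving Show

data Path = Here | GoLeft Path | GoRight Path
  deriving (Show, Eq)

-- Follow a path to a leaf, if it leads to one.
index :: Tree a -> Path -> Maybe a
index (Leaf x) Here = Just x
index (Branch l _) (GoLeft p) = index l p
index (Branch _ r) (GoRight p) = index r p
index _ _ = Nothing

-- Tree merge: pointers from either side are weakened by a constant
-- prefix, independent of the size of the other collection.
weakenL, weakenR :: Path -> Path
weakenL = GoLeft
weakenR = GoRight

-- Flat-list merge for comparison: right-hand indices must be shifted
-- by the length of the left-hand list.
weakenFlatR :: [a] -> Int -> Int
weakenFlatR left i = length left + i
```

In the typed setting the flat-list version would additionally need a proof relating the shifted index to the appended list, while the tree version is a single constructor application.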
Now that we have a representation for the collection of all functions applied on a given variable, we can finally make some progress towards implementing horizontal fusion.
To reiterate, the goal is to turn something like
let xs = Input :: [A]
in Concat2 (E $ ConcatMap f xs) (F $ ConcatMap g xs)
  where
    f :: A -> [B]
    g :: A -> [C]
into something like
let xs = Input :: [A]
    xs' = ConcatMap (\x -> map Left (f x) ++ map Right (g x)) xs :: [Either B C]
in Concat2 (E $ ConcatMap (either return (const [])) xs') (F $ ConcatMap (either (const []) return) xs')
So, first of all, we need to be able to code-gen the memoizer (the definition of xs') from the collection of functions applied on xs:
memoType : Funs a -> Type
memoType None = ()
memoType (Leaf {b} _) = b
memoType (Branch fs1 fs2) = Either (memoType fs1) (memoType fs2)

memoFun : (fs : Funs a) -> (a -> List (memoType fs))
memoFun None = const []
memoFun (Leaf f) = f
memoFun (Branch fs1 fs2) = (\x => map Left (memoFun fs1 x) <+> map Right (memoFun fs2 x))

memoExpr : (fs : Funs a) -> Dist (a :: ctx) (memoType fs)
memoExpr fs = ConcatMap (memoFun fs) (Var here)
It won't be much use if we can't look up these memoized results later on:
lookupMemo : {fs : Funs a} -> (i : FunPtr fs b) -> (memoType fs -> List b)
lookupMemo {fs = Leaf f} here = \x => [x]
lookupMemo {fs = (Branch fs1 fs2)} (left i) = either (lookupMemo i) (const [])
lookupMemo {fs = (Branch fs1 fs2)} (right i) = either (const []) (lookupMemo i)
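In Haskell, one possible way to sidestep the type-level memoType function is to index the collection by its memo type directly. The sketch below is a hypothetical analogue of Funs, FunPtr, memoFun and lookupMemo (names mirror the Idris code). One simplification: Here may also point at the () index of None, which the Idris FunPtr rules out, but such a pointer never selects anything because None contributes no memoized elements.

```haskell
{-# LANGUAGE GADTs #-}

-- Collection of functions from `a`, indexed by its memo type `r`
-- (the role played by memoType in the Idris code).
data Funs a r where
  None :: Funs a ()
  Leaf :: (a -> [b]) -> Funs a b
  Branch :: Funs a r1 -> Funs a r2 -> Funs a (Either r1 r2)

-- Pointer to the component of type b inside a memo value of type r.
data FunPtr r b where
  Here :: FunPtr b b
  GoLeft :: FunPtr r1 b -> FunPtr (Either r1 r2) b
  GoRight :: FunPtr r2 b -> FunPtr (Either r1 r2) b

-- Run every function in the collection on one element, tagging results.
memoFun :: Funs a r -> a -> [r]
memoFun None _ = []
memoFun (Leaf f) x = f x
memoFun (Branch f1 f2) x = map Left (memoFun f1 x) ++ map Right (memoFun f2 x)

-- Recover the results of a single function from one tagged element.
lookupMemo :: FunPtr r b -> r -> [b]
lookupMemo Here x = [x]
lookupMemo (GoLeft i) e = either (lookupMemo i) (const []) e
lookupMemo (GoRight i) e = either (const []) (lookupMemo i) e
```

Usage mirroring ex2 below: with fs = Branch None (Branch (Leaf f) (Leaf g)), the memo type is Either () (Either Int Int), and concatMap (lookupMemo (GoRight (GoLeft Here))) applied to the memoized list gives back exactly the output of f.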
Now, as we traverse the source tree, we of course collect usages (via ConcatMap) of several variables at the same time, since it is entirely possible to have something like
let xs = ...
in Concat2 (ConcatMap f xs) (let ys = ... in ... (ConcatMap g xs) ...)
This collection of usages will be populated in lockstep with the variable context, since at every Let binding we can also generate the memoizer for all the usages of the new variable.
namespace Usages
  data Usages : Ctx -> Type where
    Nil : Usages []
    (::) : {a : Type} -> Funs a -> Usages ctx -> Usages (a :: ctx)

  unused : {ctx : Ctx} -> Usages ctx
  unused {ctx = []} = []
  unused {ctx = _ :: ctx} = None :: unused {ctx}

  instance Semigroup (Usages ctx) where
    [] <+> [] = []
    (fs1 :: us1) <+> (fs2 :: us2) = (fs1 <+> fs2) :: (us1 <+> us2)
We will be reserving space for these synthetic variables:
ctxDup : {ctx : Ctx} -> Usages ctx -> Ctx
ctxDup {ctx = []} us = []
ctxDup {ctx = t :: ts} (fs :: us) = (memoType fs) :: t :: ctxDup us
varDup : {us : Usages ctx} -> VarPtr ctx a -> VarPtr (ctxDup us) a
varDup {us = _ :: _} here = there here
varDup {us = _ :: _} (there v) = there $ there $ varDup v
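Since ctxDup puts each variable's memoizer slot immediately in front of the variable itself, the reindexing is simple arithmetic on De Bruijn indices: varDup sends index i to 2i + 1, and the memoizer for variable i ends up at index 2i (which is where resolve below points). A value-level sketch; `Ix`, `varDupIx`, `memoIx` and `ctxDupList` are illustrative names:

```haskell
-- Untyped De Bruijn indices.
data Ix = Here | There Ix
  deriving (Show, Eq)

toInt :: Ix -> Int
toInt Here = 0
toInt (There i) = 1 + toInt i

-- Assumes n >= 0.
fromInt :: Int -> Ix
fromInt 0 = Here
fromInt n = There (fromInt (n - 1))

-- Mirrors varDup: here => there here; there v => there (there (varDup v)).
varDupIx :: Ix -> Ix
varDupIx Here = There Here
varDupIx (There v) = There (There (varDupIx v))

-- Mirrors the index produced by resolve: the memo slot for a variable.
memoIx :: Ix -> Ix
memoIx Here = Here
memoIx (There v) = There (There (memoIx v))

-- List model of ctxDup: each memo type precedes its original variable.
ctxDupList :: [m] -> [t] -> [Either m t]
ctxDupList ms ts = concat (zipWith (\m t -> [Left m, Right t]) ms ts)
```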
Now we are finally ready to define our optimizer's internal intermediate representation: "Dist with holes". Each hole stands for an application of a function on a variable, which will be filled in when we know all the usages and have all the synthetic variables for them in scope:
namespace HDist
  data Hole : Usages ctx -> Type -> Type where
    here : FunPtr u b -> Hole (u :: us) b
    there : Hole us b -> Hole (_ :: us) b

  resolve : {us : Usages ctx} -> Hole us b -> Exists (\a => (VarPtr (ctxDup us) a, a -> List b))
  resolve (here i) = Evidence _ (here, lookupMemo i)
  resolve (there h) with (resolve h)
    | Evidence a (v, f) = Evidence a (there $ there v, f)

  data HDist : Usages ctx -> Type -> Type where
    HInput : HDist us a
    HConcat : HDist us a -> HDist us a -> HDist us a
    HConcatMap : (b -> List a) -> HDist us b -> HDist us a
    HLet : HDist us a -> (fs : Funs a) -> HDist (fs :: us) b -> HDist us b
    HVar : {ctx : Ctx} -> {us : Usages ctx} -> VarPtr ctx a -> HDist us a
    HHole : (hole : Hole us a) -> HDist us a
So once we have such a holey Dist, filling it in is just a matter of walking it and resolving the holes:
fill : HDist us a -> Dist (ctxDup us) a
fill HInput = Input
fill (HConcat e1 e2) = Concat2 (fill e1) (fill e2)
fill (HConcatMap f e) = ConcatMap f $ fill e
fill (HLet e0 fs e) = Let (fill e0) $ Let (memoExpr fs) $ fill e
fill (HVar x) = Var (varDup x)
fill (HHole h) with (resolve h)
  | Evidence a (v, f) = ConcatMap f $ Var v
Horizontal (sibling) fusion, then, is just a matter of elbow grease: turning a Dist ctx a into an HDist us a such that every ConcatMap f (Var v) is turned into an HHole. We need to do some extra funny dance to shift holes around when combining two Usages from the two sides of a Concat2.
weakenHoleL : Hole us1 a -> Hole (us1 <+> us2) a
weakenHoleL {us1 = _ :: _} {us2 = _ :: _} (here i) = here (left i)
weakenHoleL {us1 = _ :: _} {us2 = _ :: _} (there h) = there $ weakenHoleL h
weakenHoleR : Hole us2 a -> Hole (us1 <+> us2) a
weakenHoleR {us1 = _ :: _} {us2 = _ :: _} (here i) = here (right i)
weakenHoleR {us1 = _ :: _} {us2 = _ :: _} (there h) = there $ weakenHoleR h
weakenL : HDist us1 a -> HDist (us1 <+> us2) a
weakenL HInput = HInput
weakenL (HConcat e1 e2) = HConcat (weakenL e1) (weakenL e2)
weakenL (HConcatMap f e) = HConcatMap f (weakenL e)
weakenL {us1 = us1} {us2 = us2} (HLet e fs x) = HLet (weakenL e) (Branch fs None) (weakenL {us2 = None :: us2} x)
weakenL (HVar x) = HVar x
weakenL (HHole hole) = HHole (weakenHoleL hole)
weakenR : HDist us2 a -> HDist (us1 <+> us2) a
weakenR HInput = HInput
weakenR (HConcat e1 e2) = HConcat (weakenR e1) (weakenR e2)
weakenR (HConcatMap f e) = HConcatMap f (weakenR e)
weakenR {us1 = us1} {us2 = us2} (HLet e fs x) = HLet (weakenR e) (Branch None fs) (weakenR {us1 = None :: us1} x)
weakenR (HVar x) = HVar x
weakenR (HHole hole) = HHole (weakenHoleR hole)
fuseHoriz : Dist ctx a -> Exists {a = Usages ctx} (\us => HDist us a)
fuseHoriz Input = Evidence unused HInput
fuseHoriz (Concat2 d1 d2) with (fuseHoriz d1)
  | Evidence us1 e1 with (fuseHoriz d2)
    | Evidence us2 e2 =
      Evidence (us1 <+> us2) $ HConcat (weakenL e1) (weakenR e2)
fuseHoriz {ctx = _ :: ctx} (ConcatMap f (Var here)) =
  Evidence (Leaf f :: unused) (HHole (here here))
fuseHoriz (ConcatMap f d) with (fuseHoriz d)
  | Evidence us e = Evidence us (HConcatMap f e)
fuseHoriz (Let d0 d) with (fuseHoriz d0)
  | Evidence us0 e0 with (fuseHoriz d)
    | Evidence (fs :: us) e =
      Evidence (us0 <+> us) $ HLet (weakenL e0) (Branch None fs) $ weakenR {us1 = None :: us0} e
fuseHoriz (Var v) = Evidence unused (HVar v)
We can use this monstrosity by combining it with fuseVert and feeding the result to fill:
fuse : Dist [] a -> Dist [] a
fuse d = fill $ getProof $ fuseHoriz . fuseVert $ d
And presto:
namespace Examples
  ex2 : Dist [] Int
  ex2 = Let Input $
          Concat2 (ConcatMap f (Var here))
                  (ConcatMap g (Var here))

  ex2' : Dist [] Int
  ex2' = Let Input $
           Let (ConcatMap (\x => map Left [] ++ map Right (map Left (f x) ++ map Right (g x))) (Var here)) $
           Concat2 (ConcatMap f' (Var here)) (ConcatMap g' (Var here))
    where
      f' : Either () (Either Int Int) -> List Int
      f' = either (const []) $ either return $ const []

      g' : Either () (Either Int Int) -> List Int
      g' = either (const []) $ either (const []) $ return

  prf2 : fuse ex2 = ex2'
  prf2 = Refl
I wish I could have fused fuseVert into fuseHoriz, since I think all it should require is an extra case:

fuseHoriz (ConcatMap f (ConcatMap g d)) = fuseHoriz (ConcatMap (concatMap f . g) d)

However, this confused the Idris termination checker unless I added an assert_smaller on ConcatMap (concatMap f . g) d vs. ConcatMap f (ConcatMap g d), and I don't understand why, since one has one more layer of ConcatMap constructors than the other.