 

Implementing sibling fusion using standard AST

Tags:

haskell

Given a deep embedding of a simple data processing DSL [1]:

{-# LANGUAGE GADTs, StandaloneDeriving #-}

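-- No Data.List import: concatMap is in the Prelude, and Data.List in
-- base >= 4.15 exports a singleton that clashes with the one defined below.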
import Text.Show.Functions

data Dist e where

    Concat :: [Dist [a]] -> Dist [a]

    -- We use ConcatMap as a primitive because it can express e.g.
    -- both map and filter.
    ConcatMap :: (a -> [b]) -> Dist [a] -> Dist [b]

    -- Expensive to traverse input (think distributed file).
    Input :: Dist [a]

    -- The bound expression may have a different type from the body.
    Let :: Name -> Dist a -> Dist e -> Dist e

    -- We're not dealing with name collisions here for simplicity.
    Var :: Name -> Dist e

deriving instance Show (Dist e)

type Name = String

we can implement the familiar producer-consumer fusion like so

-- ---------------------------------------------------------------------
-- Producer-consumer fusion

-- Fuses adjacent ConcatMaps.
fuseProducerConsumer :: Dist e -> Dist e
fuseProducerConsumer = go
  where
    go :: Dist e -> Dist e
    go (ConcatMap f (ConcatMap g e)) = ConcatMap (concatMap f . g) (go e)
    go e = e
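As written, go only fires at the root and never revisits the node it just built, so ConcatMap f (ConcatMap g (ConcatMap h Input)) still ends up with two ConcatMaps, and nothing under a Concat is visited. A minimal sketch of a fully recursive variant (same shape as the answer's fuseVert further down). To keep it small and evaluable it assumes, hypothetically, that Input is an [Int] source and it drops Let/Var; eval, size and chain are illustrative helpers, not part of the question:

```haskell
{-# LANGUAGE GADTs #-}

-- Cut-down copy of the question's Dist. Assumption: Input is a list of Ints,
-- so trees can be evaluated; Let and Var are omitted for brevity.
data Dist e where
  Concat    :: [Dist [a]] -> Dist [a]
  ConcatMap :: (a -> [b]) -> Dist [a] -> Dist [b]
  Input     :: Dist [Int]

-- Fuse bottom-up: fuse the child first, then merge with it if the fused
-- child is itself a ConcatMap. This handles chains of any length.
fuseProducerConsumer :: Dist e -> Dist e
fuseProducerConsumer (Concat ds) = Concat (map fuseProducerConsumer ds)
fuseProducerConsumer (ConcatMap f d) = case fuseProducerConsumer d of
  ConcatMap g d' -> ConcatMap (concatMap f . g) d'
  d'             -> ConcatMap f d'
fuseProducerConsumer Input = Input

-- Reference semantics, used to check that fusion preserves meaning.
eval :: [Int] -> Dist e -> e
eval inp (Concat ds)     = concatMap (eval inp) ds
eval inp (ConcatMap f d) = concatMap f (eval inp d)
eval inp Input           = inp

-- Number of AST nodes.
size :: Dist e -> Int
size (Concat ds)     = 1 + sum (map size ds)
size (ConcatMap _ d) = 1 + size d
size Input           = 1

-- Three stacked ConcatMaps; the root-only go above leaves two of them.
chain :: Dist [Int]
chain = ConcatMap (\x -> [x + 1]) (ConcatMap (\x -> [x * 2]) (ConcatMap (\x -> [x, x]) Input))
```

Here fuseProducerConsumer chain becomes a single ConcatMap over Input (two nodes instead of four), and the children of a Concat are fused independently.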

A little example showing how it works:

-- Should be able to fuse this to a single ConcatMap.
producerConsumerFusable :: Dist [Int]
producerConsumerFusable = ConcatMap (singleton . (+ 1))
                          (ConcatMap (singleton . (* 2)) Input)

singleton :: a -> [a]
singleton = (: [])

-- Expected result after optimization.
expectedProducerConsumerResult =
    ConcatMap (concatMap (singleton . (+ 1)) . (singleton . (* 2))) Input

There's another, much less well-known [2] type of fusion called sibling fusion, which removes multiple traversals of the same input. The idea is to replace something like

(map f xs, map g xs)

with

let ys = map (\ x -> (f x, g x)) xs
in (map fst ys, map snd ys)

If traversing ys is much cheaper than traversing xs (e.g. if xs is a file on the network) or if we can e.g. use producer-consumer fusion to later fuse the untagging with some other traversal, this is a win.
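The same rewrite written out over ordinary Haskell lists, as a minimal sketch (separate and siblingFused are illustrative names, not from the question). Both return the same pair, but the second traverses xs only once and the two projections read only ys:

```haskell
-- Two separate traversals of xs.
separate :: (a -> b) -> (a -> c) -> [a] -> ([b], [c])
separate f g xs = (map f xs, map g xs)

-- One traversal of xs builds ys; the projections then traverse ys only.
siblingFused :: (a -> b) -> (a -> c) -> [a] -> ([b], [c])
siblingFused f g xs =
  let ys = map (\x -> (f x, g x)) xs
  in (map fst ys, map snd ys)
```

Pairing works here only because map produces exactly one output per input. With ConcatMap the two siblings may emit different numbers of elements per input, which is why the expected result further down tags values with Either instead of pairing them.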

While producer-consumer fusion is easily implementable using our standard AST above, I don't see how to implement sibling fusion using this representation.

-- ---------------------------------------------------------------------
-- Sibling fusion

-- Fuses ConcatMaps that consume the same input.
fuseSibling :: Dist e -> Dist e
fuseSibling = id  -- ???

An example of what we want to happen:

-- The use of Concat below is not important, we just need some Dist e
-- that contains an opportunity for sibling fusion.
siblingFusable :: Dist [Int]
siblingFusable = Let "xs" Input $  -- shares one input
                 Concat [ConcatMap (singleton . (+ 1)) (Var "xs"),
                         ConcatMap (singleton . (* 2)) (Var "xs")]

-- Expected result after optimization.
expectedSiblingResult =
    Let "xs" Input $
    (Let "ys" (ConcatMap
              (mapTwo (singleton . (+ 1)) (singleton . (* 2)))
              (Var "xs"))  -- only one traversal of "xs" and thus Input
     (Concat [ConcatMap lefts  (Var "ys"),
              ConcatMap rights (Var "ys")]))

-- Some helper functions:
lefts :: Either a b -> [a]
lefts (Left x) = [x]
lefts _        = []

rights :: Either a b -> [b]
rights (Right x) = [x]
rights _         = []

mapTwo :: (a -> [b]) -> (a -> [c]) -> a -> [Either b c]
mapTwo f g x = map Left (f x) ++ map Right (g x)
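The list-level equation that makes expectedSiblingResult correct is that tagging every element once with mapTwo and then projecting with lefts and rights recovers the two original siblings. A small sketch that checks this over ordinary lists, using the question's helpers verbatim; unfused and fused are hypothetical names for illustration:

```haskell
-- The question's helpers, unchanged.
lefts :: Either a b -> [a]
lefts (Left x) = [x]
lefts _        = []

rights :: Either a b -> [b]
rights (Right x) = [x]
rights _         = []

mapTwo :: (a -> [b]) -> (a -> [c]) -> a -> [Either b c]
mapTwo f g x = map Left (f x) ++ map Right (g x)

-- Unfused: two traversals of xs.
unfused :: (a -> [b]) -> (a -> [c]) -> [a] -> ([b], [c])
unfused f g xs = (concatMap f xs, concatMap g xs)

-- Fused: one traversal of xs produces the tagged list ys, and each sibling
-- reads its own tag back out of ys (this mirrors expectedSiblingResult).
fused :: (a -> [b]) -> (a -> [c]) -> [a] -> ([b], [c])
fused f g xs = (concatMap lefts ys, concatMap rights ys)
  where ys = concatMap (mapTwo f g) xs
```

The two siblings do not need to produce the same number of outputs per element; for example, a filter-like f and a duplicating g still round-trip.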

The issue is that while we can easily spot producer-consumer fusion opportunities by pattern matching on ConcatMap ... (ConcatMap ... ...), the two consumers of a single input that give rise to a sibling fusion opportunity aren't necessarily "close" to each other in the AST in the same way.

If we could traverse the AST in the opposite direction, i.e. starting from the Inputs, parallel consumers of one input would be much easier to spot. I cannot see how to do this, however, given that each operation only refers to its input, not its output(s).
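Within the name-based representation, the "opposite direction" can be recovered by a separate pass that inverts the edges: walk the tree once and record, for every variable, the consumers that read it directly. A minimal sketch, assuming an untyped copy of the AST in which each function is replaced by a String label only so that results can be printed and compared (consumers, candidates and farApart are hypothetical names):

```haskell
import qualified Data.Map.Strict as Map

type Name = String

-- Untyped copy of the question's AST. Assumption for illustration: each
-- function is replaced by a String label so results can be shown and compared.
data Expr
  = Input
  | Concat [Expr]
  | ConcatMap String Expr
  | Let Name Expr Expr
  | Var Name
  deriving (Show, Eq)

-- Invert the "reads from" edges: for every variable, the labels of the
-- ConcatMaps that consume it directly, in left-to-right order.
-- Like the question, this ignores shadowing.
consumers :: Expr -> Map.Map Name [String]
consumers e = case e of
  Input               -> Map.empty
  Concat es           -> Map.unionsWith (++) (map consumers es)
  ConcatMap f (Var x) -> Map.singleton x [f]
  ConcatMap _ d       -> consumers d
  Let _ d b           -> Map.unionWith (++) (consumers d) (consumers b)
  Var _               -> Map.empty

-- Variables read by two or more ConcatMaps: the sibling-fusion candidates.
candidates :: Expr -> [Name]
candidates = Map.keys . Map.filter ((>= 2) . length) . consumers

siblingFusable :: Expr
siblingFusable = Let "xs" Input
  (Concat [ConcatMap "(+ 1)" (Var "xs"), ConcatMap "(* 2)" (Var "xs")])

-- Two consumers of xs that are far apart in the tree.
farApart :: Expr
farApart = Let "xs" Input
  (Concat [ConcatMap "f" (Var "xs"),
           Let "ys" Input (Concat [Var "ys", ConcatMap "g" (Var "xs")])])
```

In farApart the two consumers of xs are not siblings in the tree, yet the map still groups them under xs, which is the information a sibling-fusion rewrite needs.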

Question: Can sibling fusion be implemented using this AST representation or is there some other (e.g. graph or continuation-based) representation that would allow us to implement sibling fusion? Preferably while still using a GADT for type safety.

  1. This DSL is similar to the FlumeJava DSL for distributed computations: http://pages.cs.wisc.edu/~akella/CS838/F12/838-CloudPapers/FlumeJava.pdf
  2. It's probably less well known because it's not clearly a win in single-process programs, where the additional bookkeeping may outweigh the savings from not retraversing the input. However, if your input is a 1TB file residing on the network, it can be a very big win.
asked Jul 12 '14 by tibbe


1 Answer

I have created a monster that I will now unleash on the world. Here's an implementation of your transformation in Idris.

I started looking at this in Haskell first, and the problem is that we are essentially looking for a way to collect, for each variable, a set of functions f1 :: a -> b1, f2 :: a -> b2, ... . Coming up with a good representation for this in Haskell is tricky because on one hand, we would like to hide the b1, b2, ... types behind an existential, but on the other hand, when we see a ConcatMap we need to construct a function which extracts the right coordinates from our [Either b1 (Either b2 (...))] at just the right type.
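If we are willing to give up static typing of the tagged intermediate values, the collect-then-rewrite strategy needs no existentials at all. A minimal sketch in Haskell, assuming a hypothetical universal value type V whose Int-indexed Tag stands in for the nested Eithers; collect, project, memo, untag and varReads are illustrative names. Like the question, it ignores name capture, and it only fuses direct ConcatMap f (Var x) consumers of a Let-bound x:

```haskell
import Data.List (mapAccumL)
import Data.Maybe (fromMaybe)

type Name = String

-- Hypothetical universal value type: a plain Int, or a value tagged with
-- the index of the sibling consumer that produced it.
data V = I Int | Tag Int V
  deriving (Show, Eq)

-- Untyped version of the question's AST (no GADT index).
data Expr = Input | Concat [Expr] | ConcatMap (V -> [V]) Expr | Let Name Expr Expr | Var Name

-- Reference semantics; the Input list is supplied by the caller.
eval :: [V] -> [(Name, [V])] -> Expr -> [V]
eval inp _   Input           = inp
eval inp env (Concat es)     = concatMap (eval inp env) es
eval inp env (ConcatMap f e) = concatMap f (eval inp env e)
eval inp env (Let x d b)     = eval inp ((x, eval inp env d) : env) b
eval _   env (Var x)         = fromMaybe (error ("unbound variable " ++ x)) (lookup x env)

-- Phase 1: functions of all direct consumers ConcatMap f (Var x), left to
-- right. An inner Let that rebinds x hides its body.
collect :: Name -> Expr -> [V -> [V]]
collect x e = case e of
  ConcatMap f (Var y) | y == x -> [f]
  ConcatMap _ d -> collect x d
  Concat es     -> concatMap (collect x) es
  Let y d b     -> collect x d ++ (if y == x then [] else collect x b)
  _             -> []

-- Phase 2: replace the n-th direct consumer of x (same order as collect)
-- by a projection of tag n out of the memo variable x'.
project :: Name -> Name -> Int -> Expr -> (Int, Expr)
project x x' n e = case e of
  ConcatMap _ (Var y) | y == x -> (n + 1, ConcatMap (untag n) (Var x'))
  ConcatMap f d -> fmap (ConcatMap f) (project x x' n d)
  Concat es     -> fmap Concat (mapAccumL (project x x') n es)
  Let y d b
    | y == x    -> fmap (\d' -> Let y d' b) (project x x' n d)
    | otherwise ->
        let (n1, d') = project x x' n d
            (n2, b') = project x x' n1 b
        in (n2, Let y d' b')
  _             -> (n, e)

-- Run every consumer once, tagging each output with its consumer's index.
memo :: [V -> [V]] -> V -> [V]
memo fs v = concat [map (Tag i) (f v) | (i, f) <- zip [0 ..] fs]

-- Keep only the outputs of consumer i, with the tag stripped.
untag :: Int -> V -> [V]
untag i (Tag j v) | i == j = [v]
untag _ _                  = []

-- Bottom-up: fuse inner Lets first, then the siblings of x if it has at
-- least two direct consumers. Assumes the name x ++ "'" is not in use.
fuseSiblings :: Expr -> Expr
fuseSiblings e = case e of
  Let x d b ->
    let d' = fuseSiblings d
        b' = fuseSiblings b
        fs = collect x b'
        x' = x ++ "'"
    in if length fs < 2
         then Let x d' b'
         else Let x d' (Let x' (ConcatMap (memo fs) (Var x)) (snd (project x x' 0 b')))
  Concat es     -> Concat (map fuseSiblings es)
  ConcatMap f d -> ConcatMap f (fuseSiblings d)
  _             -> e

-- Number of reads of a variable, i.e. traversals of the list it names.
varReads :: Name -> Expr -> Int
varReads x e = case e of
  Var y         -> if y == x then 1 else 0
  Concat es     -> sum (map (varReads x) es)
  ConcatMap _ d -> varReads x d
  Let _ d b     -> varReads x d + varReads x b
  Input         -> 0

onInt :: (Int -> Int) -> V -> [V]
onInt f (I n) = [I (f n)]
onInt _ _     = []

siblingFusable :: Expr
siblingFusable = Let "xs" Input
  (Concat [ConcatMap (onInt (+ 1)) (Var "xs"), ConcatMap (onInt (* 2)) (Var "xs")])
```

What the types buy, compared with this sketch, is a static guarantee that untag i is only ever applied to values produced by memo fs; here that is merely a convention.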

So, first of all, let's make sure our variable references are well-scoped and well-typed, by indexing Dist with the variables in scope and using De Bruijn indexing for the variable occurrences:

%default total

Ctx : Type
Ctx = List Type

data VarPtr : Ctx -> Type -> Type where
  here : VarPtr (a :: ctx) a
  there : VarPtr ctx b -> VarPtr (a :: ctx) b

data Dist : Ctx -> Type -> Type where
  Input : Dist ctx a
  Concat2 : Dist ctx a -> Dist ctx a -> Dist ctx a
  ConcatMap : (a -> List b) -> Dist ctx a -> Dist ctx b
  Let : Dist ctx a -> Dist (a :: ctx) b -> Dist ctx b
  Var : VarPtr ctx a -> Dist ctx a
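Roughly the same well-scoped, well-typed representation can be written in Haskell with DataKinds, using a promoted list of types as the context, which bears on the question's wish to keep a GADT. A minimal sketch, assuming (hypothetically, so that terms can be evaluated) that Input produces Ints; Env, lookupVar, eval, ex2 and ex3 are illustrative helpers, not part of this answer:

```haskell
{-# LANGUAGE DataKinds, GADTs, KindSignatures, TypeOperators #-}

import Data.Kind (Type)

-- De Bruijn pointer into a context of element types.
data VarPtr (ctx :: [Type]) (a :: Type) where
  Here  :: VarPtr (a ': ctx) a
  There :: VarPtr ctx b -> VarPtr (a ': ctx) b

-- Assumption: Input is an Int source so that eval has something to return.
data Dist (ctx :: [Type]) (a :: Type) where
  Input     :: Dist ctx Int
  Concat2   :: Dist ctx a -> Dist ctx a -> Dist ctx a
  ConcatMap :: (a -> [b]) -> Dist ctx a -> Dist ctx b
  Let       :: Dist ctx a -> Dist (a ': ctx) b -> Dist ctx b
  Var       :: VarPtr ctx a -> Dist ctx a

-- One list of values per variable in scope, typed by the context.
data Env (ctx :: [Type]) where
  ENil  :: Env '[]
  ECons :: [a] -> Env ctx -> Env (a ': ctx)

lookupVar :: VarPtr ctx a -> Env ctx -> [a]
lookupVar Here      (ECons xs _)  = xs
lookupVar (There v) (ECons _ env) = lookupVar v env

eval :: [Int] -> Env ctx -> Dist ctx a -> [a]
eval inp _   Input           = inp
eval inp env (Concat2 d1 d2) = eval inp env d1 ++ eval inp env d2
eval inp env (ConcatMap f d) = concatMap f (eval inp env d)
eval inp env (Let d0 d)      = eval inp (ECons (eval inp env d0) env) d
eval _   env (Var v)         = lookupVar v env

-- The answer's ex2 before fusion.
ex2 :: Dist '[] Int
ex2 = Let Input (Concat2 (ConcatMap (\x -> [x + 1]) (Var Here))
                         (ConcatMap (\x -> [x * 2]) (Var Here)))

-- An outer variable is reached with There.
ex3 :: Dist '[] Int
ex3 = Let Input (Let (ConcatMap (\x -> [show x]) (Var Here))
                     (Concat2 (ConcatMap (\s -> [length s]) (Var Here))
                              (Var (There Here))))
```

An ill-scoped term such as Var Here at type Dist '[] Int is rejected by the type checker, which is what the De Bruijn indexing buys.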

As can be seen, I've made two simplifications to Dist:

  • Everything is always a list-like thing anyway, so e.g. ConcatMap's type is Dist ctx a -> Dist ctx b instead of Dist ctx (List a) -> Dist ctx (List b). With just the combinators provided in the original question, the only values of Dist one can build are lists anyway. This makes the implementation simpler (in other words, I was running into all kinds of unneeded complications before I made this change).

  • Concat2 is binary instead of n-ary. Changing fuseHoriz below to support n-ary concatenation is left as an exercise for the reader.

Let's implement vertical fusion first, just to get our feet wet:

fuseVert : Dist ctx a -> Dist ctx a
fuseVert Input = Input
fuseVert (Concat2 xs ys) = Concat2 (fuseVert xs) (fuseVert ys)
fuseVert (ConcatMap f d) = case fuseVert d of
    ConcatMap g d' => ConcatMap (concatMap f . g) d'
    d' => ConcatMap f d'
fuseVert (Let d0 d) = Let (fuseVert d0) (fuseVert d)
fuseVert (Var k) = Var k

So far so good:

namespace Examples
  f : Int -> List Int
  f = return . (+1)

  g : Int -> List Int
  g = return . (* 2)

  ex1 : Dist [] Int
  ex1 = ConcatMap f $ ConcatMap g $ Input

  ex1' : Dist [] Int
  ex1' = ConcatMap (concatMap f . g) $ Input

  prf : fuseVert ex1 = ex1'
  prf = Refl

Now for the fun part. We need a good representation of a "collection of functions from the same domain" and a way to point at a particular function (with a particular codomain) in that collection. We will be collecting these functions from ConcatMap f (Var v) calls, keyed by v, and replacing each call with a hole that will be filled in once we have finished collecting everything.

When we encounter Concat2 d1 d2, we will need to merge the functions collected from both sides, and then weaken the holes in d1 and d2 to be over this extended collection. I am using a binary tree instead of a flat list for this reason: so that the weakening is easy to implement.

It goes in its own namespace since I am reusing the here/there terminology:

namespace Funs
  data Funs : Type -> Type where
    None : Funs a
    Leaf : (a -> List b) -> Funs a
    Branch : Funs a -> Funs a -> Funs a

  instance Semigroup (Funs a) where
    (<+>) = Branch

  data FunPtr : Funs a -> Type -> Type where
    here : FunPtr (Leaf {b} _) b
    left : FunPtr fs b -> FunPtr (Branch fs _) b
    right : FunPtr fs b -> FunPtr (Branch _ fs) b
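The reason this is awkward in Haskell, as noted at the start of this answer, is that FunPtr is indexed by a value of type Funs a. A common workaround, sketched here as an assumption rather than as a translation of the Idris code, is to index Funs by a promoted tree of codomain types, so that a pointer can talk about the shape without mentioning the functions (Tree, GoLeft, GoRight, applyAt and example are illustrative names):

```haskell
{-# LANGUAGE DataKinds, GADTs, KindSignatures #-}

import Data.Kind (Type)

-- Promoted shape of a collection: the codomain type stored at each leaf.
data Tree k = TNone | TLeaf k | TBranch (Tree k) (Tree k)

-- Functions with a common domain a; the index t records their codomains,
-- so they are visible in the type instead of hidden behind an existential.
data Funs (a :: Type) (t :: Tree Type) where
  None   :: Funs a 'TNone
  Leaf   :: (a -> [b]) -> Funs a ('TLeaf b)
  Branch :: Funs a l -> Funs a r -> Funs a ('TBranch l r)

-- Path to a leaf with codomain b. It mentions only the shape t, not the
-- functions themselves, which is what makes it expressible in Haskell.
data FunPtr (t :: Tree Type) (b :: Type) where
  Here    :: FunPtr ('TLeaf b) b
  GoLeft  :: FunPtr l b -> FunPtr ('TBranch l r) b
  GoRight :: FunPtr r b -> FunPtr ('TBranch l r) b

-- Follow a pointer to the function it designates.
applyAt :: Funs a t -> FunPtr t b -> a -> [b]
applyAt (Leaf f)     Here        = f
applyAt (Branch l _) (GoLeft p)  = applyAt l p
applyAt (Branch _ r) (GoRight p) = applyAt r p

example :: Funs Int ('TBranch ('TLeaf Int) ('TLeaf String))
example = Branch (Leaf (\x -> [x + 1])) (Leaf (\x -> [show x, show x]))
```

Because Branch changes the index, it can no longer be the <> of a Semigroup instance the way (<+>) = Branch is in the Idris code; merging two collections has to produce a new type as well.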

Now that we have a representation for the collection of all functions applied on a given variable, we can finally make some progress towards implementing horizontal fusion.

To reiterate, the goal is to turn something like

let xs = Input :: [A]
in Concat2 (E $ ConcatMap f xs) (F $ ConcatMap g xs)
where
  f :: A -> [B]
  g :: A -> [C]

into something like

let xs = Input :: [A]
    xs' = ConcatMap (\x -> map Left (f x) ++ map Right (g x)) xs :: [(Either B C)]
in Concat2 (E $ ConcatMap (either return (const [])) xs') (F $ ConcatMap (either (const []) return) xs')

So first of all, we need to be able to code-gen the memoizer (the definition of xs') from the collection of functions applied on xs:

  memoType : Funs a -> Type
  memoType None = ()
  memoType (Leaf {b} _) = b
  memoType (Branch fs1 fs2) = Either (memoType fs1) (memoType fs2)

  memoFun : (fs : Funs a) -> (a -> List (memoType fs))
  memoFun None = const []
  memoFun (Leaf f) = f
  memoFun (Branch fs1 fs2) = (\xs => map Left (memoFun fs1 xs) <+> map Right (memoFun fs2 xs))

  memoExpr : (fs : Funs a) -> Dist (a :: ctx) (memoType fs)
  memoExpr fs = ConcatMap (memoFun fs) (Var here)

It won't be much use if we can't look up these memoized results later on:

  lookupMemo : {fs : Funs a} -> (i : FunPtr fs b) -> (memoType fs -> List b)
  lookupMemo {fs = Leaf f} here = \x => [x]
  lookupMemo {fs = (Branch fs1 fs2)} (left i) = either (lookupMemo i) (const [])
  lookupMemo {fs = (Branch fs1 fs2)} (right i) = either (const []) (lookupMemo i)
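In Haskell, memoType cannot be an ordinary function because it computes a type. One option, sketched here under the assumption that Funs is indexed by a promoted tree of codomains (all names illustrative), is a closed type family. The example values at the end exercise the property the rewrite depends on: projecting one pointer's results out of the memoized list gives back exactly that function's output.

```haskell
{-# LANGUAGE DataKinds, GADTs, KindSignatures, TypeFamilies #-}

import Data.Kind (Type)

data Tree k = TNone | TLeaf k | TBranch (Tree k) (Tree k)

data Funs (a :: Type) (t :: Tree Type) where
  None   :: Funs a 'TNone
  Leaf   :: (a -> [b]) -> Funs a ('TLeaf b)
  Branch :: Funs a l -> Funs a r -> Funs a ('TBranch l r)

data FunPtr (t :: Tree Type) (b :: Type) where
  Here    :: FunPtr ('TLeaf b) b
  GoLeft  :: FunPtr l b -> FunPtr ('TBranch l r) b
  GoRight :: FunPtr r b -> FunPtr ('TBranch l r) b

-- Element type of the memoized list: nested Eithers following the shape.
type family MemoType (t :: Tree Type) :: Type where
  MemoType 'TNone         = ()
  MemoType ('TLeaf b)     = b
  MemoType ('TBranch l r) = Either (MemoType l) (MemoType r)

-- Run every function and tag each result with the path to its leaf.
memoFun :: Funs a t -> a -> [MemoType t]
memoFun None _         = []
memoFun (Leaf f) x     = f x
memoFun (Branch l r) x = map Left (memoFun l x) ++ map Right (memoFun r x)

-- Keep only results produced by the designated leaf, with tags removed.
lookupMemo :: FunPtr t b -> MemoType t -> [b]
lookupMemo Here y        = [y]
lookupMemo (GoLeft p) e  = either (lookupMemo p) (const []) e
lookupMemo (GoRight p) e = either (const []) (lookupMemo p) e

type ExampleTree = 'TBranch ('TLeaf Int) ('TBranch ('TLeaf String) 'TNone)

exampleFuns :: Funs Int ExampleTree
exampleFuns = Branch (Leaf (\x -> [x + 1])) (Branch (Leaf (\x -> [show x, show x])) None)

ptrSucc :: FunPtr ExampleTree Int
ptrSucc = GoLeft Here

ptrShow :: FunPtr ExampleTree String
ptrShow = GoRight (GoLeft Here)

-- The single traversal of the input, as in the memoizer definition above.
memoized :: [MemoType ExampleTree]
memoized = concatMap (memoFun exampleFuns) [1, 2, 3]
```

For instance, concatMap (lookupMemo ptrSucc) memoized equals concatMap (\x -> [x + 1]) [1, 2, 3], i.e. [2, 3, 4]. The pointers need type signatures because MemoType is not injective, so the shape cannot be inferred from the element type alone.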

Now, as we traverse the source tree, we of course collect usages (via ConcatMap) of several variables at the same time, since it is entirely possible to have something like

let xs = ...
in Concat2 (ConcatMap f xs) (let ys = ... in ... (ConcatMap g xs) ...)

These per-variable collections are populated in lockstep with the variable context, since at every Let binding we can also generate the memoizer for all the usages of the newly bound variable.

namespace Usages
  data Usages : Ctx -> Type where
    Nil : Usages []
    (::) : {a : Type} -> Funs a -> Usages ctx -> Usages (a :: ctx)

  unused : {ctx : Ctx} -> Usages ctx
  unused {ctx = []} = []
  unused {ctx = _ :: ctx} = None :: unused {ctx}

  instance Semigroup (Usages ctx) where
    [] <+> [] = []
    (fs1 :: us1) <+> (fs2 :: us2) = (fs1 <+> fs2) :: (us1 <+> us2)

We will be reserving space for these synthetic variables:

  ctxDup : {ctx : Ctx} -> Usages ctx -> Ctx
  ctxDup {ctx = []} us = []
  ctxDup {ctx = t :: ts} (fs :: us) = (memoType fs) :: t :: ctxDup us

  varDup : {us : Usages ctx} -> VarPtr ctx a -> VarPtr (ctxDup us) a
  varDup {us = _ :: _} here = there here
  varDup {us = _ :: _} (there v) = there $ there $ varDup v

Now we are finally ready to define our optimizer's internal intermediate representation: "Dist with holes". Each hole stands for an application of a function on a variable, which will be filled in when we know all the usages and we have all the synthetic variables for them in scope:

namespace HDist
  data Hole : Usages ctx -> Type -> Type where
    here : FunPtr u b -> Hole (u :: us) b
    there : Hole us b -> Hole (_ :: us) b

  resolve : {us : Usages ctx} -> Hole us b -> Exists (\a => (VarPtr (ctxDup us) a, a -> List b))
  resolve (here i) = Evidence _ (here, lookupMemo i)
  resolve (there h) with (resolve h) | Evidence a (v, f) = Evidence a (there $ there v, f)

  data HDist : Usages ctx -> Type -> Type where
    HInput : HDist us a
    HConcat : HDist us a -> HDist us a -> HDist us a
    HConcatMap : (b -> List a) -> HDist us b -> HDist us a
    HLet : HDist us a -> (fs : Funs a) -> HDist (fs :: us) b -> HDist us b
    HVar : {ctx : Ctx} -> {us : Usages ctx} -> VarPtr ctx a -> HDist us a
    HHole : (hole : Hole us a) -> HDist us a

So once we have such a holey Dist, filling it in is just a matter of walking it and resolving the holes:

fill : HDist us a -> Dist (ctxDup us) a
fill HInput = Input
fill (HConcat e1 e2) = Concat2 (fill e1) (fill e2)
fill (HConcatMap f e) = ConcatMap f $ fill e
fill (HLet e0 fs e) = Let (fill e0) $ Let (memoExpr fs) $ fill e
fill (HVar x) = Var (varDup x)
fill (HHole h) with (resolve h) | Evidence a (v, f) = ConcatMap f $ Var v

Horizontal fusion, then, is just a matter of elbow grease: turning a Dist ctx a into an HDist us a such that every ConcatMap f (Var v) is turned into an HHole. We need an extra funny dance to shift holes around when combining the two Usages from the two sides of a Concat2.

weakenHoleL : Hole us1 a -> Hole (us1 <+> us2) a
weakenHoleL {us1 = _ :: _} {us2 = _ :: _} (here i) = here (left i)
weakenHoleL {us1 = _ :: _} {us2 = _ :: _} (there h) = there $ weakenHoleL h

weakenHoleR : Hole us2 a -> Hole (us1 <+> us2) a
weakenHoleR {us1 = _ :: _} {us2 = _ :: _} (here i) = here (right i)
weakenHoleR {us1 = _ :: _} {us2 = _ :: _} (there h) = there $ weakenHoleR h

weakenL : HDist us1 a -> HDist (us1 <+> us2) a
weakenL HInput = HInput
weakenL (HConcat e1 e2) = HConcat (weakenL e1) (weakenL e2)
weakenL (HConcatMap f e) = HConcatMap f (weakenL e)
weakenL {us1 = us1} {us2 = us2} (HLet e fs x) = HLet (weakenL e) (Branch fs None) (weakenL {us2 = None :: us2} x)
weakenL (HVar x) = HVar x
weakenL (HHole hole) = HHole (weakenHoleL hole)

weakenR : HDist us2 a -> HDist (us1 <+> us2) a
weakenR HInput = HInput
weakenR (HConcat e1 e2) = HConcat (weakenR e1) (weakenR e2)
weakenR (HConcatMap f e) = HConcatMap f (weakenR e)
weakenR {us1 = us1} {us2 = us2} (HLet e fs x) = HLet (weakenR e) (Branch None fs) (weakenR {us1 = None :: us1} x)
weakenR (HVar x) = HVar x
weakenR (HHole hole) = HHole (weakenHoleR hole)

fuseHoriz : Dist ctx a -> Exists {a = Usages ctx} (\us => HDist us a)
fuseHoriz Input = Evidence unused HInput
fuseHoriz (Concat2 d1 d2) with (fuseHoriz d1)
  | Evidence us1 e1 with (fuseHoriz d2)
    | Evidence us2 e2 =
        Evidence (us1 <+> us2) $ HConcat (weakenL e1) (weakenR e2)
fuseHoriz {ctx = _ :: ctx} (ConcatMap f (Var here)) =
    Evidence (Leaf f :: unused) (HHole (here here))
fuseHoriz (ConcatMap f d) with (fuseHoriz d)
  | Evidence us e = Evidence us (HConcatMap f e)
fuseHoriz (Let d0 d) with (fuseHoriz d0)
  | Evidence us0 e0 with (fuseHoriz d)
    | Evidence (fs :: us) e =
        Evidence (us0 <+> us) $ HLet (weakenL e0) (Branch None fs) $ weakenR {us1 = None :: us0} e
fuseHoriz (Var v) = Evidence unused (HVar v)

We can use this monstrosity by combining it with fuseVert and feeding it to fill:

fuse : Dist [] a -> Dist [] a
fuse d = fill $ getProof $ fuseHoriz . fuseVert $ d

And presto:

namespace Examples
  ex2 : Dist [] Int
  ex2 = Let Input $
        Concat2 (ConcatMap f (Var here))
                (ConcatMap g (Var here))

  ex2' : Dist [] Int
  ex2' = Let Input $
         Let (ConcatMap (\x => map Left [] ++ map Right (map Left (f x) ++ map Right (g x))) (Var here)) $
         Concat2 (ConcatMap f' (Var here)) (ConcatMap g' (Var here))
    where
      f' : Either () (Either Int Int) -> List Int
      f' = either (const []) $ either return $ const []

      g' : Either () (Either Int Int) -> List Int
      g' = either (const []) $ either (const []) $ return

  prf2 : fuse ex2 = ex2'
  prf2 = Refl

Addendum

I wish I could have fused fuseVert into fuseHoriz, since I think all it should require is an extra case:

fuseHoriz (ConcatMap f (ConcatMap g d)) = fuseHoriz (ConcatMap (concatMap f . g) d)

However, this confuses the Idris termination checker unless I add an assert_smaller on ConcatMap (concatMap f . g) d vs. ConcatMap f (ConcatMap g d), and I don't understand why, since the latter has one more layer of ConcatMap constructors than the former.
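For comparison, Haskell has no totality checker, so the combined clause can be written directly. The recursive call is on a freshly built term rather than a subterm of the argument, which is presumably what a structural termination check objects to, even though the number of ConcatMap nodes strictly decreases. A minimal sketch over a cut-down AST (assumptions: Input produces Ints, no Let/Var; fuseVertical, mapCount, eval and stacked are illustrative names):

```haskell
{-# LANGUAGE GADTs #-}

data Dist e where
  Concat2   :: Dist [a] -> Dist [a] -> Dist [a]
  ConcatMap :: (a -> [b]) -> Dist [a] -> Dist [b]
  Input     :: Dist [Int]

-- Vertical fusion with the extra clause from the addendum. The first
-- clause recurses on a newly constructed term, not a subterm.
fuseVertical :: Dist e -> Dist e
fuseVertical (ConcatMap f (ConcatMap g d)) = fuseVertical (ConcatMap (concatMap f . g) d)
fuseVertical (ConcatMap f d)               = ConcatMap f (fuseVertical d)
fuseVertical (Concat2 a b)                 = Concat2 (fuseVertical a) (fuseVertical b)
fuseVertical Input                         = Input

-- A measure that does decrease on every recursive call of the first clause.
mapCount :: Dist e -> Int
mapCount (Concat2 a b)   = mapCount a + mapCount b
mapCount (ConcatMap _ d) = 1 + mapCount d
mapCount Input           = 0

eval :: [Int] -> Dist e -> e
eval inp (Concat2 a b)   = eval inp a ++ eval inp b
eval inp (ConcatMap f d) = concatMap f (eval inp d)
eval inp Input           = inp

stacked :: Dist [Int]
stacked = Concat2
  (ConcatMap (\x -> [x + 1]) (ConcatMap (\x -> [x * 2]) (ConcatMap (\x -> [x, x]) Input)))
  (ConcatMap (\x -> [x - 1]) Input)
```

In a language with a totality checker, a measure like mapCount is the kind of argument one would supply, for example via well-founded recursion, instead of asserting that the argument is smaller.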

answered Nov 15 '22 by Cactus