Logo Questions Linux Laravel Mysql Ubuntu Git Menu
 

Fusion for length-indexed linked lists

I have the standard Vect type:

module Vect where

-- Peano naturals. The constructors are used at the type level below as the
-- length index of 'Vect'.
data Nat = Z | S Nat
-- A linked list whose length @n@ is recorded in its type.
-- NOTE(review): this snippet uses promoted constructors, GADT syntax, kind
-- signatures and a 'Type' kind without showing the language extensions or
-- the import that provide them (presumably Data.Kind) -- confirm against the
-- real module header.
data Vect (n :: Nat) (a :: Type) where
  -- The empty vector; its length index is 'Z'.
  VNil :: Vect Z a
  -- Prepend one element; the length index grows by one 'S'.
  VCons :: a -> Vect n a -> Vect (S n) a

And I have these functions on it:

-- Right fold over a 'Vect' whose result type @p@ is indexed by the length of
-- the vector being folded.
--
-- The first argument replaces 'VNil', the second replaces 'VCons', and the
-- result for a vector of length @n@ has type @p n@.
foldVect :: forall (p :: Nat -> Type) n a.
            p Z ->
            (forall m. a -> p m -> p (S m)) ->
            Vect n a -> p n
foldVect nil cons = go
  where
    -- The worker is polymorphic in the length because each recursive call
    -- sees a vector that is one element shorter.
    go :: forall l. Vect l a -> p l
    go v = case v of
      VNil -> nil
      VCons x xs -> cons x (go xs)

-- 'Vect' with its two type arguments swapped, so that @FVect a@ takes the
-- length last and can be passed where a @p :: Nat -> Type@ is expected
-- (as 'buildVect' does).
newtype FVect a n = FVect { unFVect :: Vect n a }
-- Turn a builder -- a function abstracted over the "nil" and "cons" it uses --
-- into a concrete 'Vect' by instantiating it with the real constructors.
-- The builder's result type is chosen to be @FVect a@, which is unwrapped at
-- the end.
buildVect :: forall n a.
             (forall (p :: Nat -> Type).
              p Z ->
              (forall m. a -> p m -> p (S m)) ->
              p n
             ) -> Vect n a
buildVect f = unFVect (f nil cons)
  where
    -- 'VNil', wrapped so that it fits the builder's result type.
    nil :: FVect a Z
    nil = FVect VNil

    -- 'VCons', working underneath the 'FVect' wrapper.
    cons :: forall m. a -> FVect a m -> FVect a (S m)
    cons x (FVect xs) = FVect (VCons x xs)

I attempted to recreate (part of) the machinery from base that allows for list fusion:

-- 'fmap' is just 'mapVect'.
instance Functor (Vect n) where
    fmap = mapVect
    -- Presumably inlined so that uses of 'fmap' turn into 'mapVect' calls that
    -- the "mapVect" rewrite rule can match -- confirm.
    {-# INLINE fmap #-}

-- Apply a function to every element of a 'Vect'; the length index is
-- unchanged. This is the plain recursive definition.
mapVect :: forall n a b. (a -> b) -> (Vect n a -> Vect n b)
mapVect f v = case v of
  VNil -> VNil
  VCons x xs -> VCons (f x) (mapVect f xs)
-- The "cons" step of a mapped fold: apply @f@ to the element, then pass the
-- result and the already-folded tail to @cons@.
-- Part of the attempt to recreate base's list-fusion machinery described in
-- the surrounding text.
mapFB :: forall (p :: Nat -> Type) n a b. (forall m. b -> p m -> p (S m)) -> (a -> b) -> a -> p n -> p (S n)
-- Only @cons@ and @f@ are bound on the left-hand side; the element and the
-- tail are taken by the lambda.
mapFB cons f = \x ys -> cons (f x) ys
-- 'mapFB' is not inlined before phase 0.
{-# INLINE [0] mapFB #-}
-- 'mapVect' is not inlined before phase 0 either.
{-# NOINLINE [0] mapVect #-}
-- Active only before phase 1: rewrite 'mapVect' into a 'buildVect' of a
-- 'foldVect', with 'mapFB' as the fold's cons step.
{-# RULES "mapVect" [~1] forall f xs. mapVect f xs = buildVect (\nil cons -> foldVect nil (mapFB cons f) xs) #-}

-- 'foldVect' is not inlined before phase 0.
{-# INLINE [0] foldVect #-}
-- base has this; I don't think it does anything without a "refolding" rule on mapVect
{-# INLINE [0] buildVect #-}
-- The intended fusion rule: folding a vector that was just produced by
-- 'buildVect' hands @nil@ and @cons@ straight to the builder @f@, so the
-- intermediate vector is never constructed.
-- NOTE: as the surrounding text reports, this rule does not fire as written;
-- the left-hand side GHC records for it wraps @f@ in an explicit lambda.
{-# RULES "foldVect/buildVect" forall (nil :: p Z)
                                      (cons :: forall m. a -> p m -> p (S m))
                                      (f :: forall (q :: Nat -> Type).
                                            q Z ->
                                            (forall m. a -> q m -> q (S m)) ->
                                            q n
                                      ).
          foldVect nil cons (buildVect f) = f nil cons
  #-}

And then

module Test where
import Vect

-- Two consecutive maps over a 'Vect': add 2 to every element, then multiply
-- by 5. Used to check whether the two traversals get fused into one.
test :: Vect n Int -> Vect n Int
test = fmap (*5) . fmap (+2)

No fusion happens. If I pass -ddump-simpl, I see that foldVect and buildVect have already been inlined, but...

  1. The INLINEs are phased so that they don't interfere with fusion anyway. (After all, this is how base does it for [])
  2. Replacing the INLINE [0]s with NOINLINE paints a rather stunning image:

    test
      = \ (@ (n_a141 :: Nat)) (x_X1lK :: Vect n_a141 Int) ->
          buildVect
            @ n_a141
            @ Int
            (\ (@ (p_X1jl :: Nat -> *))
               (nil_X11K [OS=OneShot] :: p_X1jl 'Z)
               (cons_X11M [OS=OneShot]
                  :: forall (m :: Nat). Int -> p_X1jl m -> p_X1jl ('S m)) ->
               foldVect
                 @ p_X1jl
                 @ n_a141
                 @ Int
                 nil_X11K
                 (\ (@ (m_a1i5 :: Nat))
                    (x1_aYI :: Int)
                    (ys_aYJ [OS=OneShot] :: p_X1jl m_a1i5) ->
                    cons_X11M
                      @ m_a1i5
                      (case x1_aYI of { GHC.Types.I# x2_a1l5 ->
                       GHC.Types.I# (GHC.Prim.*# x2_a1l5 5#)
                       })
                      ys_aYJ)
                 (buildVect
                    @ n_a141
                    @ Int
                    (\ (@ (p1_a1i0 :: Nat -> *))
                       (nil1_a10o [OS=OneShot] :: p1_a1i0 'Z)
                       (cons1_a10p [OS=OneShot]
                          :: forall (m :: Nat). Int -> p1_a1i0 m -> p1_a1i0 ('S m)) ->
                       foldVect
                         @ p1_a1i0
                         @ n_a141
                         @ Int
                         nil1_a10o
                         (\ (@ (m_a1i5 :: Nat))
                            (x1_aYI :: Int)
                            (ys_aYJ [OS=OneShot] :: p1_a1i0 m_a1i5) ->
                            cons1_a10p
                              @ m_a1i5
                              (case x1_aYI of { GHC.Types.I# x2_a1lh ->
                               GHC.Types.I# (GHC.Prim.+# x2_a1lh 2#)
                               })
                              ys_aYJ)
                         x_X1lK)))
    

    Everything is right there, but the simplifier is just not having it.

If I inspect the rule itself, I see this

"foldVect/buildVect"
    forall (@ (p_aYG :: Nat -> *))
           (@ (n_aYJ :: Nat))
           (@ a_aYH)
           (nil_aYD :: p_aYG 'Z)
           (cons_aYE :: forall (m :: Nat). a_aYH -> p_aYG m -> p_aYG ('S m))
           (f_aYF
              :: forall (q :: Nat -> *).
                 q 'Z -> (forall (m :: Nat). a_aYH -> q m -> q ('S m)) -> q n_aYJ).
      foldVect @ p_aYG
               @ n_aYJ
               @ a_aYH
               nil_aYD
               cons_aYE
               (buildVect
                  @ n_aYJ
                  @ a_aYH
                  (\ (@ (p1_a156 :: Nat -> *))
                     (ds_d1io :: p1_a156 'Z)
                     (ds1_d1ip
                        :: forall (m :: Nat). a_aYH -> p1_a156 m -> p1_a156 ('S m)) ->
                     f_aYF @ p1_a156 ds_d1io ds1_d1ip))
      = f_aYF @ p_aYG nil_aYD cons_aYE

It appears that the issue is that the argument to buildVect needs to be a lambda abstraction of a very specific form, and I'm having trouble constructing a system of rewrites where that ends up happening.

How do I get fusion to work?

(I don't know if this is useful or even correct; I'm just doing this to see if I can.)

like image 221
HTNW Avatar asked Nov 07 '22 05:11

HTNW


1 Answer

As usual, newtypes save the day whenever the compiler is being bullheaded:

module Vect where
-- everything else the same...
-- A builder function wrapped in a newtype: given a "nil" and a "cons" for any
-- length-indexed result type @p@, it produces a @p n@.
-- Wrapping it means a rewrite rule can bind the builder as an ordinary
-- variable of type @VectBuilder n a@ rather than as a polymorphic function.
newtype VectBuilder n a = VectBuilder { runVectBuilder :: forall (p :: Nat -> Type).
                                                          p Z ->
                                                          (forall m. a -> p m -> p (S m)) ->
                                                          p n
                                      }

-- Run a wrapped builder with the real constructors to obtain a 'Vect'.
-- The builder is instantiated at @FVect a@, so 'VNil' and 'VCons' are wrapped
-- in 'FVect' on the way in and the result is unwrapped with 'unFVect'.
buildVect' :: forall n a. VectBuilder n a -> Vect n a
buildVect' f = unFVect $
                runVectBuilder f (FVect VNil) $ \x (FVect xs) -> FVect $ x `VCons` xs
-- Not inlined before phase 0.
{-# INLINE [0] buildVect' #-}
-- Same type as the original 'buildVect', now a thin wrapper that packs the
-- builder into a 'VectBuilder' and delegates to @buildVect'@.
buildVect :: forall n a.
             (forall (p :: Nat -> Type).
              p Z ->
              (forall m. a -> p m -> p (S m)) ->
              p n
             ) -> Vect n a
buildVect f = buildVect' (VectBuilder f)
-- No phase restriction here, unlike @buildVect'@.
{-# INLINE buildVect #-}

-- Fusion rule stated against @buildVect'@: folding a vector that was just
-- built from a 'VectBuilder' runs the builder directly with the fold's @nil@
-- and @cons@, so no intermediate vector is built.
{-# RULES "foldVect/buildVect'" forall (nil :: p Z)
                                       (cons :: forall m. a -> p m -> p (S m))
                                       (f :: VectBuilder n a).
                                foldVect nil cons (buildVect' f) = runVectBuilder f nil cons
  #-}
-- The compiler no longer has a chance to muck up the LHS by eta-expanding f,
-- because f "isn't" a function anymore: it is a value of the newtype.

-- rule for mapVect goes unchanged, so I guess that's evidence that this is totally transparent
module Test where
import Vect
-- Same test as in the question: two consecutive maps (add 2, then multiply
-- by 5). The Core output that follows shows them combined into a single loop.
test :: Vect n Int -> Vect n Int
test = fmap (*5) . fmap (+2)
Rec {
-- RHS size: {terms: 19, types: 31, coercions: 13, joins: 0/0}
Test.test_go [Occ=LoopBreaker]
  :: forall (l :: Nat). Vect l Int -> FVect Int l
[GblId, Arity=1, Caf=NoCafRefs, Str=<S,1*U>]
Test.test_go
  = \ (@ (l_a14W :: Nat)) (ds_d1jk :: Vect l_a14W Int) ->
      case ds_d1jk of {
        VNil co_a14Y -> (Vect.$WVNil @ Int) `cast` <Co:4>;
        VCons @ n2_a151 co_a152 x_aYE xs_aYF ->
          (Vect.VCons
             @ ('S n2_a151)
             @ Int
             @ n2_a151
             @~ <Co:2>
             (case x_aYE of { GHC.Types.I# x1_a1xr ->
              GHC.Types.I# (GHC.Prim.*# (GHC.Prim.+# x1_a1xr 2#) 5#) -- success!
              })
             ((Test.test_go @ n2_a151 xs_aYF) `cast` <Co:3>))
          `cast` <Co:4>
      }
end Rec }

-- RHS size: {terms: 4, types: 5, coercions: 0, joins: 0/0}
Test.test1 :: forall (n :: Nat). Vect n Int -> FVect Int n
[GblId,
 Arity=1,
 Caf=NoCafRefs,
 Str=<S,1*U>,
 Unf=Unf{Src=<vanilla>, TopLvl=True, Value=True, ConLike=True,
         WorkFree=True, Expandable=True,
         Guidance=ALWAYS_IF(arity=1,unsat_ok=True,boring_ok=True)}]
Test.test1
  = \ (@ (n_a1wd :: Nat)) (x_X1xa :: Vect n_a1wd Int) ->
      Test.test_go @ n_a1wd x_X1xa

-- RHS size: {terms: 1, types: 0, coercions: 9, joins: 0/0}
test :: forall (n :: Nat). Vect n Int -> Vect n Int
[GblId,
 Arity=1,
 Caf=NoCafRefs,
 Str=<S,1*U>,
 Unf=Unf{Src=<vanilla>, TopLvl=True, Value=True, ConLike=True,
         WorkFree=True, Expandable=True,
         Guidance=ALWAYS_IF(arity=0,unsat_ok=True,boring_ok=True)}]
test = Test.test1 `cast` <Co:9>

Moral of the story: sufficiently high-rank types make the RULES system implode, so give GHC some help with newtypes, even if they aren't otherwise necessary.

like image 198
HTNW Avatar answered Dec 09 '22 09:12

HTNW