I have the standard `Vect` type:
module Vect where
-- Needs the DataKinds, GADTs and KindSignatures extensions, and 'Type' from
-- "Data.Kind".

-- | Peano naturals. Used at the type level as the length index of 'Vect'.
data Nat = Z | S Nat

-- | A list whose length @n@ is tracked in its type.
data Vect (n :: Nat) (a :: Type) where
  -- | The empty vector; its length is 'Z'.
  VNil :: Vect Z a
  -- | Prepend one element to a vector of length @n@, giving length @S n@.
  VCons :: a -> Vect n a -> Vect (S n) a
And I have these functions on it:
-- | Right fold over a 'Vect' whose result type @p@ may depend on the length.
--
-- The first argument is the result for 'VNil'. The second combines the head
-- of a 'VCons' with the already-folded tail, taking a @p m@ to a @p (S m)@.
foldVect :: forall (p :: Nat -> Type) n a.
            p Z ->
            (forall m. a -> p m -> p (S m)) ->
            Vect n a -> p n
-- Only two arguments are bound on the left-hand side; the vector itself is
-- consumed by the local worker.
foldVect nil cons = go
  where
    -- The worker is polymorphic in the length @l@ because every recursive
    -- call sees a shorter vector. @a@ and @p@ are the variables bound by the
    -- outer signature's @forall@ (ScopedTypeVariables).
    go :: forall l. Vect l a -> p l
    go VNil = nil
    go (VCons x xs) = x `cons` go xs
-- | 'Vect' with its two type arguments swapped, so that the length comes
-- last and @FVect a@ can be used where something of kind @Nat -> Type@ is
-- expected.
newtype FVect a n = FVect
  { unFVect :: Vect n a
  }
-- | Build a 'Vect' from a function that is abstracted over the "nil" and
-- "cons" it constructs its result with.
--
-- The builder is run with @p@ instantiated to @FVect a@: 'VNil' plays the
-- role of nil and 'VCons' the role of cons, with the 'FVect' wrapper added
-- and removed around them.
buildVect :: forall n a.
             (forall (p :: Nat -> Type).
                p Z ->
                (forall m. a -> p m -> p (S m)) ->
                p n
             ) -> Vect n a
buildVect f = unFVect $ f (FVect VNil) $ \x (FVect xs) -> FVect $ x `VCons` xs
I attempted to recreate (part of) the machinery from `base` that allows for list fusion:
-- | 'fmap' is 'mapVect'. It is marked INLINE so that uses of 'fmap' at this
-- type can be replaced by 'mapVect' itself.
instance Functor (Vect n) where
  fmap = mapVect
  {-# INLINE fmap #-}
-- | Apply a function to every element of a vector. The length index @n@ is
-- the same for the input and the result.
mapVect :: forall n a b. (a -> b) -> (Vect n a -> Vect n b)
mapVect _ VNil = VNil
mapVect g (VCons y ys) = VCons (g y) (mapVect g ys)
-- | Turn a "cons" that expects @b@ elements into one that accepts @a@
-- elements, by applying the given function to each element first.
--
-- Only the first two arguments are bound on the left-hand side; the element
-- and the folded tail are taken by the lambda, exactly as in the original
-- definition.
mapFB ::
  forall (p :: Nat -> Type) n a b.
  (forall m. b -> p m -> p (S m)) ->
  (a -> b) ->
  a -> p n -> p (S n)
mapFB k g = \y rest -> k (g y) rest
{-# INLINE [0] mapFB #-}
-- Do not inline 'mapVect' before phase 0.
{-# NOINLINE [0] mapVect #-}
-- Rewrite a call to 'mapVect' into a 'buildVect' over a 'foldVect'.
-- The [~1] activation switches the rule off once phase 1 is reached.
{-# RULES
"mapVect" [~1] forall g v.
  mapVect g v = buildVect (\z k -> foldVect z (mapFB k g) v)
  #-}
-- Do not inline 'foldVect' before phase 0, so that rules mentioning it can
-- still match calls to it in the earlier phases.
{-# INLINE [0] foldVect #-}
-- base has this; I don't think it does anything without a "refolding" rule on mapVect
-- NOTE(review): "refolding" presumably means a rule that turns an unfused
-- foldVect/mapFB back into mapVect -- confirm.
-- Same phase control for 'buildVect': no inlining before phase 0.
{-# INLINE [0] buildVect #-}
-- Fusion rule: folding a vector that was produced by 'buildVect' hands the
-- fold's "nil" and "cons" directly to the builder function @f@.
--
-- Note that @f@ is bound at a rank-2 function type. The rule as printed by
-- GHC has @f@ eta-expanded into a lambda on the left-hand side.
{-# RULES
"foldVect/buildVect" forall (nil :: p Z)
                            (cons :: forall m. a -> p m -> p (S m))
                            (f :: forall (q :: Nat -> Type).
                                  q Z ->
                                  (forall m. a -> q m -> q (S m)) ->
                                  q n
                            ).
  foldVect nil cons (buildVect f) = f nil cons
  #-}
And then
module Test where
import Vect
-- | Two 'fmap's composed back to back over the same vector: add 2 to every
-- element, then multiply every element by 5. Used as the example for
-- checking whether the two maps fuse.
test :: Vect n Int -> Vect n Int
test = fmap (*5) . fmap (+2)
No fusion happens. If I pass `-ddump-simpl`, I see that `foldVect` and `buildVect` have already been inlined, but the `INLINE`s are phased so that they don't interfere with fusion anyway. (After all, this is how `base` does it for `[]`.) Replacing the `INLINE [0]`s with `NOINLINE` paints a rather stunning image:
test
= \ (@ (n_a141 :: Nat)) (x_X1lK :: Vect n_a141 Int) ->
buildVect
@ n_a141
@ Int
(\ (@ (p_X1jl :: Nat -> *))
(nil_X11K [OS=OneShot] :: p_X1jl 'Z)
(cons_X11M [OS=OneShot]
:: forall (m :: Nat). Int -> p_X1jl m -> p_X1jl ('S m)) ->
foldVect
@ p_X1jl
@ n_a141
@ Int
nil_X11K
(\ (@ (m_a1i5 :: Nat))
(x1_aYI :: Int)
(ys_aYJ [OS=OneShot] :: p_X1jl m_a1i5) ->
cons_X11M
@ m_a1i5
(case x1_aYI of { GHC.Types.I# x2_a1l5 ->
GHC.Types.I# (GHC.Prim.*# x2_a1l5 5#)
})
ys_aYJ)
(buildVect
@ n_a141
@ Int
(\ (@ (p1_a1i0 :: Nat -> *))
(nil1_a10o [OS=OneShot] :: p1_a1i0 'Z)
(cons1_a10p [OS=OneShot]
:: forall (m :: Nat). Int -> p1_a1i0 m -> p1_a1i0 ('S m)) ->
foldVect
@ p1_a1i0
@ n_a141
@ Int
nil1_a10o
(\ (@ (m_a1i5 :: Nat))
(x1_aYI :: Int)
(ys_aYJ [OS=OneShot] :: p1_a1i0 m_a1i5) ->
cons1_a10p
@ m_a1i5
(case x1_aYI of { GHC.Types.I# x2_a1lh ->
GHC.Types.I# (GHC.Prim.+# x2_a1lh 2#)
})
ys_aYJ)
x_X1lK)))
Everything is right there, but the simplifier is just not having it.
If I inspect the rule itself, I see this
"foldVect/buildVect"
forall (@ (p_aYG :: Nat -> *))
(@ (n_aYJ :: Nat))
(@ a_aYH)
(nil_aYD :: p_aYG 'Z)
(cons_aYE :: forall (m :: Nat). a_aYH -> p_aYG m -> p_aYG ('S m))
(f_aYF
:: forall (q :: Nat -> *).
q 'Z -> (forall (m :: Nat). a_aYH -> q m -> q ('S m)) -> q n_aYJ).
foldVect @ p_aYG
@ n_aYJ
@ a_aYH
nil_aYD
cons_aYE
(buildVect
@ n_aYJ
@ a_aYH
(\ (@ (p1_a156 :: Nat -> *))
(ds_d1io :: p1_a156 'Z)
(ds1_d1ip
:: forall (m :: Nat). a_aYH -> p1_a156 m -> p1_a156 ('S m)) ->
f_aYF @ p1_a156 ds_d1io ds1_d1ip))
= f_aYF @ p_aYG nil_aYD cons_aYE
It appears that the issue is that the argument to `buildVect` needs to be a lambda abstraction of a very specific form, and I'm having trouble constructing a system of rewrites where that ends up happening.
How do I get fusion to work?
(I don't know if this is useful or even correct; I'm just doing this to see if I can.)
As usual, `newtype`s save the day whenever the compiler is being bullheaded:
module Vect where
-- everything else the same...
-- | The rank-2 builder function accepted by 'buildVect', wrapped in a
-- newtype so that it can be passed around as an ordinary (non-function)
-- value.
newtype VectBuilder n a = VectBuilder
  { runVectBuilder :: forall (p :: Nat -> Type).
                      p Z ->
                      (forall m. a -> p m -> p (S m)) ->
                      p n
  }
-- | Run a 'VectBuilder' to produce a concrete 'Vect'.
--
-- The wrapped builder is instantiated at @FVect a@, with 'VNil' as nil and
-- 'VCons' as cons (wrapping and unwrapping 'FVect' around them).
buildVect' :: forall n a. VectBuilder n a -> Vect n a
buildVect' f = unFVect $
  runVectBuilder f (FVect VNil) $ \x (FVect xs) -> FVect $ x `VCons` xs
-- No inlining before phase 0.
{-# INLINE [0] buildVect' #-}
-- | Build a 'Vect' from a function abstracted over its "nil" and "cons".
--
-- This is a thin wrapper that packs the function into a 'VectBuilder' and
-- delegates to @buildVect'@. The INLINE pragma carries no phase restriction.
buildVect :: forall n a.
             (forall (p :: Nat -> Type).
                p Z ->
                (forall m. a -> p m -> p (S m)) ->
                p n
             ) -> Vect n a
buildVect f = buildVect' (VectBuilder f)
{-# INLINE buildVect #-}
-- Fusion rule: folding the result of @buildVect'@ runs the wrapped builder
-- directly with the fold's "nil" and "cons".
{-# RULES
"foldVect/buildVect'" forall (nil :: p Z)
                             (cons :: forall m. a -> p m -> p (S m))
                             (f :: VectBuilder n a).
  foldVect nil cons (buildVect' f) = runVectBuilder f nil cons
  #-}
-- The compiler no longer has a chance to muck up the LHS by eta-expanding f,
-- because f "isn't" a function anymore.
-- rule for mapVect goes unchanged, so I guess that's evidence that this is totally transparent
module Test where
import Vect
-- | Two 'fmap's composed back to back over the same vector: add 2 to every
-- element, then multiply every element by 5.
test :: Vect n Int -> Vect n Int
test = fmap (*5) . fmap (+2)
Rec {
-- RHS size: {terms: 19, types: 31, coercions: 13, joins: 0/0}
Test.test_go [Occ=LoopBreaker]
:: forall (l :: Nat). Vect l Int -> FVect Int l
[GblId, Arity=1, Caf=NoCafRefs, Str=<S,1*U>]
Test.test_go
= \ (@ (l_a14W :: Nat)) (ds_d1jk :: Vect l_a14W Int) ->
case ds_d1jk of {
VNil co_a14Y -> (Vect.$WVNil @ Int) `cast` <Co:4>;
VCons @ n2_a151 co_a152 x_aYE xs_aYF ->
(Vect.VCons
@ ('S n2_a151)
@ Int
@ n2_a151
@~ <Co:2>
(case x_aYE of { GHC.Types.I# x1_a1xr ->
GHC.Types.I# (GHC.Prim.*# (GHC.Prim.+# x1_a1xr 2#) 5#) -- success!
})
((Test.test_go @ n2_a151 xs_aYF) `cast` <Co:3>))
`cast` <Co:4>
}
end Rec }
-- RHS size: {terms: 4, types: 5, coercions: 0, joins: 0/0}
Test.test1 :: forall (n :: Nat). Vect n Int -> FVect Int n
[GblId,
Arity=1,
Caf=NoCafRefs,
Str=<S,1*U>,
Unf=Unf{Src=<vanilla>, TopLvl=True, Value=True, ConLike=True,
WorkFree=True, Expandable=True,
Guidance=ALWAYS_IF(arity=1,unsat_ok=True,boring_ok=True)}]
Test.test1
= \ (@ (n_a1wd :: Nat)) (x_X1xa :: Vect n_a1wd Int) ->
Test.test_go @ n_a1wd x_X1xa
-- RHS size: {terms: 1, types: 0, coercions: 9, joins: 0/0}
test :: forall (n :: Nat). Vect n Int -> Vect n Int
[GblId,
Arity=1,
Caf=NoCafRefs,
Str=<S,1*U>,
Unf=Unf{Src=<vanilla>, TopLvl=True, Value=True, ConLike=True,
WorkFree=True, Expandable=True,
Guidance=ALWAYS_IF(arity=0,unsat_ok=True,boring_ok=True)}]
test = Test.test1 `cast` <Co:9>
Moral of the story: sufficiently highly ranked types make the `RULES` system implode, so give GHC some help with `newtype`s, even if they aren't otherwise necessary.
If you love us? You can donate to us via Paypal or buy me a coffee so we can maintain and grow! Thank you!
Donate Us With