Is it possible to get `-=` working with literals?



Today I found this post on Quora, which claimed that

factorial(n) = def $ do    
    assert (n<=0) "Negative factorial"    
    ret <- var 1    
    i <- var n    
    while i $ do    
        ret *= i    
        i -= 1
    return ret

could be correct Haskell code. I got curious, and ended up with

factorial :: Integer -> Integer
factorial n = def $ do
  assert (n >= 0) "Negative factorial"
  ret <- var 1
  i   <- var n
  while i $ do
      ret *= i
      i   -= 1
  return ret

using var = newSTRef, canonical definitions for def, assert and while, and

a *= b = readSTRef b >>= \b -> modifySTRef a ((*) b)
a -= b = modifySTRef a ((+) (negate b))

However, (*=) and (-=) have different types:

(-=) :: Num a => STRef s a -> a -> ST s ()
(*=) :: Num a => STRef s a -> STRef s a -> ST s ()

So ret -= i wouldn't work, because (-=) expects a plain value on its right-hand side rather than a reference. I've tried to create a fitting type class for this:

class (Monad m) => NumMod l r m where
  (+=) :: l -> r -> m ()
  (-=) :: l -> r -> m ()
  (*=) :: l -> r -> m ()

instance Num a => NumMod (STRef s a) (STRef s a) (ST s) where
  a += b    = readSTRef b >>= \b -> modifySTRef a ((+) b)
  a -= b    = readSTRef b >>= \b -> modifySTRef a ((+) (negate b))
  a *= b    = readSTRef b >>= \b -> modifySTRef a ((*) b)

instance (Num a) => NumMod (STRef s a) a (ST s) where
  a += b    = modifySTRef a ((+) (b))
  a -= b    = modifySTRef a ((+) (negate b))
  a *= b    = modifySTRef a ((*) (b))

That actually works, but only as long as factorial returns an Integer. As soon as I change the return type to something else it fails. I've tried to create another instance

instance (Num a, Integral b) => NumMod (STRef s a) b (ST s) where
  a += b    = modifySTRef a ((+) (fromIntegral $ b))
  a -= b    = modifySTRef a ((+) (negate . fromIntegral $ b))
  a *= b    = modifySTRef a ((*) (fromIntegral b))

which fails due to overlapping instances.

Is it actually possible to create a fitting typeclass and instances to get the factorial running for any Integral a? Or will this problem always occur?

The idea

The idea is simple: wrap STRef s a in a new data type and make it an instance of Num.


First, we'll need only one pragma:

{-# LANGUAGE RankNTypes #-}

import Data.STRef    (STRef, newSTRef, readSTRef, modifySTRef)
import Control.Monad (when)
import Control.Monad.ST (ST, runST)

Wrapper for STRef:

data MyRef s a
  = MySTRef (STRef s a)  -- reference (can modify)
  | MyVal a              -- pure value (modifications are ignored)

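-- Only fromInteger is defined, so integer literals become MyVal;
-- the remaining Num methods are deliberately left out (GHC will warn).
instance Num a => Num (MyRef s a) where
  fromInteger = MyVal . fromInteger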

A few helpers for MyRef to resemble STRef functions:

newMyRef :: a -> ST s (MyRef s a)
newMyRef x = do
  ref <- newSTRef x
  return (MySTRef ref)

readMyRef :: MyRef s a -> ST s a
readMyRef (MySTRef x) = readSTRef x
readMyRef (MyVal   x) = return x

I'd like to implement -= and *= using a slightly more general helper, alter:

alter :: (a -> a -> a) -> MyRef s a -> MyRef s a -> ST s ()
alter f (MySTRef x) (MySTRef y) = readSTRef y >>= modifySTRef x . flip f
alter f (MySTRef x) (MyVal   y) = modifySTRef x (flip f y)
alter _ _ _ = return ()

(-=) :: Num a => MyRef s a -> MyRef s a -> ST s ()
(-=) = alter (-)

(*=) :: Num a => MyRef s a -> MyRef s a -> ST s ()
(*=) = alter (*)

Other functions are almost unchanged:

var :: a -> ST s (MyRef s a)
var = newMyRef

def :: (forall s. ST s (MyRef s a)) -> a
def m = runST $ m >>= readMyRef

while :: (Num a, Ord a) => MyRef s a -> ST s () -> ST s ()
while i m = go
  where
    -- run the body repeatedly while the referenced value is positive
    go = do
      n <- readMyRef i
      when (n > 0) $ m >> go

assert :: Monad m => Bool -> String -> m ()
assert b str = when (not b) $ error str

factorial :: Integral a => a -> a
factorial n = def $ do
    assert (n >= 0) "Negative factorial"
    ret <- var 1
    i   <- var n
    while i $ do
      ret *= i
      i -= 1
    return ret

main :: IO ()
main = print . factorial $ 1000


Defining a partial Num instance like this feels a bit hacky, but Haskell has no separate FromInteger type class, so I guess it's OK.

Another irritating detail is that 3 *= 10 type-checks and simply evaluates to return (). I think a phantom type could indicate whether a MyRef is backed by an STRef or is pure, so that only references are allowed on the left-hand side of alter.
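The phantom type idea could look roughly like the sketch below. It is only one possible encoding, not something tested against the code above: the Backing tag, the GADT form of MyRef and the extra language extensions are all assumptions made for illustration. The equality constraint in the Num instance is there so that a bare literal such as 1 is inferred to be a pure value rather than left ambiguous.

```haskell
{-# LANGUAGE DataKinds, GADTs, KindSignatures, RankNTypes #-}
{-# LANGUAGE FlexibleContexts, TypeOperators #-}

import Control.Monad    (when)
import Control.Monad.ST (ST, runST)
import Data.STRef       (STRef, modifySTRef, newSTRef, readSTRef)

-- Hypothetical tag (not part of the code above): how a MyRef is backed.
data Backing = Ref | Pure

data MyRef (k :: Backing) s a where
  MySTRef :: STRef s a -> MyRef 'Ref  s a   -- mutable reference
  MyVal   :: a         -> MyRef 'Pure s a   -- pure value

-- The instance head matches any k; the context then forces k ~ 'Pure,
-- so numeric literals always become pure values.
instance (k ~ 'Pure, Num a) => Num (MyRef k s a) where
  fromInteger       = MyVal . fromInteger
  MyVal x + MyVal y = MyVal (x + y)
  MyVal x - MyVal y = MyVal (x - y)
  MyVal x * MyVal y = MyVal (x * y)
  abs    (MyVal x)  = MyVal (abs x)
  signum (MyVal x)  = MyVal (signum x)

var :: a -> ST s (MyRef 'Ref s a)
var = fmap MySTRef . newSTRef

readMyRef :: MyRef k s a -> ST s a
readMyRef (MySTRef r) = readSTRef r
readMyRef (MyVal   x) = return x

-- The left operand must be a real reference; the right one may be either.
alter :: (a -> a -> a) -> MyRef 'Ref s a -> MyRef k s a -> ST s ()
alter f (MySTRef r) rhs = readMyRef rhs >>= \y -> modifySTRef r (`f` y)

(-=), (*=) :: Num a => MyRef 'Ref s a -> MyRef k s a -> ST s ()
(-=) = alter (-)
(*=) = alter (*)

def :: (forall s. ST s (MyRef k s a)) -> a
def m = runST (m >>= readMyRef)

while :: (Num a, Ord a) => MyRef k s a -> ST s () -> ST s ()
while i body = go
  where
    go = do
      n <- readMyRef i
      when (n > 0) (body >> go)

factorial :: Integral a => a -> a
factorial n = def $ do
  when (n < 0) (error "Negative factorial")
  ret <- var 1
  i   <- var n
  while i $ do
    ret *= i
    i   -= 1
  return ret
```

With these definitions factorial (5 :: Int) evaluates to 120, while 3 *= 10 no longer compiles, because the literal 3 would have to be a 'Ref and the only Num instance demands 'Pure. As a side effect the Num instance can be made total, since a pure value never has to read from an STRef.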

