I am implementing R7RS-small Scheme and have run into the following problem with the implementation of equal?: equal? tests value equality, and it is furthermore able to test the equality of cyclic data structures without getting into infinite loops. However, because I am implementing Scheme in Haskell, I do not have access to underlying pointer values that can be cast to integers and used in a hash table* or search tree to track which nodes I have already followed (so as to efficiently prune paths that would result in infinite loops).
Rather, all I seemingly have to work with is identity equality (as measured by (==) on the IOArrays underlying pairs, vectors, and records). Hence, seemingly all I can do is build lists recording which nodes I have followed (separated by type), and then, for each further node I follow, search the appropriate list for nodes I have already followed. Because each lookup is a linear scan of the list, this seems to me to scale as O(n^2) in time and O(n) in space.
Am I right that, given these conditions, this is the only algorithm available to me, or are there other, more efficient implementations I am missing?
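For concreteness, the list-based approach described above might look roughly like the following minimal sketch. It is not my actual implementation: the Val type, the Cell synonym, and equalVal are hypothetical names made up for illustration, and pairs are backed by an IORef rather than an IOArray to keep the example short. The only identity test used is (==) on the mutable cell, and every membership test is a linear scan of the list.

```haskell
import Data.IORef (IORef, newIORef, readIORef, writeIORef)

-- Hypothetical value type for illustration: atoms or mutable pairs.
type Cell = IORef (Val, Val)
data Val = Atom Int | Pair Cell

-- Compare two values, treating any pair of cells that has already been
-- visited (or is currently being visited) as equal so that cycles terminate.
equalVal :: Val -> Val -> IO Bool
equalVal x y = fst <$> go [] x y

-- The visited list is threaded through the traversal and returned,
-- so work done in the car is remembered while comparing the cdr.
go :: [(Cell, Cell)] -> Val -> Val -> IO (Bool, [(Cell, Cell)])
go seen (Atom a) (Atom b) = return (a == b, seen)
go seen (Pair p) (Pair q)
  | (p, q) `elem` seen = return (True, seen)   -- linear scan, identity only
  | otherwise = do
      (car0, cdr0) <- readIORef p
      (car1, cdr1) <- readIORef q
      (okCar, seen') <- go ((p, q) : seen) car0 car1
      if okCar
        then go seen' cdr0 cdr1
        else return (False, seen')
go seen _ _ = return (False, seen)
```

A circular list can be built by creating a cell with newIORef and then overwriting its cdr with writeIORef so that it points back at the cell itself.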
I have considered tagging every value that can contain references with a tag that could be used in a search tree or hash table*, but the problem here is that this would be particularly space-inefficient for lists, as I need to use two words for the tag for every node, one being the ThreadId and one being a per-thread unique ID (the ThreadId is necessary because, as I am doing a multithreaded implementation of Scheme, I would otherwise have to protect a shared unique ID counter behind an MVar or TMVar, which would have horrible contention in many use cases).
* As I am implementing everything in a monad transformer that implements MonadIO, traditional imperative-style hash tables are available to me.
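To make the tagging idea concrete, here is a rough sketch of what such a two-word tag could look like. All names here (Tag, TagSource, freshTag, visit) are hypothetical and not part of my implementation; the sketch assumes each interpreter thread owns its own counter, and it uses Data.Set from the containers package as the search tree.

```haskell
import Control.Concurrent (ThreadId, myThreadId)
import Data.IORef (IORef, newIORef, readIORef, writeIORef)
import qualified Data.Set as Set

-- Two-word identity tag: the creating thread plus a per-thread serial number.
data Tag = Tag !ThreadId !Int
  deriving (Eq, Ord, Show)

-- Hypothetical per-thread counter. Each thread owns one, so it is never
-- shared and needs no MVar or TMVar around it.
newtype TagSource = TagSource (IORef Int)

newTagSource :: IO TagSource
newTagSource = TagSource <$> newIORef 0

-- Produce a tag that is unique as long as a TagSource is only ever used
-- from the thread that owns it.
freshTag :: TagSource -> IO Tag
freshTag (TagSource counter) = do
  tid <- myThreadId
  n <- readIORef counter
  writeIORef counter (n + 1)
  return (Tag tid n)

-- Visited set keyed on pairs of tags, giving logarithmic lookups.
type Visited = Set.Set (Tag, Tag)

-- Nothing if the pair was already visited, otherwise the updated set.
visit :: (Tag, Tag) -> Visited -> Maybe Visited
visit pair seen
  | pair `Set.member` seen = Nothing
  | otherwise = Just (Set.insert pair seen)
```

The cost I am worried about is visible in the Tag type itself: every pair, vector, and record would carry both fields in addition to its contents.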
Wouldn't Tortoise and Hare be able to fix this?
In a single list it's trivial. You let the hare step twice as fast as the tortoise and start 1 ahead of the first element. If the hare ever matches the tortoise you have a cycle.
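The single-list case can be sketched as follows. This is only an illustration under assumed types: List, next, sameCell, build, and hasCycle are hypothetical names, and each cell is an IORef holding its successor so that (==) on the IORef serves as the identity test.

```haskell
import Control.Monad (when)
import Data.IORef (IORef, newIORef, readIORef, writeIORef)

-- Hypothetical singly linked mutable list: each cell stores its successor.
data List = Nil | Cons (IORef List)

next :: List -> IO List
next Nil = return Nil
next (Cons ref) = readIORef ref

-- Identity comparison of two cells.
sameCell :: List -> List -> Bool
sameCell (Cons a) (Cons b) = a == b
sameCell _ _ = False

-- Build a list of n cells; if 'loop' is True the last cell points to the first.
build :: Int -> Bool -> IO List
build n loop = do
  refs <- mapM (const (newIORef Nil)) [1 .. n]
  case map Cons refs of
    [] -> return Nil
    first : rest -> do
      mapM_ (uncurry writeIORef) (zip refs rest)   -- link each cell to the next
      when loop (writeIORef (last refs) first)
      return first

-- The hare starts one step ahead and moves two steps per tortoise step.
hasCycle :: List -> IO Bool
hasCycle start = next start >>= go start
  where
    go _ Nil = return False
    go tortoise hare
      | sameCell tortoise hare = return True
      | otherwise = do
          hare1 <- next hare
          case hare1 of
            Nil -> return False
            _ -> do
              hare2 <- next hare1
              tortoise' <- next tortoise
              go tortoise' hare2
```

For example, build 5 True >>= hasCycle detects the loop, while build 5 False >>= hasCycle runs off the end of the list and reports no cycle.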
With cons cells it's basically a binary tree: you traverse both trees in the same fixed order, and the hare follows that same order at double speed. If the two elements are eq?, or they are atoms that are not eqv?, you short-circuit. If the tortoise and hare match, you backtrack.
Here is the algorithm I have figured out to implement this with. It is a variation on Brent's "teleporting turtle" algorithm, modified to handle not a linear list of nodes but an N-branching tree of nodes.
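For reference, the plain linear-list form of Brent's algorithm that my version builds on can be sketched like this. It is a simplified illustration, not the tree algorithm itself; List, next, sameCell, build, and hasCycleBrent are hypothetical names, and cells are IORefs so that (==) gives identity.

```haskell
import Control.Monad (when)
import Data.IORef (IORef, newIORef, readIORef, writeIORef)

-- Hypothetical singly linked mutable list: each cell stores its successor.
data List = Nil | Cons (IORef List)

next :: List -> IO List
next Nil = return Nil
next (Cons ref) = readIORef ref

-- Identity comparison of two cells.
sameCell :: List -> List -> Bool
sameCell (Cons a) (Cons b) = a == b
sameCell _ _ = False

-- Build a list of n cells; if 'loop' is True the last cell points to the first.
build :: Int -> Bool -> IO List
build n loop = do
  refs <- mapM (const (newIORef Nil)) [1 .. n]
  case map Cons refs of
    [] -> return Nil
    first : rest -> do
      mapM_ (uncurry writeIORef) (zip refs rest)
      when loop (writeIORef (last refs) first)
      return first

-- The turtle stays parked while the rabbit walks. After 'limit' steps without
-- meeting the turtle, the turtle teleports to the rabbit and the limit doubles.
hasCycleBrent :: List -> IO Bool
hasCycleBrent start = next start >>= go start 1 1
  where
    go :: List -> Int -> Int -> List -> IO Bool
    go _ _ _ Nil = return False
    go turtle limit steps rabbit
      | sameCell turtle rabbit = return True
      | steps == limit = next rabbit >>= go rabbit (limit * 2) 1   -- teleport
      | otherwise = next rabbit >>= go turtle limit (steps + 1)
```

Only one pointer is advanced per step, and the only comparison needed is identity of the current node against the parked turtle, which is why the idea carries over to a setting where nodes cannot be hashed or ordered.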
(This is not taking the actual comparison into account. There will be two copies of the state below, one for each data structure being tested for equality, and if something is not found to be equal in value, the comparison is short-circuited and false is returned.)
I maintain two stacks, a stack of nodes that I have followed over a depth-first traversal combined with the next node to follow at the same depth and a current depth value, and a stack of nodes that the turtle is going to be positioned at which record the depth the turtle is at and the distance deeper than the turtle the next turtle will be at. (In my actual implementation, the stacks are unified so that every stack frame points at a pair of nodes and a turtle (which points at a pair of nodes), which simplifies the management of the turtles.)
As I traverse the data structure depth-first, I build up the first stack, and at traversal distances that grow by powers of two I push new frames onto the turtle stack, where each new turtle points at the current node on the first stack.
When I reach a node where I can go no deeper, because it has no sibling nodes that have not been reached yet, I descend the first stack until I reach a node that does have an unchecked sibling, and then replace that node with the next sibling node to follow; if there are no sibling nodes left to follow anywhere in the stack, then we terminate with true for value equality.
Note that when descending the first stack, if the frame being popped off the top of the first stack is at the same depth (or node) as the top of the turtle stack, the top of the turtle stack is popped off as well.
If, after pushing a frame onto the first stack, the current node is equal to the node at the top of the turtle stack, I backtrack. The difference in depth between the top of the first stack and the top of the turtle stack is equal to the size of the cycle. I backtrack a full cycle, recording each node I pass and its corresponding stack states and siblings. Then I test the nodes in the frame on the first stack below the topmost frame. If they are not in the recorded nodes, then I know that the node I am at is the start of the cycle; I then pull the recorded stacks and siblings for that node and continue from there, so I can take alternative paths from within the cycle (remember this is an N-branching tree) or otherwise descend out of the cycle. If they are in the recorded nodes, I update the recorded nodes to contain the stacks below the topmost frames and the siblings of the current node, and then pop the tops of the stacks and continue.
Here is my code for a test implementation of the algorithm. The code should work now.
{-# LANGUAGE RecordWildCards, BangPatterns #-}

-- Node is exported as well because NaiveEqualTree imports it.
module EqualTree (Tree(..),
                  Node,
                  equal)
       where

import Data.Array.IO (IOArray)
import Data.Array.MArray (readArray,
                          getBounds)

data Tree a = Value a | Node (Node a)

type Node a = IOArray Int (Tree a)

data Frame a = Frame { frameNodes :: !(Node a, Node a),
                       frameSiblings :: !(Maybe (Siblings a)),
                       frameTurtle :: !(Turtle a) }

data Siblings a = Siblings { siblingNodes :: !(Node a, Node a),
                             siblingIndex :: !Int }

data Turtle a = Turtle { turtleDepth :: !Int,
                         turtleScale :: !Int,
                         turtleNodes :: !(Node a, Node a) }

data EqState a = EqState { stateFrames :: [Frame a],
                           stateCycles :: [(Node a, Node a)],
                           stateDepth :: !Int }

data Unrolled a = Unrolled { unrolledNodes :: !(Node a, Node a),
                             unrolledState :: !(EqState a),
                             unrolledSiblings :: !(Maybe (Siblings a)) }

data NodeComparison = EqualNodes | NotEqualNodes | HalfEqualNodes

equal :: Eq a => Tree a -> Tree a -> IO Bool
equal tree0 tree1 =
  let state = EqState { stateFrames = [], stateCycles = [], stateDepth = 0 }
  in ascend state tree0 tree1 Nothing

ascend :: Eq a => EqState a -> Tree a -> Tree a -> Maybe (Siblings a) -> IO Bool
ascend state (Value value0) (Value value1) siblings =
  if value0 == value1
  then descend state siblings
  else return False
ascend state (Node node0) (Node node1) siblings =
  case memberNodes (node0, node1) (stateCycles state) of
    EqualNodes -> descend state siblings
    HalfEqualNodes -> return False
    NotEqualNodes -> do
      (_, bound0) <- getBounds node0
      (_, bound1) <- getBounds node1
      if bound0 == bound1
        then
          let turtleNodes = currentTurtleNodes state
              state' = state { stateFrames =
                                 newFrame state node0 node1 siblings :
                                 stateFrames state,
                               stateDepth = (stateDepth state) + 1 }
              checkDepth = nextTurtleDepth state'
          in case turtleNodes of
            Just turtleNodes' ->
              case equalNodes (node0, node1) turtleNodes' of
                EqualNodes -> beginRecovery state node0 node1 siblings
                HalfEqualNodes -> return False
                NotEqualNodes -> ascendFirst state' node0 node1
            Nothing -> ascendFirst state' node0 node1
        else return False
ascend _ _ _ _ = return False

ascendFirst :: Eq a => EqState a -> Node a -> Node a -> IO Bool
ascendFirst state node0 node1 = do
  (_, bound) <- getBounds node0
  tree0 <- readArray node0 0
  tree1 <- readArray node1 0
  if bound > 0
    then let siblings = Siblings { siblingNodes = (node0, node1),
                                   siblingIndex = 1 }
         in ascend state tree0 tree1 (Just siblings)
    else ascend state tree0 tree1 Nothing

descend :: Eq a => EqState a -> Maybe (Siblings a) -> IO Bool
descend state Nothing =
  case stateFrames state of
    [] -> return True
    frame : rest ->
      let state' = state { stateFrames = rest,
                           stateDepth = stateDepth state - 1 }
      in descend state' (frameSiblings frame)
descend state (Just Siblings{..}) = do
  let (node0, node1) = siblingNodes
  (_, bound) <- getBounds node0
  tree0 <- readArray node0 siblingIndex
  tree1 <- readArray node1 siblingIndex
  if siblingIndex < bound
    then let siblings' = Siblings { siblingNodes = (node0, node1),
                                    siblingIndex = siblingIndex + 1 }
         in ascend state tree0 tree1 (Just siblings')
    else ascend state tree0 tree1 Nothing

beginRecovery :: Eq a => EqState a -> Node a -> Node a -> Maybe (Siblings a)
                 -> IO Bool
beginRecovery state node0 node1 siblings =
  let turtle = case stateFrames state of
                 [] -> error "must have first frame in stack"
                 frame : _ -> frameTurtle frame
      distance = (stateDepth state + 1) - turtleDepth turtle
      unrolledFrame = Unrolled { unrolledNodes = (node0, node1),
                                 unrolledState = state,
                                 unrolledSiblings = siblings }
  in unrolledFrame `seq` unrollCycle state [unrolledFrame] (distance - 1)

unrollCycle :: Eq a => EqState a -> [Unrolled a] -> Int -> IO Bool
unrollCycle state unrolled !count
  | count <= 0 = findCycleStart state unrolled
  | otherwise =
      case stateFrames state of
        [] -> error "frame must be found"
        frame : rest ->
          let state' = state { stateFrames = rest,
                               stateDepth = stateDepth state - 1 }
              unrolledFrame =
                Unrolled { unrolledNodes = frameNodes frame,
                           unrolledState = state',
                           unrolledSiblings = frameSiblings frame }
          in unrolledFrame `seq`
             unrollCycle state' (unrolledFrame : unrolled) (count - 1)

findCycleStart :: Eq a => EqState a -> [Unrolled a] -> IO Bool
findCycleStart state unrolled =
  case stateFrames state of
    [] ->
      return True
    frame : [] ->
      case memberUnrolled (frameNodes frame) unrolled of
        (NotEqualNodes, _) -> error "node not in nodes unrolled"
        (HalfEqualNodes, _) -> return False
        (EqualNodes, Just (state, siblings)) ->
          let state' =
                state { stateCycles = frameNodes frame : stateCycles state }
          in state' `seq` descend state' siblings
    frame : rest@(prevFrame : _) ->
      case memberUnrolled (frameNodes prevFrame) unrolled of
        (EqualNodes, _) ->
          let state' = state { stateFrames = rest,
                               stateDepth = stateDepth state - 1 }
              unrolledFrame =
                Unrolled { unrolledNodes = frameNodes frame,
                           unrolledState = state',
                           unrolledSiblings = frameSiblings frame }
              unrolled' = updateUnrolled unrolledFrame unrolled
          in unrolledFrame `seq` findCycleStart state' unrolled'
        (HalfEqualNodes, _) -> return False
        (NotEqualNodes, _) ->
          case memberUnrolled (frameNodes frame) unrolled of
            (NotEqualNodes, _) -> error "node not in nodes unrolled"
            (HalfEqualNodes, _) -> return False
            (EqualNodes, Just (state, siblings)) ->
              let state' =
                    state { stateCycles = frameNodes frame : stateCycles state }
              in state' `seq` descend state' siblings

updateUnrolled :: Unrolled a -> [Unrolled a] -> [Unrolled a]
updateUnrolled _ [] = []
updateUnrolled unrolled0 (unrolled1 : rest) =
  case equalNodes (unrolledNodes unrolled0) (unrolledNodes unrolled1) of
    EqualNodes -> unrolled0 : rest
    NotEqualNodes -> unrolled1 : updateUnrolled unrolled0 rest
    HalfEqualNodes -> error "this should not be possible"

memberUnrolled :: (Node a, Node a) -> [Unrolled a] ->
                  (NodeComparison, Maybe (EqState a, Maybe (Siblings a)))
memberUnrolled _ [] = (NotEqualNodes, Nothing)
memberUnrolled nodes (Unrolled{..} : rest) =
  case equalNodes nodes unrolledNodes of
    EqualNodes -> (EqualNodes, Just (unrolledState, unrolledSiblings))
    HalfEqualNodes -> (HalfEqualNodes, Nothing)
    NotEqualNodes -> memberUnrolled nodes rest

newFrame :: EqState a -> Node a -> Node a -> Maybe (Siblings a) -> Frame a
newFrame state node0 node1 siblings =
  let turtle =
        if (stateDepth state + 1) == nextTurtleDepth state
        then Turtle { turtleDepth = stateDepth state + 1,
                      turtleScale = currentTurtleScale state * 2,
                      turtleNodes = (node0, node1) }
        else case stateFrames state of
               [] -> Turtle { turtleDepth = 1, turtleScale = 2,
                              turtleNodes = (node0, node1) }
               frame : _ -> frameTurtle frame
  in Frame { frameNodes = (node0, node1),
             frameSiblings = siblings,
             frameTurtle = turtle }

memberNodes :: (Node a, Node a) -> [(Node a, Node a)] -> NodeComparison
memberNodes _ [] = NotEqualNodes
memberNodes nodes0 (nodes1 : rest) =
  case equalNodes nodes0 nodes1 of
    NotEqualNodes -> memberNodes nodes0 rest
    HalfEqualNodes -> HalfEqualNodes
    EqualNodes -> EqualNodes

equalNodes :: (Node a, Node a) -> (Node a, Node a) -> NodeComparison
equalNodes (node0, node1) (node2, node3) =
  if node0 == node2
  then if node1 == node3
       then EqualNodes
       else HalfEqualNodes
  else if node1 == node3
       then HalfEqualNodes
       else NotEqualNodes

currentTurtleNodes :: EqState a -> Maybe (Node a, Node a)
currentTurtleNodes state =
  case stateFrames state of
    [] -> Nothing
    frame : _ -> Just . turtleNodes . frameTurtle $ frame

currentTurtleScale :: EqState a -> Int
currentTurtleScale state =
  case stateFrames state of
    [] -> 1
    frame : _ -> turtleScale $ frameTurtle frame

nextTurtleDepth :: EqState a -> Int
nextTurtleDepth state =
  case stateFrames state of
    [] -> 1
    frame : _ -> let turtle = frameTurtle frame
                 in turtleDepth turtle + turtleScale turtle
Here is a naive version of the algorithm used by the test program.
{-# LANGUAGE RecordWildCards #-}

module NaiveEqualTree (Tree(..),
                       naiveEqual)
       where

import Data.Array.IO (IOArray)
import Data.Array.MArray (readArray,
                          getBounds)
import EqualTree (Tree(..),
                  Node)

data Frame a = Frame { frameNodes :: !(Node a, Node a),
                       frameSiblings :: !(Maybe (Siblings a)) }

data Siblings a = Siblings { siblingNodes :: !(Node a, Node a),
                             siblingIndex :: !Int }

data NodeComparison = EqualNodes | NotEqualNodes | HalfEqualNodes

naiveEqual :: Eq a => Tree a -> Tree a -> IO Bool
naiveEqual tree0 tree1 = ascend [] tree0 tree1 Nothing

ascend :: Eq a => [Frame a] -> Tree a -> Tree a -> Maybe (Siblings a) -> IO Bool
ascend state (Value value0) (Value value1) siblings =
  if value0 == value1
  then descend state siblings
  else return False
ascend state (Node node0) (Node node1) siblings =
  case testNodes (node0, node1) state of
    EqualNodes -> descend state siblings
    HalfEqualNodes -> return False
    NotEqualNodes -> do
      (_, bound0) <- getBounds node0
      (_, bound1) <- getBounds node1
      if bound0 == bound1
        then do
          let frame = Frame { frameNodes = (node0, node1),
                              frameSiblings = siblings }
              state' = frame : state
          tree0 <- readArray node0 0
          tree1 <- readArray node1 0
          if bound0 > 0
            then let siblings = Siblings { siblingNodes = (node0, node1),
                                           siblingIndex = 1 }
                 in frame `seq` ascend state' tree0 tree1 (Just siblings)
            else frame `seq` ascend state' tree0 tree1 Nothing
        else return False
ascend _ _ _ _ = return False

descend :: Eq a => [Frame a] -> Maybe (Siblings a) -> IO Bool
descend state Nothing =
  case state of
    [] -> return True
    frame : rest -> descend rest (frameSiblings frame)
descend state (Just Siblings{..}) = do
  let (node0, node1) = siblingNodes
  (_, bound) <- getBounds node0
  tree0 <- readArray node0 siblingIndex
  tree1 <- readArray node1 siblingIndex
  if siblingIndex < bound
    then let siblings' = Siblings { siblingNodes = (node0, node1),
                                    siblingIndex = siblingIndex + 1 }
         in ascend state tree0 tree1 (Just siblings')
    else ascend state tree0 tree1 Nothing

testNodes :: (Node a, Node a) -> [Frame a] -> NodeComparison
testNodes _ [] = NotEqualNodes
testNodes nodes (frame : rest) =
  case equalNodes nodes (frameNodes frame) of
    NotEqualNodes -> testNodes nodes rest
    HalfEqualNodes -> HalfEqualNodes
    EqualNodes -> EqualNodes

equalNodes :: (Node a, Node a) -> (Node a, Node a) -> NodeComparison
equalNodes (node0, node1) (node2, node3) =
  if node0 == node2
  then if node1 == node3
       then EqualNodes
       else HalfEqualNodes
  else if node1 == node3
       then HalfEqualNodes
       else NotEqualNodes
Here is the code of the test program. Note that this will occasionally fail on the not-equals test because it is designed to generate sets of nodes with a significant degree of commonality, as controlled by commonPortionRange.
{-# LANGUAGE TupleSections #-}

module Main where

import Data.Array (Array,
                   listArray,
                   bounds,
                   (!))
import Data.Array.IO (IOArray)
import Data.Array.MArray (writeArray,
                          newArray_)
import Control.Monad (forM_,
                      mapM,
                      mapM_,
                      liftM,
                      foldM)
import Control.Exception (SomeException,
                          catch)
import System.Random (StdGen,
                      newStdGen,
                      random,
                      randomR,
                      split)
import Prelude hiding (catch)
import EqualTree (Tree(..),
                  equal)
import NaiveEqualTree (naiveEqual)

leafChance :: Double
leafChance = 0.5

valueCount :: Int
valueCount = 1

maxNodeCount :: Int
maxNodeCount = 1024

commonPortionRange :: (Double, Double)
commonPortionRange = (0.8, 0.9)

commonRootChance :: Double
commonRootChance = 0.5

nodeSizeRange :: (Int, Int)
nodeSizeRange = (2, 5)

testCount :: Int
testCount = 1000

makeMapping :: Int -> (Int, Int) -> Int -> StdGen ->
               ([Either Int Int], StdGen)
makeMapping values range nodes gen =
  let (count, gen') = randomR range gen
  in makeMapping' 0 [] count gen'
  where makeMapping' index mapping count gen
          | index >= count = (mapping, gen)
          | otherwise =
              let (chance, gen0) = random gen
                  (slot, gen2) =
                    if chance <= leafChance
                    then let (value, gen1) = randomR (0, values - 1) gen0
                         in (Left value, gen1)
                    else let (nodeIndex, gen1) = randomR (0, nodes - 1) gen0
                         in (Right nodeIndex, gen1)
              in makeMapping' (index + 1) (slot : mapping) count gen2

makeMappings :: Int -> Int -> (Int, Int) -> StdGen ->
                ([[Either Int Int]], StdGen)
makeMappings size values range gen =
  let (size', gen') = randomR (1, size) gen
  in makeMappings' 0 size' [] gen'
  where makeMappings' index size mappings gen
          | index >= size = (mappings, gen)
          | otherwise =
              let (mapping, gen') = makeMapping values range size gen
              in makeMappings' (index + 1) size (mapping : mappings) gen'

makeMappingsPair :: Int -> (Double, Double) -> Int -> (Int, Int) -> StdGen ->
                    ([[Either Int Int]], [[Either Int Int]], StdGen)
makeMappingsPair size commonPortionRange values range gen =
  let (size', gen0) = randomR (2, size) gen
      (commonPortion, gen1) = randomR commonPortionRange gen0
      size0 = 1 + (floor $ fromIntegral size' * commonPortion)
      size1 = size' - size0
      (mappings, gen2) = makeMappingsPair' 0 size0 size' [] gen1
      (mappings0, gen3) = makeMappingsPair' 0 size1 size' [] gen2
      (mappings1, gen4) = makeMappingsPair' 0 size1 size' [] gen3
      (commonRootValue, gen5) = random gen4
  in if commonRootValue < commonRootChance
     then (mappings ++ mappings0, mappings ++ mappings1, gen5)
     else (mappings0 ++ mappings, mappings1 ++ mappings, gen5)
  where makeMappingsPair' index size size' mappings gen
          | index >= size = (mappings, gen)
          | otherwise =
              let (mapping, gen') = makeMapping values range size' gen
              in makeMappingsPair' (index + 1) size size' (mapping : mappings)
                 gen'

populateNode :: IOArray Int (Tree a) -> Array Int (IOArray Int (Tree a)) ->
                [Either a Int] -> IO ()
populateNode node nodes mapping =
  mapM_ (uncurry populateSlot) (zip [0..] mapping)
  where populateSlot index (Left value) =
          writeArray node index $ Value value
        populateSlot index (Right nodeIndex) =
          writeArray node index . Node $ nodes ! nodeIndex

makeTree :: [[Either a Int]] -> IO (Tree a)
makeTree mappings = do
  let size = length mappings
  nodes <- liftM (listArray (0, size - 1)) $ mapM makeNode mappings
  mapM_ (\(index, mapping) -> populateNode (nodes ! index) nodes mapping)
    (zip [0..] mappings)
  return . Node $ nodes ! 0
  where makeNode mapping = newArray_ (0, length mapping - 1)

testEqual :: StdGen -> IO (Bool, StdGen)
testEqual gen = do
  let (mappings, gen0) =
        makeMappings maxNodeCount valueCount nodeSizeRange gen
  tree0 <- makeTree mappings
  tree1 <- makeTree mappings
  catch (liftM (, gen0) $ equal tree0 tree1) $ \e -> do
    putStrLn $ show (e :: SomeException)
    return (False, gen0)

testNotEqual :: StdGen -> IO (Bool, Bool, StdGen)
testNotEqual gen = do
  let (mappings0, mappings1, gen0) =
        makeMappingsPair maxNodeCount commonPortionRange valueCount
        nodeSizeRange gen
  tree0 <- makeTree mappings0
  tree1 <- makeTree mappings1
  test <- naiveEqual tree0 tree1
  if not test
    then
      catch (testNotEqual' tree0 tree1 mappings0 mappings1 gen0) $ \e -> do
        putStrLn $ show (e :: SomeException)
        return (False, False, gen0)
    else return (True, True, gen0)
  where testNotEqual' tree0 tree1 mappings0 mappings1 gen0 = do
          test <- equal tree0 tree1
          if test
            then do
              putStrLn "Match failure: "
              putStrLn "Mappings 0: "
              mapM (putStrLn . show) $ zip [0..] mappings0
              putStrLn "Mappings 1: "
              mapM (putStrLn . show) $ zip [0..] mappings1
              return (False, False, gen0)
            else return (True, False, gen0)

doTestEqual :: (StdGen, Int) -> Int -> IO (StdGen, Int)
doTestEqual (gen, successCount) _ = do
  (success, gen') <- testEqual gen
  return (gen', successCount + (if success then 1 else 0))

doTestNotEqual :: (StdGen, Int, Int) -> Int -> IO (StdGen, Int, Int)
doTestNotEqual (gen, successCount, excludeCount) _ = do
  (success, exclude, gen') <- testNotEqual gen
  return (gen', successCount + (if success then 1 else 0),
          excludeCount + (if exclude then 1 else 0))

main :: IO ()
main = do
  gen <- newStdGen
  (gen0, equalSuccessCount) <- foldM doTestEqual (gen, 0) [1..testCount]
  putStrLn $ show equalSuccessCount ++ " out of " ++ show testCount ++
    " tests for equality passed"
  (_, notEqualSuccessCount, excludeCount) <-
    foldM doTestNotEqual (gen0, 0, 0) [1..testCount]
  putStrLn $ show notEqualSuccessCount ++ " out of " ++ show testCount ++
    " tests for inequality passed (with " ++ show excludeCount ++ " excluded)"