
Efficiently implementing equal? in Scheme without direct access to pointer values

I am implementing R7RS-small Scheme and have run into a problem with equal?: it tests value equality, and it must also be able to compare cyclic data structures without getting into infinite loops. However, because I am implementing Scheme in Haskell, I do not have access to underlying pointer values that can be cast to integers and used in a hash table* or search tree to track which nodes I have already followed (so as to be able to efficiently prune paths that would result in infinite loops).

Rather, all I seemingly have to work with is equality of identity (as measured by (==) upon the IOArrays underlying pairs, vectors, and records). Hence all I can seemingly do is construct lists marking which nodes I have followed (separated by type), and then, for each further node I follow, search the appropriate list for nodes I have already followed. As far as I can tell, this scales as O(n^2) in time, since each of up to n nodes is looked up with a linear search, and O(n) in space.

Am I right that, given these conditions, this is the only algorithm available to me, or are there more efficient implementations I am missing?

I have considered tagging every value that can contain references with a tag that could be used in a search tree or hash table*, but the problem here is that this would be particularly space-inefficient for lists, as I need to use two words for the tag for every node, one being the ThreadId and one being a per-thread unique ID (the ThreadId is necessary because, as I am doing a multithreaded implementation of Scheme, I would otherwise have to protect a shared unique ID counter behind an MVar or TMVar, which would have horrible contention in many use cases).

* As I am implementing everything in a monad transformer that implements MonadIO, traditional imperative-style hash tables are available to me.
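
To make the tagging idea concrete, here is a minimal sketch, not the representation used in the actual interpreter. It assumes a hypothetical `Obj` type in which every pair carries a `Data.Unique` tag, and it tracks visited pairs of tags in a `Data.Set`. The names `Obj`, `cons`, `setCdr`, and `equalObj` are illustrative only. As far as I know, `Data.Unique` draws its tags from one process-wide counter, so this sketch does not address the contention concern described above; it only shows how an ordered tag replaces the list search.

```haskell
import Data.IORef (IORef, newIORef, readIORef, writeIORef, modifyIORef')
import Data.Unique (Unique, newUnique)
import qualified Data.Set as Set

-- Hypothetical object type: every pair carries a Unique tag (an assumption
-- made for this sketch, not the asker's representation).
data Obj = Atom Int | Pair Unique (IORef Obj) (IORef Obj)

cons :: Obj -> Obj -> IO Obj
cons a d = Pair <$> newUnique <*> newIORef a <*> newIORef d

-- Overwrite the cdr of a pair; does nothing for atoms.
setCdr :: Obj -> Obj -> IO ()
setCdr (Pair _ _ d) x = writeIORef d x
setCdr (Atom _) _ = return ()

-- Pairs of tags whose comparison has already been started.
type Seen = IORef (Set.Set (Unique, Unique))

equalObj :: Obj -> Obj -> IO Bool
equalObj x y = do
  seen <- newIORef Set.empty
  go seen x y

go :: Seen -> Obj -> Obj -> IO Bool
go _ (Atom a) (Atom b) = return (a == b)
go seen (Pair u a d) (Pair v a' d')
  | u == v = return True            -- same object on both sides
  | otherwise = do
      visited <- Set.member (u, v) <$> readIORef seen
      if visited
        then return True            -- already under comparison: assume equal
        else do
          modifyIORef' seen (Set.insert (u, v))
          a0 <- readIORef a
          a1 <- readIORef a'
          carsEqual <- go seen a0 a1
          if carsEqual
            then do
              d0 <- readIORef d
              d1 <- readIORef d'
              go seen d0 d1
            else return False
go _ _ _ = return False
```

Treating an already-visited pair as equal is safe here because any mismatch found later short-circuits the whole comparison to false. Each lookup costs O(log n) in the size of the set instead of a linear scan.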

Travis Bemann asked Aug 20 '13 20:08



2 Answers

Wouldn't Tortoise and Hare be able to fix this?

In a single list it's trivial. You let the hare step twice as fast as the tortoise and start 1 ahead of the first element. If the hare ever matches the tortoise you have a cycle.
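
A minimal sketch of the single-list case, assuming a hypothetical mutable list type `Cell` whose tails are `IORef`s, so that identity can be tested with `(==)` on the references, much like `(==)` on the IOArrays in the question. `Cell` and `hasCycle` are names made up for this example.

```haskell
import Data.IORef

-- Hypothetical mutable list: a value plus a mutable tail.
data Cell = Nil | Cell Int (IORef Cell)

-- Tortoise and hare: the hare starts one step ahead and moves two steps
-- for every step of the tortoise. They can only meet if there is a cycle.
hasCycle :: Cell -> IO Bool
hasCycle start = do
  hare0 <- step start
  go start hare0
  where
    step Nil = return Nil
    step (Cell _ r) = readIORef r

    go _ Nil = return False                        -- hare fell off the end
    go (Cell _ t) (Cell _ h) | t == h = return True -- same cell: cycle found
    go tortoise hare = do
      tortoise' <- step tortoise
      hare' <- step hare >>= step
      go tortoise' hare'
```

This uses constant extra space, which is the attraction compared with keeping a list of visited nodes.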

With cons cells it's basically a binary tree: you traverse both trees in one particular order, and the hare follows the same order at double speed. If the elements are eq?, or they are atoms that are not eqv?, you short-circuit. If the tortoise and hare match, you backtrack.

Sylwester answered Oct 13 '22 17:10



Here is the algorithm I have figured out to implement this with. It is a variation on Brent's "teleporting turtle" algorithm, modified to handle not a linear list of nodes but an N-branching tree of nodes.

(This is not taking the actual comparison into account. There will be two copies of the state below, one for each data structure being tested for equality, and if something is not found to be equal in value, the comparison is short-circuited and false is returned.)

I maintain two stacks, a stack of nodes that I have followed over a depth-first traversal combined with the next node to follow at the same depth and a current depth value, and a stack of nodes that the turtle is going to be positioned at which record the depth the turtle is at and the distance deeper than the turtle the next turtle will be at. (In my actual implementation, the stacks are unified so that every stack frame points at a pair of nodes and a turtle (which points at a pair of nodes), which simplifies the management of the turtles.)

As I traverse the data structure depth-first, I build up the first stack, and at intervals of increasing powers of two distance in traversal I add new frames onto the turtle stack, where the turtle is pointing at the current node on the first stack.
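
For reference, the linear-list form of Brent's algorithm that this generalizes can be sketched as follows. This is only an illustration on a hypothetical `Cell` list type, not the tree version described here: the turtle stays put while the hare advances, and the turtle teleports to the hare's position at step counts that are increasing powers of two. `Cell`, `sameCell`, `next`, and `brent` are names invented for the sketch.

```haskell
import Data.IORef

-- Hypothetical mutable list: a value plus a mutable tail.
data Cell = Nil | Cell Int (IORef Cell)

-- Identity of a cell is taken to be the identity of its tail reference.
sameCell :: Cell -> Cell -> Bool
sameCell (Cell _ a) (Cell _ b) = a == b
sameCell _ _ = False

next :: Cell -> IO Cell
next Nil = return Nil
next (Cell _ r) = readIORef r

-- Returns Just the cycle length if the list is cyclic, Nothing otherwise.
brent :: Cell -> IO (Maybe Int)
brent start = next start >>= go 1 1 start
  where
    -- power: steps allowed before the turtle teleports
    -- len:   steps the hare has taken since the last teleport
    go :: Int -> Int -> Cell -> Cell -> IO (Maybe Int)
    go _ _ _ Nil = return Nothing
    go power len turtle hare
      | sameCell turtle hare = return (Just len)
      | len == power = next hare >>= go (power * 2) 1 hare
      | otherwise = next hare >>= go power (len + 1) turtle
```

When the hare meets the turtle, `len` is the cycle length, which corresponds to the depth difference between the two stacks in the tree version.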

When I reach a node where I can go no deeper, because it has no sibling nodes that have not been reached yet, I descend the first stack until I reach a node that does have an unchecked sibling, and then replace that node with the next sibling node to follow; if there are no sibling nodes left to follow anywhere in the stack, then we terminate with true for value equality.

Note that when descending the first stack, if the top of the first stack being popped off equals the same depth (or node) as the top of the turtle stack, the top of turtle stack is popped off.

If, after pushing a frame onto the first stack, the current node is equal to the node at the top of the turtle stack, I backtrack. The difference in depth between the top of the first stack and the top of the turtle stack is equal to the size of the cycle. I backtrack a full cycle, recording each node I pass and its corresponding stack states and siblings. Then I test the nodes in the frame on the first stack below the topmost frame. If they are not in the recorded nodes, then I know that the node I am at is the start of the cycle; I then pull the recorded stacks and siblings for that node and continue from there, so I can take alternative paths from within the cycle (remember this is an N-branching tree) or otherwise descend out of the cycle. If they are in the recorded nodes, I update the recorded nodes to contain the stacks below the topmost frames and the siblings of the current node, and then pop the tops of the stacks and continue.

Here is my code for a test implementation of the algorithm. The code should work now.

{-# LANGUAGE RecordWildCards, BangPatterns #-}

module EqualTree (Tree(..),
                  Node,  -- type synonym, imported by NaiveEqualTree
                  equal)
       where

import Data.Array.IO (IOArray)
import Data.Array.MArray (readArray,
                          getBounds)

data Tree a = Value a | Node (Node a)

type Node a = IOArray Int (Tree a)

data Frame a = Frame { frameNodes :: !(Node a, Node a),
                       frameSiblings :: !(Maybe (Siblings a)),
                       frameTurtle :: !(Turtle a) }

data Siblings a = Siblings { siblingNodes :: !(Node a, Node a),
                             siblingIndex :: !Int }

data Turtle a = Turtle { turtleDepth :: !Int,
                         turtleScale :: !Int,
                         turtleNodes :: !(Node a, Node a) }

data EqState a = EqState { stateFrames :: [Frame a],
                           stateCycles :: [(Node a, Node a)],
                           stateDepth :: !Int }

data Unrolled a = Unrolled { unrolledNodes :: !(Node a, Node a),
                             unrolledState :: !(EqState a),
                             unrolledSiblings :: !(Maybe (Siblings a)) }

data NodeComparison = EqualNodes | NotEqualNodes | HalfEqualNodes

equal :: Eq a => Tree a -> Tree a -> IO Bool
equal tree0 tree1 =
  let state = EqState { stateFrames = [], stateCycles = [], stateDepth = 0 }
  in ascend state tree0 tree1 Nothing

ascend :: Eq a => EqState a -> Tree a -> Tree a -> Maybe (Siblings a) -> IO Bool
ascend state (Value value0) (Value value1) siblings =
  if value0 == value1
  then descend state siblings
  else return False
ascend state (Node node0) (Node node1) siblings =
  case memberNodes (node0, node1) (stateCycles state) of
    EqualNodes -> descend state siblings
    HalfEqualNodes -> return False
    NotEqualNodes -> do
      (_, bound0) <- getBounds node0
      (_, bound1) <- getBounds node1
      if bound0 == bound1
        then
          let turtleNodes = currentTurtleNodes state
              state' = state { stateFrames =
                                  newFrame state node0 node1 siblings :
                                  stateFrames state,
                               stateDepth = (stateDepth state) + 1 }
              checkDepth = nextTurtleDepth state'
          in case turtleNodes of
               Just turtleNodes' -> 
                 case equalNodes (node0, node1) turtleNodes' of
                   EqualNodes -> beginRecovery state node0 node1 siblings
                   HalfEqualNodes -> return False
                   NotEqualNodes -> ascendFirst state' node0 node1
               Nothing -> ascendFirst state' node0 node1
        else return False
ascend _ _ _ _ = return False

ascendFirst :: Eq a => EqState a -> Node a -> Node a -> IO Bool
ascendFirst state node0 node1 = do
  (_, bound) <- getBounds node0
  tree0 <- readArray node0 0
  tree1 <- readArray node1 0
  if bound > 0
    then let siblings = Siblings { siblingNodes = (node0, node1),
                                   siblingIndex = 1 }
         in ascend state tree0 tree1 (Just siblings)
    else ascend state tree0 tree1 Nothing

descend :: Eq a => EqState a -> Maybe (Siblings a) -> IO Bool
descend state Nothing =
  case stateFrames state of
    [] -> return True
    frame : rest ->
      let state' = state { stateFrames = rest,
                           stateDepth = stateDepth state - 1 }
      in descend state' (frameSiblings frame)
descend state (Just Siblings{..}) = do
  let (node0, node1) = siblingNodes
  (_, bound) <- getBounds node0
  tree0 <- readArray node0 siblingIndex
  tree1 <- readArray node1 siblingIndex
  if siblingIndex < bound
    then let siblings' = Siblings { siblingNodes = (node0, node1),
                                    siblingIndex = siblingIndex + 1 }
         in ascend state tree0 tree1 (Just siblings')
    else ascend state tree0 tree1 Nothing

beginRecovery :: Eq a => EqState a -> Node a -> Node a -> Maybe (Siblings a)
                 -> IO Bool
beginRecovery state node0 node1 siblings =
  let turtle = case stateFrames state of
                 [] -> error "must have first frame in stack"
                 frame : _ -> frameTurtle frame
      distance = (stateDepth state + 1) - turtleDepth turtle
      unrolledFrame = Unrolled { unrolledNodes = (node0, node1),
                                 unrolledState = state,
                                 unrolledSiblings = siblings }
  in unrolledFrame `seq` unrollCycle state [unrolledFrame] (distance - 1)

unrollCycle :: Eq a => EqState a -> [Unrolled a] -> Int -> IO Bool
unrollCycle state unrolled !count
  | count <= 0 = findCycleStart state unrolled
  | otherwise =
      case stateFrames state of
        [] -> error "frame must be found"
        frame : rest ->
          let state' = state { stateFrames = rest,
                               stateDepth = stateDepth state - 1 }
              unrolledFrame =
                Unrolled { unrolledNodes = frameNodes frame,
                           unrolledState = state',
                           unrolledSiblings = frameSiblings frame }
          in unrolledFrame `seq`
             unrollCycle state' (unrolledFrame : unrolled) (count - 1)

findCycleStart :: Eq a => EqState a -> [Unrolled a] -> IO Bool
findCycleStart state unrolled =
  case stateFrames state of
    [] ->
      return True
    frame : [] ->
      case memberUnrolled (frameNodes frame) unrolled of
        (NotEqualNodes, _) -> error "node not in nodes unrolled"
        (HalfEqualNodes, _) -> return False
        (EqualNodes, Just (state, siblings)) ->
          let state' =
                state { stateCycles = frameNodes frame : stateCycles state }
          in state' `seq` descend state' siblings
    frame : rest@(prevFrame : _) ->
      case memberUnrolled (frameNodes prevFrame) unrolled of
        (EqualNodes, _) ->
          let state' = state { stateFrames = rest,
                               stateDepth = stateDepth state - 1 }
              unrolledFrame =
                Unrolled { unrolledNodes = frameNodes frame,
                           unrolledState = state',
                           unrolledSiblings = frameSiblings frame }
              unrolled' = updateUnrolled unrolledFrame unrolled
          in unrolledFrame `seq` findCycleStart state' unrolled'
        (HalfEqualNodes, _) -> return False
        (NotEqualNodes, _) ->
          case memberUnrolled (frameNodes frame) unrolled of
            (NotEqualNodes, _) -> error "node not in nodes unrolled"
            (HalfEqualNodes, _) -> return False
            (EqualNodes, Just (state, siblings)) ->
              let state' =
                    state { stateCycles = frameNodes frame : stateCycles state }
              in state' `seq` descend state' siblings

updateUnrolled :: Unrolled a -> [Unrolled a] -> [Unrolled a]
updateUnrolled _ [] = []
updateUnrolled unrolled0 (unrolled1 : rest) =
  case equalNodes (unrolledNodes unrolled0) (unrolledNodes unrolled1) of
    EqualNodes -> unrolled0 : rest
    NotEqualNodes -> unrolled1 : updateUnrolled unrolled0 rest
    HalfEqualNodes -> error "this should not be possible"

memberUnrolled :: (Node a, Node a) -> [Unrolled a] ->
                  (NodeComparison, Maybe (EqState a, Maybe (Siblings a)))
memberUnrolled _ [] = (NotEqualNodes, Nothing)
memberUnrolled nodes (Unrolled{..} : rest) =
  case equalNodes nodes unrolledNodes of
    EqualNodes -> (EqualNodes, Just (unrolledState, unrolledSiblings))
    HalfEqualNodes -> (HalfEqualNodes, Nothing)
    NotEqualNodes -> memberUnrolled nodes rest

newFrame :: EqState a -> Node a -> Node a -> Maybe (Siblings a) -> Frame a
newFrame state node0 node1 siblings =
  let turtle =
        if (stateDepth state + 1) == nextTurtleDepth state
        then Turtle { turtleDepth = stateDepth state + 1,
                      turtleScale = currentTurtleScale state * 2, 
                      turtleNodes = (node0, node1) }
        else case stateFrames state of
               [] -> Turtle { turtleDepth = 1, turtleScale = 2,
                              turtleNodes = (node0, node1) }
               frame : _ -> frameTurtle frame
  in Frame { frameNodes = (node0, node1),
             frameSiblings = siblings,
             frameTurtle = turtle }

memberNodes :: (Node a, Node a) -> [(Node a, Node a)] -> NodeComparison
memberNodes _ [] = NotEqualNodes
memberNodes nodes0 (nodes1 : rest) =
  case equalNodes nodes0 nodes1 of
    NotEqualNodes -> memberNodes nodes0 rest
    HalfEqualNodes -> HalfEqualNodes
    EqualNodes -> EqualNodes

equalNodes :: (Node a, Node a) -> (Node a, Node a) -> NodeComparison
equalNodes (node0, node1) (node2, node3) =
  if node0 == node2
  then if node1 == node3
       then EqualNodes
       else HalfEqualNodes
  else if node1 == node3
       then HalfEqualNodes
       else NotEqualNodes

currentTurtleNodes :: EqState a -> Maybe (Node a, Node a)
currentTurtleNodes state =
  case stateFrames state of
    [] -> Nothing
    frame : _ -> Just . turtleNodes . frameTurtle $ frame

currentTurtleScale :: EqState a -> Int
currentTurtleScale state =
  case stateFrames state of
    [] -> 1
    frame : _ -> turtleScale $ frameTurtle frame

nextTurtleDepth :: EqState a -> Int
nextTurtleDepth state =
  case stateFrames state of
    [] -> 1
    frame : _ -> let turtle = frameTurtle frame
                 in turtleDepth turtle + turtleScale turtle

Here is a naive version of the algorithm used by the test program.

{-# LANGUAGE RecordWildCards #-}

module NaiveEqualTree (Tree(..),
                       naiveEqual)
       where

import Data.Array.IO (IOArray)
import Data.Array.MArray (readArray,
                          getBounds)

import EqualTree (Tree(..),
                  Node)

data Frame a = Frame { frameNodes :: !(Node a, Node a),
                       frameSiblings :: !(Maybe (Siblings a)) }

data Siblings a = Siblings { siblingNodes :: !(Node a, Node a),
                             siblingIndex :: !Int }

data NodeComparison = EqualNodes | NotEqualNodes | HalfEqualNodes

naiveEqual :: Eq a => Tree a -> Tree a -> IO Bool
naiveEqual tree0 tree1 = ascend [] tree0 tree1 Nothing

ascend :: Eq a => [Frame a] -> Tree a -> Tree a -> Maybe (Siblings a) -> IO Bool
ascend state (Value value0) (Value value1) siblings =
  if value0 == value1
  then descend state siblings
  else return False
ascend state (Node node0) (Node node1) siblings =
  case testNodes (node0, node1) state of
    EqualNodes -> descend state siblings
    HalfEqualNodes -> return False
    NotEqualNodes -> do
      (_, bound0) <- getBounds node0
      (_, bound1) <- getBounds node1
      if bound0 == bound1
        then do
          let frame = Frame { frameNodes = (node0, node1),
                              frameSiblings = siblings }
              state' = frame : state
          tree0 <- readArray node0 0
          tree1 <- readArray node1 0
          if bound0 > 0
            then let siblings = Siblings { siblingNodes = (node0, node1),
                                           siblingIndex = 1 }
                 in frame `seq` ascend state' tree0 tree1 (Just siblings)
            else frame `seq` ascend state' tree0 tree1 Nothing
        else return False
ascend _ _ _ _ = return False

descend :: Eq a => [Frame a] -> Maybe (Siblings a) -> IO Bool
descend state Nothing =
  case state of
    [] -> return True
    frame : rest -> descend rest (frameSiblings frame)
descend state (Just Siblings{..}) = do
  let (node0, node1) = siblingNodes
  (_, bound) <- getBounds node0
  tree0 <- readArray node0 siblingIndex
  tree1 <- readArray node1 siblingIndex
  if siblingIndex < bound
    then let siblings' = Siblings { siblingNodes = (node0, node1),
                                    siblingIndex = siblingIndex + 1 }
         in ascend state tree0 tree1 (Just siblings')
    else ascend state tree0 tree1 Nothing

testNodes :: (Node a, Node a) -> [Frame a] -> NodeComparison
testNodes _ [] = NotEqualNodes
testNodes nodes (frame : rest) =
  case equalNodes nodes (frameNodes frame) of
    NotEqualNodes -> testNodes nodes rest
    HalfEqualNodes -> HalfEqualNodes
    EqualNodes -> EqualNodes

equalNodes :: (Node a, Node a) -> (Node a, Node a) -> NodeComparison
equalNodes (node0, node1) (node2, node3) =
  if node0 == node2
  then if node1 == node3
       then EqualNodes
       else HalfEqualNodes
  else if node1 == node3
       then HalfEqualNodes
       else NotEqualNodes

Here is the code of the test program. Note that this will occasionally fail on the not-equals test because it is designed to generate sets of nodes with a significant degree of commonality, as controlled by commonPortionRange.

{-# LANGUAGE TupleSections #-}

module Main where

import Data.Array (Array,
                   listArray,
                   bounds,
                   (!))
import Data.Array.IO (IOArray)
import Data.Array.MArray (writeArray,
                          newArray_)
import Control.Monad (forM_,
                      mapM,
                      mapM_,
                      liftM,
                      foldM)
import Control.Exception (SomeException,
                          catch)
import System.Random (StdGen,
                      newStdGen,
                      random,
                      randomR,
                      split)
import Prelude hiding (catch)

import EqualTree (Tree(..),
                  equal)
import NaiveEqualTree (naiveEqual)

leafChance :: Double
leafChance = 0.5

valueCount :: Int
valueCount = 1

maxNodeCount :: Int
maxNodeCount = 1024

commonPortionRange :: (Double, Double)
commonPortionRange = (0.8, 0.9)

commonRootChance :: Double
commonRootChance = 0.5

nodeSizeRange :: (Int, Int)
nodeSizeRange = (2, 5)

testCount :: Int
testCount = 1000

makeMapping :: Int -> (Int, Int) -> Int -> StdGen ->
               ([Either Int Int], StdGen)
makeMapping values range nodes gen =
  let (count, gen') = randomR range gen
  in makeMapping' 0 [] count gen'
  where makeMapping' index mapping count gen
          | index >= count = (mapping, gen)
          | otherwise =
            let (chance, gen0) = random gen
                (slot, gen2) =
                  if chance <= leafChance
                  then let (value, gen1) = randomR (0, values - 1) gen0
                       in (Left value, gen1)
                  else let (nodeIndex, gen1) = randomR (0, nodes - 1) gen0
                       in (Right nodeIndex, gen1)
            in makeMapping' (index + 1) (slot : mapping) count gen2

makeMappings :: Int -> Int -> (Int, Int) -> StdGen ->
                ([[Either Int Int]], StdGen)
makeMappings size values range gen =
  let (size', gen') = randomR (1, size) gen
  in makeMappings' 0 size' [] gen'
  where makeMappings' index size mappings gen
          | index >= size = (mappings, gen)
          | otherwise =
            let (mapping, gen') = makeMapping values range size gen
            in makeMappings' (index + 1) size (mapping : mappings) gen'

makeMappingsPair :: Int -> (Double, Double) -> Int -> (Int, Int) -> StdGen ->
                    ([[Either Int Int]], [[Either Int Int]], StdGen)
makeMappingsPair size commonPortionRange values range gen =
  let (size', gen0) = randomR (2, size) gen
      (commonPortion, gen1) = randomR commonPortionRange gen0
      size0 = 1 + (floor $ fromIntegral size' * commonPortion)
      size1 = size' - size0
      (mappings, gen2) = makeMappingsPair' 0 size0 size' [] gen1
      (mappings0, gen3) = makeMappingsPair' 0 size1 size' [] gen2
      (mappings1, gen4) = makeMappingsPair' 0 size1 size' [] gen3
      (commonRootValue, gen5) = random gen4
  in if commonRootValue < commonRootChance
     then (mappings ++ mappings0, mappings ++ mappings1, gen5)
     else (mappings0 ++ mappings, mappings1 ++ mappings, gen5)
  where makeMappingsPair' index size size' mappings gen
          | index >= size = (mappings, gen)
          | otherwise =
            let (mapping, gen') = makeMapping values range size' gen
            in makeMappingsPair' (index + 1) size size' (mapping : mappings)
               gen'

populateNode :: IOArray Int (Tree a) -> Array Int (IOArray Int (Tree a)) ->
                [Either a Int] -> IO ()
populateNode node nodes mapping =
  mapM_ (uncurry populateSlot) (zip [0..] mapping)
  where populateSlot index (Left value) =
          writeArray node index $ Value value
        populateSlot index (Right nodeIndex) =
          writeArray node index . Node $ nodes ! nodeIndex

makeTree :: [[Either a Int]] -> IO (Tree a)
makeTree mappings = do
  let size = length mappings
  nodes <- liftM (listArray (0, size - 1)) $ mapM makeNode mappings
  mapM_ (\(index, mapping) -> populateNode (nodes ! index) nodes mapping)
    (zip [0..] mappings)
  return . Node $ nodes ! 0
  where makeNode mapping = newArray_ (0, length mapping - 1)

testEqual :: StdGen -> IO (Bool, StdGen)
testEqual gen = do
  let (mappings, gen0) =
        makeMappings maxNodeCount valueCount nodeSizeRange gen
  tree0 <- makeTree mappings
  tree1 <- makeTree mappings
  catch (liftM (, gen0) $ equal tree0 tree1) $ \e -> do
    putStrLn $ show (e :: SomeException)
    return (False, gen0)

testNotEqual :: StdGen -> IO (Bool, Bool, StdGen)
testNotEqual gen = do
  let (mappings0, mappings1, gen0) =
        makeMappingsPair maxNodeCount commonPortionRange valueCount
        nodeSizeRange gen
  tree0 <- makeTree mappings0
  tree1 <- makeTree mappings1
  test <- naiveEqual tree0 tree1
  if not test
    then
      catch (testNotEqual' tree0 tree1 mappings0 mappings1 gen0) $ \e -> do
        putStrLn $ show (e :: SomeException)
        return (False, False, gen0)
    else return (True, True, gen0)
  where testNotEqual' tree0 tree1 mappings0 mappings1 gen0 = do
          test <- equal tree0 tree1
          if test
            then do
              putStrLn "Match failure: "
              putStrLn "Mappings 0: "
              mapM_ print $ zip [0..] mappings0
              putStrLn "Mappings 1: "
              mapM_ print $ zip [0..] mappings1
              return (False, False, gen0)
            else return (True, False, gen0)

doTestEqual :: (StdGen, Int) -> Int -> IO (StdGen, Int)
doTestEqual (gen, successCount) _ = do
  (success, gen') <- testEqual gen
  return (gen', successCount + (if success then 1 else 0))

doTestNotEqual :: (StdGen, Int, Int) -> Int -> IO (StdGen, Int, Int)
doTestNotEqual (gen, successCount, excludeCount) _ = do
  (success, exclude, gen') <- testNotEqual gen
  return (gen', successCount + (if success then 1 else 0),
          excludeCount + (if exclude then 1 else 0))

main :: IO ()
main = do
  gen <- newStdGen
  (gen0, equalSuccessCount) <- foldM doTestEqual (gen, 0) [1..testCount]
  putStrLn $ show equalSuccessCount ++ " out of " ++ show testCount ++
    " tests for equality passed"
  (_, notEqualSuccessCount, excludeCount) <-
    foldM doTestNotEqual (gen0, 0, 0) [1..testCount]
  putStrLn $ show notEqualSuccessCount ++ " out of " ++ show testCount ++
    " tests for inequality passed (with " ++ show excludeCount ++ " excluded)"
12 revs answered Oct 13 '22 19:10
