Finding if a number is equal to sum of 2 nodes in a binary search tree

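Here is my code for this. I traverse the whole tree and, for each node, call find() to check whether num - node.item is also in the tree. find() takes O(log n) time on a balanced tree (O(h) in general, where h is the height of the tree), so the whole program takes O(n log n) time.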

Is there a better way to implement this? I don't just mean better in terms of time complexity, but better in general. What is the best way to implement it?

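public boolean searchNum(BinTreeNode node, int num) {
    // Terminal case for the recursion: an empty subtree holds no candidate.
    if (node == null) {
        return false;
    }

    int result = num - node.item;
    // findHelper() (below) looks up the complement in the whole tree. The match
    // must be a different node; otherwise, when num == 2 * node.item, a single
    // node would be paired with itself.
    BinTreeNode match = findHelper(result, root);
    if (match != null && match != node) {
        return true;
    }
    return searchNum(node.leftChild, num) || searchNum(node.rightChild, num);
}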

public boolean find(int key) {
    // The key is present exactly when the lookup from the root finds a node.
    return findHelper(key, root) != null;
}


private BinTreeNode findHelper(int key, BinTreeNode node) {
    if (node == null) {
        return null;
    }
    if (key == node.item) {
        return node;
    } else if (key < node.item) {
        return findHelper(key, node.leftChild);
    } else {
        return findHelper(key, node.rightChild);
    }
}
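
For experimenting with the approach described in the question, here is a self-contained sketch. It assumes a hypothetical `BstNode` class with `item`, `leftChild` and `rightChild` fields (the question does not show `BinTreeNode`), and the helper class `ComplementSearch` and its methods are likewise illustrative names; the root is passed explicitly instead of being read from a `root` field. Each lookup costs O(h), so the whole search costs O(n·h), which is O(n log n) only when the tree is balanced.

```java
// Hypothetical stand-in for the question's BinTreeNode (its definition is not shown).
class BstNode {
    int item;
    BstNode leftChild, rightChild;

    BstNode(int item) {
        this.item = item;
    }
}

class ComplementSearch {
    // Plain (unbalanced) BST insert, used only to build example trees;
    // equal keys go to the right subtree.
    static BstNode insert(BstNode root, int item) {
        if (root == null) {
            return new BstNode(item);
        }
        if (item < root.item) {
            root.leftChild = insert(root.leftChild, item);
        } else {
            root.rightChild = insert(root.rightChild, item);
        }
        return root;
    }

    // Iterative BST lookup from the root: O(h) for a tree of height h.
    static BstNode find(BstNode root, int key) {
        BstNode node = root;
        while (node != null && node.item != key) {
            node = key < node.item ? node.leftChild : node.rightChild;
        }
        return node;
    }

    // Visit every node and look up its complement num - node.item in the whole tree.
    // A match on the node itself is rejected, so a node is never paired with itself.
    static boolean searchNum(BstNode root, BstNode node, int num) {
        if (node == null) {
            return false;
        }
        BstNode match = find(root, num - node.item);
        if (match != null && match != node) {
            return true;
        }
        return searchNum(root, node.leftChild, num)
                || searchNum(root, node.rightChild, num);
    }
}
```

For a tree built by inserting 8, 3, 10, 1, 6, 14, `searchNum(root, root, 16)` returns true (6 + 10), while `searchNum(root, root, 2)` returns false because the single node holding 1 may not be used twice.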
asked Dec 25 '22 20:12 by Prasanna


2 Answers

Finding two nodes in a binary search tree that sum to a given value can be done the same way as finding two elements of a sorted array that sum to that value.

For an array sorted in ascending order, keep two pointers: one starting at the beginning and one at the end. If the sum of the two elements they point to is larger than the target, move the right pointer one step to the left; if the sum is smaller than the target, move the left pointer one step to the right. Eventually the two pointers either point to two elements that sum to the target, or meet in the middle, in which case no such pair exists.

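boolean searchNumArray(int[] arr, int num) {
    // arr must be sorted in ascending order.
    int left = 0;
    int right = arr.length - 1;
    while (left < right) {
        int sum = arr[left] + arr[right];
        if (sum == num) {
            return true;
        } else if (sum > num) {
            // Even the smallest remaining partner makes the sum too big,
            // so arr[right] cannot be part of any answer.
            right--;
        } else {
            // Even the largest remaining partner leaves the sum too small,
            // so arr[left] cannot be part of any answer.
            left++;
        }
    }
    return false;
}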

An in-order traversal of a binary search tree visits its keys in sorted order, so the same idea applies directly to a binary search tree.
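
The most direct way to apply this is to copy the in-order traversal into a list and run the same two-pointer scan on it. The sketch below assumes a hypothetical `InorderNode` class with `item`, `leftChild` and `rightChild` fields; `InorderPairSum` and its methods are illustrative names. It needs O(n) extra space for the list, whereas the iterator-based approach keeps only O(h) nodes on its stacks.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical node type with the same fields as the question's BinTreeNode.
class InorderNode {
    int item;
    InorderNode leftChild, rightChild;

    InorderNode(int item) {
        this.item = item;
    }
}

class InorderPairSum {
    // Plain BST insert, used only to build example trees.
    static InorderNode insert(InorderNode root, int item) {
        if (root == null) {
            return new InorderNode(item);
        }
        if (item < root.item) {
            root.leftChild = insert(root.leftChild, item);
        } else {
            root.rightChild = insert(root.rightChild, item);
        }
        return root;
    }

    // In-order traversal appends the keys in ascending order.
    static void inorder(InorderNode node, List<Integer> out) {
        if (node == null) {
            return;
        }
        inorder(node.leftChild, out);
        out.add(node.item);
        inorder(node.rightChild, out);
    }

    // Flatten the tree into a sorted list, then run the two-pointer scan on it.
    static boolean searchNumBst(InorderNode root, int num) {
        List<Integer> sorted = new ArrayList<>();
        inorder(root, sorted);
        int left = 0;
        int right = sorted.size() - 1;
        while (left < right) {
            int sum = sorted.get(left) + sorted.get(right);
            if (sum == num) {
                return true;
            } else if (sum > num) {
                right--;
            } else {
                left++;
            }
        }
        return false;
    }
}
```

With the keys 8, 3, 10, 1, 6, 14, the list is [1, 3, 6, 8, 10, 14] and `searchNumBst(root, 24)` returns true (10 + 14).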

The following code performs iterative in-order traversals from both directions, each using an explicit stack. Every node is pushed and popped at most once per traversal, so the time complexity is O(n) and the space complexity is O(h), where h is the height of the binary tree.

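import java.util.Iterator;
import java.util.Stack;

// In-order traversal of a BST, ascending (leftToRight) or descending.
// Note: next() returns the current node without advancing; remove() advances.
class BinTreeIterator implements Iterator<BinTreeNode> {
    Stack<BinTreeNode> stack;
    boolean leftToRight;

    public boolean hasNext() {
        return !stack.empty();
    }

    // Peek at the current node; does not advance.
    public BinTreeNode next() {
        return stack.peek();
    }

    // Advance to the in-order successor (ascending) or predecessor (descending).
    public void remove() {
        BinTreeNode node = stack.pop();
        if (leftToRight) {
            // Push the right subtree down to its leftmost node; if there is no
            // right subtree, the successor is already on top of the stack.
            node = node.rightChild;
            while (node != null) {
                stack.push(node);
                node = node.leftChild;
            }
        } else {
            // Mirror image: push the left subtree down to its rightmost node.
            node = node.leftChild;
            while (node != null) {
                stack.push(node);
                node = node.rightChild;
            }
        }
    }

    public BinTreeIterator(BinTreeNode node, boolean leftToRight) {
        stack = new Stack<BinTreeNode>();
        this.leftToRight = leftToRight;

        // Push the path to the smallest (ascending) or largest (descending) key.
        if (leftToRight) {
            while (node != null) {
                stack.push(node);
                node = node.leftChild;
            }
        } else {
            while (node != null) {
                stack.push(node);
                node = node.rightChild;
            }
        }
    }
}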



public static boolean searchNumBinTree(BinTreeNode node, int num) {
    if (node == null) {
        return false;
    }

    BinTreeIterator leftIter = new BinTreeIterator(node, true);   // ascending
    BinTreeIterator rightIter = new BinTreeIterator(node, false); // descending

    while (leftIter.hasNext() && rightIter.hasNext()) {
        BinTreeNode left = leftIter.next();   // peeks, does not advance
        BinTreeNode right = rightIter.next(); // peeks, does not advance
        // When the iterators meet, every pair of distinct nodes has been ruled
        // out; this also stops a one-node tree from pairing the root with itself.
        if (left == right) {
            return false;
        }
        int sum = left.item + right.item;
        if (sum == num) {
            return true;
        } else if (sum > num) {
            rightIter.remove(); // move to the next smaller key
        } else {
            leftIter.remove();  // move to the next larger key
        }
    }

    return false;
}
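
For experimenting, here is a self-contained sketch of the same two-iterator technique. It assumes a hypothetical `IterNode` class, and it renames the iterator operations to `peek()` and `advance()` on a hypothetical `BstCursor` class, because the `java.util.Iterator` contract expects `next()` to advance. It also uses `ArrayDeque` instead of the legacy `Stack`.

```java
import java.util.ArrayDeque;
import java.util.Deque;

// Hypothetical node type mirroring BinTreeNode's fields.
class IterNode {
    int item;
    IterNode leftChild, rightChild;

    IterNode(int item) {
        this.item = item;
    }
}

// Step-by-step in-order traversal, ascending or descending, holding O(h) nodes.
class BstCursor {
    private final Deque<IterNode> stack = new ArrayDeque<>();
    private final boolean ascending;

    BstCursor(IterNode root, boolean ascending) {
        this.ascending = ascending;
        pushSpine(root);
    }

    // Push node and its chain of left children (ascending) or right children (descending).
    private void pushSpine(IterNode node) {
        while (node != null) {
            stack.push(node);
            node = ascending ? node.leftChild : node.rightChild;
        }
    }

    boolean hasNext() {
        return !stack.isEmpty();
    }

    // Current node; does not advance.
    IterNode peek() {
        return stack.peek();
    }

    // Move to the in-order successor (ascending) or predecessor (descending).
    void advance() {
        IterNode node = stack.pop();
        pushSpine(ascending ? node.rightChild : node.leftChild);
    }
}

class TwoCursorPairSum {
    // Plain BST insert, used only to build example trees.
    static IterNode insert(IterNode root, int item) {
        if (root == null) {
            return new IterNode(item);
        }
        if (item < root.item) {
            root.leftChild = insert(root.leftChild, item);
        } else {
            root.rightChild = insert(root.rightChild, item);
        }
        return root;
    }

    static boolean searchNumBinTree(IterNode root, int num) {
        BstCursor lo = new BstCursor(root, true);  // smallest key first
        BstCursor hi = new BstCursor(root, false); // largest key first
        while (lo.hasNext() && hi.hasNext()) {
            IterNode left = lo.peek();
            IterNode right = hi.peek();
            if (left == right) {
                return false; // cursors met: every distinct pair has been ruled out
            }
            int sum = left.item + right.item;
            if (sum == num) {
                return true;
            } else if (sum > num) {
                hi.advance();
            } else {
                lo.advance();
            }
        }
        return false;
    }
}
```

Keeping the read (`peek()`) and the move (`advance()`) separate lets the loop inspect both ends before deciding which side to move, which is exactly what the two-pointer scan needs.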
answered Dec 28 '22 10:12 by Chen Pang


Chen Pang has already given a perfect answer. However, I was working on the same problem today and came up with the following solution. I'm posting it here in case it helps someone.

The idea is the same as in the earlier solution, except that I use two stacks: one following in-order (stack1) and the other following reverse in-order (stack2). Once we reach the leftmost and the rightmost nodes of the BST, we can start comparing them.

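If the sum equals the required value, we are done. If it is less, pop from stack1 (moving to the next larger key); otherwise, pop from stack2 (moving to the next smaller key). The search stops when both stacks have the same node on top. Here is a Java implementation: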

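// Returns 1 if two distinct nodes of the BST rooted at A sum to B, else 0.
public int sum2(TreeNode A, int B) {
    Stack<TreeNode> stack1 = new Stack<>(); // in-order: ascending keys
    Stack<TreeNode> stack2 = new Stack<>(); // reverse in-order: descending keys
    TreeNode cur1 = A;
    TreeNode cur2 = A;

    while (!stack1.isEmpty() || !stack2.isEmpty() ||
            cur1 != null || cur2 != null) {
        if (cur1 != null || cur2 != null) {
            // Descend along the left spine (stack1) and the right spine (stack2).
            if (cur1 != null) {
                stack1.push(cur1);
                cur1 = cur1.left;
            }

            if (cur2 != null) {
                stack2.push(cur2);
                cur2 = cur2.right;
            }
        } else {
            // The two traversals have met at the same node: no pair is left to check.
            if (stack1.peek() == stack2.peek()) break;

            int sum = stack1.peek().val + stack2.peek().val;
            if (sum == B) return 1;

            if (sum < B) {
                // Continue toward the next larger key via the popped node's right subtree.
                cur1 = stack1.pop();
                cur1 = cur1.right;
            } else {
                // Continue toward the next smaller key via the popped node's left subtree.
                cur2 = stack2.pop();
                cur2 = cur2.left;
            }
        }
    }

    return 0;
}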
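
The answer does not show `TreeNode`; the sketch below assumes a hypothetical class with `val`, `left` and `right` fields and wraps `sum2` in an illustrative `TwoStackPairSum` helper so it can be run on its own. The logic follows `sum2` above, with `ArrayDeque` in place of `Stack`.

```java
import java.util.ArrayDeque;
import java.util.Deque;

// Hypothetical TreeNode; the answer uses val/left/right but does not define the class.
class TreeNode {
    int val;
    TreeNode left, right;

    TreeNode(int val) {
        this.val = val;
    }
}

class TwoStackPairSum {
    // Plain BST insert, used only to build example trees.
    static TreeNode insert(TreeNode root, int val) {
        if (root == null) {
            return new TreeNode(val);
        }
        if (val < root.val) {
            root.left = insert(root.left, val);
        } else {
            root.right = insert(root.right, val);
        }
        return root;
    }

    // Same logic as sum2 above: 1 if two distinct nodes sum to B, else 0.
    static int sum2(TreeNode A, int B) {
        Deque<TreeNode> stack1 = new ArrayDeque<>(); // in-order (ascending)
        Deque<TreeNode> stack2 = new ArrayDeque<>(); // reverse in-order (descending)
        TreeNode cur1 = A;
        TreeNode cur2 = A;

        while (!stack1.isEmpty() || !stack2.isEmpty() || cur1 != null || cur2 != null) {
            if (cur1 != null || cur2 != null) {
                if (cur1 != null) {
                    stack1.push(cur1);
                    cur1 = cur1.left;
                }
                if (cur2 != null) {
                    stack2.push(cur2);
                    cur2 = cur2.right;
                }
            } else {
                TreeNode n1 = stack1.peek();
                TreeNode n2 = stack2.peek();
                if (n1 == n2) {
                    break; // traversals met: no distinct pair left
                }
                int sum = n1.val + n2.val;
                if (sum == B) {
                    return 1;
                }
                if (sum < B) {
                    cur1 = stack1.pop().right; // head toward the next larger key
                } else {
                    cur2 = stack2.pop().left;  // head toward the next smaller key
                }
            }
        }
        return 0;
    }
}
```

With the keys 8, 3, 10, 1, 6, 14, `sum2(root, 16)` returns 1 (6 + 10) and `sum2(root, 2)` returns 0.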
answered Dec 28 '22 09:12 by pankaj