
Algorithm: sum of distances between every two nodes of a binary search tree in O(n)?

The question is to find the sum of the distances between every two nodes of a binary search tree, given that every parent-child pair is separated by unit distance. The sum is to be calculated after every insertion.

ex:

 ->first node is inserted..

      (root)

   total sum=0;

->left and right node are inserted

      (root)
      /    \
  (left)   (right)

   total sum = distance(root,left)+distance(root,right)+distance(left,right);
             =        1           +         1          +         2
             =     4

and so on.....

Solutions I came up with:

  1. Brute-force. Steps:

    1. Perform a DFS and collect all the nodes : O(n).
    2. For every pair of nodes, calculate the distance between them using the lowest common ancestor method and add the distances up : O(nC2) * O(log(n)) = O(n^2 log(n)).

    Overall complexity: O(n^2 log(n)).
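The brute-force idea can be sketched as below. This is only an illustration under stated assumptions: the parent-array representation (parent[root] == -1) and the class and method names are hypothetical, and the LCA is found by a naive walk up the parent links, so each query costs O(height), which is O(log(n)) only when the tree is balanced.

```java
// Brute-force sketch: sum of distances over all pairs, using a naive LCA walk.
// Assumption for illustration: the tree is given as a parent array, parent[root] == -1.
final class BruteForcePairSum {
    // Depth of node v, found by walking parent links up to the root.
    static int depth(int[] parent, int v) {
        int d = 0;
        while (parent[v] != -1) {
            v = parent[v];
            d++;
        }
        return d;
    }

    // Distance between u and v: lift the deeper node to the other's depth,
    // then lift both together until they meet at the lowest common ancestor.
    static int distance(int[] parent, int u, int v) {
        int du = depth(parent, u);
        int dv = depth(parent, v);
        int dist = 0;
        while (du > dv) { u = parent[u]; du--; dist++; }
        while (dv > du) { v = parent[v]; dv--; dist++; }
        while (u != v) { u = parent[u]; v = parent[v]; dist += 2; }
        return dist;
    }

    // Adds up the distance of every unordered pair of nodes.
    static long sumOfDistances(int[] parent) {
        long total = 0;
        for (int u = 0; u < parent.length; u++) {
            for (int v = u + 1; v < parent.length; v++) {
                total += distance(parent, u, v);
            }
        }
        return total;
    }

    public static void main(String[] args) {
        // root with a left and a right child, as in the example above
        System.out.println(sumOfDistances(new int[] {-1, 0, 0})); // prints 4
    }
}
```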

  2. O(nlog(n)). Steps:

    1. Before insertion, perform a DFS and collect all the nodes : O(n).
    2. Calculate the distance between the inserted node and each of the remaining nodes : O(nlog(n)).
    3. Add the sum calculated in step 2 to the existing sum.

    Overall complexity: O(nlog(n)).
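A rough sketch of this incremental approach follows. The class, the stored depth and parent fields, and the node list that stands in for the DFS of step 1 are all assumptions made for illustration; duplicates are sent to the right subtree, and the per-pair distance costs O(height), so the O(nlog(n)) bound per insertion holds only for a balanced tree.

```java
import java.util.ArrayList;
import java.util.List;

// Incremental sketch: after each insertion, add the distances from the new node
// to every node that was already in the tree to a running total.
final class IncrementalPairSum {
    private static final class Node {
        final int key;
        final Node parent;
        final int depth;
        Node left, right;

        Node(int key, Node parent) {
            this.key = key;
            this.parent = parent;
            this.depth = (parent == null) ? 0 : parent.depth + 1;
        }
    }

    private Node root;
    private final List<Node> nodes = new ArrayList<>(); // replaces the DFS of step 1
    private long total = 0;

    // Inserts key (plain unbalanced BST insertion) and returns the updated sum.
    long insert(int key) {
        Node parent = null;
        Node cur = root;
        while (cur != null) {
            parent = cur;
            cur = (key < cur.key) ? cur.left : cur.right;
        }
        Node added = new Node(key, parent);
        if (parent == null) {
            root = added;
        } else if (key < parent.key) {
            parent.left = added;
        } else {
            parent.right = added;
        }
        for (Node other : nodes) {
            total += distance(added, other); // step 2
        }
        nodes.add(added);
        return total; // step 3: the existing sum plus the new distances
    }

    // Distance via the lowest common ancestor, found by walking parent links.
    private static int distance(Node a, Node b) {
        int dist = 0;
        while (a.depth > b.depth) { a = a.parent; dist++; }
        while (b.depth > a.depth) { b = b.parent; dist++; }
        while (a != b) { a = a.parent; b = b.parent; dist += 2; }
        return dist;
    }

    public static void main(String[] args) {
        IncrementalPairSum tree = new IncrementalPairSum();
        System.out.println(tree.insert(2)); // prints 0
        System.out.println(tree.insert(1)); // prints 1
        System.out.println(tree.insert(3)); // prints 4
    }
}
```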

Now the question is: does a solution of order O(n) exist?

technical_yogi asked Jul 17 '14 01:07



2 Answers

Yes, you can find the sum of the distances between every two nodes of the whole tree by DP in O(n). Briefly, you need to know three things:

cnt[i] is the number of nodes in the subtree rooted at node i
dis[i] is the sum of the distances from every node in node i's subtree to node i
ret[i] is the sum of the distances between every two nodes in node i's subtree

Notice that ret[root] is the answer to the problem, so we just need to calculate ret[i] correctly and the problem is solved. How do we calculate ret[i]? With the help of cnt[i] and dis[i], solving recursively. The key problem is:

Given ret[left], ret[right], dis[left], dis[right], cnt[left] and cnt[right], calculate ret[node], dis[node] and cnt[node].

              (node)
          /             \
    (left-subtree) (right subtree)
      /                   \
...(node x_i) ...   ...(node y_i)...
Important: x_i is any node in the left subtree (not only a leaf!)
and y_i is any node in the right subtree (not only a leaf either!).

cnt[node] is easy: it equals cnt[left] + cnt[right] + 1.

dis[node] is not so hard: it equals dis[left] + dis[right] + cnt[left] + cnt[right]. Reason: sigma(xi->left) is dis[left], and each of those cnt[left] paths is extended by the one edge left->node, so sigma(xi->node) is dis[left] + cnt[left]. The same holds on the right side.

ret[node] is the sum of three parts:

  1. xi -> xj and yi -> yj, which equals ret[left] + ret[right].
  2. xi -> node and yi -> node, which equals dis[node].
  3. xi -> yj:

Every such path passes through node, so this part equals sigma(xi -> node -> yj). Fix xi and sum over all yj: we get cnt[right]*distance(xi,node) + sigma(node->yj). Every path node->yj passes through right, so sigma(node->yj) = sigma(node->right->yj) = cnt[right] + dis[right],

and for a fixed xi the sum is cnt[right]*distance(xi,node) + cnt[right] + dis[right].

Sum over all xi: sigma(distance(xi,node)) is cnt[left] + dis[left], and the term cnt[right] + dis[right] is added cnt[left] times, which gives cnt[right]*(cnt[left] + dis[left]) + cnt[left]*(cnt[right] + dis[right]), that is 2*cnt[left]*cnt[right] + dis[left]*cnt[right] + dis[right]*cnt[left].

Summing these three parts gives ret[node]. Doing this recursively gives ret[root].
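As a compact summary, the three recurrences can be written as follows. The symbols l and r for the left and right child of v are introduced here only for brevity, and a missing child is assumed to contribute cnt = dis = ret = 0:

```latex
\begin{aligned}
\mathrm{cnt}[v] &= \mathrm{cnt}[l] + \mathrm{cnt}[r] + 1 \\
\mathrm{dis}[v] &= \mathrm{dis}[l] + \mathrm{dis}[r] + \mathrm{cnt}[l] + \mathrm{cnt}[r] \\
\mathrm{ret}[v] &= \mathrm{ret}[l] + \mathrm{ret}[r] + \mathrm{dis}[v]
  + 2\,\mathrm{cnt}[l]\,\mathrm{cnt}[r]
  + \mathrm{dis}[l]\,\mathrm{cnt}[r]
  + \mathrm{dis}[r]\,\mathrm{cnt}[l]
\end{aligned}
```

For a root with two leaf children, cnt[l] = cnt[r] = 1 and dis and ret are 0 for both leaves, so cnt[v] = 3, dis[v] = 2 and ret[v] = 0 + 0 + 2 + 2 + 0 + 0 = 4, which matches the example in the question.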

My code:

import java.util.Arrays;

public class BSTDistance {
    int[] left;
    int[] right;
    int[] cnt;
    int[] ret;
    int[] dis;
    int nNode;
    public BSTDistance(int n) { // n is the number of nodes
        left = new int[n];
        right = new int[n];
        cnt = new int[n];
        ret = new int[n];
        dis = new int[n];
        Arrays.fill(left,-1);
        Arrays.fill(right,-1);
        nNode = n;
    }
    // attach node a as a child of node b: the left slot is filled first, then the right
    void add(int a, int b)
    {
        if (left[b] == -1)
        {
            left[b] = a;
        }
        else
        {
            right[b] = a;
        }
    }
    int cal()
    {
        _cal(0);//assume root's idx is 0
        return ret[0];
    }
    void _cal(int idx)
    {
        if (left[idx] == -1 && right[idx] == -1)
        {
            cnt[idx] = 1;
            dis[idx] = 0;
            ret[idx] = 0;
        }
        else if (left[idx] != -1  && right[idx] == -1)
        {
            _cal(left[idx]);
            cnt[idx] = cnt[left[idx]] + 1;
            dis[idx] = dis[left[idx]] + cnt[left[idx]];
            ret[idx] = ret[left[idx]] + dis[idx];
        }//left[idx] == -1 and right[idx] != -1 is impossible, guaranteed by add(int,int)
        else 
        {
            _cal(left[idx]);
            _cal(right[idx]);
            cnt[idx] = cnt[left[idx]] + 1 + cnt[right[idx]];
            dis[idx] = dis[left[idx]] + dis[right[idx]] + cnt[left[idx]] + cnt[right[idx]];
            ret[idx] = dis[idx] + ret[left[idx]] + ret[right[idx]] + 2*cnt[left[idx]]*cnt[right[idx]] + dis[left[idx]]*cnt[right[idx]] + dis[right[idx]]*cnt[left[idx]];
        }
    }
    public static void main(String[] args)
    {
        BSTDistance bst1 = new BSTDistance(3);
        bst1.add(1, 0);
        bst1.add(2, 0);
        //   (0)
        //  /   \
        //(1)   (2)
        System.out.println(bst1.cal());
        BSTDistance bst2 = new BSTDistance(5);
        bst2.add(1, 0);
        bst2.add(2, 0);
        bst2.add(3, 1);
        bst2.add(4, 1);
        //       (0)
        //      /   \
        //    (1)   (2)
        //   /   \
        // (3)   (4)
        //0 -> 1:1
        //0 -> 2:1
        //0 -> 3:2
        //0 -> 4:2
        //1 -> 2:2
        //1 -> 3:1
        //1 -> 4:1
        //2 -> 3:3
        //2 -> 4:3
        //3 -> 4:2
        //2*4+3*2+1*4=18
        System.out.println(bst2.cal());
    }
}

output:

4
18

For the convenience of readers trying to understand my solution, here are the values of cnt[], dis[] and ret[] after bst2.cal() is called:

cnt[] 5 3 1 1 1 
dis[] 6 2 0 0 0
ret[] 18 4 0 0 0 
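The same recurrences can also be sketched on a pointer-based tree. This is an illustrative variant, not the code above: the Node class and method names are hypothetical, a missing child is treated as an empty subtree with cnt = dis = ret = 0 (so a node with only a right child needs no special case), and long is used because the sum can outgrow int on large trees.

```java
// Pointer-based sketch of the cnt/dis/ret recurrences.
final class SubtreeDp {
    static final class Node {
        Node left, right;
    }

    // Returns {cnt, dis, ret} for the subtree rooted at node.
    // An empty subtree contributes {0, 0, 0}.
    static long[] solve(Node node) {
        if (node == null) {
            return new long[] {0, 0, 0};
        }
        long[] l = solve(node.left);
        long[] r = solve(node.right);
        long cnt = l[0] + r[0] + 1;
        // every node below is one edge farther from node than from its own subtree root
        long dis = l[1] + r[1] + l[0] + r[0];
        // pairs inside each side, pairs ending at node, and pairs crossing from left to right
        long ret = l[2] + r[2] + dis
                + 2 * l[0] * r[0] + l[1] * r[0] + r[1] * l[0];
        return new long[] {cnt, dis, ret};
    }

    static long sumOfDistances(Node root) {
        return solve(root)[2];
    }

    public static void main(String[] args) {
        // same shape as bst2: root 0 with children 1 and 2, node 1 with children 3 and 4
        Node root = new Node();
        root.left = new Node();
        root.right = new Node();
        root.left.left = new Node();
        root.left.right = new Node();
        System.out.println(sumOfDistances(root)); // prints 18
    }
}
```

The recursion depth equals the tree height, so a very unbalanced tree may need an explicit stack instead.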

PS: This is the solution from UESTC_elfness, for whom it is a simple problem. I'm Sayakiss, and the problem is not so hard for me either.

So you can trust us...

Sayakiss answered Oct 15 '22 05:10



We can do this by traversing the tree twice.

First, we need three arrays:

int []left, which stores the sum of the distances of the left subtree.

int []right, which stores the sum of the distances of the right subtree.

int []up, which stores the sum of the distances of the parent tree (without the current subtree).

So, in the first traversal, for each node, we calculate the left and the right distance. If the node is a leaf, simply return 0; if not, we can use this formula:

int cal(Node node){
    if (node == null) return 0; // an empty subtree contributes nothing
    // local names differ from the arrays left[] and right[] so they do not shadow them
    int leftSum = cal(node.left);
    int rightSum = cal(node.right);
    left[node.index] = leftSum;
    right[node.index] = rightSum;
    // Depending on whether the current node has a left and/or right child, we add 0, 1 or 2 to the final result
    int add = (node.left != null ? 1 : 0) + (node.right != null ? 1 : 0);
    return leftSum + rightSum + add;
}

Then, in the second traversal, we need to add to each node the total distance from its parent.

             1
            / \
           2   3
          / \
         4   5

For example, for node 1 (root), the total distance is left[1] + right[1] + 2, and up[1] = 0. We add 2 because the root has both a left and a right subtree; the exact formula for it is:

int add = 0; 
if (node.left != null) 
    add++;
if(node.right != null)
    add++;

Here, add denotes the additional distance for the current node, addLeft the additional distance if the parent node has a left subtree, and similarly addRight for the right subtree.

For node 2, the total distance is left[2] + right[2] + add + up[1] + right[1] + 1 + addRight, and up[2] = up[1] + right[1] + addRight. The 1 at the end of the formula is there because there is an edge from the current node to its parent.

For node 3, the total distance is up[1] + left[1] + 1 + addLeft, and up[3] = up[1] + left[1] + addLeft.

For node 4, the total distance is up[2] + right[2] + 1 + addRight, and up[4] = up[2] + right[2] + addRight.

So depending on whether the current node is a left or a right child, we update up accordingly.

The time complexity is O(n).
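One way to make the two-traversal idea concrete is sketched below. It is not the exact bookkeeping described above: instead of the left, right and up arrays, this sketch tracks subtree sizes (an addition made here), because moving across one edge changes the distance to every node on either side of that edge by one. All names are hypothetical. The first pass computes, for each node, the sum of distances to its descendants; the second pass turns that into the sum of distances to all other nodes. Each pair is then counted twice, so the result is halved.

```java
// Two-traversal sketch: per-node distance sums, then half of their total.
final class RerootPairSum {
    static final class Node {
        Node left, right;
        int size;   // number of nodes in this subtree
        long down;  // sum of distances from this node to its descendants
    }

    // First traversal (bottom-up): fill size and down.
    static void computeDown(Node n) {
        n.size = 1;
        n.down = 0;
        for (Node c : new Node[] {n.left, n.right}) {
            if (c == null) continue;
            computeDown(c);
            n.size += c.size;
            n.down += c.down + c.size; // each node under c is one edge farther from n
        }
    }

    // Second traversal (top-down): allHere is the sum of distances from n to every
    // other node. Returns the sum of that quantity over the whole subtree of n.
    static long computeAll(Node n, long allHere, int total) {
        long sum = allHere;
        for (Node c : new Node[] {n.left, n.right}) {
            if (c == null) continue;
            // stepping from n to c: c.size nodes get one closer, the rest one farther
            long allChild = allHere - c.size + (total - c.size);
            sum += computeAll(c, allChild, total);
        }
        return sum;
    }

    static long sumOfDistances(Node root) {
        if (root == null) return 0;
        computeDown(root);
        return computeAll(root, root.down, root.size) / 2;
    }

    public static void main(String[] args) {
        // the example tree: 1 with children 2 and 3, node 2 with children 4 and 5
        Node root = new Node();
        root.left = new Node();
        root.right = new Node();
        root.left.left = new Node();
        root.left.right = new Node();
        System.out.println(sumOfDistances(root)); // prints 18
    }
}
```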

Pham Trung answered Oct 15 '22 05:10
