The task is to find the sum of distances between every two nodes of a binary search tree, given that every parent-child pair is separated by unit distance. The sum must be recalculated after every insertion.
Example:
-> the first node is inserted
(root)
total sum = 0
-> left and right nodes are inserted
    (root)
    /    \
(left)  (right)
total sum = distance(root,left) + distance(root,right) + distance(left,right)
= 1 + 1 + 2
= 4
and so on.
Solutions I came up with:
1. Brute force. Steps:
   ... O(n).
   For every pair of nodes, O(nC2) pairs at O(log(n)) each = O(n^2 log(n)): find the distance between them using the Lowest Common Ancestor method and add them up.
   Overall complexity: O(n^2 log(n)).
2. O(n log(n)). Steps:
   ... O(n).
   ... O(n log(n)).
   ... the remaining nodes.
   Overall complexity: O(n log(n)).
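The brute-force idea (pairwise distances through the lowest common ancestor) could be sketched roughly as below. This is only an illustration for cross-checking faster solutions: the class name `BruteForceDistance`, the `addNode` helper, and the parent/depth lists are assumptions, not part of the question. Each distance query costs O(height), which is O(log(n)) only when the tree stays balanced.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical brute-force checker; all names here are illustrative.
class BruteForceDistance {
    // parent.get(v) is the index of v's parent (-1 for the root); depth.get(v) is v's depth.
    private final List<Integer> parent = new ArrayList<>();
    private final List<Integer> depth = new ArrayList<>();

    // Adds a node below parentIdx (use -1 for the root) and returns the new node's index.
    int addNode(int parentIdx) {
        parent.add(parentIdx);
        depth.add(parentIdx == -1 ? 0 : depth.get(parentIdx) + 1);
        return parent.size() - 1;
    }

    // Climbs from the deeper endpoint until both endpoints meet at their lowest common ancestor.
    int distance(int u, int v) {
        int steps = 0;
        while (u != v) {
            if (depth.get(u) >= depth.get(v)) {
                u = parent.get(u);
            } else {
                v = parent.get(v);
            }
            steps++;
        }
        return steps;
    }

    // Sums the distance over all unordered pairs of nodes.
    long sumOfAllPairs() {
        long total = 0;
        int n = parent.size();
        for (int u = 0; u < n; u++) {
            for (int v = u + 1; v < n; v++) {
                total += distance(u, v);
            }
        }
        return total;
    }

    public static void main(String[] args) {
        BruteForceDistance tree = new BruteForceDistance();
        int root = tree.addNode(-1);
        tree.addNode(root);
        tree.addNode(root);
        System.out.println(tree.sumOfAllPairs()); // root with two children, as in the example: 4
    }
}
```

Calling `sumOfAllPairs()` after each `addNode` gives the running totals the question asks for, at brute-force cost.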
Now the question is: does a solution of order O(n) exist?
Yes, you can find the sum of distances between every two nodes of the whole tree by DP in O(n). Briefly, you need to know three things:
cnt[i] is the number of nodes in the subtree rooted at node i
dis[i] is the sum of the distances from every node in node i's subtree to node i
ret[i] is the sum of the distances between every two nodes in node i's subtree
Notice that ret[root] is the answer to the problem, so just calculate ret[i] correctly and the problem is solved.
How do we calculate ret[i]? With the help of cnt[i] and dis[i], solving it recursively.
The key problem is:
Given ret[left], ret[right], dis[left], dis[right], cnt[left], cnt[right], calculate ret[node], dis[node], cnt[node].
              (node)
             /      \
  (left-subtree)  (right-subtree)
       /                  \
...(node x_i)...    ...(node y_i)...
Important: x_i is any node in the left subtree (not only a leaf!) and y_i is any node in the right subtree (not only a leaf either!).
cnt[node] is easy: it equals cnt[left] + cnt[right] + 1.
dis[node] is not so hard: it equals dis[left] + dis[right] + cnt[left] + cnt[right]. Reason: sigma(x_i -> left) is dis[left], so sigma(x_i -> node) is dis[left] + cnt[left], because each of the cnt[left] paths gains one edge. The same holds for the right subtree.
ret[node] consists of three parts:
1. Pairs inside the same subtree: ret[left] + ret[right].
2. Pairs formed by node and every node below it: dis[node].
3. Pairs with x_i in the left subtree and y_j in the right subtree: sigma(x_i -> node -> y_j). Fix x_i; then we get cnt[right]*distance(x_i,node) + sigma(node -> y_j), that is cnt[right]*distance(x_i,node) + sigma(node -> right -> y_j), which is cnt[right]*distance(x_i,node) + cnt[right] + dis[right].
Summing over x_i: cnt[left]*(cnt[right] + dis[right]) + cnt[right]*(cnt[left] + dis[left]), which is 2*cnt[left]*cnt[right] + dis[left]*cnt[right] + dis[right]*cnt[left].
Sum these three parts and we get ret[node]. Doing this recursively, we get ret[root].
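Collected in one place, the recurrences derived above can be written as follows. Writing l and r for the left and right child of v, and treating a missing child as an empty subtree with cnt = dis = ret = 0 (a convention assumed here so that one formula covers every case), a leaf gets cnt = 1 and dis = ret = 0.

```latex
\begin{aligned}
\mathrm{cnt}[v] &= \mathrm{cnt}[l] + \mathrm{cnt}[r] + 1 \\
\mathrm{dis}[v] &= \mathrm{dis}[l] + \mathrm{dis}[r] + \mathrm{cnt}[l] + \mathrm{cnt}[r] \\
\mathrm{ret}[v] &= \mathrm{ret}[l] + \mathrm{ret}[r] + \mathrm{dis}[v]
  + 2\,\mathrm{cnt}[l]\,\mathrm{cnt}[r]
  + \mathrm{dis}[l]\,\mathrm{cnt}[r]
  + \mathrm{dis}[r]\,\mathrm{cnt}[l]
\end{aligned}
```

For a root with two leaf children this gives cnt = 3, dis = 2 and ret = 0 + 0 + 2 + 2 + 0 + 0 = 4, matching the example in the question.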
My code:
import java.util.Arrays;
public class BSTDistance {
    int[] left;   // left[i]: index of node i's left child, or -1
    int[] right;  // right[i]: index of node i's right child, or -1
    int[] cnt;    // cnt[i]: number of nodes in node i's subtree
    long[] ret;   // ret[i]: sum of pairwise distances inside node i's subtree (long to avoid overflow)
    long[] dis;   // dis[i]: sum of distances from node i to every node in its subtree
    int nNode;
    public BSTDistance(int n) { // n is the number of nodes
        left = new int[n];
        right = new int[n];
        cnt = new int[n];
        ret = new long[n];
        dis = new long[n];
        Arrays.fill(left, -1);
        Arrays.fill(right, -1);
        nNode = n;
    }
    // Attaches node a as a child of node b; the left slot is filled first, then the right one
    void add(int a, int b)
    {
        if (left[b] == -1)
        {
            left[b] = a;
        }
        else
        {
            right[b] = a;
        }
    }
    long cal()
    {
        _cal(0); // assume the root's index is 0
        return ret[0];
    }
    void _cal(int idx)
    {
        if (left[idx] == -1 && right[idx] == -1)
        {
            // leaf
            cnt[idx] = 1;
            dis[idx] = 0;
            ret[idx] = 0;
        }
        else if (left[idx] != -1 && right[idx] == -1)
        {
            _cal(left[idx]);
            cnt[idx] = cnt[left[idx]] + 1;
            dis[idx] = dis[left[idx]] + cnt[left[idx]];
            ret[idx] = ret[left[idx]] + dis[idx];
        } // left[idx] == -1 and right[idx] != -1 is impossible, guaranteed by add(int,int)
        else
        {
            _cal(left[idx]);
            _cal(right[idx]);
            cnt[idx] = cnt[left[idx]] + 1 + cnt[right[idx]];
            dis[idx] = dis[left[idx]] + dis[right[idx]] + cnt[left[idx]] + cnt[right[idx]];
            ret[idx] = dis[idx] + ret[left[idx]] + ret[right[idx]] + 2L*cnt[left[idx]]*cnt[right[idx]] + dis[left[idx]]*cnt[right[idx]] + dis[right[idx]]*cnt[left[idx]];
        }
    }
    public static void main(String[] args)
    {
        BSTDistance bst1 = new BSTDistance(3);
        bst1.add(1, 0);
        bst1.add(2, 0);
        //  (0)
        //  / \
        //(1) (2)
        System.out.println(bst1.cal());
        BSTDistance bst2 = new BSTDistance(5);
        bst2.add(1, 0);
        bst2.add(2, 0);
        bst2.add(3, 1);
        bst2.add(4, 1);
        //      (0)
        //      / \
        //    (1) (2)
        //    / \
        //  (3) (4)
        //0 -> 1:1
        //0 -> 2:1
        //0 -> 3:2
        //0 -> 4:2
        //1 -> 2:2
        //1 -> 3:1
        //1 -> 4:1
        //2 -> 3:3
        //2 -> 4:3
        //3 -> 4:2
        //2*4+3*2+1*4=18
        System.out.println(bst2.cal());
    }
}
output:
4
18
To make the solution easier to follow, here are the values of cnt[], dis[] and ret[] after bst2.cal() is called:
cnt[] 5 3 1 1 1
dis[] 6 2 0 0 0
ret[] 18 4 0 0 0
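Since the question asks for the sum after every insertion, one way to apply the DP is to recompute it after each insert on a pointer-based BST. The following is a minimal sketch under assumptions: the class `BstDistanceSum`, its `Node` type, and the choice to ignore duplicate keys are illustrative and not part of the answer's code. An empty subtree is treated as cnt = dis = ret = 0, which also covers a node that has only a right child.

```java
// Hypothetical sketch: recomputes the pairwise-distance sum after each BST insertion.
class BstDistanceSum {
    private static final class Node {
        final int key;
        Node left, right;
        Node(int key) { this.key = key; }
    }

    private Node root;

    // Plain unbalanced BST insertion; duplicate keys are ignored (an assumption).
    void insert(int key) {
        if (root == null) { root = new Node(key); return; }
        Node cur = root;
        while (true) {
            if (key < cur.key) {
                if (cur.left == null) { cur.left = new Node(key); return; }
                cur = cur.left;
            } else if (key > cur.key) {
                if (cur.right == null) { cur.right = new Node(key); return; }
                cur = cur.right;
            } else {
                return;
            }
        }
    }

    // Returns {cnt, dis, ret} for the subtree rooted at node; an empty subtree yields {0, 0, 0}.
    private static long[] solve(Node node) {
        if (node == null) return new long[] {0, 0, 0};
        long[] l = solve(node.left);
        long[] r = solve(node.right);
        long cnt = l[0] + r[0] + 1;
        long dis = l[1] + r[1] + l[0] + r[0];
        long ret = l[2] + r[2] + dis + 2 * l[0] * r[0] + l[1] * r[0] + r[1] * l[0];
        return new long[] {cnt, dis, ret};
    }

    // One O(n) pass over the current tree.
    long sumOfDistances() { return solve(root)[2]; }

    public static void main(String[] args) {
        BstDistanceSum bst = new BstDistanceSum();
        for (int key : new int[] {4, 2, 6, 1, 3}) {
            bst.insert(key);
            System.out.println(bst.sumOfDistances()); // prints 0, 1, 4, 10, 18 on successive lines
        }
    }
}
```

Each query is a full O(n) traversal, so n insertions cost O(n^2) in total with this sketch. The recursion depth equals the tree height, so a very skewed tree may need an iterative traversal instead.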
PS: This solution is from UESTC_elfness, for whom it is a simple problem. I'm sayakiss, and the problem is not so hard for me either.
So you can trust us...
We can do this by traversing the tree twice.
First, we need three arrays:
int[] left, which stores the sum of the distances of the left subtree.
int[] right, which stores the sum of the distances of the right subtree.
int[] up, which stores the sum of the distances of the parent tree (excluding the current subtree).
In the first traversal, we calculate the left and right distances for each node. If the node is a leaf, simply return 0; otherwise we use this formula:
int cal(Node node) {
    if (node == null) {
        return 0; // an empty subtree contributes nothing
    }
    // Local names differ from the left[] and right[] arrays to avoid shadowing them
    int leftSum = cal(node.left);
    int rightSum = cal(node.right);
    left[node.index] = leftSum;
    right[node.index] = rightSum;
    // Depending on whether the current node has a left child, a right child, or both, we add 0, 1 or 2 to the result
    int add = (node.left != null ? 1 : 0) + (node.right != null ? 1 : 0);
    return leftSum + rightSum + add;
}
Then, in the second traversal, we add to each node the total distance that comes from its parent.
    1
   / \
  2   3
 / \
4   5
For example, for node 1 (the root), the total distance is left[1] + right[1] + 2, and up[1] = 0. We add 2 because the root has both a left and a right subtree; the exact formula is:
int add = 0;
if (node.left != null)
    add++;
if (node.right != null)
    add++;
Here add denotes the additional distance for the current node, addLeft the additional distance if the parent node has a left subtree, and addRight the same for the parent's right subtree.
For node 2, the total distance is left[2] + right[2] + add + up[1] + right[1] + 1 + addRight, and up[2] = up[1] + right[1] + addRight. The 1 at the end of the formula accounts for the edge from the current node to its parent.
For node 3, the total distance is up[1] + left[1] + 1 + addLeft, and up[3] = up[1] + left[1] + addLeft.
For node 4, the total distance is up[2] + right[2] + 1 + addRight, and up[4] = up[2] + right[2] + addRight.
So, depending on whether the current node is a left or a right child, we update up accordingly.
The time complexity is O(n)
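A two-traversal approach of this kind is commonly written as a "rerooting" DP. The sketch below uses that standard formulation rather than the exact left[]/right[]/up[] bookkeeping described above, so treat it as a related technique, not as this answer's code. The class name `RerootDistance`, the child-index arrays (with -1 for a missing child), and the assumption that every node is reachable from the root are illustrative. The first pass computes subtree sizes and downward distance sums; the second pass derives each node's total from its parent's via all[v] = all[parent] - cnt[v] + (n - cnt[v]).

```java
// Hypothetical rerooting sketch; array layout and names are assumptions.
class RerootDistance {
    // left[v] / right[v] hold child indices, or -1 when the child is missing.
    static long sumOfAllPairs(int[] left, int[] right, int root) {
        int n = left.length;
        int[] order = new int[n];   // visiting order with parents before their children
        int[] parent = new int[n];
        int[] stack = new int[n];
        int top = 0, filled = 0;
        stack[top++] = root;
        parent[root] = -1;
        while (top > 0) {
            int v = stack[--top];
            order[filled++] = v;
            for (int c : new int[] {left[v], right[v]}) {
                if (c != -1) { parent[c] = v; stack[top++] = c; }
            }
        }
        long[] cnt = new long[n];   // subtree sizes
        long[] down = new long[n];  // sum of distances from v to the nodes of its own subtree
        // First pass, children before parents.
        for (int i = n - 1; i >= 0; i--) {
            int v = order[i];
            cnt[v] += 1;
            int p = parent[v];
            if (p != -1) { cnt[p] += cnt[v]; down[p] += down[v] + cnt[v]; }
        }
        // Second pass, parents before children: all[v] is the sum of distances from v to every node.
        long[] all = new long[n];
        all[root] = down[root];
        long total = all[root];
        for (int i = 1; i < n; i++) {
            int v = order[i];
            // Moving from the parent to v brings cnt[v] nodes one edge closer and pushes the rest one edge away.
            all[v] = all[parent[v]] - cnt[v] + (n - cnt[v]);
            total += all[v];
        }
        return total / 2; // each unordered pair was counted once from each endpoint
    }

    public static void main(String[] args) {
        // The five-node tree with node 0 as root, children 1 and 2, and nodes 3 and 4 under node 1.
        int[] left = {1, 3, -1, -1, -1};
        int[] right = {2, 4, -1, -1, -1};
        System.out.println(sumOfAllPairs(left, right, 0)); // 18
    }
}
```

Both passes are iterative, so the sketch does not depend on recursion depth even for a degenerate tree.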