Do not understand the solution for the Binary Tree Maximum Path Sum problem

The website GeeksforGeeks has presented a solution for the problem of Maximum path sum for a binary tree. The question is as follows:

Given a binary tree, find the maximum path sum. The path may start and end at any node in the tree.

The core of the solution is as follows:

int findMaxUtil(Node node, Res res) {
    if (node == null)
        return 0;
    // l and r store maximum path sum going through left and
    // right child of root respectively
    int l = findMaxUtil(node.left, res);
    int r = findMaxUtil(node.right, res);
    // Max path for parent call of root. This path must
    // include at most one child of root
    int max_single = Math.max(Math.max(l, r) + node.data,
                              node.data);
    // Max Top represents the sum when the Node under
    // consideration is the root of the maxsum path and no
    // ancestors of root are there in max sum path
    int max_top = Math.max(max_single, l + r + node.data);
    // Store the Maximum Result.
    res.val = Math.max(res.val, max_top);
    return max_single;
}

int findMaxSum() {
    return findMaxSum(root);
}

// Returns maximum path sum in tree with given root
int findMaxSum(Node node) {
    // Initialize result
    Res res = new Res();
    res.val = Integer.MIN_VALUE;
    // Compute and return result
    findMaxUtil(node, res);
    return res.val;
}

Res has the following definition:

class Res {
    public int val;
}

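For anyone who wants to run the code, here is a self-contained sketch. The constructors, the BinaryTree wrapper class and the sample tree in the note below are assumptions added so that it compiles; findMaxUtil and findMaxSum follow the quoted GeeksforGeeks code.

```java
// Assumed Node class with the fields the quoted code uses.
class Node {
    int data;
    Node left, right;

    Node(int data) {
        this.data = data;
    }

    Node(int data, Node left, Node right) {
        this.data = data;
        this.left = left;
        this.right = right;
    }
}

// Mutable holder for the best path sum seen so far.
class Res {
    public int val;
}

// Assumed wrapper that owns the root, as findMaxSum() implies.
class BinaryTree {
    Node root;

    BinaryTree(Node root) {
        this.root = root;
    }

    // Returns the best sum of a path that starts at 'node' and goes down
    // through at most one child; records the best path seen anywhere in res.
    int findMaxUtil(Node node, Res res) {
        if (node == null)
            return 0;
        int l = findMaxUtil(node.left, res);
        int r = findMaxUtil(node.right, res);
        int max_single = Math.max(Math.max(l, r) + node.data, node.data);
        int max_top = Math.max(max_single, l + r + node.data);
        res.val = Math.max(res.val, max_top);
        return max_single;
    }

    int findMaxSum() {
        Res res = new Res();
        res.val = Integer.MIN_VALUE;
        findMaxUtil(root, res);
        return res.val;
    }
}
```

For example, with a root 10 whose left child 2 has children 20 and 1, and whose right child 10 has a right child -25 with children 3 and 4, findMaxSum() returns 42 (the path 20, 2, 10, 10).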
I am confused about the reasoning behind these lines of code:

int max_single = Math.max(Math.max(l, r) + node.data, node.data);  

int max_top = Math.max(max_single, l + r + node.data); 

res.val = Math.max(res.val, max_top); 

return max_single; 

I believe the code above follows this logic but I do not understand WHY this logic is correct or valid:

For each node there can be four ways that the max path goes through the node:

  1. Node only
  2. Max path through Left Child + Node
  3. Max path through Right Child + Node
  4. Max path through Left Child + Node + Max path through Right Child
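The four cases map directly onto the two variables. In the sketch below (the class and method names are hypothetical), l and r are assumed to be the best downward path sums already computed for the left and right child, with 0 for a missing child as in the quoted code. Taking the maximum of cases 1 to 3 gives max_single; adding case 4 gives max_top.

```java
import java.util.Arrays;

class FourCases {
    // The four candidate sums for a path whose lowest-depth node is this node.
    // l and r: best downward path sums starting at the left and right child.
    static int[] candidates(int l, int r, int data) {
        return new int[] {
            data,         // 1. node only
            l + data,     // 2. max path through left child + node
            r + data,     // 3. max path through right child + node
            l + data + r  // 4. left path + node + right path
        };
    }

    // Cases 1-3 only: a parent can still extend this path upward.
    static int maxSingle(int l, int r, int data) {
        int[] c = candidates(l, r, data);
        return Math.max(c[0], Math.max(c[1], c[2]));
    }

    // All four cases: the best path whose lowest-depth node is this node.
    static int maxTop(int l, int r, int data) {
        return Arrays.stream(candidates(l, r, data)).max().getAsInt();
    }
}
```

maxSingle(l, r, data) always equals Math.max(Math.max(l, r) + data, data), and maxTop equals Math.max(max_single, l + r + data), which is why the quoted code only needs two lines for the four cases.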

In particular, I do not understand why max_single is returned by findMaxUtil when the variable res.val contains the answer we are interested in. The following reason is given on the website, but I do not understand it:

An important thing to note is that the root of every subtree needs to return the maximum path sum such that at most one child of the root is involved.
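A small counterexample shows why this rule matters. The sketch below (hypothetical class name, assumed tree) runs the same recursion twice: once handing max_single to the parent, and once wrongly handing max_top up instead. For a node 2 with children 3 and 4 and parent 1, the wrong version reports 10, the sum of the fork {3, 2, 4, 1}, which is not a path; the correct answer is 9 (3, 2, 4).

```java
class ForkDemo {
    static class Node {
        int data;
        Node left, right;
        Node(int data, Node left, Node right) {
            this.data = data;
            this.left = left;
            this.right = right;
        }
    }

    private static int best;

    // Returns the value handed to the parent. If passUpTop is true, this
    // deliberately (and wrongly) hands up max_top instead of max_single.
    private static int util(Node node, boolean passUpTop) {
        if (node == null)
            return 0;
        int l = util(node.left, passUpTop);
        int r = util(node.right, passUpTop);
        int maxSingle = Math.max(Math.max(l, r) + node.data, node.data);
        int maxTop = Math.max(maxSingle, l + r + node.data);
        best = Math.max(best, maxTop);
        return passUpTop ? maxTop : maxSingle;
    }

    static int maxPathSum(Node root, boolean passUpTop) {
        best = Integer.MIN_VALUE;
        util(root, passUpTop);
        return best;
    }
}
```

If a node hands up a value that already uses both children, the parent glues itself onto a shape with three branches at one node, so the "path" is no longer a path.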

Could someone provide an explanation for this step of the solution?

3 Answers

In particular, I do not understand why max_single is returned by findMaxUtil when the variable res.val contains the answer we are interested in.

The problem is that findMaxUtil() really does two things: it returns the largest sum of a path that starts at the root of the subtree it's applied to and descends into at most one child, and it updates a variable that keeps track of the largest path sum encountered so far. There's a comment to that effect in the original code, but you edited it out of your question, perhaps for brevity:

// This function returns overall maximum path sum in 'res' 
// And returns max path sum going through root. 
int findMaxUtil(Node node, Res res) 

Because Java passes parameters by value, but every object variable in Java implicitly references the actual object, it's easy to miss the fact that the Res that's passed in the res parameter may be changed by this function. And that's exactly what happens in the lines you asked about:

int max_single = Math.max(Math.max(l, r) + node.data, node.data);  

int max_top = Math.max(max_single, l + r + node.data); 

res.val = Math.max(res.val, max_top); 

return max_single;

That first line finds the larger of the node on its own and the node plus the better of the two paths descending from its children; that result is the max path sum that starts at the root and goes down. Returning that value on the last line is one thing this function does. The second and third lines compare that value with the path that includes both children, and if the larger of the two beats every previously seen path, they update res, which is the other thing this function does. Keep in mind that res is an object that exists outside the method, so changes to it persist until the recursion finishes and findMaxSum(Node), which started the whole thing, returns res.val.
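To see both jobs side by side, the sketch below (hypothetical class name; the log list is an addition for illustration) records l, r, max_single, max_top and res.val at every node in post-order. The tree is assumed for the example: -10 with left child 9 and right child 20, where 20 has children 15 and 7.

```java
import java.util.ArrayList;
import java.util.List;

class TraceDemo {
    static class Node {
        int data;
        Node left, right;
        Node(int data, Node left, Node right) {
            this.data = data;
            this.left = left;
            this.right = right;
        }
    }

    static class Res {
        int val = Integer.MIN_VALUE;
    }

    // Same logic as findMaxUtil, plus one log line per node (post-order).
    static int findMaxUtil(Node node, Res res, List<String> log) {
        if (node == null)
            return 0;
        int l = findMaxUtil(node.left, res, log);
        int r = findMaxUtil(node.right, res, log);
        int maxSingle = Math.max(Math.max(l, r) + node.data, node.data);
        int maxTop = Math.max(maxSingle, l + r + node.data);
        res.val = Math.max(res.val, maxTop);
        log.add("node " + node.data + ": l=" + l + " r=" + r
                + " max_single=" + maxSingle + " max_top=" + maxTop
                + " res.val=" + res.val);
        return maxSingle; // only the one-sided path goes back to the parent
    }

    public static void main(String[] args) {
        Node root = new Node(-10, new Node(9, null, null),
                new Node(20, new Node(15, null, null), new Node(7, null, null)));
        Res res = new Res();
        List<String> log = new ArrayList<>();
        findMaxUtil(root, res, log);
        log.forEach(System.out::println);
    }
}
```

The last two log lines are "node 20: l=15 r=7 max_single=35 max_top=42 res.val=42" and "node -10: l=9 r=35 max_single=25 max_top=34 res.val=42". The root hands 25 upward as its return value, but the answer, 42, was already recorded in res.val at node 20.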

So, getting back to the question at the top, the reason that the findMaxUtil returns max_single is that it uses that value to recursively determine the max path through each subtree. The value in res is also updated so that findMaxSum(Node) can use it.
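The running maximum lives in a Res object rather than in an int parameter because of the Java behaviour described above. A minimal sketch of the difference (class and method names are hypothetical):

```java
class PassByValueDemo {
    static class Res {
        int val;
    }

    // Mutates the object the caller also references: the caller sees the change.
    static void update(Res res) {
        res.val = 42;
    }

    // Points the local copy of the reference at a new object: the caller does not see it.
    static void reassign(Res res) {
        res = new Res();
        res.val = 99;
    }

    // An int parameter is a copy too, so this change is lost on return.
    static void bump(int x) {
        x = 7;
    }

    static int[] run() {
        Res res = new Res();
        update(res);
        int afterUpdate = res.val;    // 42
        reassign(res);
        int afterReassign = res.val;  // still 42
        int x = 1;
        bump(x);                      // x is still 1
        return new int[] {afterUpdate, afterReassign, x};
    }
}
```

An int res parameter would behave like bump: every update made deep in the recursion would be lost when the call returns, which is why findMaxSum reads the answer from res.val.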

You're missing the role of res.val. The algorithm explores the whole tree, keeping in res.val the maximum path sum found so far. At each node it first recurses into the children and then updates res.val with the best path through that node, if it is higher than the value already stored.


Assume, as an induction hypothesis, that the algorithm works on trees of height at most n: findMaxUtil returns the maximum sum of a path that starts at the subtree's root and goes down, and afterwards res.val holds the maximum path sum found anywhere in that subtree. A tree of height n+1 consists of a root and up to two subtrees of height at most n.

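So the maximum path in your tree of height n+1 is the largest of these candidates:

  1. findMaxUtil(subtree1)
  2. findMaxUtil(subtree2)
  3. findMaxUtil(subtree1) + root.data
  4. findMaxUtil(subtree2) + root.data
  5. findMaxUtil(subtree1) + findMaxUtil(subtree2) + root.data
  6. root.data on its own
  7. res.val after both subtrees have been processed

And finally the value stored in res.val for the new tree is max(items 1 to 7). Items 1 and 2 are already covered by item 7, and the value that findMaxUtil(newTree) itself returns is max(items 3, 4, 6), which is exactly what the hypothesis requires at the next level up.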
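One way to gain confidence in the induction argument empirically is to compare the recursion with an O(n²) brute force that walks every path in small random trees. The class name, the random tree generator and the value range below are illustrative assumptions; best[0] plays the role of res.val.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;

class BruteForceCheck {
    static class Node {
        int data;
        Node left, right;
        Node(int data) { this.data = data; }
    }

    // The O(n) recursion from the question.
    static int fast(Node root) {
        int[] best = {Integer.MIN_VALUE};
        maxSingle(root, best);
        return best[0];
    }

    private static int maxSingle(Node n, int[] best) {
        if (n == null) return 0;
        int l = maxSingle(n.left, best);
        int r = maxSingle(n.right, best);
        int single = Math.max(Math.max(l, r) + n.data, n.data);
        best[0] = Math.max(best[0], Math.max(single, l + r + n.data));
        return single;
    }

    // O(n^2) brute force: treat the tree as an undirected graph and
    // walk every simple path from every start node.
    static int brute(Node root) {
        Map<Node, Node> parent = new HashMap<>();
        List<Node> nodes = new ArrayList<>();
        Deque<Node> stack = new ArrayDeque<>();
        stack.push(root);
        while (!stack.isEmpty()) {
            Node n = stack.pop();
            nodes.add(n);
            for (Node c : new Node[] {n.left, n.right}) {
                if (c != null) {
                    parent.put(c, n);
                    stack.push(c);
                }
            }
        }
        int best = Integer.MIN_VALUE;
        for (Node start : nodes)
            best = Math.max(best, walk(start, null, 0, parent));
        return best;
    }

    // Best sum over paths that begin at the start node, pass through 'n'
    // (reached from 'from') and may continue away from 'from'.
    private static int walk(Node n, Node from, int sumSoFar, Map<Node, Node> parent) {
        int sum = sumSoFar + n.data;
        int best = sum;
        for (Node next : new Node[] {n.left, n.right, parent.get(n)}) {
            if (next != null && next != from)
                best = Math.max(best, walk(next, n, sum, parent));
        }
        return best;
    }

    // Random tree with 'size' nodes and values in [-10, 10].
    static Node randomTree(Random rnd, int size) {
        if (size == 0) return null;
        Node n = new Node(rnd.nextInt(21) - 10);
        int leftSize = rnd.nextInt(size);
        n.left = randomTree(rnd, leftSize);
        n.right = randomTree(rnd, size - 1 - leftSize);
        return n;
    }
}
```

Agreement on many random trees is evidence rather than proof; the induction argument is what shows the recursion is correct in general.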

Honestly I think the description on that website is very unclear. I'll try to convince you of the reasoning behind the algorithm as best I can.

We have a binary tree, with values at the nodes:

[Diagram: a binary tree with a value at each node]

And we are looking for a path in that tree, a chain of connected nodes.

[Diagram: a path of connected nodes in that tree]

As it's a directed tree, any nonempty path consists of a lowest-depth node (i.e. the node in the path that is closest to the root of the tree), a path of zero or more nodes descending to the left of the lowest-depth node, and a path of zero or more nodes descending to the right of the lowest-depth node. In particular, somewhere in the tree there is a node that is the lowest-depth node in the maximum path. (Indeed, there might be more than one such path tied for equal value, and they might each have their own distinct lowest-depth node. That's fine. As long as there's at least one, that's what matters.)

(I've used "highest" in the diagram but I mean "lowest-depth". To be clear, any time I use "depth" or "descending" I'm talking about position in the tree. Any time I use "maximum" I'm talking about the value of a node or the sum of values of nodes in a path.)

[Diagram: a path and its lowest-depth node]

So if we can find its lowest-depth node, we know the maximum value path is composed of the node itself, a sub-path of zero or more nodes descending from (and including) its left child, and a sub-path of zero or more nodes descending from (and including) its right child. It's a small step to conclude that the left and right descending paths must each be the maximum value descending path on their side. (If this isn't obvious, consider that whatever other path you picked, you could do at least as well by picking the maximum value descending path on that side instead.) If either or both of those paths would have a negative value, then we just don't include any nodes at all on the negative side(s).
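Dropping a negative side can be written directly by clamping each child's contribution at 0. The sketch below is an equivalent rewrite of max_single and max_top in that style, not the GeeksforGeeks code itself; the class and method names are hypothetical.

```java
class ClampedGains {
    static class Node {
        int data;
        Node left, right;
        Node(int data, Node left, Node right) {
            this.data = data;
            this.left = left;
            this.right = right;
        }
    }

    private static int best;

    // Returns the best sum of a path starting at 'node' and going down.
    private static int gain(Node node) {
        if (node == null)
            return 0;
        int l = Math.max(0, gain(node.left));     // include no nodes on a negative left side
        int r = Math.max(0, gain(node.right));    // include no nodes on a negative right side
        best = Math.max(best, node.data + l + r); // 'node' as the lowest-depth node (max_top)
        return node.data + Math.max(l, r);        // one side only (max_single)
    }

    static int maxPathSum(Node root) {
        best = Integer.MIN_VALUE;
        gain(root);
        return best;
    }
}
```

Here node.data + Math.max(l, r) equals max_single and node.data + l + r equals max_top, because clamping at 0 has the same effect as also considering the node on its own.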

So we have a separate subproblem: given a subtree, what is the value of the maximum value path descending from its root? Well, it might just be the root itself, if it has no children or if all the paths starting at its children have negative sums. Otherwise it is the root plus the larger of the maximum value descending paths starting at its two children. This subproblem could easily be answered on its own, but to avoid repeated traversals and redundant work we'll combine both into one traversal of the tree.
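Answered on its own, the subproblem is a short recursive function. The sketch below (hypothetical names) solves the main problem by calling it afresh at every node. It gives the same answers but repeats work, costing O(n²) time on a degenerate, list-shaped tree; that repetition is what the combined traversal avoids.

```java
class TwoPassNaive {
    static class Node {
        int data;
        Node left, right;
        Node(int data, Node left, Node right) {
            this.data = data;
            this.left = left;
            this.right = right;
        }
    }

    // The subproblem alone: best sum of a path starting at 'node' and only going down.
    // An empty subtree contributes 0, so a negative child path is simply left out.
    static int maxDown(Node node) {
        if (node == null)
            return 0;
        return node.data + Math.max(0, Math.max(maxDown(node.left), maxDown(node.right)));
    }

    // Try every node as the lowest-depth node of the path.
    // maxDown is recomputed for each subtree, hence the repeated work.
    static int maxPathSum(Node node) {
        if (node == null)
            return Integer.MIN_VALUE;
        int throughHere = node.data
                + Math.max(0, maxDown(node.left))
                + Math.max(0, maxDown(node.right));
        return Math.max(throughHere,
                Math.max(maxPathSum(node.left), maxPathSum(node.right)));
    }
}
```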

Going back to the main problem, we know that some node is the lowest-depth node in the maximum value path. We don't even need to know when we visit it: we just recursively visit every node and find the maximum value path that has that node as its lowest-depth node, assured that at some point we will visit the one we want. At each node we calculate both the maximum value path starting at that node and descending within its subtree (max_single) and the maximum value path for which this node is the lowest-depth node (max_top). The latter is found by taking the node and "gluing on" zero, one or both of the maximum descending-only paths through its children. (Since max_single already covers gluing on zero or one of the children's paths, the only extra case to consider is the path that goes through both children.) By calculating max_top at every node and keeping the largest value found in res.val, we guarantee that we have found the largest of all values by the time we finish traversing the tree. At every node we return max_single for use in the parent's calculation, and at the end of the algorithm we just read the answer from res.val.
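The single traversal does not have to mutate a shared object. As an alternative design (not the GeeksforGeeks code; the names are hypothetical), each call can return both numbers for its subtree: max_single for the parent, and the best path found anywhere inside the subtree.

```java
class PairResult {
    static class Node {
        int data;
        Node left, right;
        Node(int data, Node left, Node right) {
            this.data = data;
            this.left = left;
            this.right = right;
        }
    }

    // Both quantities for one subtree, returned together instead of via Res.
    static final class Result {
        final int maxSingle; // best path starting at the subtree root, going down
        final int best;      // best path anywhere inside the subtree
        Result(int maxSingle, int best) {
            this.maxSingle = maxSingle;
            this.best = best;
        }
    }

    static Result solve(Node node) {
        if (node == null)
            return new Result(0, Integer.MIN_VALUE); // empty subtree: no path inside
        Result left = solve(node.left);
        Result right = solve(node.right);
        int l = left.maxSingle, r = right.maxSingle;
        int maxSingle = Math.max(Math.max(l, r) + node.data, node.data);
        int maxTop = Math.max(maxSingle, l + r + node.data);
        int best = Math.max(maxTop, Math.max(left.best, right.best));
        return new Result(maxSingle, best);
    }

    static int maxPathSum(Node root) {
        return solve(root).best;
    }
}
```

This removes Res and the hidden update that confused the question, at the cost of allocating one small object per node.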

