Tags: algorithm, graph, tree, time-complexity

Maximum Path Sum Between Red Nodes in a Binary Tree


Problem Statement

I am trying to solve a variation of the Maximum Path Sum in a Binary Tree problem where some nodes in the tree are colored red. The path sum is only valid if:

  1. The path starts and ends at a red node.
  2. The path can contain zero or more additional red nodes in between.
  3. The path can include non-red nodes as long as it starts and ends at red nodes.
  4. The path follows parent-child connections (no jumps).
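The four rules can be made concrete with a small checker for an explicit node sequence. This is a hypothetical helper: `is_valid_red_path`, `is_adjacent` and `path_sum` are illustrative names, and this `Node` takes `red` as a constructor keyword, unlike the class in the question. The statement does not say whether a single red node counts as a path. This sketch accepts one, while the DFS code below only counts paths between two distinct red nodes.

```python
from typing import List, Optional


class Node:
    # Assumed variant of the question's Node: `red` is a constructor keyword.
    def __init__(self, val: int, red: bool = False):
        self.val = val
        self.red = red
        self.left: Optional["Node"] = None
        self.right: Optional["Node"] = None


def is_adjacent(a: Node, b: Node) -> bool:
    # Two nodes share an edge iff one is a direct child of the other.
    return a.left is b or a.right is b or b.left is a or b.right is a


def is_valid_red_path(path: List[Node]) -> bool:
    """Hypothetical helper: check rules 1-4 for an explicit node sequence."""
    if not path:
        return False
    # Rule 1: both endpoints must be red (rules 2 and 3 allow anything between).
    if not (path[0].red and path[-1].red):
        return False
    # A simple path never visits the same node twice.
    if len({id(n) for n in path}) != len(path):
        return False
    # Rule 4: every consecutive pair must be joined by a parent-child edge.
    return all(is_adjacent(a, b) for a, b in zip(path, path[1:]))


def path_sum(path: List[Node]) -> int:
    return sum(n.val for n in path)
```

For the example tree below, the path 8 -> -2 -> 10 -> 7 passes this check and sums to 23.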

Given these constraints, how do I compute the maximum path sum in the binary tree?

Example

Consider this tree where (R) represents red nodes:

        10(R)
       /     \
    -2       7(R)
    /  \       \
   8(R)  -4     6
         /
       -1(R)

What I Have Tried

The standard approach for Maximum Path Sum uses DFS with recursion while maintaining a global max. I modified it to only update the global max when encountering a red-to-red path, but I am struggling to properly track valid paths and backtrack correctly.
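For reference, the standard (uncolored) Maximum Path Sum DFS mentioned above looks roughly like this. It is the well-known unconstrained version, not the red-node variant; `max_path_sum` and `gain` are illustrative names.

```python
from typing import Optional


class Node:
    def __init__(self, val: int):
        self.val = val
        self.left: Optional["Node"] = None
        self.right: Optional["Node"] = None


def max_path_sum(root: Optional[Node]) -> float:
    """Classic maximum path sum: a path may start and end at any node."""
    best = float("-inf")

    def gain(node: Optional[Node]) -> int:
        # Best sum of a downward path starting at `node`; 0 for an empty subtree.
        nonlocal best
        if node is None:
            return 0
        # Negative branches are dropped, because the path may stop anywhere.
        left = max(gain(node.left), 0)
        right = max(gain(node.right), 0)
        # Path that bends at `node`, using both sides.
        best = max(best, node.val + left + right)
        # Only one side can be extended upward to the parent.
        return node.val + max(left, right)

    gain(root)
    return best
```

The `max(..., 0)` clamps are the part that does not carry over: in the red variant a branch must be followed down to a red node even when its sum is negative, and a path may only stop at a red node.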

class Node:
    def __init__(self, val):
        self.val = val 
        self.left = None
        self.right = None 
        self.red = False 

class Solution:
    def solve(self, root):
        ans = float("-inf")
    
        def dfs(node):
            if not node:
                return [0, False]
        
            left, left_red = dfs(node.left)
            right, right_red = dfs(node.right)
            
            # update the global ans variable based on whether the current node is red or not
            nonlocal ans 
            if node.red:
                if left_red:
                    ans = max(ans, node.val + left)
                if right_red:
                    ans = max(ans, node.val + right)
                if left_red and right_red:
                    ans = max(ans, node.val + left + right)
            else:
                if left_red and right_red:
                    ans = max(ans, node.val + left + right)
            
            # return the single best rising path from this node 
            if node.red:
                local_max = float("-inf")
                if left_red:
                    local_max = max(local_max, node.val + left, node.val)
                if right_red:
                    local_max = max(local_max, node.val + right, node.val)
                
                return [max(local_max, node.val), node.red]
            else:
                local_max = float("-inf")
                if left_red:
                    local_max = max(local_max, node.val + left, node.val)
                if right_red:
                    local_max = max(local_max, node.val + right, node.val)
                
                return [local_max, left_red or right_red]
                
        dfs(root)
        return ans

soln = Solution()

root = Node(10)

root.left = Node(-5)

root.right = Node(20)

root.left.left = Node(4)
root.left.right = Node(3)

root.right.left = Node(1)
root.right.right = Node(6)
root.right.right.red = True

root.right.left.left = Node(-10)
root.right.left.left.red = True

print(soln.solve(root))

This fails for test cases where a red node is a leaf, for example:

        10 
       /   \
     -5     20  
     / \    / \ 
    4  3   1   6(R)
          /
        -10(R)

The expected answer is

-10 -> 1 -> 20 -> 6 = 17

but the output is 27.


Is DFS even the right approach here? Or do I have to reframe the tree as a graph and run a BFS from each red node to compute the path sum to every other red node?
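For comparison, the graph-based alternative can be sketched as follows: build an undirected adjacency list from the parent-child edges, then run a BFS from each red node and accumulate path sums. Because a tree has exactly one simple path between any two nodes, the BFS order does not affect the sums. The names `build_adjacency` and `brute_force_red_path`, and the `red` constructor keyword, are assumptions for illustration. It costs O(R * n) for R red nodes, so it is mainly useful as a test oracle for a faster solution.

```python
from collections import deque
from typing import Dict, List, Optional


class Node:
    # Assumed variant of the question's Node: `red` is a constructor keyword.
    def __init__(self, val: int, red: bool = False):
        self.val = val
        self.red = red
        self.left: Optional["Node"] = None
        self.right: Optional["Node"] = None


def build_adjacency(root: Node) -> Dict[Node, List[Node]]:
    # Undirected adjacency list: each parent-child edge stored in both directions.
    adj: Dict[Node, List[Node]] = {root: []}
    stack = [root]
    while stack:
        node = stack.pop()
        for child in (node.left, node.right):
            if child is not None:
                adj[node].append(child)
                adj[child] = [node]
                stack.append(child)
    return adj


def brute_force_red_path(root: Optional[Node]) -> Optional[int]:
    """BFS from every red node; None if no path joins two distinct red nodes."""
    if root is None:
        return None
    adj = build_adjacency(root)
    best: Optional[int] = None
    for start in adj:
        if not start.red:
            continue
        # In a tree the path between two nodes is unique, so BFS simply
        # accumulates the sum along the single route to each node.
        dist = {start: start.val}
        queue = deque([start])
        while queue:
            u = queue.popleft()
            for v in adj[u]:
                if v not in dist:
                    dist[v] = dist[u] + v.val
                    queue.append(v)
        for end, total in dist.items():
            if end.red and end is not start:
                best = total if best is None else max(best, total)
    return best
```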

I think there might be a small bug in how I am returning the rising path but I am not able to point it out correctly.

Problem Constraints


Solution

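  • The code above has a subtle bug in the way the rising path is returned from a non-red node.

    When returning the best rising path from a node that is not red, node.val on its own must not be a candidate. A path that stops at a non-red node does not end at a red node, yet it is still returned with the red flag set, so the parent treats it as valid. In the failing example, node 1 returns 1 instead of 1 + (-10) = -9, so node 20 computes 20 + 1 + 6 = 27 instead of 20 + (-9) + 6 = 17. Removing node.val from that comparison fixes it.

    Here is the corrected code: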

    class Node:
        def __init__(self, val):
            self.val = val 
            self.left = None
            self.right = None 
            self.red = False 
    
    class Solution:
        def solve(self, root):
            ans = float("-inf")
        
            def dfs(node):
                if not node:
                    return [0, False]
            
                left, left_red = dfs(node.left)
                right, right_red = dfs(node.right)
                
                # update the global ans variable based on whether the current node is red or not
                nonlocal ans 
                if node.red:
                    if left_red:
                        ans = max(ans, node.val + left)
                    if right_red:
                        ans = max(ans, node.val + right)
                    if left_red and right_red:
                        ans = max(ans, node.val + left + right)
                else:
                    if left_red and right_red:
                        ans = max(ans, node.val + left + right)
                
                # return the single best rising path from this node 
                if node.red:
                    local_max = float("-inf")
                    if left_red:
                        local_max = max(local_max, node.val + left, node.val)
                    if right_red:
                        local_max = max(local_max, node.val + right, node.val)
                    
                    return [max(local_max, node.val), node.red]
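                else:
                    # a non-red node is not a valid endpoint, so node.val alone is
                    # not a candidate: the path must continue down to a red node
                    local_max = float("-inf")
                    if left_red:
                        local_max = max(local_max, node.val + left)
                    if right_red:
                        local_max = max(local_max, node.val + right)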
                    
                    return [local_max, left_red or right_red]
                    
            dfs(root)
            return ans
    
    soln = Solution()
    
    root = Node(10)
    
    root.left = Node(-5)
    
    root.right = Node(20)
    
    root.left.left = Node(4)
    root.left.right = Node(3)
    
    root.right.left = Node(1)
    root.right.right = Node(6)
    root.right.right.red = True
    
    root.right.left.left = Node(-10)
    root.right.left.left.red = True
    
    print(soln.solve(root))
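A slightly more compact sketch of the same corrected DFS follows. It is an illustration under a few assumptions that differ from the code above: `Node` takes `red` as a constructor keyword, `max_red_path_sum` is a hypothetical name, and `None` replaces the `[0, False]` / `-inf` sentinels, so a tree with fewer than two red nodes yields `None` rather than `-inf`. Like the code above, it only counts paths whose two endpoints are distinct red nodes.

```python
from typing import Optional


class Node:
    # Assumed variant of the question's Node: `red` is a constructor keyword.
    def __init__(self, val: int, red: bool = False):
        self.val = val
        self.red = red
        self.left: Optional["Node"] = None
        self.right: Optional["Node"] = None


def max_red_path_sum(root: Optional[Node]) -> Optional[int]:
    """Max sum of a path between two distinct red nodes; None if none exists."""
    best: Optional[int] = None

    def consider(total: int) -> None:
        nonlocal best
        best = total if best is None else max(best, total)

    def dfs(node: Optional[Node]) -> Optional[int]:
        # Best sum of a downward path that starts at `node` and ends at a red
        # node in its subtree, or None if the subtree has no red node.
        if node is None:
            return None
        left = dfs(node.left)
        right = dfs(node.right)
        if node.red:
            # `node` is itself a valid endpoint: pair it with each side.
            for down in (left, right):
                if down is not None:
                    consider(node.val + down)
        if left is not None and right is not None:
            # Path that bends at `node`, with a red endpoint on each side.
            consider(node.val + left + right)
        # Extend upward: a red node may also stop here, a non-red node may not.
        candidates = [node.val + d for d in (left, right) if d is not None]
        if node.red:
            candidates.append(node.val)
        return max(candidates) if candidates else None

    dfs(root)
    return best
```

On the question of approach: DFS is the right tool. Each node is visited once, so this runs in O(n) time with O(h) extra space for the recursion stack (h = tree height), whereas a BFS from every red node costs O(R * n) for R red nodes. For very deep, skewed trees, CPython's default recursion limit (usually 1000) can be hit, in which case an iterative post-order traversal would be needed.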