I am trying to solve a variation of the Maximum Path Sum in a Binary Tree problem where some nodes in the tree are colored red. A path sum is only valid if the path starts and ends at red nodes.
Given this constraint, how do I compute the maximum path sum in the binary tree?
Consider this tree where (R) represents red nodes:
```
        10(R)
       /     \
     -2      7(R)
    /  \        \
  8(R)  -4       6
  /
-1(R)
```
The standard approach for Maximum Path Sum uses DFS with recursion while maintaining a global max. I modified it to only update the global max when encountering a red-to-red path, but I am struggling to properly track valid paths and backtrack correctly.
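For reference, here is a minimal sketch of the unconstrained Maximum Path Sum recursion that this approach starts from. The names `TreeNode`, `max_path_sum` and `rising` are illustrative and not from the question. Each call returns the best downward ("rising") path that starts at the node, with negative branches clamped to 0 so they can be dropped, while a nonlocal variable tracks the best path that bends at any node.

```python
from typing import Optional


class TreeNode:
    # Hypothetical minimal node type for this sketch (no red flag needed here).
    def __init__(self, val: int, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def max_path_sum(root: Optional[TreeNode]) -> float:
    """Classic unconstrained maximum path sum: a path may start and end anywhere."""
    best = float("-inf")

    def rising(node: Optional[TreeNode]) -> int:
        nonlocal best
        if node is None:
            return 0
        # Drop a branch entirely if its best downward path is negative.
        left = max(rising(node.left), 0)
        right = max(rising(node.right), 0)
        # A path that bends at this node may use both branches.
        best = max(best, node.val + left + right)
        # Only one branch can be extended upward to the parent.
        return node.val + max(left, right)

    rising(root)
    return best
```

The red-node variant changes two things in this template: a branch may only be used if it ends at a red node, and a bend only counts when both of its ends are red.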
```python
class Node:
    def __init__(self, val):
        self.val = val
        self.left = None
        self.right = None
        self.red = False


class Solution:
    def solve(self, root):
        ans = float("-inf")

        def dfs(node):
            nonlocal ans
            if not node:
                return [0, False]
            left, left_red = dfs(node.left)
            right, right_red = dfs(node.right)

            # update the global ans variable based on whether current node is red or not
            if node.red:
                if left_red:
                    ans = max(ans, node.val + left)
                if right_red:
                    ans = max(ans, node.val + right)
                if left_red and right_red:
                    ans = max(ans, node.val + left + right)
            else:
                if left_red and right_red:
                    ans = max(ans, node.val + left + right)

            # return the single best rising path from this node
            if node.red:
                local_max = float("-inf")
                if left_red:
                    local_max = max(local_max, node.val + left, node.val)
                if right_red:
                    local_max = max(local_max, node.val + right, node.val)
                return [max(local_max, node.val), node.red]
            else:
                local_max = float("-inf")
                if left_red:
                    local_max = max(local_max, node.val + left, node.val)
                if right_red:
                    local_max = max(local_max, node.val + right, node.val)
                return [local_max, left_red or right_red]

        dfs(root)
        return ans


soln = Solution()
root = Node(10)
root.left = Node(-5)
root.right = Node(20)
root.left.left = Node(4)
root.left.right = Node(3)
root.right.left = Node(1)
root.right.right = Node(6)
root.right.right.red = True
root.right.left.left = Node(-10)
root.right.left.left.red = True
print(soln.solve(root))
```
This fails for test cases where a red node is a leaf, such as the tree built above:
```
        10
       /  \
     -5    20
    /  \   / \
   4    3 1   6(R)
         /
      -10(R)
```
The correct answer should be -10 -> 1 -> 20 -> 6 = 17, but the output is 27.
Is DFS even the right approach here? Or do I have to treat the tree as a graph and run a BFS from each red node to compute the path sum to every other red node?
I suspect there is a small bug in how I return the rising path, but I can't pinpoint it.
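The graph reformulation works too, and it makes a handy brute-force cross-check, at roughly O(R·N) cost for R red nodes instead of a single O(N) DFS. Below is a minimal sketch, assuming a node type with `val`, `left`, `right` and `red` attributes like the question's `Node`. It is redeclared here with an optional `red` constructor argument so the block runs on its own. `build_adjacency` and `brute_force_red_path_sum` are hypothetical helper names. Because a tree has exactly one simple path between any two nodes, the BFS "distance" from a red node is just the sum of values along that unique path.

```python
from collections import deque
from typing import Dict, List, Optional


class Node:
    # Same shape as the question's Node, plus an optional `red` argument for brevity.
    def __init__(self, val: int, red: bool = False) -> None:
        self.val = val
        self.left: Optional["Node"] = None
        self.right: Optional["Node"] = None
        self.red = red


def build_adjacency(root: Optional[Node]) -> Dict[Node, List[Node]]:
    """Turn the tree into an undirected graph (parent <-> child edges)."""
    adj: Dict[Node, List[Node]] = {}
    stack = [root] if root else []
    while stack:
        node = stack.pop()
        adj.setdefault(node, [])
        for child in (node.left, node.right):
            if child is not None:
                adj[node].append(child)
                adj.setdefault(child, []).append(node)
                stack.append(child)
    return adj


def brute_force_red_path_sum(root: Optional[Node]) -> float:
    """Best sum over paths whose two distinct endpoints are both red."""
    adj = build_adjacency(root)
    best = float("-inf")
    for start in adj:
        if not start.red:
            continue
        # BFS from this red node. In a tree each node is reached via its unique
        # path, so the accumulated sum is the path sum from `start` to that node.
        sums = {start: start.val}
        queue = deque([start])
        while queue:
            node = queue.popleft()
            for nxt in adj[node]:
                if nxt not in sums:
                    sums[nxt] = sums[node] + nxt.val
                    queue.append(nxt)
        for end, total in sums.items():
            if end is not start and end.red:
                best = max(best, total)
    return best
```

On the failing test tree this gives 17. On the first example tree it gives 23 (7 -> 10 -> -2 -> 8). It returns `-inf` when fewer than two red nodes exist.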
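The code in the question has a subtle bug in the way the rising path is returned from a non-red node.
A rising path returned from a node that is not red must still end at a red node somewhere below it. The current node's value on its own describes a path that ends at the non-red node itself, which is not valid, so `node.val` needs to be removed from that comparison. In the failing example, node 1 returned 1 instead of 1 + (-10) = -9, so node 20 combined 1 + 20 + 6 = 27 instead of -9 + 20 + 6 = 17.
Here is the corrected code: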
```python
class Node:
    def __init__(self, val):
        self.val = val
        self.left = None
        self.right = None
        self.red = False


class Solution:
    def solve(self, root):
        ans = float("-inf")

        def dfs(node):
            nonlocal ans
            if not node:
                return [0, False]
            left, left_red = dfs(node.left)
            right, right_red = dfs(node.right)

            # update the global ans variable based on whether current node is red or not
            if node.red:
                if left_red:
                    ans = max(ans, node.val + left)
                if right_red:
                    ans = max(ans, node.val + right)
                if left_red and right_red:
                    ans = max(ans, node.val + left + right)
            else:
                if left_red and right_red:
                    ans = max(ans, node.val + left + right)

            # return the single best rising path from this node
            if node.red:
                # a red node may end a rising path by itself
                local_max = float("-inf")
                if left_red:
                    local_max = max(local_max, node.val + left, node.val)
                if right_red:
                    local_max = max(local_max, node.val + right, node.val)
                return [max(local_max, node.val), node.red]
            else:
                # a non-red node cannot end a path, so node.val alone is not a candidate
                local_max = float("-inf")
                if left_red:
                    local_max = max(local_max, node.val + left)
                if right_red:
                    local_max = max(local_max, node.val + right)
                return [local_max, left_red or right_red]

        dfs(root)
        return ans


soln = Solution()
root = Node(10)
root.left = Node(-5)
root.right = Node(20)
root.left.left = Node(4)
root.left.right = Node(3)
root.right.left = Node(1)
root.right.right = Node(6)
root.right.right.red = True
root.right.left.left = Node(-10)
root.right.left.left.red = True
print(soln.solve(root))
```
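A somewhat tidier way to express the same fix is to return `None` when a subtree has no red node at which a path could end, instead of carrying a separate boolean. The sketch below is one possible restructuring under that assumption. `max_red_path_sum` and `best_red_path` are hypothetical names, and `Node` mirrors the question's class with an optional `red` argument added for brevity. It follows the same rules as the corrected code above: a red node may end a rising path on its own, and a non-red node may only extend a child's red-ending path.

```python
from typing import Optional


class Node:
    # Mirrors the question's Node, with an optional `red` argument for brevity.
    def __init__(self, val: int, red: bool = False) -> None:
        self.val = val
        self.left: Optional["Node"] = None
        self.right: Optional["Node"] = None
        self.red = red


def max_red_path_sum(root: Optional[Node]) -> float:
    """Maximum sum of a path whose two (distinct) endpoints are red nodes."""
    best = float("-inf")

    def best_red_path(node: Optional[Node]) -> Optional[int]:
        """Best downward path starting at `node` and ending at a red node,
        or None if the subtree contains no red node."""
        nonlocal best
        if node is None:
            return None
        branches = [b for b in (best_red_path(node.left), best_red_path(node.right))
                    if b is not None]

        # Paths that turn at this node: two red-ending branches joined here,
        # or, if this node is red, one red-ending branch plus this node.
        if len(branches) == 2:
            best = max(best, node.val + branches[0] + branches[1])
        if node.red:
            for b in branches:
                best = max(best, node.val + b)

        # Rising path handed to the parent.
        candidates = [node.val + b for b in branches]
        if node.red:
            candidates.append(node.val)  # the path may end right here
        return max(candidates) if candidates else None

    best_red_path(root)
    return best
```

Using `None` as the "no red endpoint" marker avoids mixing `-inf` sentinels with real sums. It also makes the non-red rule explicit: with no candidates, nothing is passed up.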