Tags: python, algorithm, recursion, binary-tree, infinite-loop

Binary tree recursion clarity needed


I am trying to solve the "Find Nodes Distance K" problem. Here is the problem statement:

You're given the root node of a Binary Tree, a target value of a node that's contained in the tree, and a positive integer k. Write a function that returns the values of all the nodes that are exactly distance k from the node with target value. The distance between two nodes is defined as the number of edges that must be traversed to go from one node to the other.

My approach was to work step by step. First, I wrote a function to find the target node. Next, I built a dictionary that maps every node to its parent. Finally, I wrote a recursive helper function that traverses the nodes until it reaches one of two base cases: the node is None, or k is negative. I quickly learned that this runs into an infinite loop: because I traverse in three directions (parent, left child, right child), I would keep visiting the same nodes back and forth unless I kept track of which nodes had already been visited.
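To make the "three directions" concrete, here is a minimal sketch that assumes the `BinaryTree` class from the question; `build_parent_map` and `neighbors` are hypothetical helper names used only for illustration. Once a parent map exists, a node's neighbors are its parent and its two children, so the relation is symmetric: if 4 is a neighbor of 6, then 6 is also a neighbor of 4. That symmetry is exactly what lets an unguarded recursion bounce between the two.

```python
class BinaryTree:
    def __init__(self, value, left=None, right=None):
        self.value = value
        self.left = left
        self.right = right


def build_parent_map(root):
    """Map every node to its parent; the root maps to None (iterative DFS)."""
    parents = {root: None}
    stack = [root]
    while stack:
        node = stack.pop()
        for child in (node.left, node.right):
            if child is not None:
                parents[child] = node
                stack.append(child)
    return parents


def neighbors(node, parents):
    """The three directions a traversal can take: parent, left child, right child."""
    return [n for n in (parents[node], node.left, node.right) if n is not None]


# Same shape as the test case: 1 -> 2 -> 4 -> 6 and 1 -> 3 -> 5 -> {7, 8}.
root = BinaryTree(1, BinaryTree(2, BinaryTree(4, BinaryTree(6))),
                  BinaryTree(3, None, BinaryTree(5, BinaryTree(7), BinaryTree(8))))
parents = build_parent_map(root)
six = root.left.left.left
four = parents[six]
print([n.value for n in neighbors(six, parents)])   # [4]
print([n.value for n in neighbors(four, parents)])  # [2, 6]
```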

So I added a visited check in which I mark a node as visited early: the recursive function is called on a node, marks that node as visited, and then recurses into another node. My idea was that if a node cannot visit back to the same node, the infinite loop would stop. But the infinite loop persisted, and I would like an explanation of why. Thank you.

My code:


class BinaryTree:
    def __init__(self, value, left=None, right=None):
        self.value = value
        self.left = left
        self.right = right
        
def findNodesDistanceK(tree, target, k):
    targetNode = findtargetNode(tree, target)
    nodeToParents = {tree: None}
    nodeToParents = assignParentnodes(tree, nodeToParents)
    #for ele in nodeToParents:
    #    if nodeToParents[ele]:
    #        print(ele.value, nodeToParents[ele].value)
    #    else:
    #        print(ele.value, nodeToParents[ele])
    runningList = list()
    visited = list()
    nodes = helper(targetNode, k, nodeToParents, runningList, visited)
    toReturn = list()
    for ele in nodes:
        if ele.value != target:
            toReturn.append(ele.value)
    if target in toReturn:
        toReturn.remove(target)
    setreturn = set(toReturn)
    return list(setreturn)

def helper(node, k, parents, runningList, visited):

    if node is None or k < 0:
        return []
    if node is not None and k == 0:
        runningList.append(node)
        return runningList
        
    if node in visited:
        return runningList
    #visited.append(node)
    clist = list()
    alist = list()
    blist = list()
    if parents[node] is not None:
        clist = helper(parents[node], k-1, parents, runningList, visited)
    #    visited.append(parents[node])
    alist = helper(node.left, k-1, parents, runningList, visited)
    blist = helper(node.right, k-1, parents, runningList, visited)
    runningList.extend(alist)
    runningList.extend(blist)
    runningList.extend(clist)
    visited.append(node)
    return runningList



def findtargetNode(node, target):
    if not node:
        return node
    if node.value == target:
        return node
    leftnode = findtargetNode(node.left, target)
    rightnode = findtargetNode(node.right, target)
    if rightnode is None:
        return leftnode
    return rightnode

def assignParentnodes(node, dictionary):
    if node is None:
        return dictionary
    if node.left:
        dictionary[node.left] = node
    if node.right:
        dictionary[node.right] = node
    assignParentnodes(node.left, dictionary)
    assignParentnodes(node.right, dictionary)
    return dictionary

It doesn't run into recursion problems, but when I add it back, it does.

Test case that fails with the infinite loop:

{
  "nodes": [
    {"id": "1", "left": "2", "right": "3", "value": 1},
    {"id": "2", "left": "4", "right": null, "value": 2},
    {"id": "3", "left": null, "right": "5", "value": 3},
    {"id": "4", "left": "6", "right": null, "value": 4},
    {"id": "5", "left": "7", "right": "8", "value": 5},
    {"id": "6", "left": null, "right": null, "value": 6},
    {"id": "7", "left": null, "right": null, "value": 7},
    {"id": "8", "left": null, "right": null, "value": 8}
  ],
  "root": "1"
}

target:6
k:17

My thought process is this: even if I mark a node as visited early, before letting it explore its other neighbors/children, wouldn't the function just return prematurely? Sure, that wouldn't give an accurate solution, but what causes the infinite loop? And what do you suggest to fix it? Should I mark a node as visited only after it has explored all of its neighbors? Any help is appreciated.

Testing code


import program
import unittest


class TestProgram(unittest.TestCase):
    def test_case_1(self):
        root = program.BinaryTree(1)
        root.left = program.BinaryTree(2)
        root.right = program.BinaryTree(3)
        root.left.left = program.BinaryTree(4)
        root.left.left.left = program.BinaryTree(6)
        root.right.right = program.BinaryTree(5)
        root.right.right.left = program.BinaryTree(7)
        root.right.right.right = program.BinaryTree(8)
        target = 6
        k = 17
        expected = []
        actual = program.findNodesDistanceK(root, target, k)
        actual.sort()
        self.assertCountEqual(actual, expected)



Solution

  • The problem is that you keep making recursive calls on nodes that have already been visited. This allows a search path to go up and down between parent and child, and to repeat the same up-and-down when coming from a deeper or higher node. This makes the number of search paths exponentially large, and it also collects nodes in the result list that are duplicates and not necessarily at distance k.

    This is not really an infinitely long process, but in terms of both memory and time it rapidly becomes astronomical. For instance, if you change k to 8 in the example case, the expected result is the empty list, but if you let your program run long enough, it returns [2, 3, 7, 8], which is evidently wrong. For k equal to 10 the waiting time and memory usage become significant, and for k equal to 17 the search needs too much time and too many resources to ever finish normally.
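To see where the cost comes from, here is a hypothetical instrumentation on the example tree. The class is redefined so the snippet runs on its own, and `count_entries` and `build_parent_map` are illustrative names, not part of the question's code. `count_entries` models the three-way recursion with no visited check at all and counts how often a node is entered; that count grows geometrically with k. The last lines show a second effect in the question's `helper` that is worth knowing about: apart from its base cases, `helper` returns the shared `runningList` object itself, so `alist`, `blist` and `clist` are often that same object, and `runningList.extend(alist)` then extends the list with itself. Each such call doubles its length, which inflates memory use far faster than the number of calls grows.

```python
class BinaryTree:
    def __init__(self, value, left=None, right=None):
        self.value = value
        self.left = left
        self.right = right


def build_parent_map(root):
    """Map every node to its parent; the root maps to None."""
    parents = {root: None}
    stack = [root]
    while stack:
        node = stack.pop()
        for child in (node.left, node.right):
            if child is not None:
                parents[child] = node
                stack.append(child)
    return parents


def count_entries(node, k, parents):
    """Count how often the three-way recursion enters a real node when
    nothing is ever marked visited (a model of the unguarded search)."""
    if node is None or k < 0:
        return 0
    if k == 0:
        return 1
    return 1 + sum(count_entries(n, k - 1, parents)
                   for n in (parents[node], node.left, node.right))


root = BinaryTree(1, BinaryTree(2, BinaryTree(4, BinaryTree(6))),
                  BinaryTree(3, None, BinaryTree(5, BinaryTree(7), BinaryTree(8))))
parents = build_parent_map(root)
target = root.left.left.left  # the node with value 6

# Entries grow geometrically with k because the walk can go back and forth.
entries = [count_entries(target, k, parents) for k in range(1, 18)]
print(entries)

# Second effect: extending a list with itself doubles it, which is what
# happens when alist is the same object as runningList.
running_list = [7, 8]
alist = running_list
running_list.extend(alist)
print(running_list)  # [7, 8, 7, 8]
```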

    You can solve this by marking a node as visited before making any recursive call from it, and by doing the visited check before adding a node to the result list.

    So change this:

        if node is not None and k == 0:
            runningList.append(node)  # This should NOT be done when node already visited
            return runningList
        if node in visited:
            return runningList
    
        clist = list()
        alist = list()
        blist = list()
        if parents[node] is not None:
            clist = helper(parents[node], k-1, parents, runningList, visited)
        alist = helper(node.left, k-1, parents, runningList, visited)
        blist = helper(node.right, k-1, parents, runningList, visited)
        visited.append(node)   # <-- this happens too late!
    

    To this:

        if node in visited:  # Do this before recursive calls and before appending
            return runningList
        visited.append(node)  # And then immediately mark as visited
        if node is not None and k == 0:
            runningList.append(node)
            return runningList
    
        clist = list()
        alist = list()
        blist = list()
        if parents[node] is not None:
            clist = helper(parents[node], k-1, parents, runningList, visited)
        alist = helper(node.left, k-1, parents, runningList, visited)
        blist = helper(node.right, k-1, parents, runningList, visited)
    

    This will fix your issue.
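Marking on entry is safe here because a binary tree with parent links has exactly one simple path between any two nodes. The first time the search reaches a node, it has arrived along that unique path and therefore at the correct distance; the visited set only stops it from walking back. The sketch below (illustrative names, example tree rebuilt so it runs standalone) instruments that idea with a `Counter` to confirm that each node is expanded at most once, whatever k is. As a design choice, it also returns a fresh list from every call instead of a shared `runningList`, so no list is ever extended with itself.

```python
from collections import Counter


class BinaryTree:
    def __init__(self, value, left=None, right=None):
        self.value = value
        self.left = left
        self.right = right


def build_parent_map(root):
    """Map every node to its parent; the root maps to None."""
    parents = {root: None}
    stack = [root]
    while stack:
        node = stack.pop()
        for child in (node.left, node.right):
            if child is not None:
                parents[child] = node
                stack.append(child)
    return parents


def nodes_at_distance(node, k, parents, visited, entries):
    """Depth-limited DFS that marks a node visited as soon as it is entered."""
    if node is None or k < 0 or node in visited:
        return []
    visited.add(node)           # mark before recursing or collecting
    entries[node.value] += 1    # instrumentation: how often each node is expanded
    if k == 0:
        return [node]
    found = []                  # a fresh list per call, never the caller's list
    for nxt in (parents[node], node.left, node.right):
        found.extend(nodes_at_distance(nxt, k - 1, parents, visited, entries))
    return found


root = BinaryTree(1, BinaryTree(2, BinaryTree(4, BinaryTree(6))),
                  BinaryTree(3, None, BinaryTree(5, BinaryTree(7), BinaryTree(8))))
parents = build_parent_map(root)
target = root.left.left.left  # the node with value 6
for k in (2, 6, 17):
    entries = Counter()
    result = nodes_at_distance(target, k, parents, set(), entries)
    print(k, sorted(n.value for n in result), max(entries.values()))
# 2 [2] 1
# 6 [7, 8] 1
# 17 [] 1
```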

    Note that you can shorten your code quite a lot.

    Here is what it could be reduced to:

    def findNodesDistanceK(tree, target, k):
        nodeToParents = assignParentnodes(tree, {tree: None})
        targetNode = next((node for node in nodeToParents if node.value == target), None)
        nodes = helper(targetNode, k, nodeToParents, set())
        return [ele.value for ele in nodes]
    
    
    def helper(node, k, parents, visited):
        if k < 0 or not node or node in visited:
            return
        visited.add(node)
        if k == 0:
            yield node
        else:
            yield from helper(parents[node], k-1, parents, visited)
            yield from helper(node.left, k-1, parents, visited)
            yield from helper(node.right, k-1, parents, visited)
        
    
    def assignParentnodes(node, dictionary):
        if node is None:
            return dictionary
        if node.left:
            dictionary[node.left] = node
        if node.right:
            dictionary[node.right] = node
        assignParentnodes(node.left, dictionary)
        assignParentnodes(node.right, dictionary)
        return dictionary
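One limitation of both recursive versions is Python's default recursion limit (1000 frames, see `sys.getrecursionlimit()`): on a very deep, unbalanced tree, `assignParentnodes` and `helper` could raise `RecursionError`. If that matters, breadth-first search (BFS) is a common alternative. The sketch below swaps in that technique and is not the code from the question or the answer; `find_nodes_distance_k_bfs` is a hypothetical name. It builds the parent map with an explicit stack, then expands outward from the target one distance "ring" at a time, so after k rounds the queue holds exactly the nodes at distance k.

```python
from collections import deque


class BinaryTree:
    def __init__(self, value, left=None, right=None):
        self.value = value
        self.left = left
        self.right = right


def find_nodes_distance_k_bfs(tree, target, k):
    """Values of all nodes exactly k edges away from the node with value target."""
    if tree is None:
        return []
    # Pass 1: parent map plus target lookup, using an explicit stack (no recursion).
    parents = {tree: None}
    target_node = None
    stack = [tree]
    while stack:
        node = stack.pop()
        if node.value == target:
            target_node = node
        for child in (node.left, node.right):
            if child is not None:
                parents[child] = node
                stack.append(child)
    if target_node is None:
        return []
    # Pass 2: BFS from the target; each round moves the frontier one edge outward.
    visited = {target_node}
    frontier = deque([target_node])
    for _ in range(k):
        for _ in range(len(frontier)):  # only the nodes of the current ring
            node = frontier.popleft()
            for nxt in (parents[node], node.left, node.right):
                if nxt is not None and nxt not in visited:
                    visited.add(nxt)
                    frontier.append(nxt)
        if not frontier:  # ran out of nodes before reaching distance k
            break
    return [node.value for node in frontier]


root = BinaryTree(1, BinaryTree(2, BinaryTree(4, BinaryTree(6))),
                  BinaryTree(3, None, BinaryTree(5, BinaryTree(7), BinaryTree(8))))
print(sorted(find_nodes_distance_k_bfs(root, 6, 6)))  # [7, 8]
print(find_nodes_distance_k_bfs(root, 6, 17))         # []
```

Because every node enters the queue at most once, the whole search is linear in the number of nodes, independent of how large k is.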