pythonrecursionbinary-treetree-traversal

How to spread an infected node to its adjacent nodes and eventually to the whole binary tree?


I want to iteratively return the state of the binary tree until the infection is unable to spread to a new node. So how it goes is that the virus will spread to any directly adjacent healthy node and I want to return the state of the tree every after round of infections the virus does. Once every possible node is infected, the algorithm should stop.

What I want to achieve given that "B" is the origin (Example):

enter image description here

Here's what I have so far:

from binarytree import Node, build

def build_tree(data):
    if not data or data[0] == 0:
        return None
    root = Node(data[0])
    nodes = [root]
    for i in range(1, len(data), 2):
        node = nodes.pop(0)
        if data[i] != 0:
            node.left = Node(data[i])
            nodes.append(node.left)
        if i + 1 < len(data) and data[i + 1] != 0:
            node.right = Node(data[i + 1])
            nodes.append(node.right)
    return root
    
def infected(node):
    return f"*{node.value}*"

def search_node(root, value):
    if root is None or root.value == value:
        return root
    left_result = search_node(root.left, value)
    right_result = search_node(root.right, value)
    return left_result if left_result else right_result

def virus(node):
    infected_nodes = set()
    current_round = set([node])

    while current_round:
        next_round = set()

        for current_node in current_round:
            neighbors = [child for child in (current_node.left, current_node.right) if child and child.value != 0]

            for neighbor in neighbors:
                if neighbor not in infected_nodes:
                    infected_nodes.add(neighbor)
                    neighbor.value = infected(neighbor)
                    next_round.add(neighbor)

        if infected_nodes == set(current_round):
           break

        print(tree)
        current_round = next_round

Input:

sample = ["A", "B", "C", "D", "E", 0, "F"]
origin_value = "C"

tree = build_tree(sample)
origin_node = search_node(tree, origin_value)

origin_node.value = infected(origin_node)
print(tree)
virus(origin_node)

Output:

enter image description here

As seen from the output, C only spreads to F and it stops there, where it should have spread to A and to the entirety of the binary tree. What am I doing wrong here?


Solution

  • Here is modified version of the code, using .inorder function to traverse the tree. Also, I've added .infected boolean parameter to each node signaling if the Node is infected or not:

    from binarytree import Node, build
    
    
    def build_tree(data):
        nodes = []
        for i, v in enumerate(data):
            if v == 0:
                nodes.append(None)
            else:
                n = Node(v)
                n.infected = False  # <-- node is initailly *NOT* infected
                nodes.append(n)
    
        for i in range(len(data)):
            left = 2 * i + 1
            right = 2 * i + 2
    
            if left < len(nodes) and nodes[left]:
                nodes[i].left = nodes[left]
    
            if right < len(nodes) and nodes[right]:
                nodes[i].right = nodes[right]
    
        return nodes[0]
    
    
    def is_whole_tree_infected(root):
        return all(n.infected for n in root.inorder)
    
    
    def search_node(root, value):
        for n in root.inorder:
            if n.value == value:
                return n
    
    
    def find_parent(root, node):
        for n in root.inorder:
            if n.left is node:
                return n
            if n.right is node:
                return n
    
    
    def infect(node):
        node.value = f"*{node.value}*"
        node.infected = True
    
    
    def infect_neighbours(root, node):
        out = set()
        parent = find_parent(root, node)
        if parent and not parent.infected:
            out.add(parent)
        if node.left and not node.left.infected:
            out.add(node.left)
        if node.right and not node.right.infected:
            out.add(node.right)
    
        return out
    
    
    def virus(root):
        while not is_whole_tree_infected(root):
            to_infect = set()
    
            for node in root.inorder:
                if node.infected:
                    to_infect |= infect_neighbours(root, node)
    
            for n in to_infect:
                infect(n)
    
            print(root)
    
    
    sample = ["A", "B", "C", "D", "E", 0, "F"]
    origin_value = "C"
    
    tree = build_tree(sample)
    origin_node = search_node(tree, origin_value)
    infect(origin_node)
    
    print(tree)
    virus(tree)
    

    Prints:

    
        __A_
       /    \
      B     *C*
     / \       \
    D   E       F
    
    
        __*A*_
       /      \
      B       *C*_
     / \          \
    D   E         *F*
    
    
         ___*A*_
        /       \
      *B*       *C*_
     /   \          \
    D     E         *F*
    
    
           _____*A*_
          /         \
       _*B*_        *C*_
      /     \           \
    *D*     *E*         *F*