Tags: python, binary-tree, recursive-datastructures

How to check binary tree symmetry iteratively to avoid stack overflow with deep trees in Python


I have a working recursive solution that checks whether a binary tree is symmetric. It works correctly for my test cases, but I'm concerned about stack overflow with deep trees, since Python has a default recursion limit (1000 frames, as reported by sys.getrecursionlimit()).

class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

class Solution:
    def isSymmetric(self, root):
        if root is None:
            return True
        return self.compare(root.left, root.right)
    
    def compare(self, nodeA, nodeB):
        if nodeA is None and nodeB is None:
            return True
        if nodeA is None or nodeB is None:
            return False
        if nodeA.val != nodeB.val:
            return False
        return (self.compare(nodeA.left, nodeB.right) and 
                self.compare(nodeA.right, nodeB.left))

This test case works:

root = TreeNode(1)
root.left = TreeNode(2)
root.right = TreeNode(2)
root.left.left = TreeNode(3)
root.right.right = TreeNode(3)

sol = Solution()
print(sol.isSymmetric(root))  # Output: True

The problem is that when I test with a tree more than 1000 levels deep, I get:

RecursionError: maximum recursion depth exceeded in comparison

I understand that this runs in O(n) time, which is optimal, but the recursive approach uses O(h) space on the call stack, where h is the height of the tree.
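A minimal sketch that reproduces the error. It repeats the recursive solution from the question and assumes a hypothetical helper `build_tent(depth)` that hangs two mirrored chains of equal values off the root. Because that tree is symmetric, the recursion never stops early, so its depth grows with the tree height:

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


class Solution:
    # the recursive version from the question, unchanged
    def isSymmetric(self, root):
        if root is None:
            return True
        return self.compare(root.left, root.right)

    def compare(self, nodeA, nodeB):
        if nodeA is None and nodeB is None:
            return True
        if nodeA is None or nodeB is None:
            return False
        if nodeA.val != nodeB.val:
            return False
        return (self.compare(nodeA.left, nodeB.right) and
                self.compare(nodeA.right, nodeB.left))


def build_tent(depth):
    """Hypothetical helper: a symmetric tree whose left subtree is a chain
    of left children and whose right subtree is a chain of right children."""
    root = TreeNode(0)
    a = b = root
    for level in range(1, depth + 1):
        a.left = TreeNode(level)
        b.right = TreeNode(level)
        a, b = a.left, b.right
    return root


print(Solution().isSymmetric(build_tent(10)))  # True: shallow trees are fine

try:
    # 20,000 levels is far beyond the default limit of 1000 frames
    Solution().isSymmetric(build_tent(20_000))
except RecursionError as exc:
    print("RecursionError:", exc)
```

Raising the limit with sys.setrecursionlimit only moves the threshold; it does not remove the dependency on the call stack.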

How can I convert this recursive solution to an iterative one using an explicit stack or queue to avoid Python's recursion limit while maintaining the same comparison logic (mirrored left-right traversal)?

I've tried using a single stack but I'm not sure how to properly track both subtrees simultaneously for the mirrored comparison.

def isSymmetric(self, root):
    if not root:
        return True
    stack = [(root.left, root.right)]

    # Not sure how to proceed from here to compare nodes correctly

What's the correct way to manage the stack to compare pairs of nodes in mirror order?


Solution

  • In terms of time complexity, this algorithm is optimal.

    Note that if the input is symmetric, the recursion depth will be roughly equal to the height of the given binary tree. If that height is large, you may hit Python's recursion limit (1000 by default). An alternative approach is to work with an explicit stack, like this:

    class Solution:
        def isSymmetric(self, root):
            if root:
                stack = [(root.left, root.right)]
                while stack:
                    a, b = stack.pop()
                    while a or b: # while the current pair has at least one node
                        # if there is a difference, the trees are not symmetrical
                        if not a or not b or a.val != b.val:
                            return False
                        # put one child pair on the stack for later inspection
                        stack.append((a.left, b.right))
                        # take the other child pair as current pair, and repeat
                        a = a.right
                        b = b.left
            # all relevant pairs were compared and found equal
            return True
    

    This explicit stack holds pairs of nodes that still need to be compared for symmetry, where each pair consists of a node from the left subtree and one from the right subtree. When the nodes in such a pair have the same value, two "child" pairs follow from it: one is pushed on the stack for later inspection, and the other becomes the current pair. That way every relevant pair is eventually visited and compared.
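The same pair-based logic also works with a FIFO queue instead of a stack, which is the "queue" option mentioned in the question: pairs are then compared level by level (breadth-first) rather than depth-first. A minimal sketch using collections.deque; the function name `is_symmetric_bfs` and the helper `build_tent` are hypothetical, for illustration only:

```python
from collections import deque


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def is_symmetric_bfs(root):
    """Hypothetical breadth-first variant: compares mirrored node pairs
    taken from a FIFO queue instead of a stack."""
    if root is None:
        return True
    queue = deque([(root.left, root.right)])
    while queue:
        a, b = queue.popleft()
        if a is None and b is None:
            continue  # both sides empty: this pair is fine
        if a is None or b is None or a.val != b.val:
            return False  # shape or value mismatch
        # outer children are compared with each other, inner with inner
        queue.append((a.left, b.right))
        queue.append((a.right, b.left))
    return True


def build_tent(depth):
    """Hypothetical helper: a symmetric tree made of two mirrored chains."""
    root = TreeNode(0)
    a = b = root
    for level in range(1, depth + 1):
        a.left = TreeNode(level)
        b.right = TreeNode(level)
        a, b = a.left, b.right
    return root


print(is_symmetric_bfs(build_tent(50_000)))  # True, and no RecursionError

lopsided = build_tent(3)
lopsided.left.val = 99
print(is_symmetric_bfs(lopsided))  # False
```

Either container avoids the recursion limit. The difference is memory shape: a queue can hold a whole level of pending pairs at once on wide trees, while the stack version keeps pending pairs along one root-to-leaf path.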

    Space complexity

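    Even though this iterative version is not limited by the call stack, it still needs auxiliary memory for its explicit stack. At any moment the stack holds at most one pending pair per level of the current path, so its size is O(h), which is O(n) in the worst case. Note that the tent shape, with the left subtree a chain of left children and the right subtree a chain of right children, is harmless for this particular code: the pushed pair (a.left, b.right) is popped right away, and the stack never holds more than one pair. The costly shape is the reverse one, where the left subtree is a chain of right children and the right subtree a chain of left children. Then the current pair keeps descending while every step pushes an empty (None, None) pair that stays on the stack until the chain ends:

                     1
                 /       \
                2         2
                 \       /
                  3     3
                   \   /
                    ...
    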
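To see this growth concretely, one can instrument the loop to record the largest stack size it reaches. The sketch below does that for a symmetric tree whose left subtree is a chain of right children and whose right subtree is a chain of left children. The names `max_stack_size` and `build_inward` are hypothetical, and the loop body follows the non-randomized code above:

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def max_stack_size(root):
    """Hypothetical helper: run the explicit-stack symmetry check and return
    the largest number of pairs that were on the stack at any moment."""
    if root is None:
        return 0
    stack = [(root.left, root.right)]
    largest = len(stack)
    while stack:
        a, b = stack.pop()
        while a or b:
            if not a or not b or a.val != b.val:
                return largest  # asymmetric: stop early, as isSymmetric would
            stack.append((a.left, b.right))
            largest = max(largest, len(stack))
            a, b = a.right, b.left
    return largest


def build_inward(depth):
    """Hypothetical helper: the left subtree only has right children and the
    right subtree only has left children, so the tree is symmetric."""
    root = TreeNode(0)
    root.left, root.right = TreeNode(1), TreeNode(1)
    a, b = root.left, root.right
    for level in range(2, depth + 1):
        a.right = TreeNode(level)
        b.left = TreeNode(level)
        a, b = a.right, b.left
    return root


for depth in (10, 100, 1000):
    print(depth, max_stack_size(build_inward(depth)))
# prints: 10 10, then 100 100, then 1000 1000
```

The stack size equals the depth here because each step along the chains pushes a (None, None) pair that is only popped after the inner loop has reached the bottom.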

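    It is a pity that the caller of this function can easily design a tree that triggers this worst-case memory usage. To make that harder, you could randomly choose which pair is pushed on the stack and which one becomes the current pair, so that no fixed tree shape forces every step into the expensive branch. This is a mitigation rather than a guarantee, though: on a tree made of two long mirrored chains, whichever way they point, each level still leaves an empty pair on the stack with probability 1/2, so the expected stack size remains proportional to the height.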

    import random
    
    class Solution:
        def isSymmetric(self, root):
            if root:
                stack = [(root.left, root.right)]
                while stack:
                    a, b = stack.pop()
                    while a or b:
                        if not a or not b or a.val != b.val:
                            return False
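                        # randomly decide which pair will be put on the stack;
                        # swapping a and b is safe, because checking (b, a) for
                        # mirror symmetry compares the same pairs as (a, b)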
                        if random.randrange(2):
                            a, b = b, a
                        stack.append((a.left, b.right))
                        a = a.right
                        b = b.left
            return True
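Swapping a and b before pushing does not change the result, because the mirror comparison is itself symmetric: checking (x, y) compares the same child pairs as checking (y, x), only with the members of each pair swapped. The sketch below exercises the randomized loop with a seeded random.Random instance and also records the peak stack size. The `rng` parameter, the returned peak, and the helper `build_chains` are assumptions added for reproducibility and measurement, not part of the answer's code:

```python
import random


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def is_symmetric_randomized(root, rng):
    """Randomized explicit-stack check returning (result, peak stack size).
    `rng` is a random.Random instance (assumed, for reproducible runs)."""
    peak = 0
    if root:
        stack = [(root.left, root.right)]
        peak = 1
        while stack:
            a, b = stack.pop()
            while a or b:
                if not a or not b or a.val != b.val:
                    return False, peak
                if rng.randrange(2):
                    a, b = b, a  # safe: the mirror comparison is symmetric
                stack.append((a.left, b.right))
                peak = max(peak, len(stack))
                a, b = a.right, b.left
    return True, peak


def build_chains(depth, inward):
    """Hypothetical helper: two mirrored chains of equal values hanging off
    the root, pointing outward (tent) or inward."""
    root = TreeNode(0)
    root.left, root.right = TreeNode(1), TreeNode(1)
    a, b = root.left, root.right
    for level in range(2, depth + 1):
        if inward:
            a.right, b.left = TreeNode(level), TreeNode(level)
            a, b = a.right, b.left
        else:
            a.left, b.right = TreeNode(level), TreeNode(level)
            a, b = a.left, b.right
    return root


rng = random.Random(42)
for inward in (False, True):
    ok, peak = is_symmetric_randomized(build_chains(10_000, inward), rng)
    print("inward" if inward else "tent", ok, peak)
```

For both chain directions the peak comes out at roughly half the depth rather than a constant, which matches the expected-value argument: randomization changes which shapes are expensive, not the O(h) bound.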