I was asked the following question in a job interview:
Given the root node of a well-formed binary tree and two other nodes (which are guaranteed to be in the tree and are distinct), return the lowest common ancestor of the two nodes.
I didn't know any least common ancestor algorithms, so I tried to come up with one on the spot. I produced the following code:
def least_common_ancestor(root, a, b):
    lca = [None]
    def check_subtree(subtree, lca=lca):
        if lca[0] is not None or subtree is None:
            return 0
        if subtree is a or subtree is b:
            return 1
        else:
            ans = sum(check_subtree(n) for n in (subtree.left, subtree.right))
            if ans == 2:
                lca[0] = subtree
                return 0
            return ans
    check_subtree(root)
    return lca[0]

class Node:
    def __init__(self, left, right):
        self.left = left
        self.right = right
I tried the following test cases and got the answer that I expected:
a = Node(None, None)
b = Node(None, None)
tree = Node(Node(Node(None, a), b), None)
tree2 = Node(a, Node(Node(None, None), b))
tree3 = Node(a, b)
but my interviewer told me that "there is a class of trees for which your algorithm returns None." I couldn't figure out what it was and I flubbed the interview. I can't think of a case where the algorithm would make it to the bottom of the tree without ans ever becoming 2. What am I missing?
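You forgot to account for the case where a is an ancestor of b, or vice versa. You stop searching as soon as you find either node and return 1, so you'll never find the other node in that case. The count then never reaches 2 anywhere in the tree, lca[0] stays None, and that is what the function returns.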
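One way to patch the question's approach, as a minimal sketch rather than the only possible fix: keep the counting idea, but when a node is a or b, count it and keep searching below it, so a target that is its descendant is still found. The Node class here gives both children a default of None, a small convenience over the question's version, and the result is held with nonlocal instead of a one-element list.

```python
class Node:
    def __init__(self, left=None, right=None):
        self.left = left
        self.right = right


def least_common_ancestor(root, a, b):
    result = None

    def check_subtree(subtree):
        # Returns how many of a and b were found in this subtree (0, 1 or 2)
        nonlocal result
        if result is not None or subtree is None:
            return 0
        found = 1 if (subtree is a or subtree is b) else 0
        # Keep descending even if this node is a or b: the other target
        # may be one of its descendants
        found += check_subtree(subtree.left) + check_subtree(subtree.right)
        if found == 2:
            # Children are checked first, so the first node to reach 2
            # is the deepest one whose subtree holds both targets
            result = subtree
            return 0
        return found

    check_subtree(root)
    return result


# The ancestor case that the original version gets wrong:
b = Node()
a = Node(b, None)  # a is b's parent
tree = Node(a, Node())
print(least_common_ancestor(tree, a, b) is a)  # True
```

Like the original, it stops early once the ancestor is found and visits each node at most once, so it is O(n) for an n-node tree.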
If you were given a well-formed binary search tree, you could use one of its properties: you can find elements based on their size relative to the current node, because smaller elements go into the left subtree and greater ones into the right. If you know that both elements are in the tree, you only need to compare keys: walking down from the root, the first node whose key lies between the two target keys, or equals one of them, is the lowest common ancestor.
Your sample Node class doesn't store any keys, so you can't make use of this property here, but if it did, you could use:
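def lca(tree, a, b):
    # Order the keys so the range check works whichever node is passed first
    lo, hi = sorted((a.key, b.key))
    if lo <= tree.key <= hi:
        return tree  # tree separates a and b (or is one of them): it is the LCA
    if hi < tree.key:
        return lca(tree.left, a, b)  # both keys are smaller: go left
    return lca(tree.right, a, b)  # both keys are larger: go right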
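To actually run the search-tree version, the nodes need keys. The sketch below assumes a hypothetical KeyNode class (the question's Node plus a key attribute, with distinct keys) and writes the same descent as a loop instead of recursion, which sidesteps Python's recursion limit on very deep trees.

```python
class KeyNode:
    """Hypothetical BST node: the question's Node plus a key."""

    def __init__(self, key, left=None, right=None):
        self.key = key
        self.left = left
        self.right = right


def bst_lca(root, a, b):
    """Iterative LCA in a binary search tree; assumes a and b are in the tree."""
    lo, hi = sorted((a.key, b.key))
    node = root
    while node is not None:
        if hi < node.key:
            node = node.left  # both targets are smaller: LCA is on the left
        elif lo > node.key:
            node = node.right  # both targets are larger: LCA is on the right
        else:
            return node  # lo <= node.key <= hi: the two paths split here
    return None  # only reached if a or b is not actually in the tree


#        8
#      /   \
#     3     10
#    / \      \
#   1   6      14
n1, n6, n14 = KeyNode(1), KeyNode(6), KeyNode(14)
n3 = KeyNode(3, n1, n6)
n10 = KeyNode(10, None, n14)
root = KeyNode(8, n3, n10)

print(bst_lca(root, n1, n6).key)   # 3
print(bst_lca(root, n14, n6).key)  # 8
print(bst_lca(root, n3, n6).key)   # 3
```

Each step moves one level down, so the cost is proportional to the height of the tree rather than its size.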
If the tree is merely a 'regular' binary tree and not a search tree, the keys tell you nothing about where a node lives, so you have to find the paths to both elements and the point at which those paths diverge.
If your binary tree maintains parent references and depth, this can be done efficiently: walk up from the deeper of the two nodes until both are at the same depth, then move both nodes upwards in step until they reach the same node; that node is the lowest common ancestor.
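A minimal sketch of the parent-and-depth walk, assuming a hypothetical PNode class that records its parent and computes its depth when it is created (the root has depth 0). Child links aren't needed for this step, so they are left out.

```python
class PNode:
    """Hypothetical node that stores only its parent and its depth."""

    def __init__(self, parent=None):
        self.parent = parent
        self.depth = 0 if parent is None else parent.depth + 1


def lca_with_parents(a, b):
    # Lift whichever node is deeper until both are at the same depth
    while a.depth > b.depth:
        a = a.parent
    while b.depth > a.depth:
        b = b.parent
    # Walk both upward one level at a time until they meet
    while a is not b:
        a = a.parent
        b = b.parent
    return a


#         root
#        /    \
#       l      r
#      / \
#    ll   lr
#         /
#       lrl
root = PNode()
l, r = PNode(root), PNode(root)
ll, lr = PNode(l), PNode(l)
lrl = PNode(lr)

print(lca_with_parents(ll, lrl) is l)     # True
print(lca_with_parents(lrl, r) is root)   # True
print(lca_with_parents(l, lrl) is l)      # True: one node is the other's ancestor
```

The work is bounded by the depth of the deeper node, and the ancestor case needs no special handling: lifting the deeper node lands directly on the other one.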
If your nodes don't store parent references and depth, you'll have to find the path to each node with a separate search starting from the root, then find the last node the two paths have in common.
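A minimal sketch of the path-based approach for a plain Node class like the question's; path_to and lca_by_paths are hypothetical helper names. Each path is found with a depth-first search from the root, and the two paths are then compared from the top until they first differ.

```python
class Node:
    def __init__(self, left=None, right=None):
        self.left = left
        self.right = right


def path_to(root, target):
    """Return the list of nodes from root down to target, or None if absent."""
    if root is None:
        return None
    if root is target:
        return [root]
    for child in (root.left, root.right):
        sub = path_to(child, target)
        if sub is not None:
            return [root] + sub
    return None


def lca_by_paths(root, a, b):
    path_a = path_to(root, a)
    path_b = path_to(root, b)
    lca = None
    # zip stops at the end of the shorter path; the last shared node is the LCA
    for x, y in zip(path_a, path_b):
        if x is not y:
            break
        lca = x
    return lca


a, b = Node(), Node()
inner = Node(Node(None, a), b)
tree = Node(inner, None)
print(lca_by_paths(tree, a, b) is inner)  # True
```

Building each path by list concatenation costs O(h²) per search in the worst case, where h is the tree height; appending while the recursion unwinds and reversing once would avoid that, but the simpler form is kept here.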