python, recursion, decorator, python-contextvars

*Dynamically* decorate a recursive function in Python


I have a scenario where I need to dynamically decorate recursive calls within a function in Python. The key requirement is to achieve this dynamically without modifying the function in the current scope. Let me explain the situation and what I've tried so far.

Suppose I have a function traverse_tree that recursively traverses a binary tree and yields its values. Now, I want to decorate the recursive calls within this function to include additional information, such as the recursion depth. When I use a decorator directly with the function, it works as expected. However, I want to achieve the same dynamically, without modifying the function in the current scope.


import functools


class Node:
    def __init__(self, value, left=None, right=None):
        self.value = value
        self.left = left
        self.right = right


def generate_tree():
    root = Node(1)
    root.left = Node(2)
    root.right = Node(3)
    root.left.left = Node(4)
    root.left.right = Node(5)
    root.right.left = Node(6)
    root.right.right = Node(7)
    return root


def with_recursion_depth(func):
    """Yield recursion depth alongside original values of an iterator."""
    
    class Depth(int): pass
    depth = Depth(-1)

    def depth_in_value(value, depth) -> bool:
        return isinstance(value, tuple) and len(value) == 2 and value[-1] is depth

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        nonlocal depth
        depth = Depth(depth + 1)
        for value in func(*args, **kwargs):
            if depth_in_value(value, depth):
                yield value
            else:
                yield value, depth
        depth = Depth(depth - 1)

    return wrapper

# 1. using @-syntax
@with_recursion_depth
def traverse_tree(node):
    """Recursively yield values of the binary tree."""
    yield node.value
    if node.left:
        yield from traverse_tree(node.left)
    if node.right:
        yield from traverse_tree(node.right)


root = generate_tree()
for item in traverse_tree(root):
    print(item)

# Output:
# (1, 0)
# (2, 1)
# (4, 2)
# (5, 2)
# (3, 1)
# (6, 2)
# (7, 2)


# 2. Dynamically:  
def traverse_tree(node):
    """Recursively yield values of the binary tree."""
    yield node.value
    if node.left:
        yield from traverse_tree(node.left)
    if node.right:
        yield from traverse_tree(node.right)


root = generate_tree()
for item in with_recursion_depth(traverse_tree)(root):
    print(item)

# Output:
# (1, 0)
# (2, 0)
# (4, 0)
# (5, 0)
# (3, 0)
# (6, 0)
# (7, 0)

It seems that the issue lies in how the recursive calls within the function are resolved. When the decorator is applied dynamically, only the outermost call goes through the wrapper. The recursive calls inside the function body look up the global name traverse_tree each time they run, and that name still refers to the undecorated function. I can get the desired behaviour by re-assigning (traverse_tree = with_recursion_depth(traverse_tree)), but then the function has been modified in the current scope. I would like to do this dynamically, so that I can either use the undecorated function or optionally wrap it to obtain information on the recursion depth.
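To see why only the outermost call is affected, here is a minimal sketch using a hypothetical call-recording decorator (count_calls) and a toy countdown function, neither of which comes from the question. Wrapping the function without rebinding its name leaves the recursive calls untouched, while rebinding the global name reaches every level (at the cost of changing the name for every other caller in the module).

```python
import functools


def count_calls(func):
    """Hypothetical decorator that records the arguments of every call it sees."""
    calls = []

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        calls.append(args)
        return func(*args, **kwargs)

    wrapper.calls = calls
    return wrapper


def countdown(n):
    # The recursive call looks up the *global* name `countdown` each time it runs.
    if n > 0:
        countdown(n - 1)


# Wrapping without rebinding: only the outermost call goes through the wrapper.
wrapped = count_calls(countdown)
wrapped(3)
print(wrapped.calls)  # [(3,)]

# Rebinding the global name makes the recursive calls hit the wrapper too,
# but it also changes `countdown` for everything else in this module.
countdown = count_calls(countdown)
countdown(3)
print(countdown.calls)  # [(3,), (2,), (1,), (0,)]
```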

I prefer to keep things simple and would like to avoid techniques like bytecode manipulation if there are alternative solutions. However, if that's the necessary path, I'm willing to explore it. I've made an attempt in that direction, but I haven't been successful yet.

import ast
import inspect
import types


def modify_recursive_calls(func, decorator):

    def decorate_recursive_calls(node):
        if isinstance(node, ast.Call) and isinstance(node.func, ast.Name) and node.func.id == func.__name__:
            func_name = ast.copy_location(ast.Name(id=node.func.id, ctx=ast.Load()), node.func)
            decorated_func = ast.Call(
                func=ast.Name(id=decorator.__name__, ctx=ast.Load()),
                args=[func_name],
                keywords=[],
            )
            node.func = decorated_func
        for field, value in ast.iter_fields(node):
            if isinstance(value, list):
                for item in value:
                    if isinstance(item, ast.AST):
                        decorate_recursive_calls(item)
            elif isinstance(value, ast.AST):
                decorate_recursive_calls(value)

    tree = ast.parse(inspect.getsource(func))
    decorate_recursive_calls(tree)
    ast.fix_missing_locations(tree)
    modified_code = compile(tree, filename="<ast>", mode="exec")
    modified_function = types.FunctionType(modified_code.co_consts[1], func.__globals__)
    return modified_function
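A possibly simpler route than rewriting the AST, offered as a sketch rather than a tested solution: since the recursive call is resolved through the function's globals, one can build a copy of the function with types.FunctionType whose globals dict maps the function's own name to the decorated copy. The names decorate_recursion and trace_depth are hypothetical. The sketch assumes the recursion goes through a module-level name (not a closure variable), and the copied globals are a snapshot, so later changes to the module's globals are not seen by the copy.

```python
import functools
import types


def decorate_recursion(func, decorator):
    """Hypothetical helper: decorate `func` so its recursive calls are decorated
    too, without rebinding the name in the module that defines it."""
    # Snapshot of the module globals; the copy resolves global names through it.
    patched_globals = dict(func.__globals__)
    clone = types.FunctionType(
        func.__code__,
        patched_globals,
        func.__name__,
        func.__defaults__,
        func.__closure__,
    )
    clone.__kwdefaults__ = func.__kwdefaults__
    decorated = decorator(clone)
    # Recursive lookups of func.__name__ inside the copy now hit the decorated copy.
    patched_globals[func.__name__] = decorated
    return decorated


def trace_depth(func):
    """Hypothetical decorator: record (depth, args) for every call it sees."""
    depth = -1
    log = []

    @functools.wraps(func)
    def wrapper(*args):
        nonlocal depth
        depth += 1
        log.append((depth, args))
        try:
            return func(*args)
        finally:
            depth -= 1

    wrapper.log = log
    return wrapper


def factorial(n):
    return 1 if n <= 1 else n * factorial(n - 1)


traced = decorate_recursion(factorial, trace_depth)
print(traced(3))   # 6
print(traced.log)  # [(0, (3,)), (1, (2,)), (2, (1,))]
# The module-level name is untouched: plain result, no log attribute.
print(factorial(3), hasattr(factorial, "log"))  # 6 False
```

With the question's code, decorate_recursion(traverse_tree, with_recursion_depth)(root) should yield the same depths as the @-syntax version, because traverse_tree also finds itself through a global lookup.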

Solution

  • This is a case where context variables can be used. The contextvars module is a relatively recent addition to the language (Python 3.7), designed so that asynchronous code running in tasks can pass "out-of-band" information to nested calls, without that information being overwritten or confused by other tasks calling the same functions. See: https://docs.python.org/3/library/contextvars.html

    They are somewhat like threading.local, but with an ugly interface. The idea is that your innermost calls retrieve the function to be called from a context variable instead of the global scope. If the outermost call sets this context variable to the decorated function, only the calls made during that "descent" are affected.

    contextvars are robust enough. Their interface, however, is horrible to use: one must first create a copy of the context, then call a function using that copy, and only the changes that this entered function makes to the context variable are scoped to the copy. Code outside of that call will always see the unchanged value, and this is consistent and robust across threads, async tasks, and so on.

    With a simpler "mymul" recursive function and a "logger" decorator, the code could look like this:

    import contextvars
    import functools

    recursive_func = contextvars.ContextVar("recursive_func")


    def logger(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            print(f"{func.__name__} called with {args} and {kwargs}")
            return func(*args, **kwargs)
        return wrapper


    def mymul(a, b):
        if b == 0:
            return 0
        # Look up the recursive target in the context variable, not the global name.
        return a + recursive_func.get()(a, b - 1)

    recursive_func.set(mymul)

    # non-logging call:
    print(mymul(3, 4))

    # logging call - there must be an "entry point" which
    # can change the contextvar from within the new context.
    # mymul itself is called undecorated here, so only the recursive calls are logged.
    def logging_call_maker(*args, **kwargs):
        recursive_func.set(logger(mymul))
        return mymul(*args, **kwargs)


    print(contextvars.copy_context().run(logging_call_maker, 3, 4))


    # Another non-logging call:
    print(mymul(3, 4))
    

    The key point here is: the recursive function retrieves the function to be called from the ContextVar, not using its name in the global scope.

    If you like the approach but the ContextVar handling seems like too much boilerplate, it can be made easier to use; just drop a comment here. I have a project I started years ago to wrap this behavior in friendlier code, but the lack of use cases of my own led me to stop working on it.
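One way to trim that boilerplate, sketched here as an assumption and not as the project mentioned above: wrap ContextVar.set() and ContextVar.reset() in a context manager, and pass a default to get() so the plain function is used when nothing has been set. use_recursive is a hypothetical name, and this logger records calls in a list instead of printing so the effect can be checked.

```python
import contextlib
import contextvars
import functools

# Recursive functions look their recursive target up through this variable.
recursive_func = contextvars.ContextVar("recursive_func")


@contextlib.contextmanager
def use_recursive(func):
    """Hypothetical helper: route recursive calls to `func` inside the block."""
    token = recursive_func.set(func)
    try:
        yield func
    finally:
        # Restore whatever value (or unset state) was there before.
        recursive_func.reset(token)


def logger(func):
    """Record the positional arguments of every call in wrapper.calls."""
    calls = []

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        calls.append(args)
        return func(*args, **kwargs)

    wrapper.calls = calls
    return wrapper


def mymul(a, b):
    if b == 0:
        return 0
    # Fall back to the plain function when no override is active.
    return a + recursive_func.get(mymul)(a, b - 1)


print(mymul(3, 4))  # 12, nothing logged

logged = logger(mymul)
with use_recursive(logged):
    print(logged(3, 4))  # 12
print(logged.calls)  # [(3, 4), (3, 3), (3, 2), (3, 1), (3, 0)]

print(mymul(3, 4))  # 12, override no longer active
```

Because reset() restores the previous value when the with block exits, repeated or nested overrides unwind cleanly. For a generator such as the question's traverse_tree, the with block (or a copy_context().run() call) has to cover the whole iteration, since the generator body reads the context variable only when it is resumed.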