Tags: python, recursion, caching, functools, python-nonlocal

How to update nonlocal variables while caching results?


When using the functools caching decorators like lru_cache, the inner function doesn't update the values of nonlocal variables. The same code works without the decorator.

Are nonlocal variables not updated when using a caching decorator? What should I do if I need to update nonlocal variables while also storing results to avoid duplicate work? Or must the cached function necessarily return the answer?

E.g., the following does not correctly update the value of the nonlocal variable:

from functools import lru_cache

def foo(x):
    outer_var = 0

    @lru_cache
    def bar(i):
        nonlocal outer_var
        if condition:  # some condition on i
            outer_var += 1
        else:
            bar(i + 1)

    bar(x)
    return outer_var

Background

I was trying the Decode Ways problem which is finding the number of ways a string of numbers can be interpreted as letters. I start from the first letter and take one or two steps and check if they're valid. On reaching the end of the string, I update a non local variable which stores the number of ways possible. This method is giving correct answer without using lru_cache but fails when caching is used. Another method where I return the value is working but I wanted to check how to update non-local variables while using memoization decorators.

My code with the error:

# inside an enclosing method where s is defined
ways = 0

@lru_cache(None)  # works well without this
def recurse(i):
    nonlocal ways
    if i == len(s):
        ways += 1
    elif i < len(s):
        if 1 <= int(s[i]) <= 9:
            recurse(i + 1)
        if i + 2 <= len(s) and 10 <= int(s[i:i+2]) <= 26:
            recurse(i + 2)
    return

recurse(0)
return ways

The accepted solution:

@lru_cache(None)
def recurse(i):
    if i == len(s):
        return 1
    elif i < len(s):
        ans = 0
        if 1 <= int(s[i]) <= 9:
            ans += recurse(i + 1)
        if i + 2 <= len(s) and 10 <= int(s[i:i+2]) <= 26:
            ans += recurse(i + 2)
        return ans

return recurse(0)

Solution

  • There's nothing special about lru_cache, a nonlocal variable, or recursion that causes any inherent issue here. The problem is purely logical, not a behavioral anomaly. See this minimal example:

    from functools import lru_cache
    
    def foo():
        c = 0
    
        @lru_cache(None)
        def bar(i=0):
            nonlocal c
    
            if i < 5:
                c += 1
                bar(i + 1)
    
        bar()
        return c
    
    print(foo()) # => 5
    

    The problem in the cached version of the Decode Ways code is due to the overlapping nature of the recursive calls. The cache prevents the base-case call recurse(i), where i == len(s), from ever executing more than once, even when it's reached from a different recursive path.

    A good way to establish this is to put a print("hello") in the base case (the if i == len(s) branch), then feed it a sizable problem. You'll see "hello" printed once, and only once; since ways can only be updated through recurse(i) when i == len(s), you're left with ways == 1 when all is said and done.
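To make that experiment concrete, here's a sketch of the buggy pattern with the print added, using the hypothetical input "1111" (which has 5 valid decodings: the base case is reached once, so the count stays at 1):

```python
from functools import lru_cache

def decode_ways_buggy(s):
    ways = 0

    @lru_cache(None)
    def recurse(i):
        nonlocal ways
        if i == len(s):
            print("hello")  # fires only once; the cache swallows repeat hits
            ways += 1
        elif i < len(s):
            if 1 <= int(s[i]) <= 9:
                recurse(i + 1)
            if i + 2 <= len(s) and 10 <= int(s[i:i+2]) <= 26:
                recurse(i + 2)

    recurse(0)
    return ways

print(decode_ways_buggy("1111"))  # prints "hello" once, then 1 (not 5)
```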

    In the above toy example, there's only one recursive path: the calls expand for each i between 0 and 5 and the cache is never used. In contrast, Decode Ways offers multiple recursive paths, so the path via recurse(i+1) finds the base case linearly, then as the stack unwinds, recurse(i+2) tries to find other ways of reaching it.

    Adding the cache cuts off extra paths, but it has no value to return for each intermediate node. With the cache, it's like you have a memoized or dynamic programming table of subproblems, but you never update any entries, so the whole table is zero (except for the base case).
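Hand-rolling the memoization with an explicit dict makes the "table" analogy concrete. Assuming the hypothetical input "1111", every entry ends up holding None, because the function never returns anything worth storing:

```python
s = "1111"  # hypothetical input
memo = {}   # a hand-rolled equivalent of lru_cache's internal table

def recurse(i):
    if i in memo:
        return memo[i]  # a cache hit returns None and re-runs nothing
    if i == len(s):
        pass            # the ways += 1 side effect would fire here, exactly once
    elif i < len(s):
        if 1 <= int(s[i]) <= 9:
            recurse(i + 1)
        if i + 2 <= len(s) and 10 <= int(s[i:i+2]) <= 26:
            recurse(i + 2)
    memo[i] = None      # nothing useful is ever stored
    return memo[i]

recurse(0)
print(memo)  # every subproblem maps to None
```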

    Here's an example of the linear behavior the cache causes:

    from functools import lru_cache
    
    def cached():
        @lru_cache(None)
        def cached_recurse(i=0):
            print("cached", i)
    
            if i < 3:
                cached_recurse(i + 1)
                cached_recurse(i + 2)
    
        cached_recurse()
    
    def uncached():
        def uncached_recurse(i=0):
            print("uncached", i)
    
            if i < 3:
                uncached_recurse(i + 1)
                uncached_recurse(i + 2)
    
        uncached_recurse()
    
    cached()
    uncached()
    

    Output:

    cached 0
    cached 1
    cached 2
    cached 3
    cached 4
    uncached 0
    uncached 1
    uncached 2
    uncached 3
    uncached 4
    uncached 3
    uncached 2
    uncached 3
    uncached 4
    

    The solution is exactly as you show: pass the results up the tree and use the cache to store values for each node representing a subproblem. This is the best of both worlds: we have the values for the subproblems, but without re-executing the functions that ultimately lead to your ways += 1 base case.

    In other words, if you're going to use the cache, think of it as a lookup table, not just a call-tree pruner. In your attempt, the cache doesn't remember what work was done; it just prevents it from being done again.
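Putting the accepted pattern together as a self-contained function (the name num_decodings is my own; defining the cached helper inside the outer function also keeps the cache from leaking between different inputs):

```python
from functools import lru_cache

def num_decodings(s):
    """Count the decodings of a digit string by returning results
    up the tree, so the cache has values to store for each subproblem."""

    @lru_cache(None)
    def recurse(i):
        if i == len(s):
            return 1  # one complete decoding found
        ans = 0
        if 1 <= int(s[i]) <= 9:          # take one digit (1-9)
            ans += recurse(i + 1)
        if i + 2 <= len(s) and 10 <= int(s[i:i+2]) <= 26:  # take two digits (10-26)
            ans += recurse(i + 2)
        return ans

    return recurse(0)

print(num_decodings("1111"))  # => 5
print(num_decodings("226"))   # => 3 ("2 2 6", "22 6", "2 26")
```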