python caching lis

Longest Increasing Subsequence using recursion and cache


I've been trying to add a cache to my recursive LIS function so it doesn't calculate the same value twice. I would really appreciate it if someone could give me a hint about what I'm getting wrong.

This is the recursive function that returns the LIS array; it works fine:

import numpy as np

def lgs(l):
    # -np.inf replaces np.NINF, which was removed in NumPy 2.0
    return lgsRecursive(-np.inf,l,0)
    
def lgsRecursive(x,l,i):
    print(x,i)
    
    if i >= len(l):
        return []
        
    else:
        list1 = lgsRecursive(x,l,i+1)
        if l[i] > x:
            list2 = [l[i]] + lgsRecursive(l[i],l,i+1)
            if len(list1) < len(list2):
                list1 = list2
                
    return list1

assert(lgs([1, 20, 3, 7, 40, 5, 2]) == [1,3,7,40])

This is the same function with a cache added. It gives wrong answers with repeated elements (for the previous assert it returns [1, 20, 40, 40, 40, 40, 40]):

import numpy as np
cache = {}

def lgs(l):
    # -np.inf replaces np.NINF, which was removed in NumPy 2.0
    return lgsMemo(-np.inf,l,0)

def lgsMemo(x,l,i):
    global cache
    
    key = (x,i)
    
    if key in cache:
        return cache[(x,i)]
    
    if i >= len(l):
        return []
    
    else:
        list1 = lgsMemo(x,l,i+1)
        if l[i] > x:
            list2 = [l[i]] + lgsMemo(l[i],l,i+1)
            if len(list1) < len(list2):
                list1 = list2
                cache[(l[i],i+1)] = list1
            else:
                cache[(x,i+1)] = list1                  
    return list1

I think the error might be that I'm caching [l[i]] + lgsMemo(l[i],l,i+1) instead of lgsMemo(l[i],l,i+1).


Solution

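  • Why make it so hard on yourself? Your version writes results under keys like (l[i],i+1) and (x,i+1), not under the (x,i) key it looks up, so later lookups return lists that belong to a different call. You can just have two functions: one that actually calculates things, and one that checks whether the result is already in memory and delegates to the other if necessary. Notice that I had to slightly edit your recursive function so it uses the cache where possible.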

    import numpy as np
    cache = {}
    
    def lgs(l):
        # -np.inf replaces np.NINF, which was removed in NumPy 2.0
        return lgsMemo(-np.inf,l,0)
        
    def lgsRecursive(x,l,i):
        print(x,i)
        
        if i >= len(l):
            return []
            
        else:
            list1 = lgsMemo(x,l,i+1)
            if l[i] > x:
                list2 = [l[i]] + lgsMemo(l[i],l,i+1)
                if len(list1) < len(list2):
                    list1 = list2
                    
        return list1
    
    
    def lgsMemo(x,l,i):
        global cache
        if (x,i) not in cache:
            cache[(x,i)] = lgsRecursive(x,l,i)
        return cache[(x,i)]
    
    
    assert(lgs([1, 20, 3, 7, 40, 5, 2]) == [1,3,7,40])
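
One thing to watch with a module-level cache keyed only on (x,i): the key does not include the list, so a second call to lgs with a different list could pick up entries left over from the first. Below is a minimal sketch of one way around that, not the only one: it keeps the cache local to each call and always stores a result under the same key that was looked up. The inner name `best` is made up for this example, and `float("-inf")` stands in for the NumPy constant so the snippet has no dependencies.

```python
def lgs(l):
    # Fresh cache per call, so entries from one list never leak into another.
    cache = {}

    def best(x, i):
        # Longest increasing subsequence of l[i:] whose elements all exceed x.
        key = (x, i)
        if key in cache:
            return cache[key]

        if i >= len(l):
            result = []
        else:
            # Option 1: skip l[i].
            result = best(x, i + 1)
            if l[i] > x:
                # Option 2: take l[i], then continue with l[i] as the new bound.
                taken = [l[i]] + best(l[i], i + 1)
                if len(result) < len(taken):
                    result = taken

        # Store under the same key that was looked up at the top.
        cache[key] = result
        return result

    return best(float("-inf"), 0)


print(lgs([1, 20, 3, 7, 40, 5, 2]))  # [1, 3, 7, 40]
```

Because the cache lives inside lgs, calling it on [1, 2, 3] and then on [3, 2, 1] gives an independent answer each time. The trade-off is that nothing is reused between calls, which is usually what you want when the input list changes.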