python closures decorator

How to decorate instance methods and avoid sharing closure environment between instances


I'm having trouble finding a solution to this problem. When we decorate a method inside a class body, the function is not yet bound to any instance. Say we have:

from functools import wraps

def decorator(f):
    closure_variable = 0

    @wraps(f)
    def wrapper(*args, **kwargs):
        nonlocal closure_variable
        closure_variable += 1
        print(closure_variable)
        return f(*args, **kwargs)

    return wrapper

class ClassA:
    @decorator
    def decorated_method(self):
        pass

This has a surprising consequence: the decorator runs only once, when the class body is executed, so all instances of ClassA share the same closure environment.

inst1 = ClassA()
inst2 = ClassA()
inst3 = ClassA()

inst1.decorated_method()
inst2.decorated_method()
inst3.decorated_method()

The above lines will output:

1
2
3
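A quick way to confirm the sharing is to look at the function object behind each bound method. The sketch below repeats the definitions above (with the wrapper returning the result of `f`) and reads the closure cell directly; the helper names `same_function` and `cells` are only for illustration.

```python
from functools import wraps


def decorator(f):
    closure_variable = 0  # created once per decorated function, not per instance

    @wraps(f)
    def wrapper(*args, **kwargs):
        nonlocal closure_variable
        closure_variable += 1
        return f(*args, **kwargs)

    return wrapper


class ClassA:
    @decorator
    def decorated_method(self):
        pass


inst1 = ClassA()
inst2 = ClassA()

# Each attribute access builds a new bound-method object, but both wrap
# the single `wrapper` function stored on the class.
same_function = inst1.decorated_method.__func__ is inst2.decorated_method.__func__

inst1.decorated_method()
inst2.decorated_method()

# `wrapper` carries one closure cell per free variable; map names to contents.
func = ClassA.decorated_method
cells = dict(zip(func.__code__.co_freevars,
                 (cell.cell_contents for cell in func.__closure__)))

print(same_function, cells["closure_variable"])  # True 2
```

Both calls incremented the same cell, which is why the counter keeps growing across instances.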

Now to my issue at hand: I had created a decorator that caches a token and only requests a new one once it expires. This decorator was applied to a method of a class called TokenSupplier. Once I noticed the behavior above, it was clear that I don't want the cache shared between instances. Can I solve this issue and still keep the decorator design pattern?

I thought of storing a dictionary in the closure environment and using the instance hash to index the desired data, but I believe I might simply be missing something more fundamental. My goal is for each instance to have its own closure environment while still being able to use a decorator to decorate different future TokenSupplier implementations.

Thank you in advance!


Solution

  • To avoid sharing one cache across all instances, which may be neither required nor desired, keep a separate cache with its own expiry time on each instance. In other words, there is no need for a single shared cache.

    In the following implementation, the decorator is a descriptor: `__get__` receives the instance on attribute access, and the token and its expiration time are stored in a `_token_cache` dict on that instance. Other relevant info can be stored there too, which gives you full control.

    from functools import wraps
    import time
     
     
    class TokenCacheDecorator:
        def __init__(self, get_token_func):
            self.get_token_func = get_token_func
     
        def __get__(self, inst, owner):
            # Accessed on the class itself (ClassA.get_token): return the descriptor.
            if inst is None:
                return self

            # Accessed on an instance: build a wrapper bound to that instance.
            @wraps(self.get_token_func)
            def wrapper(*args, **kwargs):
                # Refresh when this instance has no cache yet or its token has expired.
                if not hasattr(inst, '_token_cache') or inst._token_cache['expiration_time'] < time.time():
                    print(f"[{id(inst)}] Cache miss")
                    token, expires_in = self.get_token_func(inst, *args, **kwargs)
                    inst._token_cache = {
                        'token': token,
                        'expiration_time': time.time() + expires_in
                    }
                    print(f"[{id(inst)}] New token - {token} expiration time: {inst._token_cache['expiration_time']}")

                return inst._token_cache['token']

            return wrapper
     
     
    class ClassA:
        def __init__(self, token, expires_in):
            self.token = token
            self.expires_in = expires_in
            self._token_cache = {'token': None, 'expiration_time': 0}
     
        @TokenCacheDecorator
        def get_token(self):
            return self.token, self.expires_in
     
     
    inst1 = ClassA("token1", 2)
    inst2 = ClassA("token2", 2)
    inst3 = ClassA("token3", 2)
     
    print(inst1.get_token())
    print(inst2.get_token())
    print(inst3.get_token())
     
    time.sleep(3)
     
    print(inst1.get_token())
    print(inst2.get_token())
    print(inst3.get_token())
    
    

    This prints:

    [4439687776] Cache miss
    [4439687776] New token - token1 expiration time: 1716693215.503801
    token1
    [4440899024] Cache miss
    [4440899024] New token - token2 expiration time: 1716693215.503846
    token2
    [4440899360] Cache miss
    [4440899360] New token - token3 expiration time: 1716693215.503862
    token3
    [4439687776] Cache miss
    [4439687776] New token - token1 expiration time: 1716693218.5076532
    token1
    [4440899024] Cache miss
    [4440899024] New token - token2 expiration time: 1716693218.50767
    token2
    [4440899360] Cache miss
    [4440899360] New token - token3 expiration time: 1716693218.507679
    token3
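The idea from the question, a dictionary kept in the closure, can also work if you prefer a plain function decorator. The following is only a minimal sketch under some assumptions: the names `cache_token`, `caches`, and `requests` are made up for illustration, and `TokenSupplier` is a hypothetical stand-in for the real class. It keys the dictionary by the instance itself rather than by `hash(inst)`, since two different objects can have the same hash, and it uses `weakref.WeakKeyDictionary` so that an entry does not keep its instance alive.

```python
import time
from functools import wraps
from weakref import WeakKeyDictionary


def cache_token(get_token_func):
    # One dictionary per decorated method, keyed by instance. Weak keys let
    # an entry disappear once its instance is garbage collected.
    caches = WeakKeyDictionary()

    @wraps(get_token_func)
    def wrapper(self, *args, **kwargs):
        entry = caches.get(self)
        # Request a new token when this instance has none or it has expired.
        if entry is None or entry["expiration_time"] < time.time():
            token, expires_in = get_token_func(self, *args, **kwargs)
            entry = {"token": token, "expiration_time": time.time() + expires_in}
            caches[self] = entry
        return entry["token"]

    return wrapper


class TokenSupplier:  # hypothetical stand-in for the asker's class
    def __init__(self, token, expires_in):
        self.token = token
        self.expires_in = expires_in
        self.requests = 0  # counts how often the underlying request ran

    @cache_token
    def get_token(self):
        self.requests += 1
        return self.token, self.expires_in


a = TokenSupplier("token1", 60)
b = TokenSupplier("token2", 60)
a.get_token()
a.get_token()  # served from a's cache entry
b.get_token()  # b has its own entry
print(a.requests, b.requests)  # 1 1
```

This variant needs instances that are hashable and weak-referenceable. Ordinary class instances are both, but a class that defines `__eq__` without `__hash__`, or that uses `__slots__` without a `__weakref__` slot, would need adjusting.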