
JIT: partial or with static argnums? Non-hashable input, but hashable partial


I am a bit lost on what exactly is going on and which option to choose. Let's go through an example:

import jax
from functools import partial
from typing import List

def dummy(a: int, b: List[str]):
    return a + 1

Because the argument b is a list, which is mutable and therefore not hashable, jitting with static_argnames fails:

j_dummy = jax.jit(dummy, static_argnames=['b'])
j_dummy(2, ['kek'])
ValueError: Non-hashable static arguments are not supported

However, if we use partial instead, jp_dummy = jax.jit(partial(dummy, b=['kek'])), it works. Somehow, the partial object does have a __hash__ method, which we can check with hash(partial(dummy, b=['kek'])).
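A plain-Python check helps explain this. The minimal sketch below assumes CPython's functools.partial, which does not define its own __hash__ or __eq__ and so falls back to object identity. It shows that a list is unhashable, a tuple of strings is hashable, and two partials built with identical arguments are not equal:

```python
from functools import partial


def dummy(a, b):
    return a + 1


# static_argnames needs hash(); a list does not support it.
try:
    hash(['kek'])
    list_is_hashable = True
except TypeError:
    list_is_hashable = False

# A tuple of strings is hashable, so it could serve as a static argument.
tuple_is_hashable = isinstance(hash(('kek',)), int)

# partial never hashes the arguments it stores: it keeps the default
# identity-based __hash__/__eq__ inherited from object.
p1 = partial(dummy, b=['kek'])
p2 = partial(dummy, b=['kek'])
print(list_is_hashable, tuple_is_hashable, p1 == p2)  # False True False
```

So hash(partial(dummy, b=['kek'])) succeeds only because the list inside the partial is never hashed at all.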

So I am a bit lost here: how should I proceed in the bigger picture? Should I create partial functions with whatever arguments and then jit them, or should I try to keep my arguments hashable? In which situations is one approach better than the other? Are there any drawbacks?


Solution

  • When you use static_argnames, the static values passed to the function become part of the cache key, so whenever the value changes, the function is recompiled:

    import jax
    
    def f(x, s):
      return x * len(s)
    
    f_jit = jax.jit(f, static_argnames=['s'])
    
    print(f_jit(2, "abc"))  # 6
    print(f_jit(2, "abcd"))  # 8
    

    This is why static arguments must be hashable: their hash (together with an equality check) is used to look up the compiled function in the JIT cache.

    By contrast, when you capture a static argument via a closure, its value does not affect the cache key, so it need not be hashable. However, because it is not part of the cache key, changing the global value does not trigger a recompilation, and you may get unexpected (stale) results:

    f_closure = jax.jit(lambda x: f(x, s))
    
    s = "abc"
    print(f_closure(2))  # 6
    s = "abcd"
    print(f_closure(2))  # 6
    

    For this reason, explicit static arguments can be safer. In your case, it may be best to convert your list into a tuple: a tuple of hashable elements (such as strings) is itself hashable, so it can be passed as an explicit static argument.
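Both behaviours can be seen side by side in a minimal sketch. It assumes a hypothetical variant of the question's dummy that returns a + len(b), so that b actually affects the result, and a hypothetical Python-side counter, trace_count, that increments only while JAX traces the function (cached calls do not run the Python body):

```python
import jax
from functools import partial

trace_count = 0  # hypothetical: counts how often JAX traces dummy


def dummy(a, b):
    # Hypothetical variant of the question's dummy that actually uses b.
    global trace_count
    trace_count += 1  # Python side effect: runs only during tracing
    return a + len(b)


# Option 1: explicit static argument, passed as a hashable tuple.
j_dummy = jax.jit(dummy, static_argnames=['b'])
r1 = j_dummy(2, ('kek',))        # traced and compiled: 2 + 1
r2 = j_dummy(2, ('kek',))        # equal tuple -> cache hit, no retrace
r3 = j_dummy(2, ('kek', 'lol'))  # new static value -> retraced: 2 + 2
print(int(r1), int(r2), int(r3), trace_count)  # 3 3 4 2

# Option 2: bake the list in with partial (b is never hashed).
b = ['kek']
jp_dummy = jax.jit(partial(dummy, b=b))
r4 = jp_dummy(2)                 # traced while len(b) == 1
b.append('lol')                  # mutate the list captured by the partial
r5 = jp_dummy(2)                 # cache hit: still uses len(b) == 1
print(int(r4), int(r5), trace_count)  # 3 3 3
```

The static-tuple version recompiles exactly when the value changes, while the partial version silently keeps the list length it saw at trace time, which is the same stale-result risk described above for closures.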