I am a bit lost on what exactly is going on and which option to choose. Let's go through an example:
import jax
from functools import partial
from typing import List
def dummy(a: int, b: List[str]):
    return a + 1
As the b argument is a mutable (and therefore unhashable) list, jitting with static_argnames fails:
j_dummy = jax.jit(dummy, static_argnames=['b'])
j_dummy(2, ['kek'])
ValueError: Non-hashable static arguments are not supported
However, if we use partial, as in jp_dummy = jax.jit(partial(dummy, b=['kek'])), it works. Somehow, the partial object does have a __hash__ method, which we can check with hash(partial(dummy, b=['kek'])).
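To see why hash(partial(...)) succeeds even though the bound list is unhashable, here is a minimal plain-Python sketch. It relies on the fact that functools.partial does not define its own __hash__ or __eq__, so it inherits the identity-based ones from object; the bound arguments are never hashed. The p1/p2 names are just for illustration.

```python
from functools import partial
from typing import List

def dummy(a: int, b: List[str]):
    return a + 1

p1 = partial(dummy, b=['kek'])
p2 = partial(dummy, b=['kek'])

# partial falls back to object identity for hashing and equality
print(hash(p1) == object.__hash__(p1))  # True
print(p1 == p2)  # False: same contents, but two different objects

# The list stored inside the partial is still unhashable
try:
    hash(p1.keywords['b'])
except TypeError as e:
    print("TypeError:", e)  # TypeError: unhashable type: 'list'
```

In other words, the partial is accepted because it is hashed as an opaque object, not because its contents are checked.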
So, I am a bit lost here: how should I proceed in the bigger picture? Should I produce partial functions with whatever arguments and then jit them, or should I try to keep my arguments hashable? In which situations is one approach better than the other? Are there any drawbacks?
When you use static_argnames, the static values passed to the function become part of the cache key, so if the value changes, the function is recompiled:
import jax
import jax.numpy as jnp
def f(x, s):
    return x * len(s)
f_jit = jax.jit(f, static_argnames=['s'])
print(f_jit(2, "abc")) # 6
print(f_jit(2, "abcd")) # 8
This is why static arguments must be hashable: their hash (together with an equality check) is used as part of the JIT cache key.
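One way to observe the recompilation is a sketch that counts traces with a Python side effect. This assumes the standard jax.jit behavior that Python code in the function body runs only while tracing, not on cached calls; trace_count is a hypothetical helper name.

```python
import jax

trace_count = 0  # hypothetical counter, incremented only during tracing

def f(x, s):
    global trace_count
    trace_count += 1  # Python side effect: runs at trace time only
    return x * len(s)

f_jit = jax.jit(f, static_argnames=['s'])
f_jit(2, "abc")   # first call: traces and compiles
f_jit(3, "abc")   # same static value, same input shape/dtype: cache hit
f_jit(2, "abcd")  # new static value: new cache entry, traces again
print(trace_count)  # 2
```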
On the other hand, when you wrap a static argument in a closure, its value does not affect the cache key, so it need not be hashable. However, since it is not part of the cache key, changing the global value does not trigger a recompilation, so you may get unexpected results:
f_closure = jax.jit(lambda x: f(x, s))
s = "abc"
print(f_closure(2)) # 6
s = "abcd"
print(f_closure(2)) # 6
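The partial approach from the question behaves like the closure here: the bound list is baked in at trace time and is not part of the cache key. A minimal sketch, assuming the bound list is mutated in place after the first call (g is a hypothetical stand-in for f):

```python
import jax
from functools import partial

def g(x, s):
    # len(s) is evaluated in Python at trace time and baked into the compiled code
    return x * len(s)

s = ['a', 'b', 'c']
g_partial = jax.jit(partial(g, s=s))  # s is bound by partial, not a static argument
first = g_partial(2)
print(first)   # 6
s.append('d')  # mutate the list bound inside the partial
second = g_partial(2)
print(second)  # still 6: the cached compilation is reused, no retrace
```

So partial avoids the hashability check, but it inherits the same staleness risk as a closure.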
For this reason, explicit static arguments can be safer. In your case, it may be best to change your list into a tuple, as tuples of hashable elements (such as strings) are hashable and can be used as explicit static arguments.
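A minimal sketch of that suggestion applied to dummy, assuming you can pass a tuple instead of a list (the Tuple[str, ...] annotation is only a hint; what matters to jax.jit is that the value is hashable):

```python
import jax
from typing import Tuple

def dummy(a: int, b: Tuple[str, ...]):
    return a + 1

j_dummy = jax.jit(dummy, static_argnames=['b'])
print(j_dummy(2, ('kek',)))  # 3

# If the data arrives as a list, convert it at the call site
names = ['kek']
print(j_dummy(2, tuple(names)))  # 3, reuses the cache entry since ('kek',) == ('kek',)
```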