I have JAX code in which I would like to scan over an array. In the body function of the scan, I use a pytree that stores some parameters and functions that I want to apply during the scan. For the scan, I used a lambda to bake in the object/pytree named params.
Does this trigger a new compilation when a new params is passed to the function example? If so, how can I avoid the recompilation?
import jax
import jax.numpy as jnp
from jax import tree_util

class Params:
    def __init__(self, x_array, a):
        self.x_array = x_array
        self.a = a
        return

    def one_step(self, state, input):
        x = state
        y = input
        next_state = (self.x_array + x + jnp.ones(self.a))*y
        return next_state

    def _tree_flatten(self):
        children = (self.x_array,)
        aux_data = {'a': self.a}
        return (children, aux_data)

    @classmethod
    def _tree_unflatten(cls, aux_data, children):
        return cls(*children, **aux_data)

tree_util.register_pytree_node(Params,
                               Params._tree_flatten,
                               Params._tree_unflatten)

def scan_body(params, state, input):
    x = state
    y = input
    x_new = params.one_step(x, y)
    return x_new, [x_new]

@jax.jit
def example(params):
    body_fun = lambda state, input: scan_body(params, state, input)
    init_state = jnp.array([0.,1.])
    input_array = jnp.array([1.,2.,3.])
    last_state, result_list = jax.lax.scan(body_fun, init_state, input_array)
    return last_state, result_list

if __name__ == "__main__":
    params1 = Params(jnp.array([1.,2.]), 2)
    last_state, result_list = example(params1)
    print(last_state)
    params2 = Params(jnp.array([3.,4.]), 2)
    last_state, result_list = example(params2)
    print(last_state)
Passing a new params object triggers recompilation only if the static part of params, or the shape, dtype, or sharding of its arrays, changes. Since aux_data is static, changing the value of params.a will lead to recompilation. Since children are dynamic, changing the shape, dtype, or sharding of params.x_array will lead to recompilation, but changing the array values/contents will not.
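To make the static/dynamic split concrete, here is a minimal sketch that flattens a few Params objects and compares the results. It reuses the Params class from the question, trimmed to its flattening logic; the names p1, p2, p3 and the treedef variables are illustrative only.

```python
import jax.numpy as jnp
from jax import tree_util


class Params:
    def __init__(self, x_array, a):
        self.x_array = x_array  # dynamic: returned as a child (leaf)
        self.a = a              # static: returned as aux_data

    def _tree_flatten(self):
        return (self.x_array,), {"a": self.a}

    @classmethod
    def _tree_unflatten(cls, aux_data, children):
        return cls(*children, **aux_data)


tree_util.register_pytree_node(
    Params, Params._tree_flatten, Params._tree_unflatten
)

p1 = Params(jnp.array([1., 2.]), 2)
p2 = Params(jnp.array([3., 4.]), 2)  # different values, same static data
p3 = Params(jnp.array([1., 2.]), 1)  # same values, different static data

leaves1, treedef1 = tree_util.tree_flatten(p1)
leaves2, treedef2 = tree_util.tree_flatten(p2)
leaves3, treedef3 = tree_util.tree_flatten(p3)

# Only x_array shows up as a leaf; a lives inside the tree definition.
print(len(leaves1), leaves1[0].shape)
# Same aux_data gives an equal tree structure; a different a does not.
print(treedef1 == treedef2)
print(treedef1 == treedef3)
```

The tree structure (including aux_data), together with the shape and dtype of each leaf, is what jit looks at when deciding whether a cached compilation can be reused; the leaf values themselves are not.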
In your example, params.x_array has the same shape, dtype, and sharding in both calls, and params.a has the same value, so there should not be any recompilation (if you're unsure, you could confirm this using the approach mentioned at https://stackoverflow.com/a/70127930/2937831).
Note in particular that the lambda used inside example cannot affect the JIT cache key, because it is not referenced in the pytree flattening output.
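One simple way to check this yourself (a sketch, not necessarily the approach described in the linked answer) is to put a Python side effect inside the jitted function. Python code in the function body runs only while JAX traces it, so a counter incremented there goes up only on a cache miss. The trace_count variable is a hypothetical helper added for illustration, and the scan body is inlined to keep the example short.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


class Params:
    def __init__(self, x_array, a):
        self.x_array = x_array
        self.a = a

    def _tree_flatten(self):
        return (self.x_array,), {"a": self.a}

    @classmethod
    def _tree_unflatten(cls, aux_data, children):
        return cls(*children, **aux_data)


tree_util.register_pytree_node(
    Params, Params._tree_flatten, Params._tree_unflatten
)

trace_count = 0  # hypothetical counter, not part of the original code


@jax.jit
def example(params):
    global trace_count
    trace_count += 1  # executes only when JAX (re)traces this function

    def body_fun(state, y):
        next_state = (params.x_array + state + jnp.ones(params.a)) * y
        return next_state, next_state

    init_state = jnp.array([0., 1.])
    input_array = jnp.array([1., 2., 3.])
    return jax.lax.scan(body_fun, init_state, input_array)


example(Params(jnp.array([1., 2.]), 2))
count_first = trace_count    # first call: traced once

example(Params(jnp.array([3., 4.]), 2))
count_same = trace_count     # new values only: cache hit, no retrace

example(Params(jnp.array([1., 2.]), 1))
count_static = trace_count   # static a changed: retraced

example(Params(jnp.array([5.]), 2))
count_shape = trace_count    # shape of x_array changed: retraced

print(count_first, count_same, count_static, count_shape)
```

If the counter stays the same between two calls, the second call reused the cached compilation.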