python jax

JAX, recompilation when using closure for a function


I have some JAX code in which I would like to scan over an array. In the body function of the scan, I use a pytree to store some parameters and functions that I want to apply during the scan. For the scan, I used a lambda to bake in the pytree object named params.

Does this trigger a new compilation when a new params is passed to the function example? If so, how can I avoid the recompilation?

import jax
import jax.numpy as jnp
from jax import tree_util

class Params:
    def __init__(self, x_array, a):
        self.x_array = x_array  # array: flattened as a child (dynamic)
        self.a = a              # Python int: flattened as aux_data (static)

    def one_step(self, state, input):
        x = state
        y = input
        next_state = (self.x_array + x + jnp.ones(self.a)) * y
        return next_state

    def _tree_flatten(self):
        children = (self.x_array,)  # dynamic leaves, traced under jit
        aux_data = {'a': self.a}    # static data, stored in the tree structure
        return (children, aux_data)

    @classmethod
    def _tree_unflatten(cls, aux_data, children):
        return cls(*children, **aux_data)
        
tree_util.register_pytree_node(Params,
                               Params._tree_flatten,
                               Params._tree_unflatten)

def scan_body(params, state, input):
    x = state
    y = input 
    x_new = params.one_step(x, y) 
    return x_new, [x_new]

@jax.jit
def example(params):
    body_fun = lambda state, input: scan_body(params, state, input)
    init_state = jnp.array([0.,1.])
    input_array = jnp.array([1.,2.,3.])
    last_state, result_list = jax.lax.scan(body_fun, init_state, input_array)
    return last_state, result_list

if __name__ == "__main__":

    params1 = Params(jnp.array([1.,2.]), 2)
    last_state, result_list = example(params1)
    print(last_state)

    params2 = Params(jnp.array([3.,4.]), 2)
    last_state, result_list = example(params2)
    print(last_state)

Solution

  • Passing a new params object triggers recompilation only if something static about it changes. Since aux_data is static, changing the value of params.a will lead to recompilation. Since children are dynamic, changing the shape, dtype, or sharding of params.x_array will lead to recompilation, but changing the array values will not.

    In your example, params.x_array has the same shape, dtype, and sharding in both calls, and params.a has the same value, so there should not be any recompilation (if you're unsure, you could confirm this using the approach mentioned at https://stackoverflow.com/a/70127930/2937831).

    Note in particular that the lambda in example and the functions used in the method implementations cannot affect the JIT cache key, because they are not referenced in the pytree flattening output.
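One way to check this yourself, as a rough sketch rather than the method from the linked answer: Python side effects inside a jitted function run only while JAX traces it, so a counter incremented in the function body tells you how many times it was traced. The names trace_count, params3, params4, and counts below are made up for illustration, and the class is a condensed copy of the one in the question.

```python
import jax
import jax.numpy as jnp
from jax import tree_util

trace_count = 0  # hypothetical helper: counts how often `example` is traced


class Params:
    def __init__(self, x_array, a):
        self.x_array = x_array  # dynamic child
        self.a = a              # static aux_data

    def one_step(self, state, inp):
        return (self.x_array + state + jnp.ones(self.a)) * inp

    def _tree_flatten(self):
        return (self.x_array,), {'a': self.a}

    @classmethod
    def _tree_unflatten(cls, aux_data, children):
        return cls(*children, **aux_data)


tree_util.register_pytree_node(
    Params, Params._tree_flatten, Params._tree_unflatten)


@jax.jit
def example(params):
    global trace_count
    trace_count += 1  # Python side effect: executes only while tracing

    def body_fun(state, inp):
        x_new = params.one_step(state, inp)
        return x_new, x_new

    init_state = jnp.array([0., 1.])
    input_array = jnp.array([1., 2., 3.])
    return jax.lax.scan(body_fun, init_state, input_array)


params1 = Params(jnp.array([1., 2.]), 2)
params2 = Params(jnp.array([3., 4.]), 2)  # new values, same shape/dtype/a
params3 = Params(jnp.array([3., 4.]), 1)  # static attribute a changed
params4 = Params(jnp.array([3, 4]), 1)    # dtype of x_array changed

# Flattening shows what jit sees: array leaves plus a treedef holding aux_data.
leaves1, treedef1 = tree_util.tree_flatten(params1)
_, treedef2 = tree_util.tree_flatten(params2)
_, treedef3 = tree_util.tree_flatten(params3)
same_structure = treedef1 == treedef2     # a is unchanged
changed_structure = treedef1 != treedef3  # a differs

# Record the running trace count after each call.
counts = []
for p in (params1, params2, params3, params4):
    example(p)
    counts.append(trace_count)
print(same_structure, changed_structure, counts)
```

If the reasoning above holds, the count should stay the same between the first and second call (only the array values changed) and go up by one for each of the last two calls (first a changed, then the dtype of x_array).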