I'm writing an interpolation routine and have a dictionary that stores the function values at the fitting points. Ideally, the dictionary keys would be the fitting-point coordinates as NumPy arrays, `np.array([x, y])`, but since NumPy arrays aren't hashable, these are converted to tuples for the keys.
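```python
import numpy as np

# fit_pt_coords: (n_pts, n_dims) array
# fn_vals: (n_pts,) array
def fit(fit_pt_coords, fn_vals):
    # NumPy arrays are unhashable, so each coordinate row becomes a tuple key
    pt_map = {tuple(k): v for k, v in zip(fit_pt_coords, fn_vals)}
    ...
```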
Later in the code I need to look up the function values using coordinates as keys in order to do the interpolation fitting. I'd like this to be inside `@jax.jit`-ed code, but there the coordinate values are of type `<class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>`, which can't be converted into a hashable tuple key. I've tried other things, like creating a dictionary key as `(x + y, x - y)`, but again this requires concrete values, and calling `.item()` results in a `ConcretizationTypeError`.
At the moment I've `@jax.jit`-ed all of the code I can and have left this part un-jitted. It would be great if I could jit this code as well, however. Are there better ways to do the dictionary indexing (or better JAX-compatible data structures) that would allow all of the code to be jitted? I'm new to JAX and still understanding how it works, so I'm sure there must be better ways of doing it...
There is no way to use traced JAX values as dictionary keys. The problem is that the values of traced arrays are not known when the function is traced and compiled, only at runtime inside the compiled XLA program, and XLA has no dictionary-like data structure that such lookups could be lowered to.
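A minimal sketch of the failure, assuming a toy `pt_map` with made-up values: inside `jax.jit`, `coord` is a tracer whose shape is known but whose values are not, so a tuple built from it holds tracers rather than floats and can never match a concrete key. The exact exception has varied across JAX versions (`TypeError` for an unhashable tracer, or `KeyError` where tracers were hashed by identity), so the sketch catches both.

```python
import jax
import jax.numpy as jnp

# Toy lookup table with Python-float tuple keys (illustrative values only).
pt_map = {(0.0, 0.0): 10.0, (1.0, 0.0): 20.0}

@jax.jit
def lookup(coord):
    # Under jit, coord is an abstract tracer: shape (2,) is known, values are not.
    key = tuple(coord)  # a tuple of two tracers, not of two floats
    return pt_map[key]  # hashing tracers cannot reproduce a concrete key's hash

try:
    lookup(jnp.array([1.0, 0.0]))
    lookup_failed = False
except (TypeError, KeyError):
    lookup_failed = True

print(lookup_failed)
```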
There are imperfect solutions, such as keeping the dictionary on the host and using something like `io_callback` to do the dict lookups on the host, but this approach comes with performance penalties (every lookup leaves the compiled program, transfers data to the host, and runs Python) that will likely make it impractical.
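A minimal sketch of the host-side lookup, assuming small illustrative fitting data and float32 coordinates (JAX's default dtype): the dict stays in ordinary Python, and `jax.experimental.io_callback` calls back into it with concrete values. The helper name `_host_lookup` is hypothetical.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental import io_callback

# Illustrative fitting data (hypothetical, not from the question).
fit_pt_coords = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]], dtype=np.float32)
fn_vals = np.array([10.0, 20.0, 30.0], dtype=np.float32)

# The dict lives on the host; keys are tuples of plain Python floats.
pt_map = {tuple(k.tolist()): v for k, v in zip(fit_pt_coords, fn_vals)}

def _host_lookup(coord):
    # Runs as Python on the host with concrete values, so hashing works here.
    key = tuple(np.asarray(coord).tolist())
    return np.asarray(pt_map[key], dtype=np.float32)

@jax.jit
def lookup(coord):
    # XLA needs the callback's output shape and dtype declared up front.
    out_spec = jax.ShapeDtypeStruct((), jnp.float32)
    return io_callback(_host_lookup, out_spec, coord)

print(lookup(jnp.array([1.0, 0.0])))
```

Every call pays a device-to-host round trip plus a Python dict lookup, which is the performance penalty mentioned above. Because the lookup has no side effects, `jax.pure_callback` may also be a fit, but the cost profile is similar.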
Unfortunately, your best approach for doing this efficiently under JIT would probably be to switch to a different interpolation algorithm that doesn't depend on hash table lookups.
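One possible hash-free replacement, sketched here as an assumption rather than the algorithm the question needs: store the coordinates and values as parallel arrays and find the matching point with a vectorized nearest-point search. `argmin` over a fixed-shape distance array is something XLA can compile, so the whole lookup stays inside `jax.jit`. The data and the `lookup` name are hypothetical.

```python
import jax
import jax.numpy as jnp

# Parallel arrays replace the dict: row i of fit_pt_coords pairs with fn_vals[i].
fit_pt_coords = jnp.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
fn_vals = jnp.array([10.0, 20.0, 30.0])

@jax.jit
def lookup(coords, vals, query):
    # Squared distance from the query to every fitting point, shape (n_pts,).
    d2 = jnp.sum((coords - query) ** 2, axis=-1)
    # The closest fitting point (distance 0 for an exact match) wins.
    return vals[jnp.argmin(d2)]

# Batch over queries with vmap; coords and vals are shared across the batch.
batched_lookup = jax.vmap(lookup, in_axes=(None, None, 0))
queries = jnp.array([[0.0, 1.0], [1.0, 0.0]])
print(batched_lookup(fit_pt_coords, fn_vals, queries))
```

This trades the O(1) hash lookup for an O(n_pts) scan per query, but it tolerates small floating-point differences in the query coordinates, which exact tuple keys do not. If the fitting points lie on a regular grid, computing the array index directly from the coordinates would avoid the scan.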