Tags: python, numpy, jax, autodiff

How to wrap a numpy function to make it work with jax.numpy?


I have some JAX code that relies on automatic differentiation, and in part of that code I would like to call a function from a library written in NumPy. When I try this, I get the following error:

The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float32[4,22324])>with<JVPTrace(level=4/1)> with
  primal = Traced<ShapedArray(float32[4,22324])>with<DynamicJaxprTrace(level=0/1)>
  tangent = Traced<ShapedArray(float32[4,22324])>with<JaxprTrace(level=3/1)> with
    pval = (ShapedArray(float32[4,22324]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7fa89e8ffa80>, in_tracers=(Traced<ShapedArray(float32[22324,4]):JaxprTrace(level=3/1)>,), out_tracer_refs=[<weakref at 0x7fa89beb15e0; to 'JaxprTracer' at 0x7fa893b5ab80>], out_avals=[ShapedArray(float32[4,22324])], primitive=transpose, params={'permutation': (1, 0)}, effects=set(), source_info=SourceInfo(traceback=<jaxlib.xla_extension.Traceback object at 0x7fa89e9312b0>, name_stack=NameStack(stack=(Transform(name='jvp'),))))
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError

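This makes sense, because NumPy is not auto-differentiable: the NumPy function tries to convert the JAX tracer into a concrete NumPy array, which JAX cannot do while tracing.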
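
A minimal reproduction of this error might look like the sketch below. `numpy_fn` is a hypothetical stand-in for the library function, not code from the actual library. It works on plain NumPy input, but under `jax.grad` its argument is a tracer, and `np.sin` calls `__array__()` on it, which raises `jax.errors.TracerArrayConversionError`.

```python
import numpy as np
import jax
import jax.numpy as jnp


def numpy_fn(x):
    # Hypothetical stand-in for a library function written with NumPy.
    return np.sum(np.sin(x) ** 2)


# On concrete NumPy input the function works as usual.
plain_value = numpy_fn(np.arange(4.0))

# Under jax.grad, `x` is a JAX tracer; np.sin tries to convert it to a
# concrete NumPy array via __array__(), which JAX refuses to do.
caught = None
try:
    jax.grad(numpy_fn)(jnp.arange(4.0))
except jax.errors.TracerArrayConversionError as err:
    caught = err
```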

Is there any way to wrap a function written in NumPy so that its NumPy calls are mapped to their jax.numpy equivalents?

A dirty way to make this work would be to modify the library so that it calls jax.numpy instead of numpy, but that makes the library harder to reuse elsewhere.
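
One less invasive variant of that modification is to have the library take the array module as a parameter instead of hard-coding `numpy`. The sketch below assumes you can edit the library; `library_fn` and its `xp` argument are hypothetical names for illustration. It only helps for operations where jax.numpy mirrors NumPy's behavior.

```python
import numpy as np
import jax
import jax.numpy as jnp


def library_fn(x, xp=np):
    # Hypothetical library function. Every array operation goes through the
    # `xp` module argument instead of a hard-coded `np`, so callers can pass
    # either `numpy` (the default) or `jax.numpy`.
    return xp.sum(xp.sin(x) ** 2)


# Existing NumPy callers are unaffected.
np_value = library_fn(np.arange(4.0))

# JAX callers pass jax.numpy, which makes the function traceable.
x = jnp.arange(4.0)
g = jax.grad(lambda v: library_fn(v, xp=jnp))(x)
# Analytically, d/dx sum(sin(x)**2) = 2 sin(x) cos(x) = sin(2x).
```

NumPy users see no change, while JAX users opt in by passing `jnp`. Code that mutates arrays in place or uses non-numeric dtypes would still need separate handling.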

Thanks!


Solution

  • Edit January 2023: JAX has since added several callback mechanisms (for example jax.pure_callback) for calling external code, such as NumPy functions, from inside transformed JAX code; see https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html
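
A rough sketch of that callback approach, assuming the NumPy function's output has the same shape and dtype as its input and that you can supply its derivative yourself: `jax.pure_callback` runs the NumPy code outside of tracing, but JAX cannot differentiate through it, so `jax.custom_jvp` provides the derivative rule. `np_sin`, `np_sin_derivative`, `call_on_host`, and `wrapped_sin` are hypothetical names used only for illustration.

```python
import numpy as np
import jax
import jax.numpy as jnp


def np_sin(x):
    # Hypothetical NumPy-only library function we cannot rewrite.
    return np.sin(x)


def np_sin_derivative(x):
    # Its derivative, which we have to supply ourselves (here also in NumPy).
    return np.cos(x)


def call_on_host(np_fn, x):
    # Run `np_fn` outside of JAX tracing. JAX only needs the shape and dtype
    # of the result; here we assume they match the input.
    out_spec = jax.ShapeDtypeStruct(x.shape, x.dtype)

    def host_fn(a):
        # Convert to NumPy in case the callback receives a jax.Array.
        return np.asarray(np_fn(np.asarray(a)), dtype=out_spec.dtype)

    return jax.pure_callback(host_fn, out_spec, x)


@jax.custom_jvp
def wrapped_sin(x):
    return call_on_host(np_sin, x)


@wrapped_sin.defjvp
def wrapped_sin_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    y = wrapped_sin(x)
    # JAX cannot differentiate through the callback, so we state the chain
    # rule explicitly: d sin(x) = cos(x) * dx.
    return y, call_on_host(np_sin_derivative, x) * x_dot


x = jnp.arange(3.0)
g = jax.jit(jax.grad(lambda v: jnp.sum(wrapped_sin(v))))(x)
```

The tangent output is linear in `x_dot`, which lets `jax.grad` use the rule for reverse mode as well as forward mode. Batching with `jax.vmap` may need extra options on `pure_callback`; check its documentation for your JAX version.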

    No. In general, there is no way to take a function that operates on NumPy arrays and automatically convert it into an equivalent function implemented in JAX. The reason is that JAX is not a 100% faithful, one-to-one implementation of NumPy's API; rather, you should think of jax.numpy as a NumPy-like wrapper around the functionality that JAX provides.

    As a simple example, consider this code:

    np.array(['A', 'B', 'C'])
    

    This has no JAX equivalent, because JAX/XLA does not support string arrays.
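
A quick check of this, as a sketch: NumPy happily builds a Unicode string array, while converting it with `jnp.asarray` fails because JAX arrays only support numeric and boolean dtypes. The exact error message varies between JAX versions.

```python
import numpy as np
import jax.numpy as jnp

# NumPy: a fixed-width Unicode string array (dtype '<U1').
strings = np.array(['A', 'B', 'C'])

# JAX: non-numeric dtypes are rejected with a TypeError.
try:
    jnp.asarray(strings)
    jax_accepts_strings = True
except TypeError:
    jax_accepts_strings = False
```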

    If you want to use JAX transforms like autodiff on your code, there's really no shortcut around rewriting your code in JAX. You can likely get a long way by replacing import numpy as np with import jax.numpy as jnp (and np. calls with jnp.), as long as you're not using external libraries (like SciPy, scikit-learn, etc.) that operate on your arrays.
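
As an illustration of that replacement, here is a sketch with a hypothetical least-squares loss: the jax.numpy version is the NumPy code with `np` swapped for `jnp`, and `jax.grad` can then differentiate it. The gradient is checked against the analytic formula (2/n) Aᵀ(Aw − b).

```python
import numpy as np
import jax
import jax.numpy as jnp


def loss_np(w, A, b):
    # Original NumPy version: fine on concrete arrays, not traceable by JAX.
    return np.mean((A @ w - b) ** 2)


def loss_jnp(w, A, b):
    # The same code with `np` replaced by `jnp`.
    return jnp.mean((A @ w - b) ** 2)


A = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
b = np.array([1.0, 0.0, -1.0])
w = np.array([0.5, -0.25])

value_np = loss_np(w, A, b)  # 1.75
value_jnp = loss_jnp(jnp.asarray(w), jnp.asarray(A), jnp.asarray(b))

# Gradient with respect to w (the first argument) via autodiff.
grad_w = jax.grad(loss_jnp)(jnp.asarray(w), jnp.asarray(A), jnp.asarray(b))

# Analytic check: d/dw mean((A w - b)**2) = (2 / n) * A.T @ (A w - b)
expected = 2.0 / len(b) * A.T @ (A @ w - b)  # [7.0, 8.0]
```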

    Additionally, as you make such replacements, keep in mind JAX's "Sharp Bits": places where jax.numpy may behave differently from your original NumPy code.
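
One sharp bit that often shows up when porting NumPy code is immutability. NumPy arrays can be modified in place, while JAX arrays cannot, so index assignment has to become a functional `.at[...].set(...)` update. A minimal sketch:

```python
import numpy as np
import jax.numpy as jnp

# NumPy: arrays are mutable, so in-place assignment works.
x_np = np.zeros(3)
x_np[0] = 1.0

# jax.numpy: arrays are immutable, so the same statement raises TypeError.
x = jnp.zeros(3)
try:
    x[0] = 1.0
    mutated_in_place = True
except TypeError:
    mutated_in_place = False

# The JAX equivalent is a functional update that returns a new array.
x = x.at[0].set(1.0)
```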