For a particular JAX function func, one can define non-differentiable arguments by using the decorator @partial(jax.custom_jvp, nondiff_argnums=...). However, in order to make it work, one must also explicitly define the differentiation rules in a custom jvp function by using the decorator @func.defjvp. I'm wondering if there is a generic way to define non-differentiable arguments for any given func, without defining a custom jvp (or vjp) function? This will be useful when the differentiation rules are too complicated to write out.
In JAX's design, non-differentiated arguments are a property of the gradient transformation being used, not a property of the function being differentiated. custom_jvp is fundamentally about customizing the gradient behavior, and using it to mark non-differentiable arguments without actually customizing the gradient is not an intended use.
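For contrast, here is a minimal sketch of custom_jvp with nondiff_argnums used as intended, i.e. together with a hand-written rule. The function scaled_sin and its arguments are hypothetical names chosen for illustration; they are not from the question.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical example function: argument 1 (n) is marked non-differentiable.
@partial(jax.custom_jvp, nondiff_argnums=(1,))
def scaled_sin(x, n):
    return jnp.sin(x) * n


# The JVP rule receives the non-differentiable arguments first,
# followed by tuples of the remaining primals and their tangents.
@scaled_sin.defjvp
def scaled_sin_jvp(n, primals, tangents):
    (x,) = primals
    (x_dot,) = tangents
    primal_out = scaled_sin(x, n)
    tangent_out = jnp.cos(x) * n * x_dot  # hand-written derivative rule
    return primal_out, tangent_out


# jax.grad differentiates with respect to argument 0 by default.
grad_value = jax.grad(scaled_sin)(1.0, 3)
```

The point of the sketch is that the decorator only makes sense once a rule like scaled_sin_jvp exists; without one, there is nothing for custom_jvp to customize.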
The way to ensure that arguments do not participate in an autodiff transformation is to specify the arguments you want to differentiate against when you call jax.grad, jax.jacobian, or another autodiff transformation; e.g.
jax.grad(func, argnums=(0,)) # differentiate with respect to argument 0.
Regardless of what func is, this will attempt to differentiate with respect to the 0th argument, and if that argument is either explicitly or implicitly not differentiable due to how func is defined, an error will be raised.
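As a rough illustration of the above, the sketch below selects arguments with argnums and shows the error for an integer argument. The body of func and the values x, y, and n are assumptions made up for this example.

```python
import jax
import jax.numpy as jnp


# Hypothetical stand-in for func: x is a float array, y a float, n an integer.
def func(x, y, n):
    return jnp.sum(x ** n) * y


x = jnp.array([1.0, 2.0, 3.0])
y = 2.0
n = 2

# Differentiate with respect to x only; y and n are treated as constants.
dx = jax.grad(func, argnums=0)(x, y, n)

# Differentiate with respect to x and y; a tuple of gradients is returned.
dx_again, dy = jax.grad(func, argnums=(0, 1))(x, y, n)

# Asking for a gradient with respect to the integer argument raises an error.
try:
    jax.grad(func, argnums=2)(x, y, n)
    raised = False
except TypeError:
    raised = True
```

Nothing about func itself had to change here: which arguments are differentiated is decided entirely at the call to jax.grad.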