I have a JAX function `cart_deriv()` which takes another function `f` and returns the `l`-th Cartesian derivative of `f(l, R·R)` with respect to `R`, implemented as follows:

```python
from functools import partial
from typing import Callable

import jax
import jax.numpy as jnp
from jax import Array, custom_vjp, jacrev, vjp


@partial(custom_vjp, nondiff_argnums=0)
def cart_deriv(f: Callable[..., float],
               l: int,
               R: Array
               ) -> Array:
    # Differentiate R -> f(l, R·R) l times with reverse mode
    df = lambda R: f(l, jnp.dot(R, R))
    for i in range(l):
        df = jacrev(df)
    return df(R)


def cart_deriv_fwd(f, l, primal):
    primal_out = cart_deriv(f, l, primal)
    residual = cart_deriv(f, l+1, primal)  # just a test
    return primal_out, residual


def cart_deriv_bwd(f, residual, cotangent):
    cotangent_out = jnp.ones(3)  # just a test
    return (None, cotangent_out)


cart_deriv.defvjp(cart_deriv_fwd, cart_deriv_bwd)

if __name__ == "__main__":
    def test_func(l, r2):
        return l + r2

    primal_out, f_vjp = vjp(cart_deriv,
                            jax.tree_util.Partial(test_func),
                            2,
                            jnp.array([1., 2., 3.])
                            )
    cotangent = jnp.ones((3, 3))
    cotangent_out = f_vjp(cotangent)
    print(cotangent_out[1].shape)
```

However, this code produces the error:

```
TypeError: cart_deriv_bwd() missing 1 required positional argument: 'cotangent'
```

I have checked that the syntax agrees with the documentation. Why is the argument `cotangent` not recognized by `vjp`, and how can I fix this error?

The issue is that `nondiff_argnums` is expected to be a sequence of integers, not a single integer:

```python
@partial(custom_vjp, nondiff_argnums=(0,))
```

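
With a sequence-valued `nondiff_argnums`, the forward rule receives every argument (including the non-differentiable ones), while the backward rule receives the non-differentiable arguments first, then the residuals, then the output cotangent, and returns a tuple with one cotangent per *differentiable* argument only. That is the shape a signature like `cart_deriv_bwd(f, residual, cotangent)` expects. A minimal sketch, loosely modeled on the `app` example in the JAX custom-derivatives tutorial; `apply_fn` and its rules are hypothetical names:

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.custom_vjp, nondiff_argnums=(0,))  # a tuple, not a bare 0
def apply_fn(f, x):
    return f(x)


def apply_fn_fwd(f, x):
    # fwd receives every argument, including the non-differentiable f
    return apply_fn(f, x), x  # (primal output, residual)


def apply_fn_bwd(f, x, g):
    # bwd receives the nondiff args first, then the residual(s), then the cotangent g,
    # and returns one cotangent per differentiable argument (here only x)
    return (g * jax.grad(f)(x),)


apply_fn.defvjp(apply_fn_fwd, apply_fn_bwd)

print(jax.grad(lambda x: apply_fn(jnp.sin, x))(0.0))  # cos(0.0) = 1.0
```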
With this properly defined, it's better to avoid wrapping the function in `Partial` and instead pass it as a static argument by closing over it in the `vjp` call:

```python
primal_out, f_vjp = vjp(partial(cart_deriv, test_func),
                        2,
                        jnp.array([1., 2., 3.])
                        )
# Cotangents for (l, R); l is an integer, so its cotangent is a float0 zero
cotangent_out = f_vjp(jnp.ones((3, 3)))
print(*cotangent_out)
# (b'',) [1. 1. 1.]
```