Tags: python, function, jax, automatic-differentiation

JAX `vjp` does not recognize cotangent argument with `custom_vjp`


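I have a JAX function `cart_deriv()` that takes another function `f` and returns the `l`-th Cartesian derivative of `f`, i.e. the `l`-th derivative of `R ↦ f(l, R·R)` with respect to the Cartesian vector `R`. It is implemented as follows: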

```python
from functools import partial
from typing import Callable

import jax
import jax.numpy as jnp
from jax import Array, custom_vjp, jacrev, vjp


@partial(custom_vjp, nondiff_argnums=0)
def cart_deriv(f: Callable[..., float],
               l: int,
               R: Array
               ) -> Array:

    df = lambda R: f(l, jnp.dot(R, R))

    for i in range(l):
        df = jacrev(df)

    return df(R)


def cart_deriv_fwd(f, l, primal):

    primal_out = cart_deriv(f, l, primal)
    residual = cart_deriv(f, l+1, primal)  ## just a test

    return primal_out, residual


def cart_deriv_bwd(f, residual, cotangent):

    cotangent_out = jnp.ones(3)  ## just a test

    return (None, cotangent_out)


cart_deriv.defvjp(cart_deriv_fwd, cart_deriv_bwd)



if __name__ == "__main__":

    def test_func(l, r2):
        return l + r2

    primal_out, f_vjp = vjp(cart_deriv, 
                            jax.tree_util.Partial(test_func),
                            2,
                            jnp.array([1., 2., 3.])
                            )

    cotangent = jnp.ones((3, 3))
    cotangent_out = f_vjp(cotangent)

    print(cotangent_out[1].shape)
```

However, this code produces the error:

```
TypeError: cart_deriv_bwd() missing 1 required positional argument: 'cotangent'
```

I have checked that the syntax matches the documentation. Why is the `cotangent` argument not being passed to `cart_deriv_bwd` by `vjp`, and how can I fix this error?


Solution

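  • The issue is that `nondiff_argnums` must be a sequence of integers, not a bare integer. Because `0` is falsy, JAX effectively treats `nondiff_argnums=0` as "no non-differentiable arguments". `cart_deriv_fwd(f, l, primal)` still happens to match that convention, but `cart_deriv_bwd` is then called with only `(residual, cotangent)`; those values land in the `f` and `residual` parameters, and `cotangent` is reported as missing. Pass a tuple instead: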

    @partial(custom_vjp, nondiff_argnums=(0,))
    

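    With this fixed, there is no need to wrap the function in `jax.tree_util.Partial` and pass it through `vjp` as an argument. Instead, bind it with `functools.partial`, so that `vjp` only sees `l` and `R` as its inputs and `f` reaches `cart_deriv` as a non-differentiable argument: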

    primal_out, f_vjp = vjp(partial(cart_deriv, test_func),
                            2,
                            jnp.array([1., 2., 3.])
                            )
    cotangent_out = f_vjp(jnp.ones((3, 3)))
    
    print(*cotangent_out)
    # (b'',) [1. 1. 1.]