Custom JVP and VJP for higher order functions in JAX


I find JAX's custom automatic differentiation capabilities (custom JVP and VJP rules) very useful, but I am having a hard time applying them to higher-order functions. Here is a minimal example of such a higher-order function:

def parent_func(x):
    def child_func(y):
        return x**2 * y
    return child_func

I would like to define custom gradients of child_func with respect to x and y. What would be the correct syntax to achieve this?


Solution

  • Gradients in JAX are defined with respect to a function's explicit inputs. Your child_func does not take x as an explicit input (it only captures x from the enclosing scope of parent_func), so you cannot directly differentiate child_func with respect to x. However, you can do so indirectly by calling it from another function that takes x as an argument. For example:

    import jax

    def func_to_differentiate(x, y):
      # x is now an explicit input, so JAX can differentiate with respect to it
      child_func = parent_func(x)
      return child_func(y)

    jax.grad(func_to_differentiate, argnums=0)(1.0, 1.0)  # 2.0


    If you wish, you can then define custom derivative rules for func_to_differentiate in the standard way, using jax.custom_jvp or jax.custom_vjp.
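A minimal sketch of what those standard rules could look like, assuming the derivative formulas for x**2 * y are written out by hand. jax.custom_jvp, defjvp, jax.custom_vjp and defvjp are JAX's own APIs; the names func_jvp, func_vjp, their rule functions, and parent_func_custom are hypothetical names introduced here for illustration. The last helper shows one way to get back a higher-order "child function" whose derivatives still go through the custom rule: partially apply the custom-differentiated function to x.

```python
import functools

import jax


def parent_func(x):
    # The higher-order function from the question.
    def child_func(y):
        return x**2 * y
    return child_func


# Custom JVP: for f(x, y) = x**2 * y, df = 2*x*y*dx + x**2*dy.
@jax.custom_jvp
def func_jvp(x, y):
    return parent_func(x)(y)


@func_jvp.defjvp
def func_jvp_rule(primals, tangents):
    x, y = primals
    x_dot, y_dot = tangents
    primal_out = func_jvp(x, y)
    # Hand-written tangent, linear in (x_dot, y_dot), so jax.grad can transpose it.
    tangent_out = 2 * x * y * x_dot + x**2 * y_dot
    return primal_out, tangent_out


# Custom VJP for the same function.
@jax.custom_vjp
def func_vjp(x, y):
    return parent_func(x)(y)


def func_vjp_fwd(x, y):
    # Forward pass: return the output plus residuals saved for the backward pass.
    return func_vjp(x, y), (x, y)


def func_vjp_bwd(res, g):
    # Backward pass: one cotangent per primal input, in argument order.
    x, y = res
    return (g * 2 * x * y, g * x**2)


func_vjp.defvjp(func_vjp_fwd, func_vjp_bwd)


def parent_func_custom(x):
    # Hypothetical higher-order wrapper: returns a function of y that passes x
    # explicitly to func_jvp, so derivatives with respect to x use the custom rule.
    return functools.partial(func_jvp, x)


print(jax.grad(func_jvp, argnums=0)(3.0, 2.0))  # 12.0
print(jax.grad(func_vjp, argnums=1)(3.0, 2.0))  # 9.0
print(jax.grad(lambda x: parent_func_custom(x)(2.0))(3.0))  # 12.0
```

The design choice here is that x never stays hidden inside a closure that JAX cannot see: it is always an explicit argument of the function carrying the custom rule, and the "child function" is just partial application on top of it. Note that a jax.custom_vjp function supports only reverse-mode differentiation (e.g. jax.grad), whereas a jax.custom_jvp function works with both jax.jvp and jax.grad.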