I find JAX's custom automatic differentiation capabilities (custom JVP and VJP rules) very useful, but I am having a hard time applying them to higher-order functions. Here is a minimal example. Given the higher-order function:
```python
def parent_func(x):
    def child_func(y):
        return x**2 * y
    return child_func
```
I would like to define custom gradients of `child_func` with respect to `x` and `y`. What would be the correct syntax to achieve this?
Gradients in JAX are defined with respect to a function's explicit inputs. Your `child_func` does not take `x` as an explicit input (it only closes over it), so you cannot directly differentiate `child_func` with respect to `x`. However, you can do so indirectly by calling it from another function that takes `x` as an argument. For example:
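```python
import jax

def func_to_differentiate(x, y):
    # Build the closure for this x, then apply it to y.
    child_func = parent_func(x)
    return child_func(y)

jax.grad(func_to_differentiate, argnums=0)(1.0, 1.0)  # 2.0, i.e. 2*x*y at x = y = 1
```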
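To make the distinction concrete, here is a small self-contained sketch (the values 3.0 and 2.0 are arbitrary choices for illustration). Differentiating `child_func` itself only gives the derivative with respect to its explicit input `y`, with `x` frozen as a constant, while passing `argnums=(0, 1)` to `jax.grad` on the wrapper returns the partial derivatives with respect to both `x` and `y`.

```python
import jax

def parent_func(x):
    def child_func(y):
        return x**2 * y
    return child_func

def func_to_differentiate(x, y):
    child_func = parent_func(x)
    return child_func(y)

# Differentiating the closure directly: x is fixed at 3.0, only y is an input.
child_func = parent_func(3.0)
dchild_dy = jax.grad(child_func)(2.0)  # x**2 = 9.0

# Both partial derivatives at once via argnums=(0, 1).
dfdx, dfdy = jax.grad(func_to_differentiate, argnums=(0, 1))(3.0, 2.0)
# dfdx = 2*x*y = 12.0, dfdy = x**2 = 9.0
```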
Then, if you wish, you can define standard custom derivative rules (`jax.custom_jvp` or `jax.custom_vjp`) for `func_to_differentiate`.
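A minimal sketch of what that could look like, assuming you want rules with respect to both `x` and `y`: `jax.custom_jvp` with `defjvp` supplies a forward-mode rule, and `jax.custom_vjp` with `defvjp` supplies a reverse-mode rule. The helper names (`func_vjp_version`, `func_vjp_fwd`, `func_vjp_bwd`, `func_to_differentiate_jvp`) are hypothetical, and the rules below simply encode the true derivatives of `x**2 * y`; in practice you would substitute whatever custom rule you need.

```python
import jax

def parent_func(x):
    def child_func(y):
        return x**2 * y
    return child_func

# Forward-mode (JVP) rule: x and y are both explicit, differentiable inputs.
@jax.custom_jvp
def func_to_differentiate(x, y):
    return parent_func(x)(y)

@func_to_differentiate.defjvp
def func_to_differentiate_jvp(primals, tangents):
    x, y = primals
    x_dot, y_dot = tangents
    primal_out = func_to_differentiate(x, y)
    # Hand-written rule: d/dx = 2*x*y, d/dy = x**2 (replace with your own).
    tangent_out = 2.0 * x * y * x_dot + x**2 * y_dot
    return primal_out, tangent_out

# Reverse-mode (VJP) rule for the same computation.
@jax.custom_vjp
def func_vjp_version(x, y):
    return parent_func(x)(y)

def func_vjp_fwd(x, y):
    # Save the residuals the backward pass needs.
    return func_vjp_version(x, y), (x, y)

def func_vjp_bwd(res, g):
    x, y = res
    # One cotangent per primal input, in order (x, y).
    return (g * 2.0 * x * y, g * x**2)

func_vjp_version.defvjp(func_vjp_fwd, func_vjp_bwd)

grad_x = jax.grad(func_to_differentiate, argnums=0)(1.0, 1.0)  # 2.0
grad_xy = jax.grad(func_vjp_version, argnums=(0, 1))(3.0, 2.0)  # (12.0, 9.0)
```

Because `x` is now an explicit argument rather than a closed-over value, both rules can return derivatives for it directly, which is the point of routing the closure through a wrapper.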