I want to implement something like the following Python function in JAX and wrap it with a call to `vmap`. I want it to be fully reverse-mode differentiable (with respect to `x`) using `grad()`, even after the `vmap`:

```python
def f(x, kmax):
    return sum([x**k for k in range(1, kmax + 1)])
```
(This is a deliberately simplified version of the function; I realize in this case I could use the closed-form expression for the geometric series; sadly the actual function I'm trying to implement does not have a closed-form sum that I'm aware of.)
Is there any way to do this? It seems like there has to be; but `fori_loop` is not reverse-mode differentiable if `kmax` is dynamic, `jax.lax.scan` needs a statically-shaped array or it will throw a `ConcretizationTypeError`, and similarly Python primitives like `range` (as used above) throw a `TracerIntegerConversionError` if wrapped in `vmap`.
I think I understand the restrictions on needing arrays to be fixed-shape, but every autodiff framework I've ever used allows you to construct arbitrarily-sized expressions dynamically somehow. A sum over a varying integer range is a pretty basic mathematical tool. How does one implement this in Jax?
EDITED to refocus the problem definition (the issue is more vmap than grad) and provide the following examples.
This, specifically, is what I'd like to be able to do:

```python
import jax

def f(x, kmax):
    return sum([x**k for k in range(1, kmax + 1)])

fmap = jax.vmap(f, in_axes=(None, -1))

x = 3.
kmaxes = jax.numpy.array([1, 2, 3])
print(fmap(x, kmaxes))

fmap_sum = lambda x, kmaxes: jax.numpy.sum(fmap(x, kmaxes))
print(fmap_sum(x, kmaxes))
print(jax.grad(fmap_sum)(x, kmaxes))
```
This throws a `TracerIntegerConversionError` at `range(1, kmax + 1)`.
What I would like it to be doing is something like this:

```python
import jax

def f(x, kmax):
    return sum([x**k for k in range(1, kmax + 1)])

def fmap(x, kmaxes):
    return [f(x, kmax) for kmax in kmaxes]

x = 3.
kmaxes = jax.numpy.array([1, 2, 3])
print(fmap(x, kmaxes))

def fmap_sum(x, kmaxes):
    return sum(fmap(x, kmaxes))

print(fmap_sum(x, kmaxes))
print(jax.grad(fmap_sum)(x, kmaxes))
```
which gives the correct result (but loses the parallelization and acceleration of vmap).
First, to make your function compatible with `vmap`, you'll need to replace the Python control flow with `jax.lax` control-flow operations. In this case, `lax.fori_loop` seems applicable:
```python
import jax
import jax.numpy as jnp

def f1(x, k):
    def body_fun(i, val):
        return val + x ** i
    return jax.lax.fori_loop(1, k + 1, body_fun, jnp.zeros_like(x))

f1map = jax.vmap(f1, (None, 0))
print(f1map(x, kmaxes))
# [ 3. 12. 39.]
```
But because the size of the loop is dynamic, this is not compatible with reverse-mode autodiff:

```python
jax.jacrev(f1map)(x, kmaxes)
# ValueError: Reverse-mode differentiation does not work for lax.while_loop or lax.fori_loop. Try using lax.scan instead.
```
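(The error message suggests `lax.scan`, which does support reverse-mode autodiff, but only over an array of static length. One way to get such an array is to scan over the indices up to a fixed maximum and mask out the terms beyond `k`; this is a sketch of that idea, not part of the original answer:)

```python
import jax
import jax.numpy as jnp

def f_scan(x, k, kmax):  # kmax must be static so the scanned array has a fixed shape
    def step(carry, i):
        # add x**i only while i <= k; padded iterations contribute 0
        return carry + jnp.where(i <= k, x ** i, 0.0), None
    total, _ = jax.lax.scan(step, jnp.zeros_like(x), jnp.arange(1, kmax + 1))
    return total

fsmap = jax.vmap(f_scan, (None, 0, None))
x = 3.
kmaxes = jnp.array([1, 2, 3])
print(fsmap(x, kmaxes, 3))              # [ 3. 12. 39.]
print(jax.jacrev(fsmap)(x, kmaxes, 3))  # [ 1.  7. 34.]
```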
To get around this, you can modify your function such that it uses a static loop size. Here's one way you might do that:
```python
def f2(x, k, kmax):  # kmax should be static
    def body_fun(i, val):
        return val + jnp.where(i <= k, x ** i, 0)
    return jax.lax.fori_loop(1, kmax + 1, body_fun, jnp.zeros_like(x))

f2map = jax.vmap(f2, (None, 0, None))
print(f2map(x, kmaxes, kmaxes.max()))  # compatible with vmap
# [ 3. 12. 39.]
print(jax.jacrev(f2map)(x, kmaxes, kmaxes.max()))  # and with reverse-mode autodiff
# [ 1.  7. 34.]
```
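(Since the loop bound is static anyway, the same padding-and-masking idea can also be written without any loop primitive at all, by materializing the exponents up to `kmax` and masking; a sketch, with `f3` as a hypothetical name:)

```python
import jax
import jax.numpy as jnp

def f3(x, k, kmax):  # kmax static, as with f2
    ks = jnp.arange(1, kmax + 1)
    # sum x**i for i = 1..k; terms with i > k are zeroed out by the mask
    return jnp.sum(jnp.where(ks <= k, x ** ks, 0.0))

f3map = jax.vmap(f3, (None, 0, None))
x = 3.
kmaxes = jnp.array([1, 2, 3])
print(f3map(x, kmaxes, 3))              # [ 3. 12. 39.]
print(jax.jacrev(f3map)(x, kmaxes, 3))  # [ 1.  7. 34.]
```

This trades the sequential loop for a small amount of extra memory (`kmax` terms per element), which is often a good fit for `vmap` and `jit`.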