How does one convert the following (to accelerate compiling)?
The `for` loop version works with `jax.jit`:
import functools
import jax
import jax.numpy as jnp
@functools.partial(jax.jit, static_argnums=0)
def func(n):
    """Return the sums of ``x[::p]`` for the strides ``p = 1, 2, 4, ...``.

    ``n`` is marked static, so the Python ``for`` loop below is unrolled while
    tracing and ``p`` is an ordinary Python int on every iteration. That is
    what makes the slice ``x[::p]`` statically shaped and therefore legal
    under ``jax.jit``.

    Args:
        n: Number of strides to evaluate. Static: each distinct value
            triggers a fresh trace/compile.

    Returns:
        A floating-point array of shape ``(n,)`` whose ``i``-th entry is
        ``sum(arange(8)[::2**i])``.
    """
    p = 1
    x = jnp.arange(8)
    y = jnp.zeros((n,))
    for idx in range(n):
        # p is a concrete Python int here, so the slice has a static shape.
        y = y.at[idx].set(jnp.sum(x[::p]))
        p = 2 * p
    return y
func(2)
# >> Array([28., 12.], dtype=float32)
but it raises a static start/stop/step error when using `scan`:
import numpy as np
def body(p, xi):
    """Scan body from the question; it fails when run under ``jax.lax.scan``.

    Kept as the failing example on purpose: inside ``scan`` the carry ``p`` is
    a traced (dynamic) value, so ``x[::p]`` would have a data-dependent shape
    and JAX raises ``IndexError: Array slice indices must have static
    start/stop/step``.

    Args:
        p: Carried stride, doubled on every step.
        xi: Current element of the scanned-over sequence (unused).

    Returns:
        A ``(next_stride, strided_sum)`` pair in the form ``scan`` expects.
    """
    # Reads the module-level array ``x``; this is the line that raises.
    y = jnp.sum(x[::p])
    p = 2 * p
    return p, y
x = jnp.arange(8)
jax.lax.scan(body, 1, np.arange(2))
# >> IndexError: Array slice indices must have static start/stop/step ...
The issue here is that within `scan`, the `p` variable represents a dynamic (traced) value, meaning that `x[::p]` is a dynamically-sized array, so the operation is not allowed in JAX transformations (see JAX sharp bits: dynamic shapes).
Often in such cases it's possible to replace approaches that use dynamically-shaped intermediates with other approaches that compute the same thing using only statically-shaped arrays; in this case, one thing you might do is replace this problematic line:
jnp.sum(x[::p])
with this, which does the same sum using only statically-sized arrays:
jnp.sum(x, where=jnp.arange(len(x)) % p == 0)
Using this idea, here's a version of your original function that uses `scan`:
import numpy as np
@functools.partial(jax.jit, static_argnums=0)
def func_scan(n):
    """Return the sums of ``x[::p]`` for ``p = 1, 2, 4, ...`` via ``lax.scan``.

    Inside ``scan`` the stride ``p`` is a traced value, so the slice
    ``x[::p]`` is not allowed. The same sum is computed with a boolean mask
    over the full, statically-shaped array instead.

    Args:
        n: Number of strides to evaluate. Static: it fixes both the output
            shape and the scan length.

    Returns:
        A floating-point array of shape ``(n,)`` whose ``i``-th entry is
        ``sum(arange(8)[::2**i])``.
    """
    x = jnp.arange(8)
    # Loop-invariant index vector used to build the stride mask.
    positions = jnp.arange(len(x))

    def body(carry, _):
        idx, y, p = carry
        # Masked sum over all of x: equivalent to jnp.sum(x[::p]) but with
        # static shapes, so it is valid for a traced p.
        strided_sum = jnp.sum(x, where=positions % p == 0)
        # Any stride >= len(x) selects only element 0, so capping the stride
        # there leaves the result unchanged while keeping the doubled value
        # from overflowing the default 32-bit integer when n is large.
        next_p = jnp.minimum(2 * p, len(x))
        return (idx + 1, y.at[idx].set(strided_sum), next_p), None

    init = (0, jnp.zeros((n,)), 1)
    (_, y, _), _ = jax.lax.scan(body, init, xs=None, length=n)
    return y
func_scan(2)
# Array([28., 12.], dtype=float32)