
Convert for loop to jax.lax.scan


How can I convert the following for loop to jax.lax.scan (to speed up compilation)? The for-loop version works with jax.jit,

import functools
import jax
import jax.numpy as jnp

@functools.partial(jax.jit, static_argnums=0)
def func(n):

    p = 1
    x = jnp.arange(8)
    y = jnp.zeros((n,))

    for idx in range(n):
        y = y.at[idx].set(jnp.sum(x[::p]))
        p = 2*p

    return y

func(2)
# >> Array([28., 12.], dtype=float32)

but raises an error about static slice start/stop/step values when using scan:

import numpy as np

def body(p, xi):

    y = jnp.sum(x[::p])

    p = 2*p

    return p, y

x = jnp.arange(8)

jax.lax.scan(body, 1, np.arange(2))
# >> IndexError: Array slice indices must have static start/stop/step ...
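To see where the error comes from, it may help to look at the slice outside of any transformation. The shape of x[::p] depends on the value of p, so JAX cannot determine it when p is a traced value inside scan. A minimal sketch:

```python
import jax.numpy as jnp

x = jnp.arange(8)

# Eagerly (no jit/scan), p is a concrete Python int, so slicing works,
# but the length of the result changes with the value of p.
for p in (1, 2, 4):
    print(p, x[::p].shape)
# 1 (8,)
# 2 (4,)
# 4 (2,)
```

Inside scan, p is only known abstractly (an integer scalar of unknown value), so the output shape of this slice cannot be fixed at trace time.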

Solution

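  • The issue here is that within scan, the p variable is a traced (dynamic) value rather than a Python integer, so x[::p] would be an array whose size depends on a runtime value. Dynamically shaped arrays are not allowed inside JAX transformations (see JAX sharp bits: dynamic shapes). The for-loop version works because there p is an ordinary Python int, so every slice has a static shape at trace time.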

    Often in such cases it's possible to replace an approach that uses dynamically shaped intermediates with one that computes the same thing using only statically shaped arrays. In this case, one option is to replace this problematic line:

    jnp.sum(x[::p])
    

    with this masked sum, which keeps only the elements whose index is a multiple of p (exactly the elements that x[::p] selects) and uses only statically sized arrays:

    jnp.sum(x, where=jnp.arange(len(x)) % p == 0)
    

    Using this idea, here's a version of your original function that uses scan:

    import functools

    import jax
    import jax.numpy as jnp

    @functools.partial(jax.jit, static_argnums=0)
    def func_scan(n):
        p = 1
        x = jnp.arange(8)
        y = jnp.zeros((n,))

        def body(carry, _):
            idx, y, p = carry
            # Masked sum: same result as jnp.sum(x[::p]), but with a static shape
            y = y.at[idx].set(jnp.sum(x, where=jnp.arange(len(x)) % p == 0))
            return (idx + 1, y, 2 * p), None

        # xs=None with an explicit length runs the body n times
        (_, y, _), _ = jax.lax.scan(body, (0, y, p), xs=None, length=n)
        return y

    func_scan(2)
    # Array([28., 12.], dtype=float32)
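A variation that may be more idiomatic, sketched here under the same assumptions, lets scan collect the per-step sums as its stacked outputs instead of carrying y and an explicit index. The name func_scan_ys is hypothetical, and starting p as jnp.int32(1) is a choice made here to keep the carry dtype fixed across iterations.

```python
import functools

import jax
import jax.numpy as jnp


@functools.partial(jax.jit, static_argnums=0)
def func_scan_ys(n):
    x = jnp.arange(8)
    idx = jnp.arange(len(x))

    def body(p, _):
        # Masked sum: equivalent to jnp.sum(x[::p]) but statically shaped.
        s = jnp.sum(x, where=idx % p == 0)
        # Carry the doubled step forward; emit s as this iteration's output.
        return 2 * p, s

    # scan stacks the per-iteration outputs into an array of length n.
    _, ys = jax.lax.scan(body, jnp.int32(1), xs=None, length=n)
    # Cast to float32 to match the dtype returned by the original func.
    return ys.astype(jnp.float32)


print(func_scan_ys(2))  # values 28. and 12., as with func(2)
```

Here the only carried state is p, and scan builds the output array itself, so there is no need for y.at[idx].set. One caveat: p is an int32, so for n of roughly 31 or more the doubling overflows; this sketch does not handle that edge case.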