
Equivalent of `jax.lax.cond` for multiple boolean conditions


Currently `jax.lax.cond` handles a single boolean condition, choosing between two branches. Is there a way to extend it to multiple boolean conditions?
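For reference, the single-condition case that `jax.lax.cond` covers looks roughly like the sketch below. The name `two_branches` is hypothetical, and it models only the first two regions of the function in the question.

```python
import jax


@jax.jit
def two_branches(x):
    # lax.cond(pred, true_fun, false_fun, *operands): pred must be a scalar
    # boolean, and the result of the branch it selects is returned.
    return jax.lax.cond(x < 0, lambda v: v, lambda v: 2 * v, x)


print(two_branches(-0.5))  # -0.5
print(two_branches(0.5))   # 1.0
```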

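As an example, the following function cannot be traced, because Python's `if`/`elif` needs a concrete boolean, while under `jax.jit` the argument `x` is an abstract tracer: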

def func(x):
    if x < 0: return x
    elif (x >= 0) & (x < 1): return 2*x
    else: return 3*x
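A quick way to see the problem, sketched under the assumption of a reasonably recent JAX version: calling `func` eagerly on a Python float works, but wrapping it in `jax.jit` makes the `if` try to convert a tracer to `bool`, which raises `jax.errors.ConcretizationTypeError` (newer versions raise a more specific subclass of it).

```python
import jax


def func(x):
    if x < 0:
        return x
    elif (x >= 0) & (x < 1):
        return 2 * x
    else:
        return 3 * x


print(func(0.5))  # eager call on a concrete float: 1.0

try:
    jax.jit(func)(0.5)  # x is now a tracer, so `if x < 0` cannot be evaluated
except jax.errors.ConcretizationTypeError as err:
    print("tracing failed:", type(err).__name__)
```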

How can I write this function in JAX in a traceable way?


Solution

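  • One compact way to write this is with `jnp.select`, which, elementwise, takes the choice paired with the first true condition and falls back to `default` where none is true: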

    import jax
    import jax.numpy as jnp
    
    @jax.jit
    def func(x):
      # x < 0 -> x; 0 <= x < 1 -> 2*x (first true condition wins); otherwise -> 3*x
      return jnp.select([x < 0, x < 1], [x, 2 * x], default=3 * x)
    
    x = jnp.array([-0.5, 0.5, 1.5])
    print(func(x))
    # [-0.5  1.   4.5]
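Note that `jnp.select` evaluates every choice expression and then picks among them, which is fine for cheap arithmetic like this. If branch-level control flow closer to `lax.cond` is preferred, two alternatives are sketched below: nesting `jax.lax.cond` calls, and encoding the conditions as an integer index for `jax.lax.switch`. The names `func_nested`, `func_switch`, and `BRANCHES` are hypothetical, and both functions assume a scalar input, so `jax.vmap` is used to apply them to an array.

```python
import jax
import jax.numpy as jnp


@jax.jit
def func_nested(x):
    # Outer cond separates x < 0; its false branch is another cond
    # that separates 0 <= x < 1 from x >= 1.
    return jax.lax.cond(
        x < 0,
        lambda v: v,
        lambda v: jax.lax.cond(v < 1, lambda w: 2 * w, lambda w: 3 * w, v),
        x,
    )


# One branch function per region: x < 0, 0 <= x < 1, x >= 1.
BRANCHES = [lambda v: v, lambda v: 2 * v, lambda v: 3 * v]


@jax.jit
def func_switch(x):
    # Encode the chain of conditions as an integer branch index:
    # 0 if x < 0, 1 if 0 <= x < 1, 2 if x >= 1.
    index = (x >= 0).astype(jnp.int32) + (x >= 1).astype(jnp.int32)
    return jax.lax.switch(index, BRANCHES, x)


# lax.cond and lax.switch need a scalar predicate / index,
# so vmap maps the scalar versions over the array elementwise.
x = jnp.array([-0.5, 0.5, 1.5])
print(jax.vmap(func_nested)(x))  # [-0.5  1.   4.5]
print(jax.vmap(func_switch)(x))  # [-0.5  1.   4.5]
```

Only for an unbatched scalar input do these execute just the selected branch; once vmapped with a batched predicate or index, JAX evaluates all branches and selects among the results, much like `jnp.select`.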