
Mapping Over Arrays of Functions in JAX


What is the most performant, idiomatic way of mapping over arrays of functions in JAX?

Context: This GitHub issue shows a way to apply vmap to several functions using lax.switch. The example is reproduced below:

from jax import lax, vmap
import jax.numpy as jnp

def func1(x):
  return 2 * x

def func2(x):
  return -2 * x

def func3(x):
  return 0 * x

functions = [func1, func2, func3]
index = jnp.arange(len(functions))
x = jnp.ones((3, 5))

vmap_functions = vmap(lambda i, x: lax.switch(i, functions, x))
vmap_functions(index, x)
# DeviceArray([[ 2.,  2.,  2.,  2.,  2.],
#              [-2., -2., -2., -2., -2.],
#              [ 0.,  0.,  0.,  0.,  0.]], dtype=float32)

My specific questions are:


Solution

  • For the kind of operation you're doing, where the functions are applied over full axes of an array in a way that's known statically, you'll probably get the best performance via a simple Python loop:

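    from typing import Callable
    import jax.numpy as jnp
    from jax import Array

    def map_functions(functions: list[Callable[[Array], Array]], x: Array) -> Array:
      assert len(functions) == x.shape[0]  # exactly one function per row
      return jnp.array([f(row) for f, row in zip(functions, x)])
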

    The method based on switch is designed for the more general case where the structure of the indices is not known statically.
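A minimal sketch comparing the two approaches on the question's functions, assuming a recent JAX release that exports `jax.Array`. The names `map_functions` and `switch_map` are hypothetical, chosen for illustration. `functions` is captured from the enclosing scope rather than passed as an argument, because `jax.jit` cannot trace a list of Python functions; closing over it keeps the function-to-row pairing static. The last call shows the case `switch` is meant for: an index array whose contents are only known at runtime.

```python
import jax
import jax.numpy as jnp
from jax import Array, lax


def func1(x):
    return 2 * x


def func2(x):
    return -2 * x


def func3(x):
    return 0 * x


functions = [func1, func2, func3]


@jax.jit
def map_functions(x: Array) -> Array:
    # Static pairing: row k always gets functions[k], so the Python loop
    # unrolls at trace time into exactly one call per row.
    return jnp.stack([f(row) for f, row in zip(functions, x)])


@jax.jit
def switch_map(index: Array, x: Array) -> Array:
    # Dynamic pairing: the function for each row comes from a traced integer,
    # so lax.switch has to choose the branch from the index value.
    return jax.vmap(lambda i, row: lax.switch(i, functions, row))(index, x)


x = jnp.ones((3, 5))

# The same static assignment computed both ways; the results agree.
static_result = map_functions(x)
switch_result = switch_map(jnp.arange(3), x)

# An arbitrary, data-dependent assignment of functions to rows needs switch.
dynamic_index = jnp.array([2, 0, 0, 1])
dynamic_result = switch_map(dynamic_index, jnp.ones((4, 5)))
print(dynamic_result[:, 0])
```

The loop version only works because the pairing is fixed when the function is traced; once the index array comes from data, the Python loop has nothing static to iterate over and `switch` becomes the natural tool.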

  • What performance penalties, if any, does this method incur? (This refers to both runtime and compile-time performance.)

    Under vmap, a switch with a batched index is implemented via select: the output of every function is computed for the full input array, and only then are the pieces needed for the output selected. If the functions are expensive to compute, this can noticeably lengthen runtimes, because every row pays for all of the branches rather than just the one it needs.
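One way to check this is to inspect the traced program with `jax.make_jaxpr`. The sketch below assumes a recent JAX release, where the batched select shows up as the `select_n` primitive (older versions may print it differently). In the printed jaxpr, you should see each branch's multiplication applied to the whole `(3, 5)` input, with no `cond` left, followed by `select_n` picking each row's result.

```python
import jax
import jax.numpy as jnp
from jax import lax


def func1(x):
    return 2 * x


def func2(x):
    return -2 * x


def func3(x):
    return 0 * x


functions = [func1, func2, func3]
index = jnp.arange(len(functions))
x = jnp.ones((3, 5))

# Trace vmap-of-switch without executing it, then print the resulting program.
batched_switch = jax.vmap(lambda i, row: lax.switch(i, functions, row))
jaxpr_text = str(jax.make_jaxpr(batched_switch)(index, x))
print(jaxpr_text)

# A batched index turns the switch into a select over all branch outputs.
uses_select = "select_n" in jaxpr_text
print("select_n present:", uses_select)
```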