
How can I apply member functions of a list of objects across slices of a JAX array using vmap?


I have a list of objects, each of which has a method to be applied to a slice of a jax.numpy array. There are n objects and n corresponding slices. How can I vectorise this using vmap?

For example, given the following code snippet:

import jax
import jax.numpy as jnp

class Obj:
    def __init__(self, i):
        self.i = i

    def f1(self, x): return (x - self.i)

x = jnp.arange(9).reshape(3, 3).astype(jnp.float32)

functions_obj = [Obj(1).f1, Obj(2).f1, Obj(3).f1]

How would I apply each function in functions_obj to its corresponding slice of x, e.g. functions_obj[k] to the row x[k]?
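To pin down the intended semantics, here is the unvectorized computation that a vmap-based version should reproduce. This is a sketch that assumes "slices" means rows along axis 0 (the question pairs 3 functions with a 3 x 3 array but does not name the axis); the name `expected` is introduced here for illustration only.

```python
import jax.numpy as jnp


class Obj:
    def __init__(self, i):
        self.i = i

    def f1(self, x):
        return x - self.i


x = jnp.arange(9).reshape(3, 3).astype(jnp.float32)
functions_obj = [Obj(1).f1, Obj(2).f1, Obj(3).f1]

# Unvectorized reference: apply functions_obj[k] to row k of x with a
# Python loop (iterating a JAX array yields its rows), then stack the
# per-row results back into a single array.
expected = jnp.stack([f(row) for f, row in zip(functions_obj, x)])
print(expected)  # [[-1. 0. 1.] [1. 2. 3.] [3. 4. 5.]]
```

The loop dispatches one small computation per object, which is what vectorising is meant to avoid.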

More details, probably not relevant: my specific use case is running the member functions of many Reinforcement Learning Gym environment objects on slices of an actions array, but I believe the problem is more general, so I formulated it as above. (P.S. I know about AsyncVectorEnv, but it does not solve my problem because I am not trying to run the step function.)


Solution

  • Use jax.lax.switch to select one function from the list by index, and jax.vmap to map over the leading axis of x and an array of indices together:

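    # Assumes functions_obj and x are defined as in the question.
    def apply_func_obj(i, x_slice):
        # Select functions_obj[i] and apply it to this slice.
        return jax.lax.switch(i, functions_obj, x_slice)

    indices = jnp.arange(len(functions_obj))
    # vmap maps over axis 0 of both arguments, so index k is paired with row x[k]
    results = jax.vmap(apply_func_obj, in_axes=(0, 0))(indices, x)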
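A self-contained version of the answer is sketched below, with the question's definitions inlined and the call wrapped in jax.jit. Two caveats worth knowing: jax.lax.switch requires every branch to return the same shape and dtype, and when the index is batched by vmap, JAX evaluates all branches on every slice and then selects the results, so the work grows with len(functions_obj). The second half shows a different technique, not part of the original answer: assuming every object is an Obj whose method depends only on the attribute i, you can stack the attributes into an array and vmap one pure function over them. The names `objs`, `f1_pure`, `params` and `results_alt` are hypothetical, introduced only for this sketch.

```python
import jax
import jax.numpy as jnp


class Obj:
    def __init__(self, i):
        self.i = i

    def f1(self, x):
        return x - self.i


x = jnp.arange(9).reshape(3, 3).astype(jnp.float32)
objs = [Obj(1), Obj(2), Obj(3)]
functions_obj = [o.f1 for o in objs]


def apply_func_obj(i, x_slice):
    # lax.switch picks functions_obj[i]; all branches must return
    # arrays of the same shape and dtype.
    return jax.lax.switch(i, functions_obj, x_slice)


indices = jnp.arange(len(functions_obj))
# Pair index k with row k of x; jit compiles the whole batched call.
results = jax.jit(jax.vmap(apply_func_obj, in_axes=(0, 0)))(indices, x)


# Alternative (assumes each method is a pure function of the attribute i):
# vectorise over the stacked attributes instead of over the methods, so
# only one function body is traced and executed.
def f1_pure(i, x_slice):
    return x_slice - i


params = jnp.array([o.i for o in objs], dtype=jnp.float32)
results_alt = jax.vmap(f1_pure)(params, x)

print(results)  # [[-1. 0. 1.] [1. 2. 3.] [3. 4. 5.]]
```

The switch-based version works for arbitrary, heterogeneous methods; the attribute-based version only applies when the objects share code and differ in data, but it avoids evaluating every branch.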