I've run into an issue with JAX that will force me to rewrite an entire 20,000-line application if I can't solve it.
I have a non-ML application that relies on pytrees to store data, and the pytrees are deep: about 6-7 levels of nesting (class1 stores class2, which stores an array of class3, etc.).
I've used Python lists to store pytrees and hoped to vmap over them, but it turns out JAX can't vmap over the elements of a list.
(So one solution would be to rewrite every single dataclass as a structured array and work from there, possibly putting all 6-7 layers of data into one mega-array.)
Is there a way to avoid the rewrite? Is there a way to store pytree classes in a vmappable form so that everything works as before?
I have my classes marked with flax.struct.dataclass, if that helps.
jax.vmap is designed to work with a struct-of-arrays pattern, and it sounds like you have an array-of-structs pattern. From your description, it sounds like you have a sequence of nested structs that look something like this:
import jax
import jax.numpy as jnp
from flax.struct import dataclass

@dataclass
class Params:
    x: jax.Array
    y: jax.Array

@dataclass
class AllParams:
    p: list[Params]

params_list = [AllParams([Params(4, 2), Params(4, 3)]),
               AllParams([Params(3, 5), Params(2, 4)]),
               AllParams([Params(3, 2), Params(6, 3)])]
Then you have a function that you want to apply to each element of the list; something like this:
def some_func(params):
    a, b = params.p
    return a.x * b.y - b.x * a.y
[some_func(params) for params in params_list]
[4, 2, -3]
But as you found, if you try to do this with vmap, you get an error:
jax.vmap(some_func)(params_list)
ValueError: vmap was requested to map its argument along axis 0, which implies that its rank should be at least 1, but is only 0 (its shape is ())
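The issue is that vmap treats the list as just another pytree: it flattens it into its leaves and maps over axis 0 of each leaf separately, rather than mapping over the elements of the list. Here every leaf is a scalar with no axis 0, hence the error.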
To address this, you can often transform your data structure from an array-of-structs into a struct-of-arrays, and then apply vmap over that. For example:
params_array = jax.tree.map(lambda *vals: jnp.array(vals), *params_list)
print(params_array)
AllParams(p=[ Params(x=Array([4, 3, 3], dtype=int32), y=Array([2, 5, 2], dtype=int32)), Params(x=Array([4, 2, 6], dtype=int32), y=Array([3, 4, 3], dtype=int32)) ])
Notice that rather than a list of structures, this is now a single structure with the batching pushed all the way down to the leaves. This is the "struct-of-arrays" pattern that vmap is designed to work with, so vmap works correctly:
jax.vmap(some_func)(params_array)
Array([ 4, 2, -3], dtype=int32)
Now, this assumes that every dataclass in your list has identical structure (the same nesting, the same list lengths, and leaves of the same shape). If not, vmap will not be applicable, because by design it maps over computations with identical structure.
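Since a mismatch deep inside a 6-7 level pytree can be hard to spot, it may help to check the structure explicitly before stacking. This is a minimal sketch of a hypothetical stack_trees helper (not part of JAX) that compares each element's pytree structure and leaf shapes against the first one and raises a readable error. It uses plain dicts and tuples to stay short, but jax.tree.structure works the same way for any registered pytree, including flax.struct.dataclass classes.

```python
import jax
import jax.numpy as jnp


def stack_trees(trees):
    """Stack a list of pytrees (array-of-structs) into one pytree of arrays.

    Hypothetical helper, not part of JAX. Raises ValueError when elements
    differ in pytree structure or leaf shapes, since vmap cannot batch them.
    """
    treedef = jax.tree.structure(trees[0])
    shapes = [jnp.shape(leaf) for leaf in jax.tree.leaves(trees[0])]
    for i, tree in enumerate(trees[1:], start=1):
        if jax.tree.structure(tree) != treedef:
            raise ValueError(f"element {i} has a different pytree structure")
        if [jnp.shape(leaf) for leaf in jax.tree.leaves(tree)] != shapes:
            raise ValueError(f"element {i} has different leaf shapes")
    # Identical structure everywhere: stack matching leaves along a new axis 0.
    return jax.tree.map(
        lambda *vals: jnp.stack([jnp.asarray(v) for v in vals]), *trees)


ok = stack_trees([{"a": 1, "b": (2, 3)}, {"a": 4, "b": (5, 6)}])
print(ok["b"][0])  # [2 5]

try:
    stack_trees([{"a": 1, "b": (2, 3)}, {"a": 4, "b": (5, 6, 7)}])
except ValueError as err:
    print(err)  # element 1 has a different pytree structure
```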