I'm trying to vmap a function. My understanding is that essentially anywhere I would write a for loop or list comprehension, I should consider vmapping instead. I have a few points of confusion:
The reason I'm asking is that vmap, like JIT, seems to run into all sorts of ConcretizationTypeErrors and (I'm not 100% sure yet) seems to need constant-sized arrays everywhere. I associate this behavior with any function I'm trying to JIT, but not with every function I write in JAX.
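For reference, here is a minimal sketch of the loop-to-vmap correspondence described above. The per-example function `f` (a dot product of a row with itself) is a hypothetical stand-in for whatever the loop body computes; `jax.vmap` produces the same stacked result as the Python list comprehension, but calls `f` once on the whole batch.

```python
import jax
import jax.numpy as jnp

def f(row):
    # Per-example computation on a single row of shape (3,)
    return jnp.dot(row, row)

xs = jnp.arange(6.0).reshape(2, 3)

# Python loop: one call to f per row, results stacked afterwards
looped = jnp.stack([f(row) for row in xs])

# vmap: f is traced once with a batched input; the batch axis is axis 0
vmapped = jax.vmap(f)(xs)
```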
Does vmap need fixed sizes for everything throughout the function(s) being vmapped?
Yes. `vmap`, like all JAX transformations, requires any arrays defined in the function to have static shapes, i.e. shapes that do not depend on the values of the inputs.
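A minimal sketch of what the static-shape requirement means in practice, using a hypothetical `sum_positive` function: boolean-mask indexing like `x[x > 0]` produces an array whose length depends on the data, which raises an error under `vmap`. A common workaround is `jnp.where`, which keeps the shape fixed by replacing unwanted entries instead of dropping them.

```python
import jax
import jax.numpy as jnp

def sum_positive(x):
    # x[x > 0].sum() would build an array whose length depends on the data,
    # which cannot be traced under vmap. jnp.where keeps the shape fixed:
    # non-positive entries become 0.0 instead of being removed.
    return jnp.where(x > 0, x, 0.0).sum()

xs = jnp.array([[1.0, -2.0, 3.0],
                [-1.0, 5.0, 0.0]])

# Each row is handled independently; the output has one value per row
out = jax.vmap(sum_positive)(xs)
```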
Does vmap try to JIT my function behind the scenes? (I'm wondering because the fixed-size requirement is behavior I expect from JIT; I didn't expect it from vmap, but I don't really know vmap.)
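No, `vmap` does not JIT-compile a function by default, although you can always compose the two if you wish (e.g. `jit(vmap(f))`). The shared restrictions come from the fact that both transformations trace your function rather than running it on concrete per-example values.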
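One way to see that `vmap` alone does not compile anything is to count how often the Python body of a function runs. This is a sketch: the `trace_count` dictionary and `make_f` helper are illustrative devices, not part of any API. The vmapped function re-runs its Python body on every call, while `jax.jit(jax.vmap(...))` runs it once while tracing and then reuses the compiled result for later calls with the same input shapes.

```python
import jax
import jax.numpy as jnp

# Illustrative counter: how many times each version executes the Python body
trace_count = {"vmap": 0, "jit_vmap": 0}

def make_f(key):
    def f(x):
        # Python side effect: runs whenever JAX executes this Python code
        trace_count[key] += 1
        return jnp.sin(x) * 2.0
    return f

xs = jnp.arange(4.0)

batched = jax.vmap(make_f("vmap"))                # vectorized, not compiled
compiled = jax.jit(jax.vmap(make_f("jit_vmap")))  # vectorized, then compiled

for _ in range(3):
    a = batched(xs)   # Python body runs on every call
    b = compiled(xs)  # Python body runs only on the first (tracing) call
```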
If vmap is JIT-ing something, how would one use something like static arguments with vmap?
As mentioned, `vmap` is unrelated to `jit`, but the analog of `jit`'s `static_argnums` is passing `None` in `in_axes`, which keeps that argument unmapped and therefore static within the transformation.
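A minimal sketch of the `in_axes=None` approach, using a hypothetical `tile_row` function whose `reps` argument determines the output shape and therefore has to stay a plain Python int. Binding the argument with `functools.partial` before vmapping is an alternative that also works; and if you compose with `jit`, the same argument can additionally be marked with `static_argnums`.

```python
import functools

import jax
import jax.numpy as jnp

def tile_row(row, reps):
    # reps sets the output shape, so it must be a concrete Python int
    return jnp.tile(row, reps)

xs = jnp.arange(6.0).reshape(3, 2)

# Option 1: in_axes=None leaves reps unbatched, so tile_row sees the int 2
out1 = jax.vmap(tile_row, in_axes=(0, None))(xs, 2)

# Option 2: bind the static argument before vmapping
out2 = jax.vmap(functools.partial(tile_row, reps=2))(xs)

# Composing with jit: reps is static for jit and unmapped for vmap
tiled = jax.jit(jax.vmap(tile_row, in_axes=(0, None)), static_argnums=1)
out3 = tiled(xs, 2)
```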
What is the best practice for dealing with extraneous information? (E.g. if some outputs have size a and others size b, do you just make an array of size max(a, b) and then ignore the extra values?)
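A sketch of the padding approach the question describes, assuming a known upper bound on the output size (the `MAX_LEN` constant and `positive_indices` function here are hypothetical). `jnp.nonzero` accepts a static `size` and a `fill_value` for unused slots, which makes it usable under `vmap`; returning a boolean mask alongside the padded result lets later code ignore the filler entries. Note that if a row has more than `MAX_LEN` matches, the extra ones are silently dropped, so the bound has to be chosen with care.

```python
import jax
import jax.numpy as jnp

MAX_LEN = 4  # hypothetical static upper bound on the per-example output size

def positive_indices(row):
    # A static size makes the output shape independent of the data;
    # unused slots are filled with -1
    (idx,) = jnp.nonzero(row > 0, size=MAX_LEN, fill_value=-1)
    valid = idx >= 0  # mask marking which slots hold real results
    return idx, valid

rows = jnp.array([[1.0, -2.0, 3.0, 0.0],
                  [1.0, -1.0, 5.0, 6.0]])

# Every row yields arrays of shape (MAX_LEN,), so the results stack cleanly
idx, valid = jax.vmap(positive_indices)(rows)

# Downstream code uses the mask instead of the padded entries
counts = valid.sum(axis=1)
```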