
Does JAX vmap JIT behind the scenes?


I'm trying to vmap a function. My understanding of vmap is essentially that anywhere I would write a ~for loop or list comprehension, I should instead consider vmapping. I have a few points of confusion:

  1. Does vmap need fixed sizes for everything through the function(s) being vmapped?
  2. Does vmap try to JIT my function behind the scenes? (I'm wondering because 1 is a behavior I expect from JIT; I didn't expect it from vmap, but I don't really know vmap.)
  3. If vmap is JIT-ing something, how would one use something like static arguments with vmap?
  4. What is the best practice for dealing with ~extraneous information (e.g. if some outputs are of size a and some of size b, do you just make an array of size max(a, b) and then ~ignore the extra values?)

The reason I'm asking is that vmap, like JIT, seems to run into all sorts of ConcretizationTypeErrors and seems (I'm not 100% sure yet) to need constant-sized arrays for everything. I associate this behavior with any function I'm trying to JIT, but not with every function I write in JAX.


Solution

  • Does vmap need fixed sizes for everything through the function(s) being vmapped?

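    Yes. vmap, like jit, works by tracing your function with abstract values in place of the mapped inputs, so any arrays created inside the function must have static shapes, i.e. shapes that do not depend on the values of the mapped inputs. For the same reason, Python control flow such as if or while cannot branch on a mapped value; use jnp.where or jax.lax.cond instead.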
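To make the restriction concrete, here is a minimal sketch; the function names (relu_python_if, relu_where, ones_of_length, fails_under_vmap) are made up for illustration. Under vmap, branching with a Python if on a mapped value, or using a mapped value as a shape, fails at trace time, while a value-dependent computation written with jnp.where works. The helper only checks whether vmap raises a TypeError, which is the base class of JAX's ConcretizationTypeError.

```python
import jax
import jax.numpy as jnp


def relu_python_if(x):
    # Hypothetical example: Python `if` needs a concrete bool,
    # but under vmap x is an abstract batched value
    if x > 0:
        return x
    return 0.0 * x


def relu_where(x):
    # Hypothetical example: jnp.where selects elementwise, so it works on abstract values
    return jnp.where(x > 0, x, 0.0)


def ones_of_length(n):
    # Hypothetical example: the output *shape* depends on the *value* of n,
    # which is not allowed under vmap
    return jnp.ones(n)


def fails_under_vmap(f, arg):
    """Hypothetical helper: True if jax.vmap(f)(arg) raises a TypeError."""
    try:
        jax.vmap(f)(arg)
    except TypeError:  # ConcretizationTypeError and its subclasses derive from TypeError
        return True
    return False


xs = jnp.array([-1.0, 2.0, -3.0])
ns = jnp.array([1, 2, 3])

print(fails_under_vmap(relu_python_if, xs))  # True
print(fails_under_vmap(ones_of_length, ns))  # True
print(jax.vmap(relu_where)(xs))              # works: only the values depend on x
```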

  • Does vmap try to JIT my function behind the scenes? (Wondering bc. 1 is a behavior I expect from JIT, I didn't expect it from vmap but I don't really know vmap).

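    No. vmap does not JIT-compile a function by default: it runs the function once with batched tracer values and dispatches each resulting batched operation eagerly, one at a time. You can always compose the two if you wish, e.g. jit(vmap(f)).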
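One way to see the difference is to count how often the Python body of a function runs. In this sketch (f and trace_count are illustrative names), a list comprehension calls f once per element; vmap calls it once per invocation with batched values and caches nothing; jit(vmap(f)) traces once and reuses the compiled program on later calls with the same input shape.

```python
import jax
import jax.numpy as jnp

trace_count = {"n": 0}  # counts how many times the Python body of f runs


def f(x):
    # Hypothetical example function
    trace_count["n"] += 1  # Python side effect: runs once per call/trace, not per element
    return jnp.sin(x) ** 2 + jnp.cos(x) ** 2


xs = jnp.linspace(0.0, 1.0, 5)

# The list-comprehension version calls f once per element
looped = jnp.stack([f(x) for x in xs])   # trace_count: 5

# vmap calls f once with batched values and runs the ops eagerly
batched = jax.vmap(f)(xs)                # trace_count: 6
batched_again = jax.vmap(f)(xs)          # trace_count: 7 (no caching without jit)

# Composing with jit compiles the batched computation; f is traced once, then reused
fast = jax.jit(jax.vmap(f))
compiled = fast(xs)                      # trace_count: 8
compiled_again = fast(xs)                # trace_count: still 8 (cached)
```

All three versions compute the same values; the differences are in how many times Python runs f and whether the result is compiled.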

  • If vmap is jit-ing something, how would one use something like a static-arguments with vmap?

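    As mentioned, vmap is unrelated to jit, but the closest analogy to jit's static_argnums is passing None in in_axes for that argument. The argument is then unmapped: vmap hands it to the function unchanged, so a Python value such as an int stays concrete within the transformation and can be used, for example, as a shape. If you also wrap the vmapped function in jit, mark that argument with static_argnums too, since jit traces its non-static inputs regardless of in_axes. Binding the argument beforehand with functools.partial works as well.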
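A minimal sketch of these options, using a hypothetical function sum_first_n whose second argument must be a concrete int because it sets a slice length. It shows in_axes=None on its own, binding the argument with functools.partial instead, and, when jit is composed on top, also marking the argument with static_argnums.

```python
import functools

import jax
import jax.numpy as jnp


def sum_first_n(x, n):
    # Hypothetical example: n sets a slice length, so it must be a concrete Python int
    return jnp.sum(x[:n])


xs = jnp.arange(12.0).reshape(4, 3)  # 4 rows to map over

# 1. in_axes=None: n is not mapped, and vmap passes it through unchanged
out_none = jax.vmap(sum_first_n, in_axes=(0, None))(xs, 2)

# 2. Bind n ahead of time; vmap then only sees x
out_partial = jax.vmap(functools.partial(sum_first_n, n=2))(xs)

# 3. Under jit, n must also be static, otherwise jit would turn it into a tracer
fast = jax.jit(jax.vmap(sum_first_n, in_axes=(0, None)), static_argnums=1)
out_jit = fast(xs, 2)
```

With static_argnums, jit recompiles for each distinct value of n, which is the same trade-off as using static_argnums without vmap.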

  • What is the best practice for dealing with ~extraneous information (eg if some outputs are sized a and some sized b, do you just make an array sized max(a,b) then ~ignore the extra values?)
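The padding approach the question describes is a common pattern: choose a fixed maximum size, fill the unused slots, and return a boolean mask alongside the values so later steps can ignore the padding. Here is a minimal sketch, assuming a known upper bound MAX_LEN and an illustrative function first_n (both hypothetical names).

```python
import jax
import jax.numpy as jnp

MAX_LEN = 4  # hypothetical fixed upper bound on any single output's length


def first_n(x, n):
    # Hypothetical example: x always has shape (MAX_LEN,);
    # n can differ per example and may be traced
    mask = jnp.arange(MAX_LEN) < n     # shape (MAX_LEN,) no matter what n is
    values = jnp.where(mask, x, 0.0)   # fill the unused slots with 0 (neutral for sums)
    return values, mask


xs = jnp.arange(12.0).reshape(3, MAX_LEN)
ns = jnp.array([1, 4, 2])              # logical output length per example

values, mask = jax.vmap(first_n)(xs, ns)

# Reductions use the mask so the padding never contributes
row_means = jnp.sum(values, axis=1) / jnp.sum(mask, axis=1)

# Outside any transformation, the ragged results can be recovered eagerly
ragged = [v[m] for v, m in zip(values, mask)]
```

For reductions other than a sum, such as a max, a different fill value (for example -inf) keeps the padding from affecting the result.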