
JAX: vmapping a lax.while_loop


I have a function that uses jax.lax.while_loop, and I want to vmap it. However, the vmapped version runs much more slowly than the original.

I understand that under vmap, lax.cond is transformed into a select, which evaluates all branches and can therefore be slower.

Is something similar happening here? If so, what is the best practice for "do X while Y is true" with vmap?


Solution

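  • Under vmap, a while_loop becomes a single while_loop with a batched body_fun and cond_fun. The batched loop keeps running as long as cond_fun is true for any element of the batch, so every element effectively executes the same number of iterations (the maximum over the batch). Elements that have already finished still have body_fun computed, and those results are discarded via a select. If different batch elements need vastly different numbers of iterations, this can mean much more computation than executing the individual while_loops in sequence.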
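A small sketch can make this concrete. The function `count_doublings` below is a hypothetical example (not from the question): it counts how many doublings a positive integer needs to reach 100. `manual_batched_loop` is a hand-written analogue of what the batched loop conceptually does (one shared loop that runs while *any* element is active, with `jnp.where` masking finished elements). It is an illustration, not JAX's actual implementation. The last part uses `lax.map`, which runs each element's while_loop sequentially, as one possible alternative when iteration counts vary widely.

```python
import jax
import jax.numpy as jnp
from jax import lax


def count_doublings(x):
    """Hypothetical example: doublings needed for x (> 0) to reach at least 100."""
    def cond_fun(state):
        x, steps = state
        return x < 100

    def body_fun(state):
        x, steps = state
        return x * 2, steps + 1

    _, steps = lax.while_loop(cond_fun, body_fun, (x, jnp.zeros_like(x)))
    return steps


def manual_batched_loop(xs):
    """Conceptual analogue of vmap(count_doublings): one shared loop for the batch."""
    def cond_fun(state):
        x, steps, n_iter = state
        # The shared loop continues while ANY element still needs work.
        return jnp.any(x < 100)

    def body_fun(state):
        x, steps, n_iter = state
        active = x < 100
        # x * 2 is computed for every element; finished ones are masked out.
        x = jnp.where(active, x * 2, x)
        steps = jnp.where(active, steps + 1, steps)
        return x, steps, n_iter + 1

    init = (xs, jnp.zeros_like(xs), jnp.int32(0))
    _, steps, n_iter = lax.while_loop(cond_fun, body_fun, init)
    return steps, n_iter


xs = jnp.array([1, 50, 99, 100])

vmapped_steps = jax.vmap(count_doublings)(xs)        # per-element counts
manual_steps, n_iter = manual_batched_loop(xs)       # same counts, plus shared loop length
sequential_steps = lax.map(count_doublings, xs)      # each element loops independently

print(vmapped_steps)    # per-element doublings
print(int(n_iter))      # iterations of the shared loop = max over the batch
print(sequential_steps)
```

Here the element `1` needs 7 doublings while the others need at most 1, so the shared loop runs 7 iterations and evaluates the body for all 4 elements each time (28 body evaluations), whereas running the loops one by one needs only 9. Whether `lax.map` is actually faster depends on the workload: it avoids the masked work but gives up the vectorization that vmap provides, so it is worth benchmarking both.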