I have a function that uses `jax.lax.while_loop`. Now I want to `vmap` it. However, `vmap` makes the execution much slower than the original.
I understand that `lax.cond` under `vmap` is transformed into `select`, which evaluates all branches and thus may slow down the computation.
Is a similar thing happening here? If so, what is the best practice for expressing "do xx while y is true" with `vmap`?
A `while_loop` under `vmap` becomes a single `while_loop` over a batched `body_fun` and `cond_fun`, meaning effectively that every loop in the batch executes for the same number of iterations. If different batch elements require vastly different numbers of iterations, this can result in extra computation compared to executing the individual `while_loop`s in sequence.
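
To make this concrete, here is a minimal sketch. The function `count_halvings` and the hand-written `batched_count_halvings` are hypothetical names made up for illustration, not part of JAX. The second function is only an approximation of what the batched loop does conceptually: it keeps looping while any element still satisfies the condition, and it masks updates for elements that have already finished.

```python
import jax
import jax.numpy as jnp
from jax import lax


def count_halvings(x):
    # Halve x until it drops below 1 and return how many halvings were needed.
    def cond_fun(state):
        val, _ = state
        return val >= 1.0

    def body_fun(state):
        val, n = state
        return val / 2.0, n + 1

    _, n = lax.while_loop(cond_fun, body_fun, (x, 0))
    return n


def batched_count_halvings(xs):
    # Hand-written approximation of the vmapped loop (illustrative only):
    # one loop for the whole batch that runs until every element is done.
    def cond_fun(state):
        vals, _ = state
        return jnp.any(vals >= 1.0)

    def body_fun(state):
        vals, ns = state
        active = vals >= 1.0
        # Finished elements are carried along unchanged, but the body
        # is still computed for them on every iteration.
        new_vals = jnp.where(active, vals / 2.0, vals)
        new_ns = jnp.where(active, ns + 1, ns)
        return new_vals, new_ns

    init = (xs, jnp.zeros(xs.shape, dtype=jnp.int32))
    _, ns = lax.while_loop(cond_fun, body_fun, init)
    return ns


xs = jnp.array([1.0, 8.0, 1024.0])
counts = jax.vmap(count_halvings)(xs)
manual_counts = batched_count_halvings(xs)
```

Both versions return the correct per-element counts, but the batched loop runs as many iterations as the slowest element needs, doing masked work for the elements that finished early. When iteration counts vary widely across the batch, that wasted work is the likely source of the slowdown.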