I am a JAX beginner, and someone experienced with JAX told me that if we have repeated calls to a `scan`/`for` loop (e.g. when these are themselves wrapped by another `for` loop), it might be better to leave the loop as a `for` loop instead of converting it to a `scan`. Their reasoning was that the `for` loop is unrolled completely and so only has a one-time (large) compilation cost, whereas the `scan` is not unrolled by default: even though its compilation cost is small, the fact that it is rolled means that the cost of repeatedly running the loop ends up making the `scan` more expensive than the `for` loop. This did not strike me immediately when I started writing my code, but it made sense once I thought about it.
So I tested this assumption with code based on the following pseudocode (the full code is very long, so I hope these excerpts are easier to follow):

    import jax

    @jax.jit
    def act():
        for k in range(num_algo_iters):  # Currently switching this loop between a scan and a for loop
            jax.lax.scan(rollout_func, ...)  # Currently fixed to be a scan (arguments elided)

    for i in range(num_train_steps):  # Currently fixed to be a for loop
        for j in range(num_env_steps):  # Currently fixed to be a for loop
            act()

The only loop in the above code that I switched between `scan` and `for` was the `k` loop. I then set `num_env_steps` to 1, 100, 1000 and 10000 to see whether increasing the number of times `act()` (and thus the `k` loop) is executed makes a difference to the timing. (The tests used 5 iterations for the `k` loop and 2 iterations for the innermost `scan`, although these vary in general, if that matters.)
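For concreteness, the two variants of the `k` loop being compared might look like the following minimal sketch. It is a hypothetical reconstruction, not the actual code: `rollout_func`, the scalar carry and the names `act_for`/`act_scan` are illustrative assumptions; only the loop counts (5 and 2) come from the test described above.

```python
import jax
import jax.numpy as jnp

NUM_ALGO_ITERS = 5  # iterations of the k loop
NUM_ROLLOUT = 2     # iterations of the innermost scan

def rollout_func(carry, _):
    # Hypothetical stand-in for the real rollout step
    carry = jnp.tanh(carry) + 0.1
    return carry, carry

def inner_scan(carry):
    # The innermost scan, identical in both variants
    carry, _ = jax.lax.scan(rollout_func, carry, None, length=NUM_ROLLOUT)
    return carry

@jax.jit
def act_for(carry):
    # k loop as a Python for loop: traced NUM_ALGO_ITERS times, so the
    # compiled program contains NUM_ALGO_ITERS copies of inner_scan
    for _ in range(NUM_ALGO_ITERS):
        carry = inner_scan(carry)
    return carry

@jax.jit
def act_scan(carry):
    # k loop as a scan: the body (including the inner scan) is traced once
    def k_step(c, _):
        return inner_scan(c), None
    carry, _ = jax.lax.scan(k_step, carry, None, length=NUM_ALGO_ITERS)
    return carry

x0 = jnp.zeros(())
print(act_for(x0), act_scan(x0))
```

Both variants compute the same value; they differ only in how the `k` loop is represented in the compiled program.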
The times taken for `act()` with `num_env_steps` = 1, 100, 1000 and 10000 were 1.5, 11.3, 99.0 and 956.2 seconds respectively for the `scan` version, and 5.1, 14.5, 103.6 and 972.7 seconds for the `for` version. So the `for` version never ended up faster for the numbers of repeats I tried.
So now I am wondering whether, for any number of repeats (i.e. `num_env_steps`), unrolling the `for` loop actually makes the program faster than using `scan`.
My questions:
1. Would increasing `num_env_steps` to 100k or 1 million make the `for` version faster, or can we always just replace a `for` with a `scan`? I ask because I wonder whether I am over-optimising my code by converting every `for` to a `scan`.
2. If I set `unroll=True` for the `scan`, would it then always be fine to replace all `for`s with `scan`s and expect speed-ups?
3. Is there a rule of thumb for when to use `for` and when to use `scan` if I am only interested in such speed-ups?

`act` was jitted, by the way.
Choosing between `scan` and a `for` loop is essentially a tradeoff between compilation cost and runtime cost.
JAX unrolls Python control flow, meaning that a `for` loop with 100 iterations leads to a single linear program with 100 copies of the loop body. The benefit of this is that it leaves the compiler free to optimize code across loop iterations, e.g. fusing operations between one iteration and the next, or noticing that one output is unused and eliding every computation in its graph. The downside is that compilation cost grows super-linearly with the size of the program, so `for` loops with large loop bodies and/or many iterations can lead to very long compilation times.
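One way to see the unrolling directly is to inspect the jaxpr (the traced program) of each version. The sketch below uses a toy loop body (an assumption, not from the question) and counts how many `sin` operations appear at the top level of the program: the Python `for` loop produces one per iteration, while `scan` keeps a single copy of the body inside one `scan` operation.

```python
import jax
import jax.numpy as jnp

N = 100

def body(x):
    # Toy loop body: one sin plus one multiply
    return jax.lax.sin(x) * 2.0

def unrolled(x):
    # Python for loop: tracing runs the body N times, so the program gets N copies
    for _ in range(N):
        x = body(x)
    return x

def rolled(x):
    # lax.scan: the body is traced once and the loop itself is part of the program
    def step(carry, _):
        return body(carry), None
    x, _ = jax.lax.scan(step, x, None, length=N)
    return x

def count_top_level(fn, x, name):
    # Count primitives called `name` among the top-level equations of fn's jaxpr
    jaxpr = jax.make_jaxpr(fn)(x).jaxpr
    return sum(eqn.primitive.name == name for eqn in jaxpr.eqns)

x = jnp.float32(1.0)
print(count_top_level(unrolled, x, "sin"))  # one sin per unrolled iteration
print(count_top_level(rolled, x, "sin"))    # the sin lives inside the scan body instead
print(count_top_level(rolled, x, "scan"))   # a single scan operation
```

The unrolled program is what XLA has to compile, which is why compile time grows with both the body size and the iteration count.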
With `scan` or `fori_loop`, on the other hand, the looping logic is pushed into the HLO, and the loop body is only parsed and compiled once. This results in much more efficient compilation, but may leave some runtime performance on the table compared to a `for` loop, because the compiler has fewer degrees of freedom to work with.
The best option will depend on the details of your program and on the relative importance of runtime and compile-time costs in your particular application. Speaking very generally, though: for a smaller loop body with fewer iterations, `for` loops are often the better choice. For a larger loop body with more iterations, `scan`/`fori_loop` is likely better.
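Because `act` is jitted and (assuming its input shapes and static arguments do not change) is only compiled on its first call, its total cost is roughly one compilation plus `num_train_steps * num_env_steps` executions. A minimal sketch for measuring those two parts separately uses JAX's ahead-of-time API (`jax.jit(fn).lower(x).compile()`); the loop body and `N` below are placeholder assumptions to replace with your own, and the numbers you get will depend on hardware and program.

```python
import time

import jax
import jax.numpy as jnp

N = 200  # placeholder trip count; vary it together with the body size

def body(x):
    # Placeholder loop body; substitute your own
    return jnp.sin(x) * 2.0

def with_for(x):
    for _ in range(N):  # unrolled at trace time
        x = body(x)
    return x

def with_scan(x):
    x, _ = jax.lax.scan(lambda c, _: (body(c), None), x, None, length=N)
    return x

def measure(fn, x, repeats=100):
    """Return (compile seconds, mean seconds per call) for jax.jit(fn) at input x."""
    t0 = time.perf_counter()
    compiled = jax.jit(fn).lower(x).compile()  # tracing + lowering + XLA compilation
    compile_s = time.perf_counter() - t0

    compiled(x).block_until_ready()  # warm-up call, excluded from the runtime figure
    t0 = time.perf_counter()
    for _ in range(repeats):
        # Block on each call so asynchronous dispatch does not hide the run time
        compiled(x).block_until_ready()
    run_s = (time.perf_counter() - t0) / repeats
    return compile_s, run_s

x = jnp.float32(1.0)
for name, fn in [("for", with_for), ("scan", with_scan)]:
    compile_s, run_s = measure(fn, x)
    print(f"{name}: compile {compile_s:.3f}s, run {run_s * 1e6:.1f}us per call")
```

If the `for` version's per-call runtime is not meaningfully lower than the `scan` version's, no number of repeats will let it recover its extra compile time.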
Note also that `scan` has an `unroll` parameter that lets you tune the tradeoff between these extremes: `unroll=True` is effectively equivalent to a `for` loop, while `unroll=n` for `1 < n < N_iterations` effectively puts a small `for` loop within each step of the larger `scan`.
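A short sketch of the `unroll` parameter, using a placeholder loop body (an assumption, not from the question). All settings compute the same result; they differ only in how many copies of the body the compiled program contains per loop iteration, so it is worth timing a few values (e.g. with an ahead-of-time compile as above) rather than assuming one is best.

```python
import functools

import jax
import jax.numpy as jnp

N = 12  # number of scan iterations

def step(carry, _):
    # Placeholder loop body
    return jnp.sin(carry) * 2.0, None

def scan_with_unroll(x, unroll):
    # unroll=1 (default): one copy of the body per loop iteration
    # unroll=3: three copies of the body per loop iteration, so 12 / 3 = 4 iterations
    # unroll=True: all N copies, effectively a Python for loop
    x, _ = jax.lax.scan(step, x, None, length=N, unroll=unroll)
    return x

def python_for(x):
    # Reference: plain Python for loop, fully unrolled at trace time
    for _ in range(N):
        x, _ = step(x, None)
    return x

x = jnp.float32(0.5)
settings = [1, 3, True]
# One jitted function per setting via functools.partial; a list is used instead of a
# dict (and partial instead of a static argument) because True == 1 in Python
results = [jax.jit(functools.partial(scan_with_unroll, unroll=u))(x) for u in settings]
reference = jax.jit(python_for)(x)
for u, value in zip(settings, results):
    print(f"unroll={u}: {value}")
```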