python jax

How to choose between `jax.lax.scan` and a `for` loop in JAX?


I am a JAX beginner. Someone experienced with JAX told me that when a scan/for loop is called repeatedly (e.g. when it is itself wrapped in another for loop), it might be better to leave it as a for loop rather than convert it to a scan. The reasoning was that the for loop is unrolled completely and so incurs only a one-time, large compilation cost, whereas the scan is not unrolled by default: its compilation cost is small, but because it stays rolled, the cost of running it repeatedly could end up making the scan more expensive than the for loop. This did not occur to me when I started writing my code, but it made sense once I thought about it.

So, I tested this assumption using code based on the following pseudo-code (the full code is really long and I hope these relevant parts I provide here are easier to understand):

for i in range(num_train_steps):  # Currently fixed to be a for loop
  for j in range(num_env_steps):  # Currently fixed to be a for loop
    act()

def act():
  for k in range(num_algo_iters):  # Currently playing around with making this one either a scan or a for loop
    jax.lax.scan(rollout_func)  # Currently fixed to be a scan

The only loop I switched between scan and for in the code above was the k loop. I then set num_env_steps to 1, 100, 1000 and 10000 to see whether increasing the number of times act() (and thus the k loop) was executed made a difference to the timing. (The tests used 5 iterations for the k loop and 2 iterations for the innermost scan, although both are variable in general, in case that matters.) The times taken for act() at these settings were 1.5, 11.3, 99.0 and 956.2 seconds for the scan version and 5.1, 14.5, 103.6 and 972.7 seconds for the for version. So the for version was never faster for the numbers of repeats I tried.

So now I am wondering whether there is any number of repeats (i.e. num_env_steps) for which unrolling the for loop actually makes the program faster than with scan. My questions:

  1. Would maybe increasing the repeats even more by setting num_env_steps to 100k or 1 million make it faster or can we always just replace a for with a scan? I have this question because I wonder if I am trying to over-optimise my code by converting every for to a scan.
  2. If I set unroll=True for the scan, would it then always be fine to replace all for loops with scans and expect speed-ups?
  3. Is there a rule of thumb that can help me decide when to use for and when to use scan if I am only interested in such speed-ups?

act() was jitted, by the way.


Solution

  • scan vs for loop is essentially a tradeoff between compilation cost and runtime cost.

    JAX unrolls Python control flow, meaning that a for loop with 100 iterations leads to a single linear program with 100 copies of the loop body. The benefit of this is that it leaves the compiler free to optimize code across loop iterations, e.g. fusing operations between one iteration and the next; or noticing that one output is unused and eliding every computation in its graph. The downside is that compilation cost grows super-linearly with the size of the program, so for loops with large loop bodies and/or many iterations can lead to very long compilation times.

    With scan or fori_loop, on the other hand, the looping logic is pushed into the HLO, and the loop body is only parsed and compiled once. This results in much more efficient compilation, but may leave some runtime performance on the table compared to a for loop, because the compiler has fewer degrees of freedom to work with.
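
    As a rough illustration of the difference described above, the sketch below traces the same computation once as a Python for loop and once as a jax.lax.scan, then counts the equations in each resulting jaxpr. The body function, the helper names and the iteration counts are hypothetical, chosen only for this example. The jaxpr is the traced program before XLA compiles it, so the counts are a proxy for program size, not a measurement of compile time.

```python
import functools

import jax
import jax.numpy as jnp


def body(x):
    # Hypothetical loop body, used only for illustration.
    return jnp.sin(x) * 0.5 + 1.0


def with_for(x, n):
    # Python loop: the body is traced n times, so it appears n times
    # in the traced program.
    for _ in range(n):
        x = body(x)
    return x


def with_scan(x, n):
    # lax.scan: the body is traced once and the loop itself becomes
    # a single operation in the traced program.
    def step(carry, _):
        return body(carry), None

    out, _ = jax.lax.scan(step, x, None, length=n)
    return out


def count_eqns(fn, n, x):
    # Number of top-level equations in the jaxpr for a given iteration count.
    closed = jax.make_jaxpr(functools.partial(fn, n=n))(x)
    return len(closed.jaxpr.eqns)


x0 = jnp.float32(0.5)
for_eqns = {n: count_eqns(with_for, n, x0) for n in (10, 100)}
scan_eqns = {n: count_eqns(with_scan, n, x0) for n in (10, 100)}

print("for loop, top-level equations:", for_eqns)
print("scan, top-level equations:   ", scan_eqns)
```

    Under these assumptions the for-loop count should grow with the number of iterations while the scan count stays fixed, which is the program-size effect that drives the compile-time difference.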

    The best option will depend on the details of your program, and the relative importance of runtime and compile time costs in your particular application. Speaking very generally, though: for a smaller loop body with fewer iterations, for loops are often the better choice. For a larger loop body with more iterations, scan / fori_loop is likely better.

    Note also that scan has an unroll parameter that gives you the ability to tune the tradeoff between these extremes: unroll=True is effectively equivalent to a for loop, while unroll=n for 1 < n < N_iterations effectively puts a small for loop within each step of the larger scan.
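
    The unroll setting mentioned above can be tried with a minimal sketch like the following. The step function and the input sizes are hypothetical, and passing a boolean to unroll assumes a reasonably recent JAX release. All three variants compute the same result; they differ only in how much of the loop is laid out explicitly for the compiler.

```python
import jax
import jax.numpy as jnp


def step(carry, x):
    # Hypothetical step function: running sum of squares.
    carry = carry + x * x
    return carry, carry


xs = jnp.arange(8, dtype=jnp.float32)
init = jnp.float32(0.0)


@jax.jit
def rolled(init, xs):
    # Default: the loop stays rolled (unroll=1).
    return jax.lax.scan(step, init, xs)


@jax.jit
def partly_unrolled(init, xs):
    # Two copies of the step body per iteration of the underlying loop.
    return jax.lax.scan(step, init, xs, unroll=2)


@jax.jit
def fully_unrolled(init, xs):
    # Whole loop unrolled, comparable to writing a Python for loop.
    return jax.lax.scan(step, init, xs, unroll=True)


final_rolled, ys_rolled = rolled(init, xs)
final_partial, ys_partial = partly_unrolled(init, xs)
final_full, ys_full = fully_unrolled(init, xs)

print(final_rolled, final_partial, final_full)
```

    Which setting is fastest depends on the loop body and the backend, so timing each variant on the real workload (with compilation time measured separately from steady-state run time) is the only reliable way to choose.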