Tags: python, arrays, numpy, performance, jax

Why is array manipulation in JAX much slower?


I'm working on converting a transformation-heavy numerical pipeline from NumPy to JAX to take advantage of JIT acceleration. However, I've found that some basic operations like broadcast_to and moveaxis are significantly slower in JAX than in NumPy, both with and without JIT, even for large batch sizes like 3,000,000 where I would expect JAX to be much quicker.

### Benchmark: moveaxis + broadcast_to ###
NumPy: moveaxis + broadcast_to → 0.000116 s
JAX: moveaxis + broadcast_to → 0.204249 s
JAX JIT: moveaxis + broadcast_to → 0.054713 s

### Benchmark: broadcast_to only ###
NumPy: broadcast_to → 0.000059 s
JAX: broadcast_to → 0.062167 s
JAX JIT: broadcast_to → 0.057625 s

Am I doing something wrong? Are there better ways of performing these kinds of manipulations?

Here's a minimal benchmark ChatGPT generated, comparing broadcast_to and moveaxis in NumPy, JAX, and JAX with JIT:

import timeit

import jax
import jax.numpy as jnp
import numpy as np
from jax import jit

# Base transformation matrix
M_np = np.array([[1, 0, 0, 0.5],
                 [0, 1, 0, 0],
                 [0, 0, 1, 0],
                 [0, 0, 0, 1]])

M_jax = jnp.array(M_np)

# Batch size
n = 1_000_000

print("### Benchmark: moveaxis + broadcast_to ###")

# NumPy (broadcast_to and moveaxis return views, so no data is copied)
t_numpy = timeit.timeit(
    lambda: np.moveaxis(np.broadcast_to(M_np[:, :, None], (4, 4, n)), 2, 0),
    number=10
)
print(f"NumPy: moveaxis + broadcast_to → {t_numpy:.6f} s")

# JAX
t_jax = timeit.timeit(
    lambda: jnp.moveaxis(jnp.broadcast_to(M_jax[:, :, None], (4, 4, n)), 2, 0).block_until_ready(),
    number=10
)
print(f"JAX: moveaxis + broadcast_to → {t_jax:.6f} s")

# JAX JIT
@jit
def broadcast_and_move_jax(M):
    return jnp.moveaxis(jnp.broadcast_to(M[:, :, None], (4, 4, n)), 2, 0)

# Warm-up
broadcast_and_move_jax(M_jax).block_until_ready()

t_jit = timeit.timeit(
    lambda: broadcast_and_move_jax(M_jax).block_until_ready(),
    number=10
)
print(f"JAX JIT: moveaxis + broadcast_to → {t_jit:.6f} s")

print("\n### Benchmark: broadcast_to only ###")

# NumPy
t_numpy_b = timeit.timeit(
    lambda: np.broadcast_to(M_np[:, :, None], (4, 4, n)),
    number=10
)
print(f"NumPy: broadcast_to → {t_numpy_b:.6f} s")

# JAX
t_jax_b = timeit.timeit(
    lambda: jnp.broadcast_to(M_jax[:, :, None], (4, 4, n)).block_until_ready(),
    number=10
)
print(f"JAX: broadcast_to → {t_jax_b:.6f} s")

# JAX JIT
@jit
def broadcast_only_jax(M):
    return jnp.broadcast_to(M[:, :, None], (4, 4, n))

broadcast_only_jax(M_jax).block_until_ready()

t_jit_b = timeit.timeit(
    lambda: broadcast_only_jax(M_jax).block_until_ready(),
    number=10
)
print(f"JAX JIT: broadcast_to → {t_jit_b:.6f} s")



Solution

    There are a couple of things happening here, both stemming from the different execution models of NumPy and JAX.

    First, NumPy operations like broadcasting, transposing, reshaping, slicing, etc. typically return views of the original buffer. In JAX, it is not possible for two array objects to share memory, and so the equivalent operations return copies. I suspect this is the largest contribution to the timing difference here.
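To see the view-versus-copy difference on the NumPy side, here is a minimal sketch, reusing the question's matrix and n = 1,000,000. It inspects the strides of the NumPy result and then times a variant that forces a real contiguous copy with np.ascontiguousarray, which is closer to the work JAX has to do for the same expression. Timings will vary by machine, and np.may_share_memory only performs a cheap bounds check rather than an exact overlap test.

```python
import timeit

import numpy as np

# Same matrix and batch size as in the question.
M_np = np.array([[1, 0, 0, 0.5],
                 [0, 1, 0, 0],
                 [0, 0, 1, 0],
                 [0, 0, 0, 1]])
n = 1_000_000

# broadcast_to returns a read-only view: the new batch axis has a stride of 0,
# so all n "copies" of M_np point at the same 16 numbers.
view = np.broadcast_to(M_np[:, :, None], (4, 4, n))
print(view.strides, view.flags.writeable)

# moveaxis is also a view: it only permutes the shape and strides.
moved = np.moveaxis(view, 2, 0)
print(moved.shape, moved.strides)

# Forcing a contiguous copy allocates and fills n * 16 float64 values,
# which is much closer to the work JAX does for the same expression.
copied = np.ascontiguousarray(moved)
print(np.may_share_memory(view, M_np), np.may_share_memory(copied, M_np))

t_view = timeit.timeit(
    lambda: np.moveaxis(np.broadcast_to(M_np[:, :, None], (4, 4, n)), 2, 0),
    number=10,
)
t_copy = timeit.timeit(
    lambda: np.ascontiguousarray(
        np.moveaxis(np.broadcast_to(M_np[:, :, None], (4, 4, n)), 2, 0)
    ),
    number=10,
)
print(f"view only: {t_view:.6f} s, forced copy: {t_copy:.6f} s")
```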

    Second, NumPy tends to have very fast dispatch time for individual operations, while JAX's per-operation dispatch is much slower. This matters most when the operation itself is very cheap (like "return a view of the array with different strides/shape").
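The dispatch cost is easiest to see with a chain of many tiny operations. In this sketch, many_small_ops is a hypothetical helper that applies 50 rounds of sin-then-add to an 8-element array, so eager JAX pays the dispatch overhead roughly 100 times per call, while the jitted version is a single dispatch. Exact numbers depend on hardware and JAX version; only the relative ordering is the point.

```python
import timeit

import jax
import jax.numpy as jnp
import numpy as np


def many_small_ops(xp, x):
    """Hypothetical helper: 50 rounds of two trivially cheap ops (sin, add)."""
    for _ in range(50):
        x = xp.sin(x) + 1.0
    return x


x_np = np.ones(8, dtype=np.float32)
x_jax = jnp.ones(8, dtype=jnp.float32)

# One compiled program for the whole chain: dispatch overhead is paid once per call.
jit_ops = jax.jit(lambda x: many_small_ops(jnp, x))

# Warm up both JAX paths so compilation is excluded from the timings.
jit_ops(x_jax).block_until_ready()
many_small_ops(jnp, x_jax).block_until_ready()

t_np = timeit.timeit(lambda: many_small_ops(np, x_np), number=100)
t_eager = timeit.timeit(lambda: many_small_ops(jnp, x_jax).block_until_ready(), number=100)
t_jit = timeit.timeit(lambda: jit_ops(x_jax).block_until_ready(), number=100)

print(f"NumPy (100 ops per call):     {t_np:.4f} s")
print(f"JAX eager (100 ops per call): {t_eager:.4f} s")
print(f"JAX jit (1 dispatch per call): {t_jit:.4f} s")
```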

    You might wonder, given these points, how JAX could ever be faster than NumPy. The key is JIT compilation of sequences of operations: within JIT-compiled code, sequences of operations are fused so that the output of each individual operation need not be allocated (or indeed, need not even exist at all as a buffer of intermediate values). Additionally, for JIT-compiled sequences of operations, the dispatch overhead is paid only once for the whole program. Compare this to NumPy, where there's no way to fuse operations or to avoid paying the dispatch cost of each and every operation.
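Here is a sketch of the fusion point using the question's own chain of operations, followed by a reduction over the batch axis (the sum is an illustrative addition, not part of the original pipeline). In eager mode each step returns its own buffer, so two arrays of 16 * n values are allocated; under jit, XLA is free to fuse or simplify the steps so that the large intermediate need not be allocated. The result should equal n * M either way.

```python
import timeit

import jax
import jax.numpy as jnp

n = 1_000_000
M = jnp.array([[1, 0, 0, 0.5],
               [0, 1, 0, 0],
               [0, 0, 1, 0],
               [0, 0, 0, 1]])


def chain(M):
    b = jnp.broadcast_to(M[:, :, None], (4, 4, n))  # eager: a full (4, 4, n) buffer
    t = jnp.moveaxis(b, 2, 0)                       # eager: another (n, 4, 4) buffer
    return t.sum(axis=0)                            # small (4, 4) result


chain_jit = jax.jit(chain)
chain_jit(M).block_until_ready()  # compile outside the timed region

t_eager = timeit.timeit(lambda: chain(M).block_until_ready(), number=10)
t_jit = timeit.timeit(lambda: chain_jit(M).block_until_ready(), number=10)
print(f"eager: {t_eager:.6f} s, jit: {t_jit:.6f} s")
```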

    So in microbenchmarks like this, you can expect JAX to be slower than NumPy. But for real-world sequences of operations wrapped in JIT, you should often find that JAX is faster, even when executing on CPU.
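On the question of better ways: often the best option is not to materialize the broadcast at all. As a hypothetical example, assume the pipeline applies one 4x4 homogeneous transform M to n points stored as an (n, 4) array pts (the point data below is random and purely illustrative). Instead of building an (n, 4, 4) stack of copies of M, a single matrix product expresses the same computation.

```python
import jax
import jax.numpy as jnp

M = jnp.array([[1, 0, 0, 0.5],
               [0, 1, 0, 0],
               [0, 0, 1, 0],
               [0, 0, 0, 1]])


@jax.jit
def transform_broadcast(M, pts):
    # Builds an explicit (n, 4, 4) stack, then does a batched matrix-vector product.
    Ms = jnp.broadcast_to(M, (pts.shape[0], 4, 4))
    return jnp.einsum("nij,nj->ni", Ms, pts)


@jax.jit
def transform_direct(M, pts):
    # Same result without any broadcast: each row p becomes M @ p, i.e. pts @ M.T.
    return pts @ M.T


# Hypothetical input: n random 3D points in homogeneous coordinates (w = 1).
key = jax.random.PRNGKey(0)
xyz = jax.random.normal(key, (1_000_000, 3))
pts = jnp.concatenate([xyz, jnp.ones((xyz.shape[0], 1))], axis=1)

out_a = transform_broadcast(M, pts)
out_b = transform_direct(M, pts)
```

Under jit, XLA may also be able to avoid materializing the stack in transform_broadcast, but writing the computation as pts @ M.T does not rely on that and lets XLA use a plain matrix multiply.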

    This type of question comes up often enough that there's a section devoted to it in the JAX FAQ: "Is JAX faster than NumPy?"


    Answering the follow-up question:

    Does the statement "In JAX, it is not possible for two array objects to share memory, and so the equivalent operations return copies" also hold within a jitted environment?

    This question is not really well-formulated, because in a jitted environment, array objects do not necessarily correspond to buffers of values. Let's make this more concrete with a simple example:

    import jax
    
    @jax.jit
    def f(x):
      y = x[::2]
      return y.sum()
    

    You might ask: in this program, is y a copy or a view of x? The answer is neither, because y is never explicitly created. Instead, JIT fuses the slice and the sum into a single operation: the array x is the input, and the array y.sum() is the output, and the intermediate array y is never actually created.
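One way to see where y disappears is to compare the program at two stages. jax.make_jaxpr shows the traced program (the jaxpr) before XLA compilation, and there y still appears as a separate intermediate produced by a slicing step (whether it prints as a slice or a gather depends on the JAX version). It is XLA compilation that then fuses it away.

```python
import jax
import jax.numpy as jnp


@jax.jit
def f(x):
    y = x[::2]
    return y.sum()


x = jnp.arange(10)
print(f(x))                  # sums 0 + 2 + 4 + 6 + 8
print(jax.make_jaxpr(f)(x))  # traced program: y is still an explicit intermediate here
```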

    You can see this by printing the compiled HLO for this function:

    x = jax.numpy.arange(10)
    print(f.lower(x).compile().as_text())
    
    HloModule jit_f, is_scheduled=true, entry_computation_layout={(s32[10]{0})->s32[]}, allow_spmd_sharding_propagation_to_parameters={true}, allow_spmd_sharding_propagation_to_output={true}
    
    %region_0.9 (Arg_0.10: s32[], Arg_1.11: s32[]) -> s32[] {
      %Arg_0.10 = s32[] parameter(0), metadata={op_name="jit(f)/jit(main)/reduce_sum"}
      %Arg_1.11 = s32[] parameter(1), metadata={op_name="jit(f)/jit(main)/reduce_sum"}
      ROOT %add.12 = s32[] add(s32[] %Arg_0.10, s32[] %Arg_1.11), metadata={op_name="jit(f)/jit(main)/reduce_sum" source_file="<ipython-input-1-9ea6c70efef5>" source_line=5}
    }
    
    %fused_computation (param_0.2: s32[10]) -> s32[] {
      %param_0.2 = s32[10]{0} parameter(0)
      %iota.0 = s32[5]{0} iota(), iota_dimension=0, metadata={op_name="jit(f)/jit(main)/iota" source_file="<ipython-input-1-9ea6c70efef5>" source_line=4}
      %constant.1 = s32[] constant(2)
      %broadcast.0 = s32[5]{0} broadcast(s32[] %constant.1), dimensions={}
      %multiply.0 = s32[5]{0} multiply(s32[5]{0} %iota.0, s32[5]{0} %broadcast.0), metadata={op_name="jit(f)/jit(main)/mul" source_file="<ipython-input-1-9ea6c70efef5>" source_line=4}
      %bitcast.1 = s32[5,1]{1,0} bitcast(s32[5]{0} %multiply.0), metadata={op_name="jit(f)/jit(main)/mul" source_file="<ipython-input-1-9ea6c70efef5>" source_line=4}
      %gather.0 = s32[5]{0} gather(s32[10]{0} %param_0.2, s32[5,1]{1,0} %bitcast.1), offset_dims={}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1}, indices_are_sorted=true, metadata={op_name="jit(f)/jit(main)/gather" source_file="<ipython-input-1-9ea6c70efef5>" source_line=4}
      %constant.0 = s32[] constant(0)
      ROOT %reduce.0 = s32[] reduce(s32[5]{0} %gather.0, s32[] %constant.0), dimensions={0}, to_apply=%region_0.9, metadata={op_name="jit(f)/jit(main)/reduce_sum" source_file="<ipython-input-1-9ea6c70efef5>" source_line=5}
    }
    
    ENTRY %main.14 (Arg_0.1: s32[10]) -> s32[] {
      %Arg_0.1 = s32[10]{0} parameter(0), metadata={op_name="x"}
      ROOT %gather_reduce_fusion = s32[] fusion(s32[10]{0} %Arg_0.1), kind=kLoop, calls=%fused_computation, metadata={op_name="jit(f)/jit(main)/reduce_sum" source_file="<ipython-input-1-9ea6c70efef5>" source_line=5}
    }
    

    The output is complicated, but the main thing to look at here is the ENTRY %main section, which is the "main" program generated by compilation. It consists of two steps: %Arg_0.1 identifies the input argument, and ROOT %gather_reduce_fusion is essentially a single compiled kernel that sums every second element of the input. No intermediate arrays are generated. The blocks above it (e.g. the %fused_computation (param_0.2: s32[10]) -> s32[] definition) describe the operations performed within this kernel, but together they represent a single fused operation.
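Since the ENTRY block is the part that matters, a small helper can pull it out of the compiled text. entry_block is a hypothetical helper written for this answer, and it assumes the HLO text layout shown above (an ENTRY header line, indented instructions, and a closing brace on its own line), which may change between XLA versions.

```python
import jax
import jax.numpy as jnp


@jax.jit
def f(x):
    y = x[::2]
    return y.sum()


def entry_block(hlo_text):
    """Hypothetical helper: return the lines of the ENTRY computation in HLO text."""
    lines = hlo_text.splitlines()
    start = next(i for i, line in enumerate(lines) if line.startswith("ENTRY"))
    block = []
    for line in lines[start:]:
        block.append(line)
        if line.strip() == "}":  # end of the ENTRY computation
            break
    return block


x = jnp.arange(10)
hlo = f.lower(x).compile().as_text()
entry = entry_block(hlo)
print("\n".join(entry))
```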

    Notice that the sliced array represented by y in the Python code never actually appears in the main function block, so questions about its memory layout cannot be answered except by saying "y doesn't exist in the compiled program".