Tags: python, numpy, jax, cupy

Batched matrix multiplication with JAX on GPU faster with larger matrices


I'm trying to perform batched matrix multiplication with JAX on GPU. I noticed that multiplying shapes (1000, 1000, 3, 35) @ (1000, 1000, 35, 1) is ~3x faster than multiplying (1000, 1000, 3, 25) @ (1000, 1000, 25, 1) with f64, and ~5x faster with f32.

What explains this difference, considering that on CPU neither JAX nor NumPy shows this behaviour, and on GPU CuPy doesn't show it either?

I'm running JAX 0.4.32 on an NVIDIA RTX A5000 (and get similar results on a Tesla T4). Code to reproduce:

import numpy as np
import cupy as cp
from cupyx.profiler import benchmark
from jax import config
config.update("jax_enable_x64", True)
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

# One shared, unseeded generator supplies every random input below.
rng = np.random.default_rng(seed=None)

# Inner (contracted) dimension sizes swept by the benchmarks: 5, 10, ..., 50.
x = np.arange(start=5, stop=55, step=5)

GPU timings:

# GPU sweep: for K in 5, 10, ..., 50, time (1000, 1000, 3, K) @ (1000, 1000, K, 1)
# with CuPy, then with a jitted JAX function on the first GPU device.
# For each K the left operand is drawn from `rng` first, then the right one.
dtype = cp.float64
timings_cp = []
for k in range(5, 55, 5):
    a = cp.array(rng.random((1000, 1000, 3, k)), dtype=dtype)
    b = cp.array(rng.random((1000, 1000, k, 1)), dtype=dtype)
    perf = benchmark(lambda a, b: a@b, (a, b), n_repeat=10, n_warmup=10)
    timings_cp.append(perf)

dtype = jnp.float64
timings_jax_gpu = []
gpu_device = jax.devices('gpu')[0]
with jax.default_device(gpu_device):
    for k in range(5, 55, 5):
        a = jnp.array(rng.random((1000, 1000, 3, k)), dtype=dtype)
        b = jnp.array(rng.random((1000, 1000, k, 1)), dtype=dtype)
        # A fresh jit wrapper per size; its compilation happens during the
        # warm-up calls, not the timed ones.
        func = jax.jit(lambda a, b: a@b)
        # JAX dispatches asynchronously, so block on the result to make each
        # timed call include the device work.
        perf = benchmark(
            lambda a, b: func(a, b).block_until_ready(),
            (a, b),
            n_repeat=10,
            n_warmup=10,
        )
        timings_jax_gpu.append(perf)

mean_gpu_cp = [t.gpu_times.mean() for t in timings_cp]
mean_gpu_jax = [t.gpu_times.mean() for t in timings_jax_gpu]
plt.figure()
plt.plot(x, mean_gpu_cp, label="CuPy")
plt.plot(x, mean_gpu_jax, label="JAX GPU")
plt.legend()

[Figure: mean GPU time vs. inner dimension size for CuPy and JAX GPU (image not available)]

Timings with those specific shapes:

# Time the two shapes from the question (K = 25, then K = 35) on the first GPU,
# reusing one jitted function; the shape change triggers a recompile, which
# falls within the warm-up calls.
dtype = jnp.float64
with jax.default_device(jax.devices('gpu')[0]):
    func = jax.jit(lambda a, b: a@b)
    for k in (25, 35):
        a = jnp.array(rng.random((1000, 1000, 3, k)), dtype=dtype)
        b = jnp.array(rng.random((1000, 1000, k, 1)), dtype=dtype)
        perf = benchmark(
            lambda a, b: func(a, b).block_until_ready(),
            (a, b),
            n_repeat=1000,
            n_warmup=10,
        )
        print(perf.gpu_times.mean())

This gives (mean GPU time per call; first the K = 25 shapes, then the K = 35 shapes):

f64:
0.01453789699935913
0.004859122595310211

f32:

0.005860503035545349
0.001209742688536644

CPU timings:

# CPU sweep over the same K values: NumPy first, then a jitted JAX function
# pinned to the first CPU device (inputs keep rng's float64 dtype, since x64
# mode is enabled at the top of the script).
timings_np = []
for k in range(5, 55, 5):
    a = rng.random((1000, 1000, 3, k))
    b = rng.random((1000, 1000, k, 1))
    perf = benchmark(lambda a, b: a@b, (a, b), n_repeat=10, n_warmup=10)
    timings_np.append(perf)

timings_jax_cpu = []
cpu_device = jax.devices('cpu')[0]
with jax.default_device(cpu_device):
    for k in range(5, 55, 5):
        a = jnp.array(rng.random((1000, 1000, 3, k)))
        b = jnp.array(rng.random((1000, 1000, k, 1)))
        func = jax.jit(lambda a, b: a@b)
        # Block on the result so each timed call covers the full computation.
        perf = benchmark(
            lambda a, b: func(a, b).block_until_ready(),
            (a, b),
            n_repeat=10,
            n_warmup=10,
        )
        timings_jax_cpu.append(perf)

mean_cpu_np = [t.cpu_times.mean() for t in timings_np]
mean_cpu_jax = [t.cpu_times.mean() for t in timings_jax_cpu]
plt.figure()
plt.plot(x, mean_cpu_np, label="NumPy")
plt.plot(x, mean_cpu_jax, label="JAX CPU")
plt.legend()

[Figure: mean CPU time vs. inner dimension size for NumPy and JAX CPU (image not available)]


Solution

  • The difference seems to come from the XLA compiler emitting a kLoop fusion for smaller sizes and a kInput fusion for larger sizes. You can read about the effect of these fusion kinds in this source comment: https://github.com/openxla/xla/blob/e6b6e61b29cc439350a6ad2f9d39535cb06011e5/xla/hlo/ir/hlo_instruction.h#L639-L656

    The compiler likely uses a heuristic to choose between the two, and that heuristic appears to be suboptimal at the boundary for your particular problem. You can see which fusion kind is chosen by printing the compiled HLO for your operation (look at the `kind=` attribute of the fusion):

    # Compile `a @ b` for the K = 25 shapes and print the optimized HLO text.
    a = jnp.array(rng.random((1000, 1000, 3, 25)), dtype=dtype)
    b = jnp.array(rng.random((1000, 1000, 25, 1)), dtype=dtype)
    compiled = jax.jit(lambda a, b: a @ b).lower(a, b).compile()
    print(compiled.as_text())
    
    HloModule jit__lambda_, is_scheduled=true, entry_computation_layout={(f64[1000,1000,3,25]{3,2,1,0}, f64[1000,1000,25,1]{3,2,1,0})->f64[1000,1000,3,1]{3,2,1,0}}, allow_spmd_sharding_propagation_to_parameters={true,true}, allow_spmd_sharding_propagation_to_output={true}, frontend_attributes={fingerprint_before_lhs="a02cbfe0fda9d44e2bd23462363b6cc0"}
    
    %scalar_add_computation (scalar_lhs: f64[], scalar_rhs: f64[]) -> f64[] {
      %scalar_rhs = f64[] parameter(1)
      %scalar_lhs = f64[] parameter(0)
      ROOT %add.2 = f64[] add(f64[] %scalar_lhs, f64[] %scalar_rhs)
    }
    
    %fused_reduce (param_0.7: f64[1000,1000,3,25], param_1.6: f64[1000,1000,25,1]) -> f64[1000,1000,3] {
      %param_0.7 = f64[1000,1000,3,25]{3,2,1,0} parameter(0)
      %param_1.6 = f64[1000,1000,25,1]{3,2,1,0} parameter(1)
      %bitcast.28.5 = f64[1000,1000,25]{2,1,0} bitcast(f64[1000,1000,25,1]{3,2,1,0} %param_1.6)
      %broadcast.2.5 = f64[1000,1000,3,25]{3,2,1,0} broadcast(f64[1000,1000,25]{2,1,0} %bitcast.28.5), dimensions={0,1,3}, metadata={op_name="jit(<lambda>)/jit(main)/dot_general" source_file="<ipython-input-4-68f2557428ff>" source_line=3}
      %multiply.2.3 = f64[1000,1000,3,25]{3,2,1,0} multiply(f64[1000,1000,3,25]{3,2,1,0} %param_0.7, f64[1000,1000,3,25]{3,2,1,0} %broadcast.2.5)
      %constant_4 = f64[] constant(0)
      ROOT %reduce.2 = f64[1000,1000,3]{2,1,0} reduce(f64[1000,1000,3,25]{3,2,1,0} %multiply.2.3, f64[] %constant_4), dimensions={3}, to_apply=%scalar_add_computation, metadata={op_name="jit(<lambda>)/jit(main)/dot_general" source_file="<ipython-input-4-68f2557428ff>" source_line=3}
    }
    
    ENTRY %main.4 (Arg_0.1.0: f64[1000,1000,3,25], Arg_1.2.0: f64[1000,1000,25,1]) -> f64[1000,1000,3,1] {
      %Arg_1.2.0 = f64[1000,1000,25,1]{3,2,1,0} parameter(1), metadata={op_name="b"}
      %Arg_0.1.0 = f64[1000,1000,3,25]{3,2,1,0} parameter(0), metadata={op_name="a"}
      %loop_reduce_fusion = f64[1000,1000,3]{2,1,0} fusion(f64[1000,1000,3,25]{3,2,1,0} %Arg_0.1.0, f64[1000,1000,25,1]{3,2,1,0} %Arg_1.2.0), kind=kLoop, calls=%fused_reduce, metadata={op_name="jit(<lambda>)/jit(main)/dot_general" source_file="<ipython-input-4-68f2557428ff>" source_line=3}
      ROOT %bitcast.1.0 = f64[1000,1000,3,1]{3,2,1,0} bitcast(f64[1000,1000,3]{2,1,0} %loop_reduce_fusion), metadata={op_name="jit(<lambda>)/jit(main)/dot_general" source_file="<ipython-input-4-68f2557428ff>" source_line=3}
    }
    
    # Same dump for the K = 35 shapes.
    a = jnp.array(rng.random((1000, 1000, 3, 35)), dtype=dtype)
    b = jnp.array(rng.random((1000, 1000, 35, 1)), dtype=dtype)
    compiled = jax.jit(lambda a, b: a @ b).lower(a, b).compile()
    print(compiled.as_text())
    
    %scalar_add_computation (scalar_lhs: f64[], scalar_rhs: f64[]) -> f64[] {
      %scalar_rhs = f64[] parameter(1)
      %scalar_lhs = f64[] parameter(0)
      ROOT %add.2 = f64[] add(f64[] %scalar_lhs, f64[] %scalar_rhs)
    }
    
    %fused_reduce (param_0.5: f64[1000,1000,3,35], param_1.2: f64[1000,1000,35,1]) -> f64[1000,1000,3] {
      %param_0.5 = f64[1000,1000,3,35]{3,2,1,0} parameter(0)
      %param_1.2 = f64[1000,1000,35,1]{3,2,1,0} parameter(1)
      %bitcast.28.3 = f64[1000,1000,35]{2,1,0} bitcast(f64[1000,1000,35,1]{3,2,1,0} %param_1.2)
      %broadcast.2.3 = f64[1000,1000,3,35]{3,2,1,0} broadcast(f64[1000,1000,35]{2,1,0} %bitcast.28.3), dimensions={0,1,3}, metadata={op_name="jit(<lambda>)/jit(main)/dot_general" source_file="<ipython-input-3-eb3ac06eae7a>" source_line=4}
      %multiply.2.1 = f64[1000,1000,3,35]{3,2,1,0} multiply(f64[1000,1000,3,35]{3,2,1,0} %param_0.5, f64[1000,1000,3,35]{3,2,1,0} %broadcast.2.3)
      %constant_3 = f64[] constant(0)
      ROOT %reduce.2 = f64[1000,1000,3]{2,1,0} reduce(f64[1000,1000,3,35]{3,2,1,0} %multiply.2.1, f64[] %constant_3), dimensions={3}, to_apply=%scalar_add_computation, metadata={op_name="jit(<lambda>)/jit(main)/dot_general" source_file="<ipython-input-3-eb3ac06eae7a>" source_line=4}
    }
    
    ENTRY %main.4 (Arg_0.1.0: f64[1000,1000,3,35], Arg_1.2.0: f64[1000,1000,35,1]) -> f64[1000,1000,3,1] {
      %Arg_1.2.0 = f64[1000,1000,35,1]{3,2,1,0} parameter(1), metadata={op_name="b"}
      %Arg_0.1.0 = f64[1000,1000,3,35]{3,2,1,0} parameter(0), metadata={op_name="a"}
      %input_reduce_fusion = f64[1000,1000,3]{2,1,0} fusion(f64[1000,1000,3,35]{3,2,1,0} %Arg_0.1.0, f64[1000,1000,35,1]{3,2,1,0} %Arg_1.2.0), kind=kInput, calls=%fused_reduce, metadata={op_name="jit(<lambda>)/jit(main)/dot_general" source_file="<ipython-input-3-eb3ac06eae7a>" source_line=4}
      ROOT %bitcast.1.0 = f64[1000,1000,3,1]{3,2,1,0} bitcast(f64[1000,1000,3]{2,1,0} %input_reduce_fusion), metadata={op_name="jit(<lambda>)/jit(main)/dot_general" source_file="<ipython-input-3-eb3ac06eae7a>" source_line=4}
    }
    

    Here's a script to observe this compiler decision as a function of the inner dimension size:

    # For each K, compile `a @ b` and report whether the string 'kLoop'
    # appears in the optimized HLO text. The names `size` and `hlo_text` are
    # echoed verbatim by the `=` f-string specifiers, so they are kept as-is.
    for size in range(10, 55, 5):
        a = jnp.array(rng.random((1000, 1000, 3, size)), dtype=dtype)
        b = jnp.array(rng.random((1000, 1000, size, 1)), dtype=dtype)
        compiled = jax.jit(lambda a, b: a @ b).lower(a, b).compile()
        hlo_text = compiled.as_text()
        print(f"{size=} {'kLoop' in hlo_text=}")
    
    size=10 'kLoop' in hlo_text=True
    size=15 'kLoop' in hlo_text=True
    size=20 'kLoop' in hlo_text=True
    size=25 'kLoop' in hlo_text=True
    size=30 'kLoop' in hlo_text=True
    size=35 'kLoop' in hlo_text=False
    size=40 'kLoop' in hlo_text=False
    size=45 'kLoop' in hlo_text=False
    size=50 'kLoop' in hlo_text=False
    

    I don't have a workaround to suggest beyond reporting this at https://github.com/openxla/xla; the compiler heuristic for choosing between kLoop and kInput fusions may need some additional logic.