I'm trying to perform batched matrix multiplication with JAX on GPU, and noticed that it is ~3x faster to multiply shapes (1000, 1000, 3, 35) @ (1000, 1000, 35, 1) than it is to multiply (1000, 1000, 3, 25) @ (1000, 1000, 25, 1) with f64 and ~5x with f32.
What explains this difference, considering that on the CPU neither JAX nor NumPy shows this behaviour, and on the GPU CuPy doesn't show it either?
I'm running this with JAX 0.4.32 on an NVIDIA RTX A5000 (and get similar results on a Tesla T4). Code to reproduce:
import numpy as np
import cupy as cp
from cupyx.profiler import benchmark
from jax import config
config.update("jax_enable_x64", True)
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
# Inner (contracted) dimension sizes used as the x-axis of the plots: 5, 10, ..., 50.
x = np.arange(start=5, stop=55, step=5)
# Unseeded NumPy generator that supplies the random operands for every benchmark.
rng = np.random.default_rng()
GPU timings:
# Benchmark CuPy's batched matmul (1000, 1000, 3, i) @ (1000, 1000, i, 1)
# for each inner dimension i = 5, 10, ..., 50.
dtype = cp.float64
timings_cp = []
for i in range(5, 55, 5):
    # Operands are generated with the NumPy generator and then turned into
    # CuPy arrays of the requested dtype.
    a = cp.array(rng.random((1000, 1000, 3, i)), dtype=dtype)
    b = cp.array(rng.random((1000, 1000, i, 1)), dtype=dtype)
    # 10 warm-up calls followed by 10 timed calls per size.
    timings_cp.append(benchmark(lambda a, b: a @ b, (a, b), n_repeat=10, n_warmup=10))
# Benchmark the jitted JAX matmul on the first GPU for the same sweep of
# inner dimensions i = 5, 10, ..., 50.
dtype = jnp.float64
timings_jax_gpu = []
with jax.default_device(jax.devices('gpu')[0]):
    # The jitted wrapper does not depend on the loop variable, so it is built
    # once; jit still compiles a separate executable for every new input shape.
    func = jax.jit(lambda a, b: a @ b)
    for i in range(5, 55, 5):
        a = jnp.array(rng.random((1000, 1000, 3, i)), dtype=dtype)
        b = jnp.array(rng.random((1000, 1000, i, 1)), dtype=dtype)
        # block_until_ready() makes every timed call wait for the result, so
        # the measurement covers the computation and not just its dispatch.
        timings_jax_gpu.append(benchmark(lambda a, b: func(a, b).block_until_ready(), (a, b), n_repeat=10, n_warmup=10))
# Plot the mean GPU time per call against the inner dimension, one curve per library.
plt.figure()
for label, results in (("CuPy", timings_cp), ("JAX GPU", timings_jax_gpu)):
    mean_gpu_times = [r.gpu_times.mean() for r in results]
    plt.plot(x, mean_gpu_times, label=label)
plt.legend()
Timings with those specific shapes:
# Time the two shapes from the question, inner dimension 25 and then 35, with
# many more repeats than the sweep. The results quoted below list f64 and f32
# numbers; presumably the f32 run was obtained by changing `dtype` here.
dtype = jnp.float64
with jax.default_device(jax.devices('gpu')[0]):
    # One jitted function serves both shapes; it is recompiled for the second.
    func = jax.jit(lambda a, b: a @ b)
    for k in (25, 35):
        a = jnp.array(rng.random((1000, 1000, 3, k)), dtype=dtype)
        b = jnp.array(rng.random((1000, 1000, k, 1)), dtype=dtype)
        # Prints the mean GPU time over 1000 timed calls (after 10 warm-ups).
        print(benchmark(lambda a, b: func(a, b).block_until_ready(), (a, b), n_repeat=1000, n_warmup=10).gpu_times.mean())
Gives
f64:
0.01453789699935913
0.004859122595310211
f32:
0.005860503035545349
0.001209742688536644
CPU timings:
# Benchmark NumPy's batched matmul on the CPU for each inner dimension
# i = 5, 10, ..., 50. The same benchmark() helper is used as for the GPU runs;
# the plot further down reads its `cpu_times` field.
timings_np = []
for i in range(5, 55, 5):
    a = rng.random((1000, 1000, 3, i))
    b = rng.random((1000, 1000, i, 1))
    # 10 warm-up calls followed by 10 timed calls per size.
    timings_np.append(benchmark(lambda a, b: a @ b, (a, b), n_repeat=10, n_warmup=10))
# Benchmark the jitted JAX matmul on the first CPU device for each inner
# dimension i = 5, 10, ..., 50.
timings_jax_cpu = []
with jax.default_device(jax.devices('cpu')[0]):
    # The jitted wrapper does not depend on the loop variable, so it is built
    # once; jit still compiles a separate executable for every new input shape.
    func = jax.jit(lambda a, b: a @ b)
    for i in range(5, 55, 5):
        # No dtype is passed here, unlike in the GPU benchmark.
        a = jnp.array(rng.random((1000, 1000, 3, i)))
        b = jnp.array(rng.random((1000, 1000, i, 1)))
        # block_until_ready() makes every timed call wait for the result.
        timings_jax_cpu.append(benchmark(lambda a, b: func(a, b).block_until_ready(), (a, b), n_repeat=10, n_warmup=10))
# Plot the mean CPU time per call against the inner dimension, one curve per library.
plt.figure()
for label, results in (("NumPy", timings_np), ("JAX CPU", timings_jax_cpu)):
    mean_cpu_times = [r.cpu_times.mean() for r in results]
    plt.plot(x, mean_cpu_times, label=label)
plt.legend()
The difference seems to come from the compiler emitting a `kLoop` fusion for smaller sizes and a `kInput` fusion for larger sizes. You can read about the effect of these in this source comment: https://github.com/openxla/xla/blob/e6b6e61b29cc439350a6ad2f9d39535cb06011e5/xla/hlo/ir/hlo_instruction.h#L639-L656
The compiler likely uses some heuristic to choose between the two, and it appears that this heuristic is suboptimal at the boundary for your particular problem. You can see this by outputting the compiled HLO for your operation:
# Build the operands with inner dimension 25 and print the compiled HLO for
# the jitted matmul (the dump that follows shows a kLoop fusion).
# NOTE(review): `dtype` is not defined in this snippet; presumably it is the
# jnp.float64 set by the earlier benchmark code. Confirm before running alone.
a = jnp.array(rng.random((1000, 1000, 3, 25)), dtype=dtype)
b = jnp.array(rng.random((1000, 1000, 25, 1)), dtype=dtype)
# Lower for these concrete inputs, compile, and print the result as text.
print(jax.jit(lambda a, b: a @ b).lower(a, b).compile().as_text())
HloModule jit__lambda_, is_scheduled=true, entry_computation_layout={(f64[1000,1000,3,25]{3,2,1,0}, f64[1000,1000,25,1]{3,2,1,0})->f64[1000,1000,3,1]{3,2,1,0}}, allow_spmd_sharding_propagation_to_parameters={true,true}, allow_spmd_sharding_propagation_to_output={true}, frontend_attributes={fingerprint_before_lhs="a02cbfe0fda9d44e2bd23462363b6cc0"}
%scalar_add_computation (scalar_lhs: f64[], scalar_rhs: f64[]) -> f64[] {
%scalar_rhs = f64[] parameter(1)
%scalar_lhs = f64[] parameter(0)
ROOT %add.2 = f64[] add(f64[] %scalar_lhs, f64[] %scalar_rhs)
}
%fused_reduce (param_0.7: f64[1000,1000,3,25], param_1.6: f64[1000,1000,25,1]) -> f64[1000,1000,3] {
%param_0.7 = f64[1000,1000,3,25]{3,2,1,0} parameter(0)
%param_1.6 = f64[1000,1000,25,1]{3,2,1,0} parameter(1)
%bitcast.28.5 = f64[1000,1000,25]{2,1,0} bitcast(f64[1000,1000,25,1]{3,2,1,0} %param_1.6)
%broadcast.2.5 = f64[1000,1000,3,25]{3,2,1,0} broadcast(f64[1000,1000,25]{2,1,0} %bitcast.28.5), dimensions={0,1,3}, metadata={op_name="jit(<lambda>)/jit(main)/dot_general" source_file="<ipython-input-4-68f2557428ff>" source_line=3}
%multiply.2.3 = f64[1000,1000,3,25]{3,2,1,0} multiply(f64[1000,1000,3,25]{3,2,1,0} %param_0.7, f64[1000,1000,3,25]{3,2,1,0} %broadcast.2.5)
%constant_4 = f64[] constant(0)
ROOT %reduce.2 = f64[1000,1000,3]{2,1,0} reduce(f64[1000,1000,3,25]{3,2,1,0} %multiply.2.3, f64[] %constant_4), dimensions={3}, to_apply=%scalar_add_computation, metadata={op_name="jit(<lambda>)/jit(main)/dot_general" source_file="<ipython-input-4-68f2557428ff>" source_line=3}
}
ENTRY %main.4 (Arg_0.1.0: f64[1000,1000,3,25], Arg_1.2.0: f64[1000,1000,25,1]) -> f64[1000,1000,3,1] {
%Arg_1.2.0 = f64[1000,1000,25,1]{3,2,1,0} parameter(1), metadata={op_name="b"}
%Arg_0.1.0 = f64[1000,1000,3,25]{3,2,1,0} parameter(0), metadata={op_name="a"}
%loop_reduce_fusion = f64[1000,1000,3]{2,1,0} fusion(f64[1000,1000,3,25]{3,2,1,0} %Arg_0.1.0, f64[1000,1000,25,1]{3,2,1,0} %Arg_1.2.0), kind=kLoop, calls=%fused_reduce, metadata={op_name="jit(<lambda>)/jit(main)/dot_general" source_file="<ipython-input-4-68f2557428ff>" source_line=3}
ROOT %bitcast.1.0 = f64[1000,1000,3,1]{3,2,1,0} bitcast(f64[1000,1000,3]{2,1,0} %loop_reduce_fusion), metadata={op_name="jit(<lambda>)/jit(main)/dot_general" source_file="<ipython-input-4-68f2557428ff>" source_line=3}
}
# Same inspection with inner dimension 35 (the dump that follows shows a
# kInput fusion instead of kLoop).
# NOTE(review): `dtype` is not defined in this snippet; presumably it is the
# jnp.float64 set by the earlier benchmark code. Confirm before running alone.
a = jnp.array(rng.random((1000, 1000, 3, 35)), dtype=dtype)
b = jnp.array(rng.random((1000, 1000, 35, 1)), dtype=dtype)
# Lower for these concrete inputs, compile, and print the result as text.
print(jax.jit(lambda a, b: a @ b).lower(a, b).compile().as_text())
%scalar_add_computation (scalar_lhs: f64[], scalar_rhs: f64[]) -> f64[] {
%scalar_rhs = f64[] parameter(1)
%scalar_lhs = f64[] parameter(0)
ROOT %add.2 = f64[] add(f64[] %scalar_lhs, f64[] %scalar_rhs)
}
%fused_reduce (param_0.5: f64[1000,1000,3,35], param_1.2: f64[1000,1000,35,1]) -> f64[1000,1000,3] {
%param_0.5 = f64[1000,1000,3,35]{3,2,1,0} parameter(0)
%param_1.2 = f64[1000,1000,35,1]{3,2,1,0} parameter(1)
%bitcast.28.3 = f64[1000,1000,35]{2,1,0} bitcast(f64[1000,1000,35,1]{3,2,1,0} %param_1.2)
%broadcast.2.3 = f64[1000,1000,3,35]{3,2,1,0} broadcast(f64[1000,1000,35]{2,1,0} %bitcast.28.3), dimensions={0,1,3}, metadata={op_name="jit(<lambda>)/jit(main)/dot_general" source_file="<ipython-input-3-eb3ac06eae7a>" source_line=4}
%multiply.2.1 = f64[1000,1000,3,35]{3,2,1,0} multiply(f64[1000,1000,3,35]{3,2,1,0} %param_0.5, f64[1000,1000,3,35]{3,2,1,0} %broadcast.2.3)
%constant_3 = f64[] constant(0)
ROOT %reduce.2 = f64[1000,1000,3]{2,1,0} reduce(f64[1000,1000,3,35]{3,2,1,0} %multiply.2.1, f64[] %constant_3), dimensions={3}, to_apply=%scalar_add_computation, metadata={op_name="jit(<lambda>)/jit(main)/dot_general" source_file="<ipython-input-3-eb3ac06eae7a>" source_line=4}
}
ENTRY %main.4 (Arg_0.1.0: f64[1000,1000,3,35], Arg_1.2.0: f64[1000,1000,35,1]) -> f64[1000,1000,3,1] {
%Arg_1.2.0 = f64[1000,1000,35,1]{3,2,1,0} parameter(1), metadata={op_name="b"}
%Arg_0.1.0 = f64[1000,1000,3,35]{3,2,1,0} parameter(0), metadata={op_name="a"}
%input_reduce_fusion = f64[1000,1000,3]{2,1,0} fusion(f64[1000,1000,3,35]{3,2,1,0} %Arg_0.1.0, f64[1000,1000,35,1]{3,2,1,0} %Arg_1.2.0), kind=kInput, calls=%fused_reduce, metadata={op_name="jit(<lambda>)/jit(main)/dot_general" source_file="<ipython-input-3-eb3ac06eae7a>" source_line=4}
ROOT %bitcast.1.0 = f64[1000,1000,3,1]{3,2,1,0} bitcast(f64[1000,1000,3]{2,1,0} %input_reduce_fusion), metadata={op_name="jit(<lambda>)/jit(main)/dot_general" source_file="<ipython-input-3-eb3ac06eae7a>" source_line=4}
}
Here's a script to observe this compiler decision with respect to size:
# For each inner dimension 10, 15, ..., 50, compile the matmul and report
# whether the compiled HLO text mentions a kLoop fusion.
# NOTE(review): `dtype` is not defined in this snippet; presumably it is the
# jnp.float64 set by the earlier benchmark code. Confirm before running alone.
for size in range(10, 55, 5):
    a = jnp.array(rng.random((1000, 1000, 3, size)), dtype=dtype)
    b = jnp.array(rng.random((1000, 1000, size, 1)), dtype=dtype)
    hlo_text = jax.jit(lambda a, b: a @ b).lower(a, b).compile().as_text()
    # A plain substring test is enough: the fusion kind appears as
    # `kind=kLoop` or `kind=kInput` in the dumps shown above.
    print(f"{size=} {'kLoop' in hlo_text=}")
size=10 'kLoop' in hlo_text=True
size=15 'kLoop' in hlo_text=True
size=20 'kLoop' in hlo_text=True
size=25 'kLoop' in hlo_text=True
size=30 'kLoop' in hlo_text=True
size=35 'kLoop' in hlo_text=False
size=40 'kLoop' in hlo_text=False
size=45 'kLoop' in hlo_text=False
size=50 'kLoop' in hlo_text=False
I don't have any suggestion beyond perhaps reporting this at https://github.com/openxla/xla; it may be that the compiler heuristic for choosing to emit `kLoop` vs. `kInput` needs some additional logic.