Why are the "batch" axes always the leading axes in NumPy? I designed all my packages to use the trailing axes as batch axes because this seems more natural to me. Now I'm thinking about switching to NumPy's convention - just to make things more intuitive for NumPy users. Any ideas on that?
Performance-wise, this could be a really bad idea:
import numpy as np
# two stacks of 3x3 matrices with the batch axes leading: shape (50000, 8, 3, 3)
np.random.seed(6512)
a = np.random.rand(50000, 8, 3, 3)
np.random.seed(85742)
b = np.random.rand(50000, 8, 3, 3)
c = a @ b
# 19.8 ms ± 543 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
d = np.einsum("...ik,...kj->...ij", a, b)
# 84.1 ms ± 2.4 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# now use the trailing axes (ensure C-contiguous arrays for transposed data)
A = np.ascontiguousarray(np.transpose(a, [2, 3, 0, 1])) # A_ijab
B = np.ascontiguousarray(np.transpose(b, [2, 3, 0, 1])) # B_ijab
C = (B.T @ A.T).T # (C^T)_baji = (B^T)_bajk (A^T)_baki -> C_ijab = A_ikab B_kjab
# 16.9 ms ± 1.82 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
D = np.einsum("ik...,kj...->ij...", A, B)
# 17.2 ms ± 842 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
assert np.allclose(c, d)
assert np.allclose(C, D)
assert np.allclose(np.transpose(D, [2, 3, 0, 1]), d)
assert np.allclose(np.transpose(C, [2, 3, 0, 1]), c)
Or more complicated einsums:
# crossed-dyadic product
# ----------------------
E = np.einsum("ik...,jl...->ijkl...", A, B)
# 76.5 ms ± 2.22 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
e = np.einsum("...ik,...jl->...ijkl", a, b)
# 207 ms ± 3.29 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
assert np.allclose(np.transpose(E, [4, 5, 0, 1, 2, 3]), e)
NumPy is written in C and uses the C convention for arrays, that is, row-major ordering. Operating on the last axes (the right-most and most contiguous ones) is therefore the most cache-friendly. Transposing a large array significantly increases the pressure on the RAM, so it often results in a much slower computation (memory bandwidth is often the limiting factor for NumPy operations).
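You can see this layout directly. Here is a minimal check reusing the array shape from the question (the printed stride values assume float64 and a C-contiguous allocation):
import numpy as np

a = np.random.rand(50000, 8, 3, 3)
print(a.flags['C_CONTIGUOUS'])  # True: the last axis varies fastest in memory
print(a.strides)                # (576, 72, 24, 8) bytes: unit 8-byte steps along the last axis
at = np.transpose(a, [2, 3, 0, 1])
print(at.flags['C_CONTIGUOUS']) # False: transposing only permutes the strides and returns a view
print(at.strides)               # (24, 8, 576, 72): the last axis now jumps 72 bytes per element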
That being said, in your case NumPy is clearly not optimized for 3x3 matrices. The overhead of its internal generic iterators (which enable broadcasting) is so large here that the computation is bound by it. Most BLAS libraries are also not optimized for such extremely small matrices. Some linear algebra libraries provide batched operations for this (AFAIK cuBLAS does, for example), but NumPy does not support them yet.
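A rough back-of-the-envelope estimate based on the timings above shows how overhead-dominated this is (the per-core throughput figure is only an order-of-magnitude assumption):
n_mat = 50000 * 8                # 400,000 independent 3x3 products
per_matmul = 19.8e-3 / n_mat     # ~50 ns per 3x3 multiplication, taken from the a @ b timing above
flops = 2 * 3 * 3 * 3            # 54 floating-point operations per 3x3 multiplication
print(flops / per_matmul / 1e9)  # ~1.1 GFLOP/s, far below what a single modern core can sustain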
Modern mainstream CPUs can compute a 3x3 matrix multiplication in only a few nanoseconds, so the overhead of generic code dominates and the matrices cannot be processed efficiently. To get a fast implementation, you need to write compiled code that specifically supports fixed-size 3x3 matrices; the compiler can then generate efficient instructions for this particular case. Hand-written assembly (or compiler SIMD intrinsics) can certainly be significantly faster for this use case, but it is hard to write, hard to maintain, and bug-prone. A good solution is to use Cython (with memory views and the right compilation flags) or Numba (if possible with the fastmath flag). You can find an example of Numba code used to solve a similar problem here.
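For illustration, here is a minimal Numba sketch of such a fixed-size batched kernel. This is my own example rather than the code linked above; the function name and the flattening of the two batch axes into one are assumptions:
import numpy as np
from numba import njit, prange

@njit(parallel=True, fastmath=True)
def matmul_3x3_batch(a, b):
    # a, b: C-contiguous (n, 3, 3) arrays with a single leading batch axis
    n = a.shape[0]
    out = np.empty_like(a)
    for m in prange(n):
        for i in range(3):
            for j in range(3):
                s = 0.0
                for k in range(3):
                    s += a[m, i, k] * b[m, k, j]
                out[m, i, j] = s
    return out

# usage with the a and b arrays from the question: fold the two batch axes into one,
# multiply, then restore the original shape
c_numba = matmul_3x3_batch(a.reshape(-1, 3, 3), b.reshape(-1, 3, 3)).reshape(a.shape)
assert np.allclose(c_numba, a @ b)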