Suppose I have two arrays:
import numpy as np
a = np.random.randn(32, 6, 6, 20, 64, 3, 3)
b = np.random.randn(20, 128, 64, 3, 3)
and want to sum over the last 3 axes and keep the shared axis. The output dimension should be (32, 6, 6, 20, 128). Notice that the axis with size 20 is shared by both a and b. Let's call this axis the "group" axis.
I have two methods for this task:
The first one is just a simple einsum:
def method1(a, b):
    return np.einsum('NHWgihw, goihw -> NHWgo', a, b, optimize=True)  # output shape: (32,6,6,20,128)
In the second method I loop through the group dimension and use einsum/tensordot to compute the result for each group, then stack the results:
def method2(a, b):
    result = []
    for g in range(b.shape[0]):  # loop through each group dimension
        # result.append(np.tensordot(a[..., g, :, :, :], b[g, ...], axes=((-3,-2,-1),(-3,-2,-1))))
        result.append(np.einsum('NHWihw, oihw -> NHWo', a[..., g, :, :, :], b[g, ...], optimize=True))  # output shape: (32,6,6,128)
    return np.stack(result, axis=-2)  # output shape: (32,6,6,20,128)
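Both methods compute the same contraction; a quick sanity check (a minimal sketch, equal up to floating-point rounding differences):

out1 = method1(a, b)
out2 = method2(a, b)
print(out1.shape, out2.shape)   # (32, 6, 6, 20, 128) for both
print(np.allclose(out1, out2))  # expected: True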
Here's the timing for both methods in my Jupyter notebook:
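(A minimal sketch of the timing cell; the exact numbers are not reproduced here and depend on the machine.)

%timeit method1(a, b)
%timeit method2(a, b)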
We can see that the second method, with a loop, is faster than the first method.
My question is: why is the second method, with an explicit Python loop, faster than the single einsum call in the first method?
Thanks for any help!
As pointed out by @Murali in the comments, method1 is not very efficient because it does not manage to use BLAS calls, as opposed to method2, which does. In fact, np.einsum does a reasonably good job in method1, but it computes the result sequentially, while method2 mostly runs in parallel thanks to OpenBLAS (used by Numpy on most machines). That being said, method2 is sub-optimal since it does not fully use the available cores (parts of the computation are done sequentially) and it appears not to use the cache efficiently. On my 6-core machine, it barely uses 50% of all the cores.
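As an aside, the whole contraction can also be mapped onto BLAS without a Python-level loop by rewriting it as a batched matrix product. Here is a minimal sketch (not benchmarked here), assuming C-order float64 inputs with the shapes from the question:

def method_matmul(a, b):
    # Fold (N,H,W) and (i,h,w) so that each group g becomes one matrix product:
    # (g, N*H*W, i*h*w) @ (g, i*h*w, o) -> (g, N*H*W, o), handled by BLAS dgemm
    sN, sH, sW, sg, si, sh, sw = a.shape
    so = b.shape[1]
    ra = a.transpose(3, 0, 1, 2, 4, 5, 6).reshape(sg, sN*sH*sW, si*sh*sw)
    rb = b.reshape(sg, so, si*sh*sw).transpose(0, 2, 1)
    out = np.matmul(ra, rb)  # shape: (g, N*H*W, o)
    return out.transpose(1, 0, 2).reshape(sN, sH, sW, sg, so)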
One solution to speed up this computation is to write a highly-optimized parallel Numba code for this.
First of all, a semi-naive implementation is to use nested for loops to compute the Einstein summation and to reshape the input/output arrays so Numba can better optimize the code (e.g. unrolling, use of SIMD instructions). Here is the result:
import numba as nb

@nb.njit('float64[:,:,:,:,::1](float64[:,:,:,:,:,:,::1], float64[:,:,:,:,::1])')
def compute(a, b):
    sN, sH, sW, sg, si, sh, sw = a.shape
    so = b.shape[1]
    assert b.shape == (sg, so, si, sh, sw)
    ra = a.reshape(sN*sH*sW, sg, si*sh*sw)
    rb = b.reshape(sg, so, si*sh*sw)
    out = np.empty((sN*sH*sW, sg, so), dtype=np.float64)
    for NHW in range(sN*sH*sW):
        for g in range(sg):
            for o in range(so):
                s = 0.0
                # Reduction
                for ihw in range(si*sh*sw):
                    s += ra[NHW, g, ihw] * rb[g, o, ihw]
                out[NHW, g, o] = s
    return out.reshape((sN, sH, sW, sg, so))
Note that the input arrays are assumed to be contiguous. If this is not the case, please consider performing a copy (which is cheap compared to the computation).
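For instance, a small usage sketch, assuming the a and b arrays from the question:

a_c = np.ascontiguousarray(a)  # no-op if a is already C-contiguous
b_c = np.ascontiguousarray(b)
out = compute(a_c, b_c)
# Optional sanity check against the pure-Numpy version:
# assert np.allclose(out, method1(a, b))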
While the above code works, it is far from being efficient. Here are some improvements that can be performed:
- Compute the NHW loop in parallel.
- Use fastmath=True. This flag is unsafe if the input data contains special values like NaN or +inf/-inf. However, it helps the compiler generate much faster code using SIMD instructions (this is not possible otherwise, since IEEE-754 floating-point operations are not associative).
- Interchange the NHW-based loop and the g-based loop: this results in better performance since it improves cache locality (rb is more likely to fit in the last-level cache of mainstream CPUs, whereas it would likely be fetched from RAM otherwise).
- Tile the o-based loop so rb can almost fully be read from the lower-level caches (e.g. L1 or L2).

All these improvements except the last one are implemented in the following code:
@nb.njit('float64[:,:,:,:,::1](float64[:,:,:,:,:,:,::1], float64[:,:,:,:,::1])', parallel=True, fastmath=True)
def method3(a, b):
    sN, sH, sW, sg, si, sh, sw = a.shape
    so = b.shape[1]
    assert b.shape == (sg, so, si, sh, sw)
    ra = a.reshape(sN*sH*sW, sg, si*sh*sw)
    rb = b.reshape(sg, so, si*sh*sw)
    out = np.zeros((sN*sH*sW, sg, so), dtype=np.float64)

    for g in range(sg):
        for k in nb.prange((sN*sH*sW)//2):
            NHW = k*2
            so_vect_max = (so // 4) * 4
            for o in range(0, so_vect_max, 4):
                s00 = s01 = s02 = s03 = s10 = s11 = s12 = s13 = 0.0
                # Useful since Numba does not optimize well the following loop otherwise
                ra_row0 = ra[NHW+0, g, :]
                ra_row1 = ra[NHW+1, g, :]
                rb_row0 = rb[g, o+0, :]
                rb_row1 = rb[g, o+1, :]
                rb_row2 = rb[g, o+2, :]
                rb_row3 = rb[g, o+3, :]
                # Highly-optimized reduction using register blocking
                for ihw in range(si*sh*sw):
                    ra_0 = ra_row0[ihw]
                    ra_1 = ra_row1[ihw]
                    rb_0 = rb_row0[ihw]
                    rb_1 = rb_row1[ihw]
                    rb_2 = rb_row2[ihw]
                    rb_3 = rb_row3[ihw]
                    s00 += ra_0 * rb_0; s01 += ra_0 * rb_1
                    s02 += ra_0 * rb_2; s03 += ra_0 * rb_3
                    s10 += ra_1 * rb_0; s11 += ra_1 * rb_1
                    s12 += ra_1 * rb_2; s13 += ra_1 * rb_3
                out[NHW+0, g, o+0] = s00; out[NHW+0, g, o+1] = s01
                out[NHW+0, g, o+2] = s02; out[NHW+0, g, o+3] = s03
                out[NHW+1, g, o+0] = s10; out[NHW+1, g, o+1] = s11
                out[NHW+1, g, o+2] = s12; out[NHW+1, g, o+3] = s13
            # Remaining part for `o`
            for o in range(so_vect_max, so):
                for ihw in range(si*sh*sw):
                    out[NHW, g, o] += ra[NHW, g, ihw] * rb[g, o, ihw]
                    out[NHW+1, g, o] += ra[NHW+1, g, ihw] * rb[g, o, ihw]
        # Remaining part for `k`
        if (sN*sH*sW) % 2 == 1:
            k = sN*sH*sW - 1
            for o in range(so):
                for ihw in range(si*sh*sw):
                    out[k, g, o] += ra[k, g, ihw] * rb[g, o, ihw]

    return out.reshape((sN, sH, sW, sg, so))
This code is much more complex and uglier, but also far more efficient. I did not implement the tiling optimization since it would make the code even less readable. However, it should result in significantly faster code on many-core processors (especially the ones with a small L2/L3 cache).
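For reference, here is a rough sketch of the tiling idea alone (it deliberately omits the register blocking used above, so it only illustrates the loop structure; the inputs are assumed C-contiguous and the tile size is a hypothetical value):

import numba as nb
import numpy as np

@nb.njit(parallel=True, fastmath=True)
def method3_tiled_sketch(a, b):
    sN, sH, sW, sg, si, sh, sw = a.shape
    so = b.shape[1]
    ra = a.reshape(sN*sH*sW, sg, si*sh*sw)   # assumes C-contiguous inputs
    rb = b.reshape(sg, so, si*sh*sw)
    out = np.zeros((sN*sH*sW, sg, so), dtype=np.float64)
    o_tile = 16  # hypothetical tile size, to be tuned for the cache size
    for g in range(sg):
        for o_start in range(0, so, o_tile):
            o_end = min(o_start + o_tile, so)
            # The rb[g, o_start:o_end, :] tile is reused from the cache
            # while sweeping over all the NHW rows in parallel.
            for NHW in nb.prange(sN*sH*sW):
                for o in range(o_start, o_end):
                    s = 0.0
                    for ihw in range(si*sh*sw):
                        s += ra[NHW, g, ihw] * rb[g, o, ihw]
                    out[NHW, g, o] = s
    return out.reshape((sN, sH, sW, sg, so))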
Here are performance results on my i5-9600KF 6-core processor:
method1: 816 ms
method2: 104 ms
method3: 40 ms
Theoretical optimal: 9 ms (optimistic lower bound)
The code is about 2.7 times faster than method2. There is still room for improvement, since the optimal time is about 4 times better than that of method3.
The main reason why Numba does not generate fast code comes from the underlying JIT, which fails to vectorize the loop efficiently. Implementing the tiling strategy should improve the execution time further, bringing it closer to the optimal one. The tiling strategy is critical for much bigger arrays, especially if so is much bigger.
If you want a faster implementation, you certainly need to write native C/C++ code using SIMD intrinsics directly (which are unfortunately not portable) or a SIMD library (e.g. XSIMD).
If you want an even faster implementation, then you need to use faster hardware (with more cores) or more dedicated hardware. Server-side GPUs (i.e. not the ones of personal computers) should be able to speed up such a computation a lot, since the input is small, the computation is clearly compute-bound and it makes massive use of FMA floating-point operations. A good first step is to try cupy.einsum.
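A minimal sketch of what this could look like, assuming a CUDA-capable GPU and that CuPy is installed (the host/device copies are included for completeness):

import cupy as cp

def method_gpu(a, b):
    a_gpu = cp.asarray(a)  # host -> device copy
    b_gpu = cp.asarray(b)
    out_gpu = cp.einsum('NHWgihw,goihw->NHWgo', a_gpu, b_gpu)
    return cp.asnumpy(out_gpu)  # device -> host copy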
In order to understand why method1 is not faster, I checked the executed code. Here is the main loop:
1a0:┌─→; Part of the reduction (see below)
│ movapd xmm0,XMMWORD PTR [rdi-0x1000]
│
│ ; Decrement the number of loop cycle
│ sub r9,0x8
│
│ ; Prefetch items so to reduce the impact
│ ; of the latency of reading from the RAM.
│ prefetcht0 BYTE PTR [r8]
│ prefetcht0 BYTE PTR [rdi]
│
│ ; Part of the reduction (see below)
│ mulpd xmm0,XMMWORD PTR [r8-0x1000]
│
│ ; Increment iterator for the two arrays
│ add rdi,0x40
│ add r8,0x40
│
│ ; Main computational part:
│ ; reduction using add+mul SSE2 instructions
│ addpd xmm1,xmm0 <--- Slow
│ movapd xmm0,XMMWORD PTR [rdi-0x1030]
│ mulpd xmm0,XMMWORD PTR [r8-0x1030]
│ addpd xmm1,xmm0 <--- Slow
│ movapd xmm0,XMMWORD PTR [rdi-0x1020]
│ mulpd xmm0,XMMWORD PTR [r8-0x1020]
│ addpd xmm0,xmm1 <--- Slow
│ movapd xmm1,XMMWORD PTR [rdi-0x1010]
│ mulpd xmm1,XMMWORD PTR [r8-0x1010]
│ addpd xmm1,xmm0 <--- Slow
│
│ ; Is the loop over?
│ ; If not, jump to the beginning of the loop.
├──cmp r9,0x7
└──jg 1a0
It turns out that Numpy uses the SSE2 instruction set (which is available on all x86-64 processors). However, my machine, like almost all relatively recent processors, supports the AVX instruction set, which can compute twice as many items at once per instruction. My machine also supports fused multiply-add (FMA) instructions, which are twice as fast in this case. Moreover, the loop is clearly bound by the addpd instructions, which accumulate the result in mostly the same register. The processor cannot execute them efficiently, since an addpd has a latency of a few cycles and up to two of them can be executed at the same time on modern x86-64 processors (which is not possible here, since only one instruction at a time can perform the accumulation in xmm1).
Here is the executed code of the main computational part of method2 (the dgemm call of OpenBLAS):
6a40:┌─→vbroadcastsd ymm0,QWORD PTR [rsi-0x60]
│ vbroadcastsd ymm1,QWORD PTR [rsi-0x58]
│ vbroadcastsd ymm2,QWORD PTR [rsi-0x50]
│ vbroadcastsd ymm3,QWORD PTR [rsi-0x48]
│ vfmadd231pd ymm4,ymm0,YMMWORD PTR [rdi-0x80]
│ vfmadd231pd ymm5,ymm1,YMMWORD PTR [rdi-0x60]
│ vbroadcastsd ymm0,QWORD PTR [rsi-0x40]
│ vbroadcastsd ymm1,QWORD PTR [rsi-0x38]
│ vfmadd231pd ymm6,ymm2,YMMWORD PTR [rdi-0x40]
│ vfmadd231pd ymm7,ymm3,YMMWORD PTR [rdi-0x20]
│ vbroadcastsd ymm2,QWORD PTR [rsi-0x30]
│ vbroadcastsd ymm3,QWORD PTR [rsi-0x28]
│ vfmadd231pd ymm4,ymm0,YMMWORD PTR [rdi]
│ vfmadd231pd ymm5,ymm1,YMMWORD PTR [rdi+0x20]
│ vfmadd231pd ymm6,ymm2,YMMWORD PTR [rdi+0x40]
│ vfmadd231pd ymm7,ymm3,YMMWORD PTR [rdi+0x60]
│ add rsi,0x40
│ add rdi,0x100
├──dec rax
└──jne 6a40
This loop is far more optimized: it makes use of the AVX instruction set as well as the FMA one (i.e. the vfmadd231pd instructions). Furthermore, the loop is better unrolled and there is no latency/dependency issue like in the Numpy code. However, while this loop is highly efficient, the cores are not used efficiently, due to some sequential checks done in Numpy and a sequential copy performed in OpenBLAS. Moreover, I am not sure the loop makes efficient use of the cache in this case, since a lot of reads/writes are performed in RAM on my machine. Indeed, the RAM throughput is about 15 GiB/s (out of 35~40 GiB/s) due to many cache misses, while the throughput of method3 is 6 GiB/s (so more work is done in the cache) with a significantly faster execution.
Here is the executed code of the main computational part of method3:
.LBB0_5:
vorpd 2880(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rsi,%ymm0,8), %ymm2
vmovupd %ymm2, 3040(%rsp)
vorpd 2848(%rsp), %ymm8, %ymm1
vpcmpeqd %ymm2, %ymm2, %ymm2
vgatherqpd %ymm2, (%rsi,%ymm1,8), %ymm3
vmovupd %ymm3, 3104(%rsp)
vorpd 2912(%rsp), %ymm8, %ymm2
vpcmpeqd %ymm3, %ymm3, %ymm3
vgatherqpd %ymm3, (%rsi,%ymm2,8), %ymm4
vmovupd %ymm4, 3136(%rsp)
vorpd 2816(%rsp), %ymm8, %ymm3
vpcmpeqd %ymm4, %ymm4, %ymm4
vgatherqpd %ymm4, (%rsi,%ymm3,8), %ymm5
vmovupd %ymm5, 3808(%rsp)
vorpd 2784(%rsp), %ymm8, %ymm9
vpcmpeqd %ymm4, %ymm4, %ymm4
vgatherqpd %ymm4, (%rsi,%ymm9,8), %ymm5
vmovupd %ymm5, 3840(%rsp)
vorpd 2752(%rsp), %ymm8, %ymm10
vpcmpeqd %ymm4, %ymm4, %ymm4
vgatherqpd %ymm4, (%rsi,%ymm10,8), %ymm5
vmovupd %ymm5, 3872(%rsp)
vpaddq 2944(%rsp), %ymm8, %ymm4
vorpd 2720(%rsp), %ymm8, %ymm11
vpcmpeqd %ymm13, %ymm13, %ymm13
vgatherqpd %ymm13, (%rsi,%ymm11,8), %ymm5
vmovupd %ymm5, 3904(%rsp)
vpcmpeqd %ymm13, %ymm13, %ymm13
vgatherqpd %ymm13, (%rdx,%ymm0,8), %ymm5
vmovupd %ymm5, 3552(%rsp)
vpcmpeqd %ymm0, %ymm0, %ymm0
vgatherqpd %ymm0, (%rdx,%ymm1,8), %ymm5
vmovupd %ymm5, 3616(%rsp)
vpcmpeqd %ymm0, %ymm0, %ymm0
vgatherqpd %ymm0, (%rdx,%ymm2,8), %ymm1
vmovupd %ymm1, 3648(%rsp)
vpcmpeqd %ymm0, %ymm0, %ymm0
vgatherqpd %ymm0, (%rdx,%ymm3,8), %ymm1
vmovupd %ymm1, 3680(%rsp)
vpcmpeqd %ymm0, %ymm0, %ymm0
vgatherqpd %ymm0, (%rdx,%ymm9,8), %ymm1
vmovupd %ymm1, 3712(%rsp)
vpcmpeqd %ymm0, %ymm0, %ymm0
vgatherqpd %ymm0, (%rdx,%ymm10,8), %ymm1
vmovupd %ymm1, 3744(%rsp)
vpcmpeqd %ymm0, %ymm0, %ymm0
vgatherqpd %ymm0, (%rdx,%ymm11,8), %ymm1
vmovupd %ymm1, 3776(%rsp)
vpcmpeqd %ymm0, %ymm0, %ymm0
vgatherqpd %ymm0, (%rsi,%ymm4,8), %ymm6
vpcmpeqd %ymm0, %ymm0, %ymm0
vgatherqpd %ymm0, (%rdx,%ymm4,8), %ymm3
vpaddq 2688(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rsi,%ymm0,8), %ymm7
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm4
vmovupd %ymm4, 3360(%rsp)
vpaddq 2656(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rsi,%ymm0,8), %ymm13
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm4
vmovupd %ymm4, 3392(%rsp)
vpaddq 2624(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rsi,%ymm0,8), %ymm15
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm4
vmovupd %ymm4, 3424(%rsp)
vpaddq 2592(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rsi,%ymm0,8), %ymm9
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm4
vmovupd %ymm4, 3456(%rsp)
vpaddq 2560(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rsi,%ymm0,8), %ymm14
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm4
vmovupd %ymm4, 3488(%rsp)
vpaddq 2528(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rsi,%ymm0,8), %ymm11
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm4
vmovupd %ymm4, 3520(%rsp)
vpaddq 2496(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rsi,%ymm0,8), %ymm10
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm4
vmovupd %ymm4, 3584(%rsp)
vpaddq 2464(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm2
vpaddq 2432(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm12
vpaddq 2400(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm4
vmovupd %ymm4, 3168(%rsp)
vpaddq 2368(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm4
vmovupd %ymm4, 3200(%rsp)
vpaddq 2336(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm4
vmovupd %ymm4, 3232(%rsp)
vpaddq 2304(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm4
vmovupd %ymm4, 3264(%rsp)
vpaddq 2272(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm4
vmovupd %ymm4, 3296(%rsp)
vpaddq 2240(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm4
vmovupd %ymm4, 3328(%rsp)
vpaddq 2208(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm4
vpaddq 2176(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm5
vmovupd %ymm5, 2976(%rsp)
vpaddq 2144(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm5
vmovupd %ymm5, 3008(%rsp)
vpaddq 2112(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm5
vmovupd %ymm5, 3072(%rsp)
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rsi,%ymm8,8), %ymm0
vpcmpeqd %ymm5, %ymm5, %ymm5
vgatherqpd %ymm5, (%rdx,%ymm8,8), %ymm1
vmovupd 768(%rsp), %ymm5
vfmadd231pd %ymm0, %ymm1, %ymm5
vmovupd %ymm5, 768(%rsp)
vmovupd 32(%rsp), %ymm5
vfmadd231pd %ymm0, %ymm3, %ymm5
vmovupd %ymm5, 32(%rsp)
vmovupd 1024(%rsp), %ymm5
vfmadd231pd %ymm0, %ymm2, %ymm5
vmovupd %ymm5, 1024(%rsp)
vmovupd 1280(%rsp), %ymm5
vfmadd231pd %ymm0, %ymm4, %ymm5
vmovupd %ymm5, 1280(%rsp)
vmovupd 1344(%rsp), %ymm0
vfmadd231pd %ymm1, %ymm6, %ymm0
vmovupd %ymm0, 1344(%rsp)
vmovupd 480(%rsp), %ymm0
vfmadd231pd %ymm3, %ymm6, %ymm0
vmovupd %ymm0, 480(%rsp)
vmovupd 1600(%rsp), %ymm0
vfmadd231pd %ymm2, %ymm6, %ymm0
vmovupd %ymm0, 1600(%rsp)
vmovupd 1856(%rsp), %ymm0
vfmadd231pd %ymm4, %ymm6, %ymm0
vmovupd %ymm0, 1856(%rsp)
vpaddq 2080(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm2
vpaddq 2048(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm4
vmovupd 800(%rsp), %ymm0
vmovupd 3552(%rsp), %ymm1
vmovupd 3040(%rsp), %ymm3
vfmadd231pd %ymm3, %ymm1, %ymm0
vmovupd %ymm0, 800(%rsp)
vmovupd 64(%rsp), %ymm0
vmovupd 3360(%rsp), %ymm5
vfmadd231pd %ymm3, %ymm5, %ymm0
vmovupd %ymm0, 64(%rsp)
vmovupd 1056(%rsp), %ymm0
vfmadd231pd %ymm3, %ymm12, %ymm0
vmovupd %ymm0, 1056(%rsp)
vmovupd 288(%rsp), %ymm0
vmovupd 2976(%rsp), %ymm6
vfmadd231pd %ymm3, %ymm6, %ymm0
vmovupd %ymm0, 288(%rsp)
vmovupd 1376(%rsp), %ymm0
vfmadd231pd %ymm1, %ymm7, %ymm0
vmovupd %ymm0, 1376(%rsp)
vmovupd 512(%rsp), %ymm0
vfmadd231pd %ymm5, %ymm7, %ymm0
vmovupd %ymm0, 512(%rsp)
vmovupd 1632(%rsp), %ymm0
vfmadd231pd %ymm12, %ymm7, %ymm0
vmovupd %ymm0, 1632(%rsp)
vmovupd 1888(%rsp), %ymm0
vfmadd231pd %ymm6, %ymm7, %ymm0
vmovupd %ymm0, 1888(%rsp)
vmovupd 832(%rsp), %ymm0
vmovupd 3616(%rsp), %ymm1
vmovupd 3104(%rsp), %ymm6
vfmadd231pd %ymm6, %ymm1, %ymm0
vmovupd %ymm0, 832(%rsp)
vmovupd 96(%rsp), %ymm0
vmovupd 3392(%rsp), %ymm3
vfmadd231pd %ymm6, %ymm3, %ymm0
vmovupd %ymm0, 96(%rsp)
vmovupd 1088(%rsp), %ymm0
vmovupd 3168(%rsp), %ymm5
vfmadd231pd %ymm6, %ymm5, %ymm0
vmovupd %ymm0, 1088(%rsp)
vmovupd 320(%rsp), %ymm0
vmovupd 3008(%rsp), %ymm7
vfmadd231pd %ymm6, %ymm7, %ymm0
vmovupd %ymm0, 320(%rsp)
vmovupd 1408(%rsp), %ymm0
vfmadd231pd %ymm1, %ymm13, %ymm0
vmovupd %ymm0, 1408(%rsp)
vmovupd 544(%rsp), %ymm0
vfmadd231pd %ymm3, %ymm13, %ymm0
vmovupd %ymm0, 544(%rsp)
vmovupd 1664(%rsp), %ymm0
vfmadd231pd %ymm5, %ymm13, %ymm0
vmovupd %ymm0, 1664(%rsp)
vmovupd 1920(%rsp), %ymm0
vfmadd231pd %ymm7, %ymm13, %ymm0
vmovupd %ymm0, 1920(%rsp)
vpaddq 2016(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm3
vmovupd 864(%rsp), %ymm0
vmovupd 3648(%rsp), %ymm1
vmovupd 3136(%rsp), %ymm6
vfmadd231pd %ymm6, %ymm1, %ymm0
vmovupd %ymm0, 864(%rsp)
vmovupd 128(%rsp), %ymm0
vmovupd 3424(%rsp), %ymm5
vfmadd231pd %ymm6, %ymm5, %ymm0
vmovupd %ymm0, 128(%rsp)
vmovupd 1120(%rsp), %ymm0
vmovupd 3200(%rsp), %ymm7
vfmadd231pd %ymm6, %ymm7, %ymm0
vmovupd %ymm0, 1120(%rsp)
vmovupd 352(%rsp), %ymm0
vmovupd 3072(%rsp), %ymm12
vfmadd231pd %ymm6, %ymm12, %ymm0
vmovupd %ymm0, 352(%rsp)
vmovupd 1440(%rsp), %ymm0
vfmadd231pd %ymm1, %ymm15, %ymm0
vmovupd %ymm0, 1440(%rsp)
vmovupd 576(%rsp), %ymm0
vfmadd231pd %ymm5, %ymm15, %ymm0
vmovupd %ymm0, 576(%rsp)
vmovupd 1696(%rsp), %ymm0
vfmadd231pd %ymm7, %ymm15, %ymm0
vmovupd %ymm0, 1696(%rsp)
vmovupd 736(%rsp), %ymm0
vfmadd231pd %ymm12, %ymm15, %ymm0
vmovupd %ymm0, 736(%rsp)
vmovupd 896(%rsp), %ymm0
vmovupd 3808(%rsp), %ymm1
vmovupd 3680(%rsp), %ymm5
vfmadd231pd %ymm1, %ymm5, %ymm0
vmovupd %ymm0, 896(%rsp)
vmovupd 160(%rsp), %ymm0
vmovupd 3456(%rsp), %ymm6
vfmadd231pd %ymm1, %ymm6, %ymm0
vmovupd %ymm0, 160(%rsp)
vmovupd 1152(%rsp), %ymm0
vmovupd 3232(%rsp), %ymm7
vfmadd231pd %ymm1, %ymm7, %ymm0
vmovupd %ymm0, 1152(%rsp)
vmovupd 384(%rsp), %ymm0
vfmadd231pd %ymm1, %ymm2, %ymm0
vmovupd %ymm0, 384(%rsp)
vmovupd 1472(%rsp), %ymm0
vfmadd231pd %ymm5, %ymm9, %ymm0
vmovupd %ymm0, 1472(%rsp)
vmovupd 608(%rsp), %ymm0
vfmadd231pd %ymm6, %ymm9, %ymm0
vmovupd %ymm0, 608(%rsp)
vmovupd 1728(%rsp), %ymm0
vfmadd231pd %ymm7, %ymm9, %ymm0
vmovupd %ymm0, 1728(%rsp)
vmovupd -128(%rsp), %ymm0
vfmadd231pd %ymm2, %ymm9, %ymm0
vmovupd %ymm0, -128(%rsp)
vmovupd 928(%rsp), %ymm0
vmovupd 3840(%rsp), %ymm1
vmovupd 3712(%rsp), %ymm2
vfmadd231pd %ymm1, %ymm2, %ymm0
vmovupd %ymm0, 928(%rsp)
vmovupd 192(%rsp), %ymm0
vmovupd 3488(%rsp), %ymm5
vfmadd231pd %ymm1, %ymm5, %ymm0
vmovupd %ymm0, 192(%rsp)
vmovupd 1184(%rsp), %ymm0
vmovupd 3264(%rsp), %ymm6
vfmadd231pd %ymm1, %ymm6, %ymm0
vmovupd %ymm0, 1184(%rsp)
vmovupd 416(%rsp), %ymm0
vfmadd231pd %ymm1, %ymm4, %ymm0
vmovupd %ymm0, 416(%rsp)
vmovupd 1504(%rsp), %ymm0
vfmadd231pd %ymm2, %ymm14, %ymm0
vmovupd %ymm0, 1504(%rsp)
vmovupd 640(%rsp), %ymm0
vfmadd231pd %ymm5, %ymm14, %ymm0
vmovupd %ymm0, 640(%rsp)
vmovupd 1760(%rsp), %ymm0
vfmadd231pd %ymm6, %ymm14, %ymm0
vmovupd %ymm0, 1760(%rsp)
vmovupd -96(%rsp), %ymm0
vfmadd231pd %ymm4, %ymm14, %ymm0
vmovupd %ymm0, -96(%rsp)
vpaddq 1984(%rsp), %ymm8, %ymm0
vpcmpeqd %ymm1, %ymm1, %ymm1
vgatherqpd %ymm1, (%rdx,%ymm0,8), %ymm2
vmovupd 960(%rsp), %ymm0
vmovupd 3872(%rsp), %ymm1
vmovupd 3744(%rsp), %ymm4
vfmadd231pd %ymm1, %ymm4, %ymm0
vmovupd %ymm0, 960(%rsp)
vmovupd 224(%rsp), %ymm0
vmovupd 3520(%rsp), %ymm5
vfmadd231pd %ymm1, %ymm5, %ymm0
vmovupd %ymm0, 224(%rsp)
vmovupd 1216(%rsp), %ymm0
vmovupd 3296(%rsp), %ymm6
vfmadd231pd %ymm1, %ymm6, %ymm0
vmovupd %ymm0, 1216(%rsp)
vmovupd 448(%rsp), %ymm0
vfmadd231pd %ymm1, %ymm3, %ymm0
vmovupd %ymm0, 448(%rsp)
vmovupd 1536(%rsp), %ymm0
vfmadd231pd %ymm4, %ymm11, %ymm0
vmovupd %ymm0, 1536(%rsp)
vmovupd 672(%rsp), %ymm0
vfmadd231pd %ymm5, %ymm11, %ymm0
vmovupd %ymm0, 672(%rsp)
vmovupd 1792(%rsp), %ymm0
vfmadd231pd %ymm6, %ymm11, %ymm0
vmovupd %ymm0, 1792(%rsp)
vmovupd -64(%rsp), %ymm0
vfmadd231pd %ymm3, %ymm11, %ymm0
vmovupd %ymm0, -64(%rsp)
vmovupd 992(%rsp), %ymm0
vmovupd 3904(%rsp), %ymm1
vmovupd 3776(%rsp), %ymm3
vfmadd231pd %ymm1, %ymm3, %ymm0
vmovupd %ymm0, 992(%rsp)
vmovupd 256(%rsp), %ymm0
vmovupd 3584(%rsp), %ymm4
vfmadd231pd %ymm1, %ymm4, %ymm0
vmovupd %ymm0, 256(%rsp)
vmovupd 1248(%rsp), %ymm0
vmovupd 3328(%rsp), %ymm5
vfmadd231pd %ymm1, %ymm5, %ymm0
vmovupd %ymm0, 1248(%rsp)
vmovupd 1312(%rsp), %ymm0
vfmadd231pd %ymm1, %ymm2, %ymm0
vmovupd %ymm0, 1312(%rsp)
vmovupd 1568(%rsp), %ymm0
vfmadd231pd %ymm3, %ymm10, %ymm0
vmovupd %ymm0, 1568(%rsp)
vmovupd 704(%rsp), %ymm0
vfmadd231pd %ymm4, %ymm10, %ymm0
vmovupd %ymm0, 704(%rsp)
vmovupd 1824(%rsp), %ymm0
vfmadd231pd %ymm5, %ymm10, %ymm0
vmovupd %ymm0, 1824(%rsp)
vmovupd -32(%rsp), %ymm0
vfmadd231pd %ymm2, %ymm10, %ymm0
vmovupd %ymm0, -32(%rsp)
vpaddq 1952(%rsp), %ymm8, %ymm8
addq $-4, %rcx
jne .LBB0_5
The loop is huge and is clearly not vectorized properly: there are a lot of completely useless instructions and the loads from memory appear not to be contiguous (see the vgatherqpd instructions). Numba does not generate good code here because the underlying JIT (llvmlite/LLVM) fails to vectorize the code efficiently. In fact, I found out that a similar C++ code is badly vectorized by Clang 13.0 on a simplified example (GCC and ICC also fail on a more complex code), while a hand-written SIMD implementation works much better. It looks like a bug in the optimizer, or at least a missed optimization. This is why the Numba code is much slower than the optimal code. That being said, this implementation makes quite efficient use of the cache and is properly multithreaded.
I also found out that the BLAS code is faster on Linux than on Windows on my machine (with the default packages coming from pip and the same Numpy version 1.20.3). Thus, the gap between method2 and method3 is smaller there, but the latter is still significantly faster.
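If you want to check which BLAS your Numpy build is linked against (it can differ between the Linux and Windows wheels), a quick check is:

import numpy as np
np.show_config()  # prints the BLAS/LAPACK libraries Numpy was built against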