I need to optimize a matrix-vector multiplication. The data looks like the following: a 90000 x 81 matrix of floats (row-major) multiplied by a vector of 81 floats.
Some non-functional requirements also have to be met by this routine:
- The data is given as a std::vector (a raw float pointer works as well, for example).
- Using an external library such as Eigen or Blas is not an option for me, either.
This is my code (simplified; I assume the input is perfectly blocked, for the sake of readability):
#include <arm_neon.h>

// input_height = 90000
// input_width = 81
for (uint32_t y = 0; y < input_height; y += 4) {
float32x4_t sum0 = vmovq_n_f32(0);
float32x4_t sum1 = vmovq_n_f32(0);
float32x4_t sum2 = vmovq_n_f32(0);
float32x4_t sum3 = vmovq_n_f32(0);
for (uint32_t x = 0; x < input_width; x += 16) {
float32x4x4_t A = load_matrix_transpose(kernel + x);
float32x4x4_t B0 = load_matrix_transpose(input + y * input_width + x);
float32x4x4_t B1 = load_matrix_transpose(input + (y + 1) * input_width + x);
float32x4x4_t B2 = load_matrix_transpose(input + (y + 2) * input_width + x);
float32x4x4_t B3 = load_matrix_transpose(input + (y + 3) * input_width + x);
matrix_element_wise_multiplication(A, B0, sum0);
matrix_element_wise_multiplication(A, B1, sum1);
matrix_element_wise_multiplication(A, B2, sum2);
matrix_element_wise_multiplication(A, B3, sum3);
}
output[y] = vaddvq_f32(sum0);
output[y + 1] = vaddvq_f32(sum1);
output[y + 2] = vaddvq_f32(sum2);
output[y + 3] = vaddvq_f32(sum3);
}
where load_matrix_transpose and matrix_element_wise_multiplication are the following functions:
inline float32x4x4_t load_matrix_transpose(float *a) {
float32x4x4_t ret;
ret.val[0] = vld1q_f32(a);
ret.val[1] = vld1q_f32(a + 4);
ret.val[2] = vld1q_f32(a + 8);
ret.val[3] = vld1q_f32(a + 12);
return ret;
}
inline void matrix_element_wise_multiplication(float32x4x4_t & A, float32x4x4_t & B, float32x4_t & C) {
C = vmlaq_f32(C, A.val[0], B.val[0]);
C = vmlaq_f32(C, A.val[1], B.val[1]);
C = vmlaq_f32(C, A.val[2], B.val[2]);
C = vmlaq_f32(C, A.val[3], B.val[3]);
}
On my Raspberry Pi 4 (ARMv8, 8 GB RAM, 4 cores) the code takes about 60 ms at optimization level -O3.
Over a long run (many iterations), the NEON version is exactly twice as fast as the normal code.
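For reference, the normal code is essentially one plain dot product of 81 floats per output row; a minimal sketch of such a scalar baseline (assuming unpadded, row-major data; not necessarily the exact code behind the timings):
#include <cstdint>

// Plain scalar reference: one dot product per output row.
void mat_vec_mult_scalar(float *output, const float *input, const float *kernel,
                         uint32_t input_height, uint32_t input_width) {
    for (uint32_t y = 0; y < input_height; ++y) {
        float sum = 0.0f;
        for (uint32_t x = 0; x < input_width; ++x)
            sum += input[y * input_width + x] * kernel[x];
        output[y] = sum;
    }
}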
My question is: is there any way to optimize the code further? I have tried many things, but cannot get any further improvement relative to the normal code.
Data locality is the highest priority when it comes to optimizations, and you should be aware of the register capacity since registers are BY FAR the fastest and most scarce resource.
- aarch64: 32 x 128-bit NEON registers (512 bytes)
- aarch32: 16 x 128-bit NEON registers (256 bytes)
An 81x90000 matrix, when transposed, requires holding 90000 intermediate values for the multiplication, and since those 360000 bytes don't fit into a 512-byte register bank, there will be TONS of memory swapping, which translates into HUGE performance hits.
On the other hand, the 81 * 4 = 324 bytes of the vector fit nicely into the 512 bytes.
#include <arm_neon.h>

void matVecMult81x90000(float *pDst, float *pMat, float *pVec)
{
register float32x4_t vec0_3, vec4_7, vec8_11, vec12_15, vec16_19, vec20_23, vec24_27, vec28_31, vec32_35, vec36_39, vec40_43, vec44_47, vec48_51, vec52_55, vec56_59, vec60_63, vec64_67, vec68_71, vec72_75, vec76_79, vec80;
register float32x4_t mat0, mat1, mat2, mat3, mat4, rslt;
register float32x2_t drslt;
register uint32_t nRows = 90000;
vec80 = vdupq_n_f32(0.0f);
mat4 = vdupq_n_f32(0.0f);
vec0_3 = vld1q_f32(pVec); pVec += 4;
vec4_7 = vld1q_f32(pVec); pVec += 4;
vec8_11 = vld1q_f32(pVec); pVec += 4;
vec12_15 = vld1q_f32(pVec); pVec += 4;
vec16_19 = vld1q_f32(pVec); pVec += 4;
vec20_23 = vld1q_f32(pVec); pVec += 4;
vec24_27 = vld1q_f32(pVec); pVec += 4;
vec28_31 = vld1q_f32(pVec); pVec += 4;
vec32_35 = vld1q_f32(pVec); pVec += 4;
vec36_39 = vld1q_f32(pVec); pVec += 4;
vec40_43 = vld1q_f32(pVec); pVec += 4;
vec44_47 = vld1q_f32(pVec); pVec += 4;
vec48_51 = vld1q_f32(pVec); pVec += 4;
vec52_55 = vld1q_f32(pVec); pVec += 4;
vec56_59 = vld1q_f32(pVec); pVec += 4;
vec60_63 = vld1q_f32(pVec); pVec += 4;
vec64_67 = vld1q_f32(pVec); pVec += 4;
vec68_71 = vld1q_f32(pVec); pVec += 4;
vec72_75 = vld1q_f32(pVec); pVec += 4;
vec76_79 = vld1q_f32(pVec); pVec += 4;
vec80 = vld1q_lane_f32(pVec, vec80, 0);
do {
mat0 = vld1q_f32(pMat); pMat += 4;
mat1 = vld1q_f32(pMat); pMat += 4;
mat2 = vld1q_f32(pMat); pMat += 4;
mat3 = vld1q_f32(pMat); pMat += 4;
rslt = vmulq_f32(mat0, vec0_3);
rslt += vmulq_f32(mat1, vec4_7);
rslt += vmulq_f32(mat2, vec8_11);
rslt += vmulq_f32(mat3, vec12_15);
mat0 = vld1q_f32(pMat); pMat += 4;
mat1 = vld1q_f32(pMat); pMat += 4;
mat2 = vld1q_f32(pMat); pMat += 4;
mat3 = vld1q_f32(pMat); pMat += 4;
rslt += vmulq_f32(mat0, vec16_19);
rslt += vmulq_f32(mat1, vec20_23);
rslt += vmulq_f32(mat2, vec24_27);
rslt += vmulq_f32(mat3, vec28_31);
mat0 = vld1q_f32(pMat); pMat += 4;
mat1 = vld1q_f32(pMat); pMat += 4;
mat2 = vld1q_f32(pMat); pMat += 4;
mat3 = vld1q_f32(pMat); pMat += 4;
rslt += vmulq_f32(mat0, vec32_35);
rslt += vmulq_f32(mat1, vec36_39);
rslt += vmulq_f32(mat2, vec40_43);
rslt += vmulq_f32(mat3, vec44_47);
mat0 = vld1q_f32(pMat); pMat += 4;
mat1 = vld1q_f32(pMat); pMat += 4;
mat2 = vld1q_f32(pMat); pMat += 4;
mat3 = vld1q_f32(pMat); pMat += 4;
rslt += vmulq_f32(mat0, vec48_51);
rslt += vmulq_f32(mat1, vec52_55);
rslt += vmulq_f32(mat2, vec56_59);
rslt += vmulq_f32(mat3, vec60_63);
mat0 = vld1q_f32(pMat); pMat += 4;
mat1 = vld1q_f32(pMat); pMat += 4;
mat2 = vld1q_f32(pMat); pMat += 4;
mat3 = vld1q_f32(pMat); pMat += 4;
mat4 = vld1q_lane_f32(pMat, mat4, 0); pMat += 1;
rslt += vmulq_f32(mat0, vec64_67);
rslt += vmulq_f32(mat1, vec68_71);
rslt += vmulq_f32(mat2, vec72_75);
rslt += vmulq_f32(mat3, vec76_79);
rslt += vmulq_f32(mat4, vec80);
*pDst++ = vaddvq_f32(rslt);
} while (--nRows);
}
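A small driver to exercise this routine and sanity-check it against a naive loop might look as follows (a sketch; the fill values are arbitrary and the routine above is assumed to be compiled in the same translation unit):
#include <cmath>
#include <cstdio>
#include <vector>

void matVecMult81x90000(float *pDst, float *pMat, float *pVec); // routine above

int main() {
    const int rows = 90000, cols = 81;
    std::vector<float> mat((size_t)rows * cols), vec(cols), dst(rows);
    for (size_t i = 0; i < mat.size(); ++i) mat[i] = (float)(i % 7) * 0.25f;
    for (int i = 0; i < cols; ++i) vec[i] = (float)i * 0.01f;

    matVecMult81x90000(dst.data(), mat.data(), vec.data());

    // Compare against a naive reference; small differences are expected
    // because the summation order differs.
    float max_err = 0.0f;
    for (int r = 0; r < rows; ++r) {
        float s = 0.0f;
        for (int c = 0; c < cols; ++c)
            s += mat[(size_t)r * cols + c] * vec[c];
        max_err = std::fmax(max_err, std::fabs(dst[r] - s));
    }
    std::printf("max abs error: %g\n", max_err);
    return 0;
}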
Unfortunately, the compilers don't play along nicely (both GCC and Clang).
The generated code spills parts of the vector to the stack inside the loop.
Below is the same function in hand-written assembly, without any stack swapping:
.arch armv8-a
.global matVecMult81x90000_asm
.text
.balign 64
.func
matVecMult81x90000_asm:
// init loop counter
mov w3, #90000 & 0xffff
movk w3, #90000>>16, lsl #16
// preserve registers
stp d8, d9, [sp, #-48]!
stp d10, d11, [sp, #1*16]
stp d12, d13, [sp, #2*16]
// load vectors
ldp q0, q1, [x2, #0*32]
ldp q2, q3, [x2, #1*32]
ldp q4, q5, [x2, #2*32]
ldp q6, q7, [x2, #3*32]
ldp q8, q9, [x2, #4*32]
ldp q10, q11, [x2, #5*32]
ldp q12, q13, [x2, #6*32]
ldp q16, q17, [x2, #7*32]
ldp q18, q19, [x2, #8*32]
ldp q20, q21, [x2, #9*32]
ldr s22, [x2, #10*32]
// loop
.balign 64
1:
ldp q24, q25, [x1, #0*32]
ldp q26, q27, [x1, #1*32]
ldp q28, q29, [x1, #2*32]
ldp q30, q31, [x1, #3*32]
subs w3, w3, #1
fmul v23.4s, v24.4s, v0.4s
fmla v23.4s, v25.4s, v1.4s
fmla v23.4s, v26.4s, v2.4s
fmla v23.4s, v27.4s, v3.4s
fmla v23.4s, v28.4s, v4.4s
fmla v23.4s, v29.4s, v5.4s
fmla v23.4s, v30.4s, v6.4s
fmla v23.4s, v31.4s, v7.4s
ldp q24, q25, [x1, #4*32]
ldp q26, q27, [x1, #5*32]
ldp q28, q29, [x1, #6*32]
ldp q30, q31, [x1, #7*32]
fmla v23.4s, v24.4s, v8.4s
fmla v23.4s, v25.4s, v9.4s
fmla v23.4s, v26.4s, v10.4s
fmla v23.4s, v27.4s, v11.4s
fmla v23.4s, v28.4s, v12.4s
fmla v23.4s, v29.4s, v13.4s
fmla v23.4s, v30.4s, v16.4s
fmla v23.4s, v31.4s, v17.4s
ldp q24, q25, [x1, #8*32]
ldp q26, q27, [x1, #9*32]
ldr s28, [x1, #10*32]
fmla v23.4s, v24.4s, v18.4s
fmla v23.4s, v25.4s, v19.4s
fmla v23.4s, v26.4s, v20.4s
fmla v23.4s, v27.4s, v21.4s
fmla v23.4s, v28.4s, v22.4s
add x1, x1, #81*4
faddp v23.4s, v23.4s, v23.4s
faddp v23.2s, v23.2s, v23.2s
str s23, [x0], #4
b.ne 1b
.balign 8
// restore registers
ldp d10, d11, [sp, #1*16]
ldp d12, d13, [sp, #2*16]
ldp d8, d9, [sp], #48
// return
ret
.endfunc
.end
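To call the assembly version from C or C++, declare a matching prototype and assemble the .s file together with the rest of the program (the file name below is just an example):
// build example: g++ -O3 main.cpp matVecMult81x90000_asm.s -o matvec
#ifdef __cplusplus
extern "C"
#endif
void matVecMult81x90000_asm(float *pDst, float *pMat, float *pVec);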
Test results on RK3368:
- Clang intrinsics: 10.41 ms
- Assembly: 9.59 ms
The compilers didn't perform that badly in this case, but more often than not they are unbelievably stupid. I strongly recommend learning assembly.
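For anyone who wants to reproduce this kind of comparison, a simple timing harness along these lines should do (a sketch, not the exact setup behind the numbers above):
#include <chrono>
#include <cstdio>
#include <vector>

void matVecMult81x90000(float *pDst, float *pMat, float *pVec);                 // intrinsics version
extern "C" void matVecMult81x90000_asm(float *pDst, float *pMat, float *pVec);  // assembly version

template <typename F>
static double time_ms(F f, int reps) {
    auto t0 = std::chrono::steady_clock::now();
    for (int i = 0; i < reps; ++i) f();
    auto t1 = std::chrono::steady_clock::now();
    return std::chrono::duration<double, std::milli>(t1 - t0).count() / reps;
}

int main() {
    const int rows = 90000, cols = 81, reps = 100;
    std::vector<float> mat((size_t)rows * cols, 1.0f), vec(cols, 1.0f), dst(rows);

    double t_intr = time_ms([&] { matVecMult81x90000(dst.data(), mat.data(), vec.data()); }, reps);
    double t_asm  = time_ms([&] { matVecMult81x90000_asm(dst.data(), mat.data(), vec.data()); }, reps);

    std::printf("intrinsics: %.2f ms  assembly: %.2f ms\n", t_intr, t_asm);
    return 0;
}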