I'm trying to optimize a particular piece of code that calculates the Mahalanobis distance in a vectorized manner. I have a standard implementation that uses traditional Python multiplication (`@`), and another implementation that uses einsum. However, I'm surprised that the einsum implementation is slower than the standard Python implementation. Is there anything I'm doing inefficiently in einsum, or are there other methods, such as tensordot, that I should be looking into?
```python
# SETUP
import numpy as np

BATCH_SZ = 128
GAUSSIANS = 100
xvals = np.random.random((BATCH_SZ, 1, 4))
means = np.random.random((GAUSSIANS, 1, 4))
inv_covs = np.random.random((GAUSSIANS, 4, 4))
```
```python
%%timeit
xvals_newdim = xvals[:, np.newaxis, ...]
means_newdim = means[np.newaxis, ...]
diff_newdim = xvals_newdim - means_newdim
regular = diff_newdim @ inv_covs @ (diff_newdim).transpose(0, 1, 3, 2)
```

>> 731 µs ± 10.5 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
```python
%%timeit
diff = xvals - means.squeeze(1)
einsum = np.einsum("ijk,jkl,ijl->ij", diff, inv_covs, diff)
```

>> 949 µs ± 22.9 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
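Before comparing timings, it can help to confirm that the two cells compute the same quantity. The check below is a minimal sketch that assumes a seeded generator (`np.random.default_rng(0)`, not in the original) for reproducibility. Both expressions give the squared Mahalanobis form d^T M d (no final square root); the `@` version just keeps two extra singleton axes.

```python
import numpy as np

BATCH_SZ = 128
GAUSSIANS = 100
rng = np.random.default_rng(0)  # seeded only to make the check reproducible
xvals = rng.random((BATCH_SZ, 1, 4))
means = rng.random((GAUSSIANS, 1, 4))
inv_covs = rng.random((GAUSSIANS, 4, 4))

# Version 1: broadcasted @ on (BATCH, GAUSSIANS, 1, 4) row vectors.
diff_newdim = xvals[:, np.newaxis, ...] - means[np.newaxis, ...]
regular = diff_newdim @ inv_covs @ diff_newdim.transpose(0, 1, 3, 2)

# Version 2: einsum on a (BATCH, GAUSSIANS, 4) difference array.
diff = xvals - means.squeeze(1)
einsum = np.einsum("ijk,jkl,ijl->ij", diff, inv_covs, diff)

# `regular` has two trailing singleton axes; drop them before comparing.
print(regular.shape, einsum.shape)  # (128, 100, 1, 1) (128, 100)
print(np.allclose(regular[..., 0, 0], einsum))
```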
First things first: to optimize such code, one needs to understand what is going on, then profile, then estimate the achievable time, and only then look for a better solution.
TL;DR: both versions are inefficient and run serially. Neither BLAS libraries nor Numpy are designed to optimize this use case. In fact, even basic Numpy operations are not efficient when the last axis is very small (i.e. 4). This can be optimized using Numba by writing an implementation specifically designed for your matrix sizes.
`@` is a Python operator, but for Numpy arrays it calls a Numpy function internally (`np.matmul`), just like `+` or `*` do. It performs a loop iterating over all the matrices and calls a highly optimized BLAS implementation on each matrix. A BLAS is a numerical linear algebra library. There are many BLAS implementations, but the default one for Numpy is generally OpenBLAS, which is pretty optimized, especially for large matrices. Please also note that `np.einsum` can call BLAS implementations for specific patterns (if the `optimize` flag is set properly), but this is not the case here. It is also worth mentioning that `np.einsum` is well optimized for 2 input matrices, less so for 3, and not optimized at all for more. This is because the number of possible patterns grows exponentially and the optimized code paths are written by hand. For more information about how `np.einsum` works, please read How is numpy.einsum implemented?.
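A small sketch of the two mechanisms just described, assuming the same shapes as the question (the seeded data is illustrative). It shows the `optimize` flag of `np.einsum` (default `False`) and the equivalent `np.matmul` call that `@` maps to. As the answer notes, the `optimize=True` path does not turn into a single BLAS call for this pattern, because `j` is a batch index shared by the operands and the output.

```python
import numpy as np

rng = np.random.default_rng(0)
diff = rng.random((128, 100, 4))
inv_covs = rng.random((100, 4, 4))

# Default (optimize=False): one generic C loop over all of i, j, k, l.
plain = np.einsum("ijk,jkl,ijl->ij", diff, inv_covs, diff)

# optimize=True lets einsum split the expression into pairwise contractions;
# it can only hand a step to tensordot/BLAS when the pattern allows it.
optimized = np.einsum("ijk,jkl,ijl->ij", diff, inv_covs, diff, optimize=True)

# The same contraction through np.matmul (what `@` calls):
# (128, 100, 1, 4) @ (100, 4, 4) -> (128, 100, 1, 4), i.e. 12,800 tiny products.
tmp = np.matmul(diff[:, :, np.newaxis, :], inv_covs)
via_matmul = (tmp[:, :, 0, :] * diff).sum(axis=-1)

print(np.allclose(plain, optimized), np.allclose(plain, via_matmul))
```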
The thing is, you are multiplying many very small matrices, and most BLAS implementations are not optimized for that. Neither is Numpy: the cost of the generic loop iteration can become large compared to the computation itself, not to mention the function call to the BLAS. Profiling Numpy shows that the slowest function of the `np.einsum` implementation is `PyArray_TransferNDimToStrided`. This function is not the main computing function but a helper one. In fact, the main computing function takes only 20% of the overall time, which leaves a lot of room for improvement! The same is true for the BLAS implementation: `cblas_dgemv` only takes about 20% of the time, as does `dgemv_n_HASWELL` (the main computing kernel of the BLAS `cblas_dgemv` function). The rest is nearly pure overhead of the BLAS library or Numpy (roughly half the time for both). Moreover, both versions run serially. Indeed, `np.einsum` is not optimized to run with multiple threads, and the BLAS cannot use multiple threads because the matrices are too small for multi-threading to pay off (it has a significant overhead). This means both versions are pretty inefficient.
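To make the "small last axis" point concrete in pure Numpy, one option is to move the two 4-long axes to the front and loop over them in Python, so that every Numpy call works on a whole `(128, 100)` slice instead of many 4-element rows. This is a hypothetical sketch (the name `quad_form_unrolled` is made up here), not something from the answer, and it is not benchmarked here; whether it beats the question's versions will depend on the machine.

```python
import numpy as np

def quad_form_unrolled(diff, inv_covs):
    """Sketch: Python loop over the tiny k/l axes, Numpy over the long i/j axes.

    diff: (ni, nj, 4), inv_covs: (nj, 4, 4) -> (ni, nj)
    """
    # Small axes first, so d[k] is a contiguous (ni, nj) block and
    # m[k, l] is a contiguous (nj,) vector.
    d = np.ascontiguousarray(diff.transpose(2, 0, 1))      # (4, ni, nj)
    m = np.ascontiguousarray(inv_covs.transpose(1, 2, 0))  # (4, 4, nj)
    res = np.zeros(d.shape[1:])
    for k in range(d.shape[0]):
        for l in range(m.shape[1]):
            # (ni, nj) * (nj,) * (ni, nj): m[k, l] broadcasts along the i axis.
            res += d[k] * m[k, l] * d[l]
    return res

rng = np.random.default_rng(0)
diff = rng.random((128, 100, 4))
inv_covs = rng.random((100, 4, 4))
ref = np.einsum("ijk,jkl,ijl->ij", diff, inv_covs, diff)
print(np.allclose(quad_form_unrolled(diff, inv_covs), ref))
```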
To know how inefficient the versions are, one needs to know the amount of computation to do and the speed of the processor. The number of Flop (floating-point operations) is provided by `np.einsum_path` and is 5.120e+05 for an optimized implementation (6.144e+05 otherwise). Mainstream CPUs usually perform >=100 GFlops/s with multiple threads and dozens of GFlops/s serially. For example, my i5-9600KF processor can achieve 300-400 GFlops/s in parallel and 50-60 GFlops/s serially. Since the computation lasts 0.52 ms for the BLAS version (the best one), the code runs at about 1 GFlops/s (5.12e5 Flop / 0.52e-3 s), which is a poor result compared to the optimum.
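The estimate above can be reproduced with a short script. This sketch assumes the question's shapes and parses the "Naive FLOP count" and "Optimized FLOP count" lines of the report returned by `np.einsum_path`; the 0.52 ms figure is the one quoted in the answer and should be replaced by your own measurement. With that timing the result comes out at roughly 1 GFlop/s.

```python
import re
import numpy as np

# Shapes from the question: i = batch (128), j = Gaussians (100), k = l = 4.
rng = np.random.default_rng(0)
diff = rng.random((128, 100, 4))
inv_covs = rng.random((100, 4, 4))

# einsum_path returns the chosen contraction path and a text report that
# includes the naive and optimized FLOP counts.
path, report = np.einsum_path("ijk,jkl,ijl->ij", diff, inv_covs, diff,
                              optimize="greedy")
print(report)

def flop_count(report, label):
    """Extract one '<label> FLOP count:' figure from the einsum_path report."""
    match = re.search(label + r" FLOP count:\s*([0-9.eE+-]+)", report)
    return float(match.group(1))

naive = flop_count(report, "Naive")
optimized = flop_count(report, "Optimized")

# Achieved throughput = FLOPs / elapsed time.
elapsed_s = 0.52e-3  # timing quoted in the answer; substitute your own
gflops = optimized / elapsed_s / 1e9
print(f"naive={naive:.3e}  optimized={optimized:.3e}  -> {gflops:.2f} GFlop/s")
```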
One solution to speed up the computation is to design a Numba (JIT compiler) or Cython (Python-to-C compiler) implementation that is optimized for your specific matrix sizes. Indeed, the last dimension is too small for generic code to be fast. Even basic compiled code would not be very fast in this case: the overhead of a C loop alone can be quite large compared to the actual computation. We can tell the compiler that the size of some matrix axes is small and fixed at compile time, so the compiler can generate much faster code (thanks to loop unrolling, tiling and SIMD instructions). This can be done with a basic `assert` in Numba. In addition, we can use the `fastmath=True` flag to speed up the computation even more if no special floating-point (FP) values such as NaN or infinity are involved. This flag can also impact the accuracy of the result since it assumes FP math is associative (which is not true). In short, it breaks the IEEE-754 standard for the sake of performance. Here is the resulting code:
```python
import numba as nb
import numpy as np

# use `fastmath=True` for better performance if there is no
# special value used and the accuracy is not critical.
@nb.njit('(float64[:,:,::1], float64[:,:,::1])', fastmath=True)
def compute_fast_einsum(diff, inv_covs):
    ni, nj, nk = diff.shape
    nl = inv_covs.shape[2]
    # Check the shapes and tell Numba the last two axes have size 4,
    # so it can specialize (unroll/vectorize) the inner loops.
    assert inv_covs.shape == (nj, nk, nl)
    assert nk == 4 and nl == 4
    res = np.empty((ni, nj), dtype=np.float64)
    for i in range(ni):
        for j in range(nj):
            # Quadratic form diff[i, j] . inv_covs[j] . diff[i, j]
            s = 0.0
            for k in range(nk):
                for l in range(nl):
                    s += diff[i, j, k] * inv_covs[j, k, l] * diff[i, j, l]
            res[i, j] = s
    return res
```
```python
%%timeit
diff = xvals - means.squeeze(1)
compute_fast_einsum(diff, inv_covs)
```
Here are the performance results on my machine (mean ± std. dev. of 7 runs, 1000 loops each):

```
@ operator:          602 µs ± 3.33 µs per loop
einsum:              698 µs ± 4.62 µs per loop
Numba code:          193 µs ±  544 ns per loop
Numba + fastmath:    177 µs ±  624 ns per loop
Best Numba:        < 100 µs                     <------ 6x-7x faster!
```
Note that about 100 µs is spent computing `diff`, which is not efficient. This part can also be optimized with Numba. In fact, the value of `diff` can be computed on the fly in the `i`-based loop from the other arrays. This makes the computation more cache friendly. This version is called "Best Numba" in the results. Note that the Numba versions do not even use multiple threads. That being said, the overhead of multi-threading is generally about 5-500 µs, so using multiple threads may be slower on some machines (on mainstream PCs, i.e. not computing servers, the overhead is generally 5-100 µs, and it is about 10 µs on my machine).
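The "Best Numba" code itself is not shown above, so here is a hypothetical sketch of the idea, not the author's version: `diff` is computed on the fly inside the loops instead of being materialized. The function name `mahalanobis_sq_fused` and the 2-D `(ni, 4)` / `(nj, 4)` input layout are assumptions made for this example, and it has not been benchmarked here. The `try`/`except` fallback only exists so the sketch still runs (slowly, as plain Python) when Numba is not installed.

```python
import numpy as np

try:
    import numba as nb
    njit = nb.njit
except ImportError:
    # Fallback: a no-op stand-in for `njit(...)` so the sketch runs without Numba.
    def njit(*args, **kwargs):
        def wrap(f):
            return f
        return wrap

@njit(fastmath=True)
def mahalanobis_sq_fused(xvals, means, inv_covs):
    """Hypothetical sketch: fuse the `diff` computation into the main loop.

    xvals: (ni, 4), means: (nj, 4), inv_covs: (nj, 4, 4) -> (ni, nj)
    """
    ni, nk = xvals.shape
    nj = means.shape[0]
    # Fixed small sizes, as in the answer's version.
    assert nk == 4 and inv_covs.shape[1] == 4 and inv_covs.shape[2] == 4
    res = np.empty((ni, nj), dtype=np.float64)
    d = np.empty(4, dtype=np.float64)  # one row of diff, reused for every (i, j)
    for i in range(ni):
        for j in range(nj):
            for k in range(4):
                d[k] = xvals[i, k] - means[j, k]
            s = 0.0
            for k in range(4):
                for l in range(4):
                    s += d[k] * inv_covs[j, k, l] * d[l]
            res[i, j] = s
    return res

# Usage with the question's shapes: drop the singleton middle axis first.
rng = np.random.default_rng(0)
xvals = rng.random((128, 1, 4))
means = rng.random((100, 1, 4))
inv_covs = rng.random((100, 4, 4))
res = mahalanobis_sq_fused(xvals[:, 0, :], means[:, 0, :], inv_covs)
diff = xvals - means.squeeze(1)
ref = np.einsum("ijk,jkl,ijl->ij", diff, inv_covs, diff)
print(np.allclose(res, ref))
```

If multi-threading turns out to be worth its overhead on your machine, the usual Numba route is to pass `parallel=True` to `nb.njit` and replace the outer `range` with `nb.prange`.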