
Numpy matmul and einsum 6 to 7 times slower than MATLAB


I am trying to port some code from MATLAB to Python and I am getting much slower performance from Python. I am not very good at Python coding, so any advice to speed these up will be much appreciated.

I tried an einsum one-liner (takes 7.5 seconds on my machine):

import numpy as np

n = 4
N = 200
M = 100
X = 0.1*np.random.rand(M, n, N)
w = 0.1*np.random.rand(M, N, 1)

G = np.einsum('ijk,iljm,lmn->il', w, np.exp(np.einsum('ijk,ljn->ilkn',X,X)), w)

I also tried a matmul implementation (takes 6 seconds on my machine):

G = np.zeros((M, M))
for i in range(M):
    G[:, i] = np.squeeze(w[i,...].T @ (np.exp(X[i, :, :].T @ X) @ w))
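
A quick shape trace of one loop iteration, at reduced sizes chosen only for illustration (an assumption), shows how `@` broadcasts the single (N, n) matrix `X[i].T` against all M pages of `X`, so each iteration already performs a batch of M small matrix products:

```python
import numpy as np

n, N, M = 4, 6, 5  # reduced sizes, for illustration only
rng = np.random.default_rng(0)
X = 0.1 * rng.random((M, n, N))
w = 0.1 * rng.random((M, N, 1))

i = 0
K = X[i].T @ X        # (N, n) @ (M, n, N) broadcasts to (M, N, N)
Kw = np.exp(K) @ w    # (M, N, N) @ (M, N, 1) -> (M, N, 1)
col = w[i].T @ Kw     # (1, N) @ (M, N, 1) -> (M, 1, 1), squeezed to (M,) in the loop
print(K.shape, Kw.shape, col.shape)  # (5, 6, 6) (5, 6, 1) (5, 1, 1)
```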

But my original MATLAB code is way faster (takes 1 second on my machine)

n = 4;
N = 200;
M = 100;
X = 0.1*rand(n, N, M);
w = 0.1*rand(N, 1, M);

G=zeros(M);
for i=1:M
    G(:,i) = squeeze(pagemtimes(pagemtimes(w(:,1,i).', exp(pagemtimes(X(:,:,i),'transpose',X,'none'))) ,w));
end

I was expecting both Python implementations to be comparable in speed to MATLAB, but they are not. Any ideas why the Python implementations are this slow, or any suggestions to speed them up?


Solution

  • First of all, np.einsum has an optimize parameter, which is set to False by default (mainly because the path optimization can be more expensive than the computation itself in some cases, and it is generally better to pre-compute the optimal path once in a separate call). You can use optimize=True to significantly speed up np.einsum (it finds the optimal contraction path in this case, though the internal implementation of each step is not optimal). Note that pagemtimes in Matlab is more specialized than np.einsum, so there is no need for such a parameter (i.e. it is fast by default in this case).
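
The path pre-computation mentioned above can be sketched with np.einsum_path: its first return value is a contraction path that can be passed back as the optimize argument, so repeated calls (e.g. inside a loop) skip the path search. Sizes are reduced here (an assumption) so the 4-D intermediate stays small.

```python
import numpy as np

n, N, M = 4, 20, 10  # reduced sizes; the question uses n=4, N=200, M=100
rng = np.random.default_rng(0)
X = 0.1 * rng.random((M, n, N))
w = 0.1 * rng.random((M, N, 1))

E = np.exp(np.einsum('ijk,ljn->ilkn', X, X, optimize=True))

# Search for the contraction order once; einsum_path returns (path, printable report).
path, report = np.einsum_path('ijk,iljm,lmn->il', w, E, w, optimize='optimal')

# Reuse the precomputed path: no path search happens in this call.
G_fast = np.einsum('ijk,iljm,lmn->il', w, E, w, optimize=path)
G_naive = np.einsum('ijk,iljm,lmn->il', w, E, w)  # default optimize=False
print(np.allclose(G_fast, G_naive))  # True
```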

    Moreover, Numpy functions like np.exp create a new array by default. Computing arrays in-place is generally faster (and it also consumes less memory). This can be done with the out parameter.
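
A minimal illustration of the out parameter on a small array (sizes are arbitrary): passing the input itself as out makes the ufunc overwrite it instead of allocating a result, and the call returns that same array.

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.random((3, 4))

expected = np.exp(a)    # default: a new array is allocated for the result
res = np.exp(a, out=a)  # in-place: the result is written into a's own buffer

print(res is a)                  # True, no new array was created
print(np.allclose(a, expected))  # True
```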

    np.exp is fairly expensive on most machines because it runs serially (like most Numpy functions) and is often not heavily optimized internally either. Using a fast math library such as Intel's helps. I suspect Matlab uses this kind of fast math library internally. Alternatively, one can use multiple threads to compute it faster. This is easy to do with the numexpr package.
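
If numexpr is not available, a rough standard-library alternative (a swapped-in technique, not what this answer uses) is to split the array into chunks and run np.exp on each chunk from a thread pool. This assumes NumPy releases the GIL while looping over numeric data, which it does for large float64 arrays; the helper name parallel_exp_inplace is hypothetical, and no particular speed-up is claimed.

```python
import os
from concurrent.futures import ThreadPoolExecutor

import numpy as np


def parallel_exp_inplace(a, n_threads=None):
    """Hypothetical helper: compute np.exp(a) in place, split across threads."""
    if not a.flags.c_contiguous:
        raise ValueError("expected a C-contiguous array so that reshape(-1) is a view")
    if n_threads is None:
        n_threads = os.cpu_count() or 1
    flat = a.reshape(-1)                      # 1-D view sharing a's memory
    chunks = np.array_split(flat, n_threads)  # non-overlapping slices (views) of flat
    with ThreadPoolExecutor(max_workers=n_threads) as pool:
        # Each worker exponentiates its own slice in place via out=.
        list(pool.map(lambda c: np.exp(c, out=c), chunks))
    return a


rng = np.random.default_rng(0)
tmp = 0.1 * rng.random((20, 20, 50, 50))
expected = np.exp(tmp)
parallel_exp_inplace(tmp, n_threads=4)
print(np.allclose(tmp, expected))  # True
```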

    Here is the resulting more optimized Numpy code:

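    import numpy as np
    import numexpr as ne
    
    # Same initialization as in the question
    n = 4
    N = 200
    M = 100
    X = 0.1*np.random.rand(M, n, N)
    w = 0.1*np.random.rand(M, N, 1)
    
    # All X_i^T X_l products at once: shape (M, M, N, N)
    tmp = np.einsum('ijk,ljn->ilkn', X, X, optimize=True)
    # Multithreaded exp, written back into tmp to avoid another large allocation
    ne.evaluate('exp(tmp)', out=tmp)
    # G[i, l] = w_i^T exp(X_i^T X_l) w_l
    G = np.einsum('ijk,iljm,lmn->il', w, tmp, w, optimize=True)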
    
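
To check that the faster pipeline computes the same G as the question's loop, here is a small sketch at reduced sizes (an assumption, to keep the 4-D intermediate tiny). It uses np.exp(tmp, out=tmp) in place of the numexpr call so it only needs NumPy; the function names are hypothetical.

```python
import numpy as np


def g_loop(X, w):
    """The question's matmul loop."""
    M = X.shape[0]
    G = np.zeros((M, M))
    for i in range(M):
        G[:, i] = np.squeeze(w[i, ...].T @ (np.exp(X[i, :, :].T @ X) @ w))
    return G


def g_einsum_inplace(X, w):
    """The two-step einsum, with NumPy's in-place exp instead of numexpr."""
    tmp = np.einsum('ijk,ljn->ilkn', X, X, optimize=True)  # shape (M, M, N, N)
    np.exp(tmp, out=tmp)
    return np.einsum('ijk,iljm,lmn->il', w, tmp, w, optimize=True)


n, N, M = 4, 20, 10
rng = np.random.default_rng(0)
X = 0.1 * rng.random((M, n, N))
w = 0.1 * rng.random((M, N, 1))
print(np.allclose(g_loop(X, w), g_einsum_inplace(X, w)))  # True
```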

    Performance results

    Here are the results on my machine (with an i5-9600KF CPU and 32 GiB of RAM, on Windows):

    Naive einsums:        6.62 s
    CPython loops:        3.37 s
    This answer:          1.27 s   <----
    
    max9111 solution:     0.47 s   (using an unmodified Numba v0.57)
    max9111 solution:     0.54 s   (using a modified Numba v0.57)
    

    The optimized code is about 5.2 times faster than the naive einsum code and 2.7 times faster than the fastest initial implementation (the loop)!


    Note about performance and possible optimizations

    The first einsum takes a significant fraction of the runtime of the faster implementation on my machine. This is mainly because einsum performs many small matrix multiplications internally, in a way that is not very efficient. Indeed, each matrix multiplication is parallelized by a BLAS library (like OpenBLAS, which is the default on most machines like mine). The thing is, OpenBLAS is not efficient at multiplying small matrices in parallel; in fact, parallelizing each small matrix multiplication individually is inefficient. A more efficient solution is to compute all the matrix multiplications in parallel, with each thread performing several serial matrix multiplications. This is certainly what Matlab does, and why it can be a bit faster. This can be done with a parallel Numba code (or with Cython) and by disabling the parallel execution of BLAS routines (note that this can have performance side effects on a larger script if it is done globally).
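
The "each thread performs several serial matrix multiplications" idea can be roughly sketched without Numba by distributing the columns of G over a thread pool, since NumPy generally releases the GIL inside matmul and exp. This is a swapped-in technique (Python threads rather than the Numba/Cython approach described above), the helper name g_threaded is hypothetical, and it assumes BLAS multithreading has been limited separately (for example by setting OPENBLAS_NUM_THREADS=1 before NumPy is imported), which the code itself does not do.

```python
from concurrent.futures import ThreadPoolExecutor

import numpy as np


def g_threaded(X, w, n_threads=4):
    """Hypothetical helper: fill G one column per task, tasks spread over threads."""
    M = X.shape[0]
    G = np.empty((M, M))

    def column(i):
        # Serial work for one i: M small (N, N) products, exp, then two contractions.
        E = np.exp(X[i].T @ X)                   # (M, N, N)
        G[:, i] = (w[i].T @ (E @ w)).reshape(M)  # (M, 1, 1) -> (M,)

    with ThreadPoolExecutor(max_workers=n_threads) as pool:
        list(pool.map(column, range(M)))  # each task writes a distinct column
    return G


n, N, M = 4, 20, 10  # reduced sizes for a quick check
rng = np.random.default_rng(0)
X = 0.1 * rng.random((M, n, N))
w = 0.1 * rng.random((M, N, 1))
G_ref = np.einsum('ijk,iljm,lmn->il', w, np.exp(np.einsum('ijk,ljn->ilkn', X, X)), w)
print(np.allclose(g_threaded(X, w), G_ref))  # True
```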

    Another possible optimization is to do all the operations at once in Numba using multiple threads. This solution can certainly reduce the memory footprint even more and further improve performance. However, it is far from easy to write an optimized implementation, and the resulting code will be significantly harder to maintain. This is what max9111's code does.
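
To put a number on the memory footprint: the intermediate tmp array has shape (M, M, N, N), so with float64 (8 bytes per element) its size grows as M²N². A small sketch (the helper name is hypothetical) computes it and checks the formula against an actual einsum output:

```python
import numpy as np


def intermediate_nbytes(M, N, itemsize=8):
    """Bytes of the (M, M, N, N) tensor from einsum('ijk,ljn->ilkn', X, X); 8 bytes per float64."""
    return M * M * N * N * itemsize


# Sizes from the question: 100 * 100 * 200 * 200 * 8 bytes
print(intermediate_nbytes(100, 200) / 1e9)  # 3.2 (GB)

# Cross-check the formula on a tiny case
X = np.random.rand(3, 2, 5)
tmp = np.einsum('ijk,ljn->ilkn', X, X)
print(tmp.nbytes == intermediate_nbytes(3, 5))  # True
```

By contrast, each iteration of the question's loop only materializes one (M, N, N) block, i.e. 100 · 200 · 200 · 8 bytes = 32 MB for the question's sizes.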