python, arrays, numpy, matrix-multiplication, matrix-indexing

Performance degradation of matrix multiplication involving integer-array-indexed arrays in numpy


I'm working on a project where I have to perform some row and/or column permutations before a (broadcasted) matrix multiplication. While the implementation is straightforward with numpy, I noticed that performance varies dramatically depending on how the permutation is applied. A minimal example with timings is as follows:

import time
import numpy as np

n = 800
nt = 50

a = np.random.randn(10, n)
b = np.random.randn(7, n, n)
p = np.random.permutation(n)

##############################################
t0 = time.time()
c = [a @ b for i in range(nt)]
t1 = time.time()
print('Time: ', t1-t0)          # ~ 0.2 seconds
##############################################
b1 = b[:, :, p]
t0 = time.time()
c = [a @ b1 for i in range(nt)]
t1 = time.time()
print('Time: ', t1-t0)          # ~ 0.22 seconds
##############################################
b2 = b[:, p, :][:, :, p]
t0 = time.time()
c = [a @ b2 for i in range(nt)]
t1 = time.time()
print('Time: ', t1-t0)          # ~ 4.1 seconds
##############################################
b3 = b[:, :, p][:, p, :]
t0 = time.time()
c = [a @ b3 for i in range(nt)]
t1 = time.time()
print('Time: ', t1-t0)          # ~ 12.5 seconds

In the above example, there is essentially no performance change when only the last dimension of the 3-d array is permuted. However, significant performance degradation is observed when the last two dimensions are both permuted - and the order in which the permutations are applied also makes a difference.

After hours of investigation, I found that the strides reported by __array_interface__ possibly reveal the root cause of the above discrepancy:

print('b: ', b.__array_interface__['strides'])      # None
print('b1: ', b1.__array_interface__['strides'])    # (6400, 8, 44800)
print('b2: ', b2.__array_interface__['strides'])    # (8, 56, 44800)
print('b3: ', b3.__array_interface__['strides'])    # (8, 44800, 56)
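
For completeness, the contiguity flags tell the same story (a quick check added here for illustration):

# b is C-contiguous; the fancy-indexed arrays are not, though their
# F-contiguity may differ depending on how the strides happen to line up.
for name, arr in [('b', b), ('b1', b1), ('b2', b2), ('b3', b3)]:
    print(name, arr.flags['C_CONTIGUOUS'], arr.flags['F_CONTIGUOUS'])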

The strides clearly show that none of b1, b2, or b3 is C-contiguous. However, this behavior was completely unexpected. How does integer-array indexing work here, and how can the performance degradation be circumvented in this case?


Solution

  • The differences you see are the result of both how indexing is done and how matmul is implemented. For some inputs the arrays can be passed as-is to the optimized library code; for others, some sort of copy is required.

    In [43]: n=800
    In [44]: a = np.random.randn(10, n)
        ...: b = np.random.randn(7, n, n)
        ...: p = np.random.permutation(n)
    
    In [45]: b1 = b[:, :, p] 
    In [46]: b2 = b[:, p, :][:, :, p]
    In [47]: b3 = b[:, :, p][:, p, :]
    

    As you note, the indexed arrays have non-C-contiguous strides:

    In [48]: b.shape, b.strides
    Out[48]: ((7, 800, 800), (5120000, 6400, 8))
    
    In [49]: b1.shape, b1.strides
    Out[49]: ((7, 800, 800), (6400, 8, 44800))
    

    Advanced indexing does make a copy, but the final result may actually be a transpose of a different base. This process isn't documented (that I know of), but it is evident when we look at the base and strides.

    In [50]: b1.base.shape, b1.base.strides
    Out[50]: ((800, 7, 800), (44800, 6400, 8))
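
    A quick check (added here for illustration) confirms that b1 is a view of that base, sharing memory rather than holding a second copy:

    # b1 does not own its buffer; it is a transposed view of the
    # copy that advanced indexing produced.
    print(b1.flags['OWNDATA'])             # False
    print(np.shares_memory(b1, b1.base))   # True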
    

    Now for some timings:

    In [51]: timeit a@b
    12.3 ms ± 1.12 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
    
    In [52]: timeit a@b1
    The slowest run took 4.07 times longer than the fastest. This could mean that an intermediate result is being cached.
    19 ms ± 10.5 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
    

    As the other answer suggested, we can try an extra layer of copying, so that b1 is its own base. But that may or may not save time:

    In [53]: timeit a@(b1.copy())
    128 ms ± 159 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
    
    In [54]: %%timeit bb=b1.copy() # taking copy outside
        ...: a@bb
    18.9 ms ± 878 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
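
    Why copying helps at all: ndarray.copy defaults to order='C', so the copy rebuilds a C-contiguous buffer that matmul can pass straight to the optimized library code (a quick check, for illustration):

    # copy() defaults to order='C', restoring C-contiguity
    bb = b1.copy()
    print(bb.flags['C_CONTIGUOUS'])   # True
    print(bb.strides)                 # (5120000, 6400, 8), same layout as b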
    

    But for b2 and b3 the extra copy does help:

    In [56]: timeit a@b2
    311 ms ± 372 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
    
    In [57]: timeit a@(b2.copy())
    143 ms ± 583 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
    
    In [58]: timeit a@b3
    1.22 s ± 8.79 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    
    In [59]: timeit a@(b3.copy())
    51.8 ms ± 158 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
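
    If the copy has to be made anyway, np.ascontiguousarray expresses the same workaround a bit more explicitly (a minimal sketch, using the arrays above):

    # Force one C-contiguous copy up front; every later matmul can then
    # hand the array straight to the fast library code.
    b3c = np.ascontiguousarray(b3)
    assert b3c.flags['C_CONTIGUOUS']
    c = a @ b3c    # comparable to the a@(b3.copy()) timing above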
    

    A further observation: your b is 3d. The fast code underlying matmul operates on two 2d arrays. With a 3d b, it has to iterate over the first, batch, dimension. It iterates in C code, so it isn't that slow, but the timings are still sensitive (linearly) to that dimension.
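
    Concretely, the batched product is equivalent to looping over that first dimension (a sketch of the semantics, not of how the C loop is actually written):

    # a is (10, n); b is (7, n, n). matmul broadcasts a across the batch
    # dimension, computing 7 separate (10, n) @ (n, n) products.
    c_batched = a @ b
    c_loop = np.stack([a @ b[i] for i in range(b.shape[0])])
    assert np.allclose(c_batched, c_loop)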

    Others can elaborate on how memory layout affects these timings. Some layouts require less paging/caching than others, and are thus faster.


    For the other permutations, the bases are:

    In [69]: b2.base.shape, b2.base.strides
    Out[69]: ((800, 7, 800), (44800, 8, 56))
    
    In [70]: b3.base.shape, b3.base.strides
    Out[70]: ((800, 7, 800), (44800, 8, 56))
    

    Another way to do the double permutation is:

    In [71]: b2a = b[:, p[:,None], p]; b2a.shape, b2a.strides
    Out[71]: ((7, 800, 800), (8, 44800, 56))
    
    In [72]: np.allclose(b2,b2a)
    Out[72]: True
    

    This b2a is laid out more like b3, but it doesn't help with the timings:

    In [73]: timeit a@b2a
    1.27 s ± 116 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    
    In [74]: timeit a@b2
    346 ms ± 11.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    

    Other operators

    Note that other operators, like element-wise multiplication, are not as sensitive to these contiguity issues:

    In [81]: timeit a[:7,None,:]*b
    23.2 ms ± 1.02 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
    
    In [82]: timeit a[:7,None,:]*b1
    32.8 ms ± 3.13 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
    
    In [83]: timeit a[:7,None,:]*b2
    53.8 ms ± 3.34 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
    
    In [84]: timeit a[:7,None,:]*b3
    35.6 ms ± 3.61 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
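
    That fits the ufunc machinery: element-wise loops walk the arrays with general strides, so no contiguous copy is needed and the result is identical for any layout (a quick check, for illustration):

    # The strided ufunc loop handles any memory layout; contiguity only
    # affects speed, not the result.
    e1 = a[:7, None, :] * b2
    e2 = a[:7, None, :] * np.ascontiguousarray(b2)
    assert np.allclose(e1, e2)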