python numpy numpy-einsum

Efficiently compute all multi-dimensional traces for all offsets and store in matrix


I have an $N \times N \times N$ array $a$ and would like to implement the formula $$ b_{ij} = \sum_{k=0}^{N-1-\max(i,j)} a_{i+k,\,j+k,\,k} $$ efficiently.

Right now, I'm doing this via

# b[i, j] is the sum of a[i+k, j+k, k] over every k that stays in bounds.
b = np.zeros((N, N))
for i, j in np.ndindex(N, N):
    # The diagonal starting at (i, j, 0) leaves the cube after this many steps.
    span = N - max(i, j)
    # 'iii' with no explicit output sums the main diagonal of the sub-cube.
    b[i, j] = np.einsum('iii', a[i:i + span, j:j + span, :span])

which seems quite inefficient.

Is there a way to do this without cython which works just with numpy (or any other, numpy-compatible interface such as jax.numpy)?

Edit 1: added missing explicit bounds.


Solution

  • It's not straightforward to vectorise because the sum is jagged rather than rectangular: the number of terms, $N - \max(i, j)$, differs from one output element to the next. I propose building the index arrays as follows, which seems fast enough. I also show a comparative benchmark.

    import time
    
    import numpy as np
    
    
    def nonvectorised(a: np.ndarray) -> np.ndarray:
        """Reference implementation of ``b[i, j] = sum_k a[i+k, j+k, k]``.

        Plain Python loops, intended as the ground truth the faster variants
        are checked against.

        Args:
            a: Three-dimensional array. It does not have to be cubic; the
                summation stops as soon as any of the three indices would
                leave the array.

        Returns:
            Array of shape ``a.shape[:2]`` and dtype ``a.dtype`` holding the
            offset diagonal sums.
        """
        rows, cols = a.shape[:2]
        depth = a.shape[2]
        b = np.zeros((rows, cols), dtype=a.dtype)
        for i in range(rows):
            for j in range(cols):
                # k is also the index into the third axis, so it must be
                # bounded by that axis as well as by the first two.
                kmax = min(rows - i, cols - j, depth)
                for k in range(kmax):
                    b[i, j] += a[i + k, j + k, k]
        return b
    
    
    def semivectorised(a: np.ndarray) -> np.ndarray:
        """Compute ``b[i, j] = sum_k a[i+k, j+k, k]`` one diagonal length at a time.

        All output cells whose diagonal has the same number of terms are
        gathered with a single fancy-indexing expression and summed together,
        so the Python-level loop runs ``n`` times instead of ``n * n``.

        Only ``a.shape[1]`` is used as the side length ``n`` for every axis,
        so the function is written for a cubic ``(n, n, n)`` input.

        Args:
            a: Input array, indexed along three axes.

        Returns:
            Array of shape ``a.shape[:2]`` and dtype ``a.dtype``.
        """
        _, n = shape2d = a.shape[:2]
        out = np.zeros(shape2d, dtype=a.dtype)

        for run in range(1, n + 1):
            # Cells with max(i, j) == pivot have exactly `run` terms.
            pivot = n - run
            tail = range(pivot, n)   # pivot + k for k in 0..run-1
            depth = range(run)       # k itself
            # slide[s, k] == s + k: the non-pivot index for start offset s.
            slide = np.arange(pivot + 1)[:, np.newaxis] + np.arange(run)

            # Row `pivot`, columns 0..pivot (lower triangle and the diagonal).
            lower = a[tail, slide, depth].sum(axis=1)
            out[pivot, range(pivot + 1)] = lower

            # Column `pivot`, rows 0..pivot-1 (strict upper triangle).
            upper = a[slide[:-1, :], tail, depth].sum(axis=1)
            out[range(pivot), pivot] = upper

        return out
    
    
    def literal_trace(a: np.ndarray) -> np.ndarray:
        """Compute the offset diagonal sums with one ``np.trace`` call per cell.

        For each ``(i, j)`` the sub-array starting at that corner is taken and
        the elements ``[k, k, k]`` are summed. The side length is read from
        ``a.shape[0]`` only.

        Args:
            a: Input array, indexed along three axes.

        Returns:
            ``(N, N)`` array with NumPy's default float dtype, where
            ``N = a.shape[0]``.
        """
        # https://stackoverflow.com/a/79676753/313768
        size = a.shape[0]
        result = np.zeros((size, size))
        for row, col in np.ndindex(size, size):
            corner = a[row:, col:, :]
            # The shortest axis of the corner block limits the diagonal.
            steps = range(min(corner.shape))
            # corner[k, k] is a vector along the last axis; stacking them gives
            # a matrix whose own main diagonal is corner[k, k, k].
            result[row, col] = np.trace([corner[k, k] for k in steps])
        return result
    
    
    def test() -> None:
        """Check that all three implementations agree on a seeded random cube."""
        rng = np.random.default_rng(seed=0)
        cube = rng.integers(low=0, high=10, size=(12, 12, 12))

        reference = nonvectorised(cube)
        fast = semivectorised(cube)
        traced = literal_trace(cube)

        assert np.array_equal(reference, fast)
        # literal_trace accumulates into a float array; cast back to compare.
        assert np.array_equal(reference, traced.astype(cube.dtype))
    
    
    def profile() -> None:
        """Time a single call of each faster implementation and print the result.

        Uses a seeded 100x100x100 integer cube and prints one line per method
        with the elapsed ``time.perf_counter`` difference.
        """
        rng = np.random.default_rng(seed=0)
        cube = rng.integers(low=0, high=99, size=(100, 100, 100))
        candidates = (semivectorised, literal_trace)
        for method in candidates:
            t0 = time.perf_counter()
            method(cube)
            t1 = time.perf_counter()
            print(f'{method.__name__}: {t1-t0:.4f}')
    
    
    # Script entry point: run the agreement check first, so timings are only
    # printed when the implementations produce matching results.
    if __name__ == '__main__':
        test()
        profile()
    
    semivectorised: 0.0082
    literal_trace: 0.2085
    

    Index caching

    Because you need to reuse the indexing inside a loop, you can build the index arrays once up front and cache them:

    
    class Tracer(typing.NamedTuple):
        n: int
        a: np.ndarray
        b: np.ndarray
        indices: tuple[tuple, ...]
    
        @classmethod
        def build_like(cls, x: np.ndarray) -> typing.Self:
            return cls.build(n=x.shape[0], dtype=x.dtype)
    
        @classmethod
        def build(cls, n: int, dtype: np.dtype) -> typing.Self:
            irange = np.arange(1 + n, dtype=np.int32)
            ij = irange[:, np.newaxis] + irange
    
            return cls(
                n=n,
                a=np.empty((n, n, n), dtype=dtype),
                b=np.empty((n, n), dtype=dtype),
                indices=tuple([
                    (
                        # a lower triangular indices
                        (r_klen_n := range(n - klen, n), ij[:n - klen + 1, :klen], r_klen := range(klen)),
                        # a upper triangular indices
                        (ij[:n - klen, :klen], r_klen_n, r_klen),
                        # b lower triangular indices
                        (n - klen, range(1 + n - klen)),
                        # b upper triangular indices
                        (range(n - klen), n - klen),
                    )
                    for klen in range(1, 1 + n)
                ]),
            )
    
        def trace(self, x: np.ndarray, y: np.ndarray) -> np.ndarray:
            a, b = self.a, self.b
            np.multiply(x, y[..., np.newaxis], out=a)
            for av1, av2, bv1, bv2 in self.indices:
                b[bv1] = a[av1].sum(axis=1)
                b[bv2] = a[av2].sum(axis=1)
    
            return self.b
    
    # ...
    # Build once, outside the loop: this allocates the buffers and precomputes
    # every index tuple for the size and dtype of x.
    tracer = Tracer.build_like(x)
    # Call as often as needed; the result is the tracer's own `b` buffer, so
    # copy it if it must survive the next call.
    b = tracer.trace(x, y)