
NumPy einsum - Why did this happen?


Can you explain why this happened?

import numpy as np
a = np.array([[1,2],
              [3,4],
              [5,6]
              ])
b = np.array([[2,2,2],
              [2,2,2]])
print(np.einsum("xy,zx -> yx",a,b))

The output of the code is:

[[ 4 12 20]
 [ 8 16 24]]

which means each entry is calculated like this: [1*2+1*2, 3*2+3*2, ...]. But I expected it to be calculated like this: [[1*2, 3*2, 5*2], [2*2, 4*2, 6*2]]. Where did I make a mistake?


Solution

  • Your code is equivalent to:

    (a[None] * b[..., None]).sum(axis=0).T
    

    You start with a of shape (x, y) = (3, 2) and b of shape (z, x) = (2, 3).

    First let's align the arrays:

    # a[None]                                shape: (1, x, y)
    array([[[1, 2],
            [3, 4],
            [5, 6]]])
    
    # b[..., None]                           shape: (z, x, 1)
    array([[[2],
            [2],
            [2]],
    
           [[2],
            [2],
            [2]]])
    

    and multiply:

    # a[None] * b[..., None]                 shape: (z, x, y)
    array([[[ 2,  4],
            [ 6,  8],
            [10, 12]],
    
           [[ 2,  4],
            [ 6,  8],
            [10, 12]]])
    

    Sum over axis 0 (z), because z appears in the inputs but not in the output subscripts yx:

    # (a[None] * b[..., None]).sum(axis=0)   shape: (x, y)
    array([[ 4,  8],
           [12, 16],
           [20, 24]])
    

    Swap x and y:

    # (a[None] * b[..., None]).sum(axis=0).T shape: (y, x)
    array([[ 4, 12, 20],
           [ 8, 16, 24]])
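
The steps above can also be spelled out with explicit loops, which makes the implicit sum over z visible. This is only a minimal sketch: the helper name `einsum_loops` is made up for illustration and is not part of NumPy.

```python
import numpy as np

a = np.array([[1, 2], [3, 4], [5, 6]])   # shape (x=3, y=2)
b = np.array([[2, 2, 2], [2, 2, 2]])     # shape (z=2, x=3)

def einsum_loops(a, b):
    """Hypothetical helper: writes out "xy,zx->yx" with plain loops."""
    x_len, y_len = a.shape
    z_len = b.shape[0]
    out = np.zeros((y_len, x_len), dtype=np.result_type(a, b))
    for y in range(y_len):
        for x in range(x_len):
            for z in range(z_len):       # z is absent from the output, so it is summed
                out[y, x] += a[x, y] * b[z, x]
    return out

loops = einsum_loops(a, b)
broadcast = (a[None] * b[..., None]).sum(axis=0).T
direct = np.einsum("xy,zx->yx", a, b)

# All three routes should agree.
print(np.array_equal(loops, direct), np.array_equal(broadcast, direct))
```

Because b has two identical rows along z, every product a[x, y] * 2 is added twice, which is where terms like 1*2+1*2 come from.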
    

    What you want is np.einsum('yx,xy->xy', a, b). Both indices appear in the output, so nothing is summed:

    array([[ 2,  6, 10],
           [ 4,  8, 12]])
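
Note that the letters in that call are relabeled compared with the question: a is now read as (y, x) and b as (x, y). A minimal sketch, assuming the same a and b as in the question, that checks the call against the question's original labels and against a plain element-wise product:

```python
import numpy as np

a = np.array([[1, 2], [3, 4], [5, 6]])   # shape (3, 2)
b = np.array([[2, 2, 2], [2, 2, 2]])     # shape (2, 3)

# The call from the answer: a is indexed (y, x), b is indexed (x, y).
wanted = np.einsum("yx,xy->xy", a, b)

# Same computation with the question's labels: a is (x, y), b is (y, x).
same_labels = np.einsum("xy,yx->yx", a, b)

# Without einsum: transpose a to match b's shape, then multiply element-wise.
elementwise = a.T * b

print(wanted)
```

In all three forms the second operand shares both of its axes with the first, so there is no leftover index such as z to be summed away.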