Hello, could someone please help me figure out how to use np.einsum to produce the result of the code below? I have a (3,3,3) tensor, and I got the result below using two for loops. I am not familiar with np.einsum. Ideally, I would also like to sum each of the resulting rows to get nine values.
Result of the code below:
[1 1 1]
[2 2 2]
[1 1 1]
[2 2 2]
[4 4 4]
[2 2 2]
[1 1 1]
[2 2 2]
[1 1 1]
[1 1 1]
3
6
3
9
12
6
15
18
9
6
12
6
18
24
12
import numpy as np

# All (x, y) index pairs: (0, 0), (0, 1), ..., (2, 2)
bb = []
for x in range(3):
    for y in range(3):
        bb.append((x, y))

a = np.array([[[1,2,1],[3,4,2],[5,6,3]],
              [[1,2,1],[3,4,2],[5,6,3]],
              [[1,2,1],[3,4,2],[5,6,3]]])
# b is identical to a and is not used below
b = np.array([[[1,2,1],[3,4,2],[5,6,3]],
              [[1,2,1],[3,4,2],[5,6,3]],
              [[1,2,1],[3,4,2],[5,6,3]]])

for z in range(9):
    llAI = bb[z]
    aal = a[:, llAI[0], llAI[1]]      # vector along axis 0 for the z-th (x, y) pair
    for f in range(9):
        mmAI = bb[f]
        aam = a[:, mmAI[0], mmAI[1]]
        print(np.sum(aal * aam))      # dot product of the two vectors
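To have something concrete to compare a vectorized version against, here is a minimal sketch that stores the double loop's sums in a 9x9 array instead of printing them, assuming both loops run over range(9) as written. The name `out` is only illustrative. Its first row reproduces the sums printed above (3, 6, 3, 9, 12, 6, 15, 18, 9).

```python
import numpy as np

# Same (3,3,3) tensor as in the question: every a[j] is the same 3x3 matrix
a = np.array([[[1, 2, 1], [3, 4, 2], [5, 6, 3]]] * 3)

# (x, y) index pairs in the same order as bb: (0,0), (0,1), ..., (2,2)
bb = [(x, y) for x in range(3) for y in range(3)]

# out[z, f] holds the value the double loop prints for the pair (z, f)
out = np.zeros((9, 9), dtype=a.dtype)
for z in range(9):
    aal = a[:, bb[z][0], bb[z][1]]      # length-3 vector along axis 0
    for f in range(9):
        aam = a[:, bb[f][0], bb[f][1]]
        out[z, f] = np.sum(aal * aam)

print(out[0].tolist())  # [3, 6, 3, 9, 12, 6, 15, 18, 9]
```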
It took a bit to figure out what you are doing. Since z iterates on range(3), aal is successively a[:,0,0], a[:,0,1], a[:,0,2].
Or done all at once:
In [178]: aaL = a[:,0,:]; aaL
Out[178]:
array([[1, 2, 1],
       [1, 2, 1],
       [1, 2, 1]])
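A quick check of that claim, as a small sketch (the loop over the first three entries of bb is only for illustration): each column y of a[:,0,:] is exactly the vector aal = a[:,0,y] that the loop builds one at a time.

```python
import numpy as np

a = np.array([[[1, 2, 1], [3, 4, 2], [5, 6, 3]]] * 3)
bb = [(x, y) for x in range(3) for y in range(3)]

aaL = a[:, 0, :]   # shape (3, 3): axis 0 of a stays, y runs along the columns

# Column y of aaL is the vector aal = a[:, 0, y] from the first three iterations
for z in range(3):
    x, y = bb[z]
    assert np.array_equal(aaL[:, y], a[:, x, y])

print(aaL.shape)  # (3, 3)
```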
aam does the same iteration. So the sums of their products, using matmul/@/dot, are:
In [179]: aaL.T@aaL
Out[179]:
array([[ 3,  6,  3],
       [ 6, 12,  6],
       [ 3,  6,  3]])
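Why matmul gives the loop's numbers: entry [i, k] of aaL.T @ aaL is the sum over j of aaL[j, i] * aaL[j, k], which is np.sum(aal*aam) for the i-th and k-th vectors. The sketch below rebuilds the 3x3 result with the question's np.sum pattern and compares. Flattening it gives nine values, which also happen to be the sums of the nine product rows in the question's output ([1 1 1] gives 3, [2 2 2] gives 6, and so on); that reading of "sum each row" is an assumption.

```python
import numpy as np

a = np.array([[[1, 2, 1], [3, 4, 2], [5, 6, 3]]] * 3)
aaL = a[:, 0, :]

gram = aaL.T @ aaL   # gram[i, k] = sum over j of aaL[j, i] * aaL[j, k]

# Same 3x3 result built with the question's np.sum(aal * aam) pattern
loop = np.array([[np.sum(aaL[:, i] * aaL[:, k]) for k in range(3)]
                 for i in range(3)])

print(np.array_equal(gram, loop))  # True
print(gram.ravel().tolist())       # [3, 6, 3, 6, 12, 6, 3, 6, 3]
```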
Or in einsum:
In [180]: np.einsum('ji,jk->ik',aaL,aaL)
Out[180]:
array([[ 3,  6,  3],
       [ 6, 12,  6],
       [ 3,  6,  3]])
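Since the question mentions being unfamiliar with einsum, here is a sketch of how to read the subscript string 'ji,jk->ik': j appears in both inputs but not after the arrow, so it is summed over; i and k appear after the arrow, so they index the result. The explicit triple loop below computes the same thing.

```python
import numpy as np

a = np.array([[[1, 2, 1], [3, 4, 2], [5, 6, 3]]] * 3)
aaL = a[:, 0, :]

# j is shared and missing from the output -> summed; i and k index the result
res = np.einsum('ji,jk->ik', aaL, aaL)

# The same contraction written out as explicit loops
manual = np.zeros((3, 3), dtype=aaL.dtype)
for i in range(3):
    for k in range(3):
        for j in range(3):
            manual[i, k] += aaL[j, i] * aaL[j, k]

print(np.array_equal(res, manual))       # True
print(np.array_equal(res, aaL.T @ aaL))  # True
```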
Your indexing array:
In [183]: bb
Out[183]: [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2), (2, 0), (2, 1), (2, 2)]
In [185]: np.array(bb)[:3,:]
Out[185]:
array([[0, 0],
       [0, 1],
       [0, 2]])
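The order of bb is the row-major (C-order) split of z into (z // 3, z % 3), which is what np.unravel_index computes for shape (3, 3). A consequence, shown in this sketch, is that a[:, x, y] is column 3*x + y of a.reshape(3, 9), so the bb lookup can be replaced by a reshape. The name `A` is only illustrative.

```python
import numpy as np

bb = [(x, y) for x in range(3) for y in range(3)]

# bb[z] == (z // 3, z % 3); np.unravel_index does the same split for shape (3, 3)
for z in range(9):
    x, y = np.unravel_index(z, (3, 3))
    assert (int(x), int(y)) == bb[z] == divmod(z, 3)

# Because of this ordering, a[:, x, y] is column 3*x + y of a.reshape(3, 9)
a = np.array([[[1, 2, 1], [3, 4, 2], [5, 6, 3]]] * 3)
A = a.reshape(3, 9)
assert all(np.array_equal(A[:, z], a[:, bb[z][0], bb[z][1]]) for z in range(9))

print(np.array(bb)[:3, :].tolist())  # [[0, 0], [0, 1], [0, 2]]
```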
If I generalize it to the remaining ranges of bb:
In [192]: for i in range(3):
     ...:     aaL = a[:,i]
     ...:     print(aaL.T@aaL)
     ...:
[[ 3  6  3]
 [ 6 12  6]
 [ 3  6  3]]
[[27 36 18]
 [36 48 24]
 [18 24 12]]
[[ 75  90  45]
 [ 90 108  54]
 [ 45  54  27]]
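The Python loop over i can also be pushed into matmul itself, which broadcasts over leading axes. This sketch is an alternative to the einsum below, not part of the original answer: it moves the middle axis of a to the front so that each batch element computes a[:, i].T @ a[:, i].

```python
import numpy as np

a = np.array([[[1, 2, 1], [3, 4, 2], [5, 6, 3]]] * 3)

# Loop version: one 3x3 block per middle index i
blocks = np.stack([a[:, i].T @ a[:, i] for i in range(3)])

# Batched version: matmul broadcasts over the leading axis i
left = a.transpose(1, 2, 0)   # axes [i, y, j], i.e. a[:, i].T for each i
right = a.transpose(1, 0, 2)  # axes [i, j, k], i.e. a[:, i] for each i
batched = left @ right

print(np.array_equal(blocks, batched))  # True
print(batched[1].tolist())  # [[27, 36, 18], [36, 48, 24], [18, 24, 12]]
```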
Adding a dimension to the einsum:
In [195]: np.einsum('jmi,jmk->mik', a,a)
Out[195]:
array([[[  3,   6,   3],
        [  6,  12,   6],
        [  3,   6,   3]],

       [[ 27,  36,  18],
        [ 36,  48,  24],
        [ 18,  24,  12]],

       [[ 75,  90,  45],
        [ 90, 108,  54],
        [ 45,  54,  27]]])
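The einsum above only pairs vectors that share the same middle index. The question's loops, as written with range(9), also pair vectors from different middle indices (for example a[:,0,0] with a[:,1,0], giving 9). Assuming all 81 values are wanted, a sketch is to keep the two middle indices separate, then reshape to 9x9 in bb order. The names `full`, `A`, and `blocks` are only illustrative.

```python
import numpy as np

a = np.array([[[1, 2, 1], [3, 4, 2], [5, 6, 3]]] * 3)

# All 81 (z, f) pairs: j is summed, (x, y) belongs to aal and (u, v) to aam
full = np.einsum('jxy,juv->xyuv', a, a).reshape(9, 9)

# Equivalent: flatten the last two axes (same order as bb), then one matmul
A = a.reshape(3, 9)
assert np.array_equal(full, A.T @ A)

# The (3, 3, 3) result above is the x == u part of this array
blocks = np.einsum('jmi,jmk->mik', a, a)
for m in range(3):
    assert np.array_equal(full[3*m:3*m + 3, 3*m:3*m + 3], blocks[m])

print(full[0].tolist())  # [3, 6, 3, 9, 12, 6, 15, 18, 9]
```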