I have a tensor operation that I would like to replicate using a combination of torch.stack() and torch.tensordot(), so that I can generalize it further in a larger program. In short, I want to reproduce the tensor V_1 below, using those two operations, as a second tensor V_2.
import torch

N, t, J = 4, 2, 3
K_f, K_r = 1, 1
R = 5
K = K_f + K_r
id = torch.arange(N).repeat(t).sort().values   # individual index for each of the N*t rows (unused below)
X = torch.randn(N*t, K, J)
Y = torch.randn(N*t, 1)
D = torch.randn(N, K_r, R)
Draw = D.repeat_interleave(t, 0)               # expand the draws to the N*t rows
beta = torch.randn(2*K_r + K_f, 1)
beta_R = (beta[0:K_r, 0] + beta[K_r:2*K_r, 0] * Draw).repeat(1, J, 1)   # one coefficient per draw, repeated across J
print("shape beta_R:", beta_R.shape)
beta_F = beta[2*K_r:2*K_r + K_f, 0].repeat(N*t, J, R)                   # constant coefficient, repeated to (N*t, J, R)
print("shape beta_F:", beta_F.shape)
XX_0 = X[:, 0, :].unsqueeze(2).repeat(1, 1, R)
print("shape XX_0:", XX_0.shape)
XX_1 = X[:, 1, :].unsqueeze(2).repeat(1, 1, R)
print("shape XX_1:", XX_1.shape)
V_1 = XX_0 * beta_R + XX_1 * beta_F
print("shape V_1:", V_1.shape)
#shape beta_R: torch.Size([8, 3, 5])
#shape beta_F: torch.Size([8, 3, 5])
#shape XX_0: torch.Size([8, 3, 5])
#shape XX_1: torch.Size([8, 3, 5])
#shape V_1: torch.Size([8, 3, 5])
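To spell out what V_1 contains (this single-entry check is my own addition, and the indices n, j, r are arbitrary): each entry is the product of an X slice and the matching coefficient, summed over the K = 2 coefficient blocks.
n, j, r = 3, 1, 2  # arbitrary indices
lhs = V_1[n, j, r]
rhs = X[n, 0, j] * beta_R[n, 0, r] + X[n, 1, j] * beta_F[n, 0, r]
print(torch.allclose(lhs, rhs))  # True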
Now I want to do the same, but by stacking my tensors (using torch.stack()) and applying a generalized dot product (using torch.tensordot()). However, I am a bit confused by the dims argument, which is not doing what I expected.
#%% Replicating using stacking and tensordot
stack_XX = torch.stack((XX_0, XX_1), 0)
print("shape stack_XX:", stack_XX.shape)
stack_beta = torch.stack((beta_R, beta_F), 0)
print("shape stack_beta:", stack_beta.shape)
# dot product between stack_XX and stack_beta along the first dimension
V_2 = torch.tensordot(stack_XX, stack_beta, dims=([0], [0]))
print("shape V_2:", V_2.shape)
# check if the two are equal
torch.all(V_1.eq(V_2))
#shape stack_XX: torch.Size([2, 8, 3, 5])
#shape stack_beta: torch.Size([2, 8, 3, 5])
#shape V_2: torch.Size([8, 3, 5, 8, 3, 5])
#tensor(False)
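If I understand the docs correctly, torch.tensordot() sums over the dimensions given in dims but then forms a full outer product over all remaining dimensions of both tensors, which would explain the (8, 3, 5, 8, 3, 5) shape above. A small toy example of my own to illustrate:
a = torch.randn(2, 3)
b = torch.randn(2, 4)
# contracting dim 0 of both: result[i, j] = sum_k a[k, i] * b[k, j]
print(torch.tensordot(a, b, dims=([0], [0])).shape)  # torch.Size([3, 4])
What I actually need is to keep the remaining dimensions aligned (element-wise) rather than crossed, which tensordot alone does not seem to provide.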
So I am basically trying to get tensor(True) when running torch.all(V_1.eq(V_2)).
Maybe:
torch.einsum('abcd,abcd->bcd', stack_XX, stack_beta)
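If I am not mistaken, that einsum gives the intended result: it multiplies the two stacked tensors element-wise and sums over the first (stacking) axis, which is exactly XX_0 * beta_R + XX_1 * beta_F. A quick check, plus an equivalent formulation without einsum (both are my own verification, run against the tensors defined above; allclose is used to guard against floating-point rounding):
V_2 = torch.einsum('abcd,abcd->bcd', stack_XX, stack_beta)
print(V_2.shape)                 # torch.Size([8, 3, 5])
print(torch.allclose(V_1, V_2))  # True

# equivalent without einsum: element-wise product, then sum over the stacking dimension
V_3 = (stack_XX * stack_beta).sum(dim=0)
print(torch.allclose(V_1, V_3))  # True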