Let T and L be two batches of M×N matrices, and f(ti, lj) a function that computes a score for a pair of matrices ti and lj. For instance, if
import torch

T, L = torch.rand(4, 3, 2), torch.rand(4, 3, 2)
# T = tensor([[[0.0017, 0.5781],
#              [0.8136, 0.5971],
#              [0.7697, 0.0795]],
#
#             [[0.2794, 0.7285],
#              [0.1528, 0.8503],
#              [0.9714, 0.1060]],
#
#             [[0.6907, 0.8831],
#              [0.4691, 0.4254],
#              [0.2539, 0.7538]],
#
#             [[0.3717, 0.2229],
#              [0.6134, 0.4810],
#              [0.7595, 0.9449]]])
and the score function is defined as shown in the following code snippet:
def score(ti, lj):
    """MaxSim score of matrices ti and lj."""
    # Dot products between every row of ti and every row of lj: shape (M, M)
    m = torch.matmul(ti, torch.transpose(lj, 0, 1))
    # For each row of ti, keep its best-matching row of lj, then sum
    return torch.sum(torch.max(m, 1).values, dim=-1)
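For a single pair it returns a 0-dim tensor; with the tensors above, for example:

score(T[0], L[0])  # tensor(2.3405), i.e. the expected S[0, 0] below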
How can I compute a score matrix S, where S[i, j] is the score between T[i] and L[j]?
# S = tensor([[2.3405, 2.2594, 2.0989, 1.6450],
#             [2.5939, 2.4186, 2.3946, 2.0648],
#             [2.9447, 2.3652, 2.3829, 2.1536],
#             [2.8195, 2.3105, 2.2563, 1.8388]])
NOTE: This operation must be differentiable.
I'd recommend using einsum for the pairwise matrix multiplications:
# m[b, c, i, k] = dot(T[b, i], L[c, k]), i.e. (T[b] @ L[c].T)[i, k]
m = torch.einsum('b i j, c k j -> b c i k', T, L)
which results in
>>> m.shape
torch.Size([4, 4, 3, 3])
that is, a tensor that contains all 16 matrix products. Then the rest is simply
out = torch.max(m, -1).values.sum(dim=-1)  # max over rows of L[c], sum over rows of T[b]
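As a quick end-to-end sanity check, here is a minimal sketch that just restates the question's setup and compares the einsum result against a pairwise loop over score:

import torch

def score(ti, lj):
    """MaxSim score from the question."""
    m = torch.matmul(ti, torch.transpose(lj, 0, 1))
    return torch.sum(torch.max(m, 1).values, dim=-1)

T, L = torch.rand(4, 3, 2), torch.rand(4, 3, 2)

# einsum path: all 16 pairwise (3, 3) products, then the max/sum reduction
m = torch.einsum('b i j, c k j -> b c i k', T, L)
out = torch.max(m, -1).values.sum(dim=-1)

# reference: score() applied to every (T[i], L[j]) pair
expected = torch.stack([torch.stack([score(ti, lj) for lj in L]) for ti in T])
assert torch.allclose(out, expected)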
Alternatively, you could use broadcasting for the matrix multiplications, but I think it is quite a bit more cumbersome than the einsum solution.
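For reference, here is a minimal sketch of that broadcasting variant (same result, just more shape juggling):

# T.unsqueeze(1) has shape (4, 1, 3, 2); L.transpose(-1, -2).unsqueeze(0)
# has shape (1, 4, 2, 3). matmul broadcasts over the leading dims,
# giving the same (4, 4, 3, 3) tensor of pairwise products as einsum.
m = torch.matmul(T.unsqueeze(1), L.transpose(-1, -2).unsqueeze(0))
out = torch.max(m, -1).values.sum(dim=-1)

Both reductions are differentiable (max backpropagates through the selected entries), so the NOTE in the question is satisfied either way.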