I am learning to implement a Factorization Machine in PyTorch, which requires pairwise feature-crossing operations. For example, given three features [A, B, C] whose embeddings are [vA, vB, vC], the feature crossings are [vA·vB], [vA·vC], [vB·vC].
I know this operation can be simplified with matrix operations (the standard FM trick, sketched below), but that shortcut only gives the final result, i.e. a single value per sample.
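Here is a rough PyTorch sketch of what I mean by that shortcut (my own sketch, assuming feature_emb has shape [batch_size x feature_len x embedding_size]):

sum_then_square = feature_emb.sum(dim=1) ** 2                       # [batch_size, embedding_size]
square_then_sum = (feature_emb ** 2).sum(dim=1)                     # [batch_size, embedding_size]
g_feature = 0.5 * (sum_then_square - square_then_sum).sum(dim=1)    # [batch_size]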
The question is: how can I get all the cross_vec values in the code below without a for loop? Note: the size of feature_emb is [batch_size x feature_len x embedding_size].
g_feature = 0
for i in range(self.feature_len):
    for j in range(self.feature_len):
        if j <= i:
            continue
        # element-wise product of the i-th and j-th feature embeddings
        cross_vec = feature_emb[:, i, :] * feature_emb[:, j, :]
        # accumulate the dot product <v_i, v_j> for each sample in the batch
        g_feature += torch.sum(cross_vec, dim=1)
You can use broadcasting:

cross_vec = (feature_emb[:, None, ...] * feature_emb[..., None, :]).sum(dim=-1)

This should give you cross_vec of shape (batch_size, feature_len, feature_len).
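If you also want the aggregated g_feature from your loop, one option (a small sketch on top of the above, assuming cross_vec is the (batch_size, feature_len, feature_len) matrix it produces) is to sum its strict upper triangle, which covers exactly the i < j pairs:

g_feature = cross_vec.triu(diagonal=1).sum(dim=(1, 2))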
Alternatively, you can use torch.bmm
cross_vec = torch.bmm(feature_emb, feature_emb.transpose(1, 2))
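For completeness, here is a small self-contained check (my own sketch, with made-up sizes) that the vectorized version matches the original double loop:

import torch

torch.manual_seed(0)
batch_size, feature_len, embedding_size = 4, 3, 8
feature_emb = torch.randn(batch_size, feature_len, embedding_size)

# original double loop from the question
g_loop = torch.zeros(batch_size)
for i in range(feature_len):
    for j in range(i + 1, feature_len):
        g_loop += (feature_emb[:, i, :] * feature_emb[:, j, :]).sum(dim=1)

# vectorized: all pairwise dot products, then sum the strict upper triangle
cross_vec = torch.bmm(feature_emb, feature_emb.transpose(1, 2))
g_vec = cross_vec.triu(diagonal=1).sum(dim=(1, 2))

print(torch.allclose(g_loop, g_vec))  # expected: True

Note that the broadcasting version materializes a (batch_size, feature_len, feature_len, embedding_size) intermediate before the .sum, so torch.bmm tends to be the more memory-friendly choice when feature_len or embedding_size is large.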