I'm following the PyTorch seq2seq tutorial, and the torch.bmm
method is used like below:
attn_applied = torch.bmm(attn_weights.unsqueeze(0),
                         encoder_outputs.unsqueeze(0))
I understand why we need to multiply the attention weights and the encoder outputs.
What I don't quite understand is why we need the bmm
method here.
The torch.bmm
documentation says:
Performs a batch matrix-matrix product of matrices stored in batch1 and batch2.
batch1 and batch2 must be 3-D tensors each containing the same number of matrices.
If batch1 is a (b×n×m) tensor, batch2 is a (b×m×p) tensor, out will be a (b×n×p) tensor.
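For reference, a minimal check of that shape rule with dummy tensors (the sizes here are arbitrary):

import torch

# shape rule from the docs: (b x n x m) @ (b x m x p) -> (b x n x p)
batch1 = torch.randn(4, 2, 3)   # b=4, n=2, m=3
batch2 = torch.randn(4, 3, 5)   # b=4, m=3, p=5
out = torch.bmm(batch1, batch2)
print(out.shape)                # torch.Size([4, 2, 5])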
In the seq2seq model, the encoder encodes the input sequences given as mini-batches. Say, for example, the input is B x S x d, where B is the batch size, S is the maximum sequence length and d is the word embedding dimension. Then the encoder's output is B x S x h, where h is the hidden state size of the encoder (which is an RNN).
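A minimal sketch of those shapes, assuming a toy GRU encoder with batch_first=True and placeholder sizes:

import torch
import torch.nn as nn

B, S, d, h = 32, 10, 100, 256                 # toy batch size, max seq length, embedding dim, hidden size

encoder = nn.GRU(input_size=d, hidden_size=h, batch_first=True)
embedded = torch.randn(B, S, d)               # embedded input mini-batch: B x S x d
encoder_outputs, _ = encoder(embedded)        # B x S x h
print(encoder_outputs.shape)                  # torch.Size([32, 10, 256])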
Now, while decoding (during training), the input sequences are given one at a time, so the input is B x 1 x d and the decoder produces a tensor of shape B x 1 x h. Now, to compute the context vector, we need to compare this decoder hidden state with the encoder's encoded states.
So, consider you have two tensors of shape T1 = B x S x h and T2 = B x 1 x h. Then you can do batch matrix multiplication as follows.
out = torch.bmm(T1, T2.transpose(1, 2))
Essentially, you are multiplying a tensor of shape B x S x h with a tensor of shape B x h x 1, and it will result in a tensor of shape B x S x 1, which is the attention weight for each batch.
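With dummy tensors of those shapes (toy sizes, matching the sketch above), the shapes work out like this:

import torch

B, S, h = 32, 10, 256                        # toy batch size, source length, hidden size

T1 = torch.randn(B, S, h)                    # encoder outputs: B x S x h
T2 = torch.randn(B, 1, h)                    # current decoder hidden state: B x 1 x h

out = torch.bmm(T1, T2.transpose(1, 2))      # (B x S x h) @ (B x h x 1) -> B x S x 1
print(out.shape)                             # torch.Size([32, 10, 1])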
Here, the attention weights B x S x 1 represent a similarity score between the decoder's current hidden state and all of the encoder's hidden states. Now you can multiply the attention weights with the encoder's hidden states B x S x h by transposing the encoder states first, and it will result in a tensor of shape B x h x 1. And if you squeeze at dim=2, you will get a tensor of shape B x h, which is your context vector.
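Continuing the same toy-sized sketch (the softmax here is the usual normalization step, left implicit above):

import torch

B, S, h = 32, 10, 256
T1 = torch.randn(B, S, h)                              # encoder outputs: B x S x h
attn_weights = torch.randn(B, S, 1).softmax(dim=1)     # attention weights: B x S x 1

context = torch.bmm(T1.transpose(1, 2), attn_weights)  # (B x h x S) @ (B x S x 1) -> B x h x 1
context = context.squeeze(2)                           # B x h
print(context.shape)                                   # torch.Size([32, 256])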
This context vector (B x h) is usually concatenated with the decoder's hidden state (B x 1 x h, squeezed at dim=1) to predict the next token.
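A rough sketch of that final step; the output projection layer and vocabulary size below are hypothetical, just to show the shapes:

import torch
import torch.nn as nn

B, h, vocab_size = 32, 256, 5000                    # toy sizes; vocab_size is a placeholder

context = torch.randn(B, h)                         # context vector from the previous step
decoder_hidden = torch.randn(B, 1, h).squeeze(1)    # B x 1 x h -> B x h

out_proj = nn.Linear(2 * h, vocab_size)             # hypothetical output projection
logits = out_proj(torch.cat([context, decoder_hidden], dim=1))  # B x 2h -> B x vocab_size
print(logits.shape)                                 # torch.Size([32, 5000])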