I'm implementing the self-attention part of a Transformer encoder using PyTorch's nn.MultiheadAttention
and I'm confused about the padding masking.
The following picture shows the self-attention weights for each query (row) and key (column).
As you can see, there are some "<PAD>" tokens, and I have already masked them in the keys, so those positions do not contribute to the attention weights.
I still have two questions:
1. In the query part, can I also mask the "<PAD>" tokens (everything except the red square part)? Is this reasonable?
2. How can I mask "<PAD>" in the queries?
The attention weights are computed with a softmax along each row, after applying the mask given in the src_mask or src_key_padding_mask argument. If I set an entire "<PAD>" row to -inf, the softmax returns nan, and the loss then becomes nan as well.
There is no need to mask the queries during self-attention. It is enough not to use the states corresponding to the <PAD>
tokens later in the network (either as hidden states or as keys/values); then they will not influence the loss function or anything else in the network.
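As a hedged sketch of what "not using them later" can look like, assuming token-level labels where <PAD> positions are marked with -100: CrossEntropyLoss(ignore_index=-100) simply drops those positions, so whatever the encoder produced at the <PAD> queries never reaches the loss or the gradients.

```python
import torch
import torch.nn as nn

logits = torch.randn(2, 5, 10)                     # (batch, seq_len, num_classes)
labels = torch.tensor([[3, 1, 4, -100, -100],
                       [2, 0, -100, -100, -100]])  # -100 at <PAD> positions
loss_fn = nn.CrossEntropyLoss(ignore_index=-100)   # ignores <PAD> targets
loss = loss_fn(logits.view(-1, 10), labels.view(-1))
```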
If you want to make sure that you did not introduce a bug that lets gradients flow through the <PAD>
tokens, you can explicitly zero out the self-attention output at those positions using torch.where
after it is computed.
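A minimal sketch of that check, reusing the same hypothetical names as above (True = <PAD> in the boolean mask): the output rows at <PAD> query positions are overwritten with zeros, so no gradient can flow through them.

```python
import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=16, num_heads=4, batch_first=True)
x = torch.randn(2, 5, 16)
key_padding_mask = torch.tensor([[False, False, False, True, True],
                                 [False, False, True, True, True]])  # True = <PAD>

out, _ = mha(x, x, x, key_padding_mask=key_padding_mask)
pad = key_padding_mask.unsqueeze(-1)                  # (batch, seq_len, 1)
out = torch.where(pad, torch.zeros_like(out), out)    # zero every <PAD> row
```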