
Transformer encoder layer with PyTorch: The shape of the 2D attn_mask is torch.Size([16, 512]), but should be (16, 16)


Here is a minimal example:

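import torch
import torch.nn as nn

# d_model=512, nhead=8, dim_feedforward=2048, dropout=0.5
encoder_layers = nn.TransformerEncoderLayer(512, 8, 2048, 0.5)
mask = torch.randint(0, 2, (16, 512)).bool()  # bool mask of shape (16, 512)
text = torch.randn(16, 512)                   # 16 tokens, 512 features each
print(mask)
print(text)
encoder_layers(text, mask)  # the second positional argument is src_mask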

This gives me the following error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-31-b326564457ab> in <module>
      4 print(mask)
      5 print(text)
----> 6 encoder_layers(text,mask)

5 frames
/usr/local/lib/python3.9/dist-packages/torch/nn/functional.py in multi_head_attention_forward(query, key, value, embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias, bias_k, bias_v, add_zero_attn, dropout_p, out_proj_weight, out_proj_bias, training, key_padding_mask, need_weights, attn_mask, use_separate_proj_weight, q_proj_weight, k_proj_weight, v_proj_weight, static_k, static_v, average_attn_weights)
   5067             correct_2d_size = (tgt_len, src_len)
   5068             if attn_mask.shape != correct_2d_size:
-> 5069                 raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
   5070             attn_mask = attn_mask.unsqueeze(0)
   5071         elif attn_mask.dim() == 3:

RuntimeError: The shape of the 2D attn_mask is torch.Size([16, 512]), but should be (16, 16).

I don't understand why this doesn't work. Shouldn't the mask length equal the number of tokens?


Solution

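  • The attention mask (src_mask) specifies which query vectors may attend to which key vectors in the self-attention block; it has nothing to do with the 512 embedding features. For example, in machine translation the training batch contains the entire target sentence, but we don't want the query at each target word to attend to the keys of later words in that sentence. So at training time we apply a mask that hides the future keys from each query. Therefore the attention mask must have shape [len(queries), len(keys)]. Your input text of shape (16, 512) is treated as an unbatched sequence of 16 tokens with 512 features each, so len(queries) = len(keys) = 16 and the mask must be (16, 16). If what you actually want is to ignore padding tokens, pass a per-token mask through src_key_padding_mask instead.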
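A minimal sketch of both kinds of mask, assuming a PyTorch version that accepts unbatched 2D input to nn.TransformerEncoderLayer (the traceback's expected size of (16, 16) suggests the asker's version does). The causal mask is only one example of a valid (16, 16) mask, and the names seq_len, causal_mask, batch and padding_mask are illustrative. For boolean masks, True means "not allowed to attend".

```python
import torch
import torch.nn as nn

seq_len, d_model, nhead = 16, 512, 8

# Same configuration as the question: d_model=512, nhead=8, dim_feedforward=2048, dropout=0.5
layer = nn.TransformerEncoderLayer(d_model, nhead, 2048, 0.5)
layer.eval()  # disable dropout so repeated runs give the same output

# Unbatched input of shape (seq_len, d_model): 16 tokens, 512 features each
text = torch.randn(seq_len, d_model)

# 1) src_mask must be (num_queries, num_keys) = (16, 16).
#    Example: a causal mask where True above the diagonal hides future keys from each query.
causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
out = layer(text, src_mask=causal_mask)
print(out.shape)  # torch.Size([16, 512])

# 2) To ignore padding tokens, use src_key_padding_mask, shape (batch, seq_len).
#    With the default batch_first=False, batched input is (seq_len, batch, d_model).
batch = torch.randn(seq_len, 2, d_model)
padding_mask = torch.zeros(2, seq_len, dtype=torch.bool)
padding_mask[1, 12:] = True  # treat the last 4 tokens of the 2nd sequence as padding
out2 = layer(batch, src_key_padding_mask=padding_mask)
print(out2.shape)  # torch.Size([16, 2, 512])
```

The causal mask keeps the diagonal unmasked (diagonal=1), so every query can still attend to at least its own key.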