python pytorch transformer-model

Issue with Padding Mask in PyTorch Transformer Encoder


I'm encountering an issue with the padding mask in PyTorch's Transformer Encoder. I'm trying to ensure that the values in the padded sequences do not affect the output of the model. However, even after setting the padded values to zeros in the input sequence, I'm still observing differences in the output.

Here's a simplified version of my code:

import torch as th
from torch import nn

# Data
batch_size = 2
seq_len = 5
input_size = 16
src = th.randn(batch_size, seq_len, input_size)

# Set some values to a high value
src[0, 2, :] = 1000.0
src[1, 4, :] = 1000.0

# Generate a padding mask (True marks padded positions to be ignored)
padding_mask = th.zeros(batch_size, seq_len, dtype=th.bool)
padding_mask[0, 2] = True
padding_mask[1, 4] = True

# Pass the data through the encoder of the model
encoder = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(
        d_model=input_size,
        nhead=1,
        batch_first=True,
    ),
    num_layers=1,
    norm=None,
)
out1000 = encoder(src, src_key_padding_mask=padding_mask)

# Zero out the masked positions so their values should not affect the output
src[0, 2, :] = 0.0
src[1, 4, :] = 0.0

# Pass the modified data through the model
out0 = encoder(src, src_key_padding_mask=padding_mask)

# Check that the outputs match at the non-padded positions
assert th.allclose(
    out1000[padding_mask == 0],
    out0[padding_mask == 0],
    atol=1e-5,
)

Despite zeroing out the padded values in the input sequence, I'm still observing differences in the output of the Transformer Encoder. Could someone help me understand why this is happening, and how I can ensure that the values at the padded positions do not affect the model's output?


Solution

  • The discrepancy is caused by dropout in the encoder layer. A freshly constructed module is in training mode, so dropout (default p=0.1 in TransformerEncoderLayer) randomly zeroes different activations on each forward pass, which makes the two outputs differ even at the non-padded positions. You can fix this by passing dropout=0.0 to TransformerEncoderLayer (or by switching the encoder to evaluation mode, as sketched after the code below):

    encoder = nn.TransformerEncoder(
        nn.TransformerEncoderLayer(
            d_model=input_size,
            nhead=1,
            batch_first=True,
            dropout=0.0,  # disable dropout so repeated forward passes are deterministic
        ),
        num_layers=1,
        norm=None,
    )
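
    Alternatively, if you want to keep dropout for training, you can switch the module to eval mode before comparing outputs, since dropout only behaves stochastically in training mode. Here is a minimal sketch of that variant, reusing the setup from your question:

    import torch as th
    from torch import nn

    # Keep the default dropout (p=0.1) but run the comparison in eval mode
    encoder = nn.TransformerEncoder(
        nn.TransformerEncoderLayer(d_model=16, nhead=1, batch_first=True),
        num_layers=1,
        norm=None,
    )
    encoder.eval()  # dropout layers become no-ops outside training mode

    src = th.randn(2, 5, 16)
    padding_mask = th.zeros(2, 5, dtype=th.bool)
    padding_mask[0, 2] = True
    padding_mask[1, 4] = True

    with th.no_grad():
        src[0, 2, :] = 1000.0
        src[1, 4, :] = 1000.0
        out1000 = encoder(src, src_key_padding_mask=padding_mask)

        src[0, 2, :] = 0.0
        src[1, 4, :] = 0.0
        out0 = encoder(src, src_key_padding_mask=padding_mask)

    # The non-padded positions should now be unaffected by the padded values
    assert th.allclose(out1000[~padding_mask], out0[~padding_mask], atol=1e-5)

    With dropout disabled (either way), the attention weights for masked keys are exactly zero, so whatever values sit at the padded positions cannot leak into the outputs of the other positions.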