Tags: audio, deep-learning, pytorch, transformer-model, pattern-recognition

How to correctly apply LayerNorm after MultiheadAttention with different input shapes (batch_first vs default) in PyTorch?


I’m working on an audio recognition task using a Transformer-based model in PyTorch. My input features are generated by a CNN-based embedding layer and have the shape [batch_size, d_model, n_token], where n_token is the sequence length and d_model is the feature dimension.
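This channel-first layout is what PyTorch convolution layers produce: nn.Conv1d returns (N, C_out, L_out). A minimal sketch, assuming a hypothetical Conv1d embedding over a log-mel spectrogram (the `embed` layer and all sizes are illustrative, not the actual model):

```python
import torch
import torch.nn as nn

# Hypothetical sizes: 2 clips, 64 mel bins, 100 frames, d_model = 128.
batch_size, n_mels, n_frames, d_model = 2, 64, 100, 128

# Hypothetical CNN embedding: Conv1d treats the mel bins as input channels and
# slides over time; kernel_size=3 with padding=1 keeps the number of frames.
embed = nn.Conv1d(in_channels=n_mels, out_channels=d_model, kernel_size=3, padding=1)

spec = torch.randn(batch_size, n_mels, n_frames)  # stand-in for a spectrogram
features = embed(spec)
print(features.shape)  # torch.Size([2, 128, 100]) -> [batch_size, d_model, n_token]
```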

By default, nn.MultiheadAttention (when batch_first=False) expects input in the shape (seq, batch, feature). To make things more intuitive, I chose to set batch_first=True and then permute my data from [batch_size, d_model, n_token] to [batch_size, n_token, d_model] so that the time dimension comes before the feature dimension. Here’s a simplified code snippet:

import torch
import torch.nn as nn

# Original shape: [batch_size, d_model, n_token]
# (concat_cls_token, d_model and num_heads come from the surrounding model code)
data = concat_cls_token(data)   # [batch_size, d_model, n_token+1]
data = data.permute(0, 2, 1)    # [batch_size, n_token+1, d_model]

multihead_att = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
data, _ = multihead_att(data, data, data)   # self-attention: query = key = value
# Result shape: [batch_size, n_token+1, d_model]
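The batch_first flag only changes which axis is treated as the batch; the math is the same. A small sketch with illustrative sizes that copies one module's weights into a default (batch_first=False) module and compares the outputs on transposed input:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size, seq_len, d_model, num_heads = 2, 5, 16, 4

mha_bf = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
mha_sf = nn.MultiheadAttention(d_model, num_heads)  # default: batch_first=False
mha_sf.load_state_dict(mha_bf.state_dict())  # same weights, only the layout flag differs

x_bf = torch.randn(batch_size, seq_len, d_model)  # (batch, seq, feature)
x_sf = x_bf.transpose(0, 1)                       # (seq, batch, feature)

out_bf, _ = mha_bf(x_bf, x_bf, x_bf)
out_sf, _ = mha_sf(x_sf, x_sf, x_sf)

print(out_bf.shape)  # torch.Size([2, 5, 16])
print(out_sf.shape)  # torch.Size([5, 2, 16])
# Same numbers, just with the first two axes swapped.
print(torch.allclose(out_bf, out_sf.transpose(0, 1), atol=1e-5))  # True
```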

After applying multi-head attention, I use LayerNorm(d_model) directly on this [batch_size, n_token+1, d_model] tensor. My understanding is that LayerNorm normalizes over the feature dimension, so as long as the feature dimension (d_model) is the last one, it should work fine. But I have two main questions:
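That understanding can be checked directly by recomputing LayerNorm by hand. A minimal sketch, assuming a freshly initialized layer (weight = 1, bias = 0) so the affine part is the identity; sizes are illustrative:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size, seq_len, d_model = 2, 5, 16
x = torch.randn(batch_size, seq_len, d_model)

ln = nn.LayerNorm(d_model)  # at init: weight (gamma) = 1, bias (beta) = 0

# Manual version: statistics over the last dim only, biased variance, same eps.
mean = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, unbiased=False, keepdim=True)
manual = (x - mean) / torch.sqrt(var + ln.eps)

print(torch.allclose(ln(x), manual, atol=1e-5))  # True
```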

  1. If I had stuck with the default multi-head attention format (seq, batch, feature), that is, [n_token+1, batch_size, d_model], would LayerNorm(d_model) still correctly normalize along the feature dimension without permuting the tensor again?
  2. In practice, what’s the best approach for tasks like mine (audio sequence recognition)? Is it recommended to keep the data in [batch_size, seq_len, d_model] format before calling LayerNorm, or is it perfectly acceptable to use (seq, batch, feature) as long as the feature dimension is last?

Both my advisor and I are a bit uncertain. Below are more details from my forward method and the corresponding AttentionBlock implementation for reference:

def forward(self, x: torch.Tensor):
    # Initial: x is [batch_size, d_model, num_tokens]
    x = self.expand(x)
    x = self.concat_cls_token(x)   # [batch_size, d_model, num_tokens+1]
    x = x.permute(0, 2, 1)         # [batch_size, num_tokens+1, d_model]
    x = self.positional_encoder(x)
    x = self.attention_block(x)    # [batch_size, num_tokens+1, d_model]
    x = x.permute(0, 2, 1)         # [batch_size, d_model, num_tokens+1]

    x = self.get_cls_token(x)      # [batch_size, d_model, 1]
    y = self.class_mlp(x)          # [batch_size, n_classes]
    return y
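The helpers concat_cls_token, get_cls_token and class_mlp are not shown above. A minimal sketch of what they might look like in the channel-first layout, assuming a learnable CLS vector prepended at token index 0 (the CLSHelper class, the linear classifier and all sizes are hypothetical):

```python
import torch
import torch.nn as nn


class CLSHelper(nn.Module):
    """Hypothetical stand-in for concat_cls_token / get_cls_token in [B, d_model, T] layout."""

    def __init__(self, d_model: int):
        super().__init__()
        # One learnable CLS vector, stored channel-first so it can be concatenated on the token axis.
        self.cls = nn.Parameter(torch.zeros(1, d_model, 1))

    def concat_cls_token(self, x: torch.Tensor) -> torch.Tensor:
        # x: [B, d_model, T] -> [B, d_model, T + 1], CLS at token index 0
        cls = self.cls.expand(x.size(0), -1, -1)
        return torch.cat([cls, x], dim=2)

    @staticmethod
    def get_cls_token(x: torch.Tensor) -> torch.Tensor:
        # x: [B, d_model, T + 1] -> [B, d_model, 1]; slicing keeps the token axis
        return x[:, :, :1]


batch_size, d_model, n_token, n_classes = 2, 16, 10, 3
helper = CLSHelper(d_model)
x = torch.randn(batch_size, d_model, n_token)
x = helper.concat_cls_token(x)            # [2, 16, 11]
cls = helper.get_cls_token(x)             # [2, 16, 1]
# Hypothetical class_mlp: squeeze the size-1 token axis first, because
# nn.Linear acts on the last dim, which would otherwise be 1 rather than d_model.
logits = nn.Linear(d_model, n_classes)(cls.squeeze(-1))  # [2, 3]
```

The squeeze matters for the same reason as in the LayerNorm question: nn.Linear, like nn.LayerNorm, operates on the last dimension.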

And the implementation of AttentionBlock:

import torch
import torch.nn as nn
from collections import OrderedDict

class AttentionBlock(nn.Module):
    @staticmethod
    def make_ffn(hidden_dim: int) -> torch.nn.Module:
        return nn.Sequential(
            OrderedDict([
                ("ffn_linear1", nn.Linear(in_features=hidden_dim, out_features=hidden_dim)),
                ("ffn_relu", nn.ReLU()),
                ("ffn_linear2", nn.Linear(in_features=hidden_dim, out_features=hidden_dim))
            ])
        )

    def __init__(self, embed_dim, n_head):
        super().__init__()
        self.attention = nn.MultiheadAttention(embed_dim, n_head, batch_first=True)
        self.layer_norm1 = nn.LayerNorm(embed_dim)
        self.feed_forward = self.make_ffn(embed_dim)
        self.layer_norm2 = nn.LayerNorm(embed_dim)

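    def forward(self, x: torch.Tensor):
        # Self-attention with query = key = value = x, shape [batch, seq, embed_dim]
        attn_output, _ = self.attention(x, x, x)
        # Post-norm residual: LayerNorm normalizes over the last dim (embed_dim)
        x = self.layer_norm1(x + attn_output)
        ff_output = self.feed_forward(x)
        x = self.layer_norm2(x + ff_output)
        return x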
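For comparison, PyTorch's built-in nn.TransformerEncoderLayer with norm_first=False, activation="relu", dim_feedforward=d_model and dropout=0.0 appears to have the same post-norm structure as AttentionBlock (residual then LayerNorm after attention, and again after the FFN). A small sketch with illustrative sizes that runs identical weights in both layouts and checks that the outputs match:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size, seq_len, d_model, n_head = 2, 7, 16, 4


def make_layer(batch_first: bool) -> nn.TransformerEncoderLayer:
    # Post-norm, ReLU, FFN width = d_model and no dropout, mirroring AttentionBlock.
    return nn.TransformerEncoderLayer(
        d_model, n_head, dim_feedforward=d_model, dropout=0.0,
        activation="relu", batch_first=batch_first, norm_first=False,
    )


layer_bf = make_layer(batch_first=True)
layer_sf = make_layer(batch_first=False)
layer_sf.load_state_dict(layer_bf.state_dict())  # identical weights

x = torch.randn(batch_size, seq_len, d_model)
y_bf = layer_bf(x)                  # [B, T, D]
y_sf = layer_sf(x.transpose(0, 1))  # [T, B, D]

# Both layers apply LayerNorm with D as the last dim, so the results agree.
print(torch.allclose(y_bf, y_sf.transpose(0, 1), atol=1e-5))  # True
```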

Solution

  • From the LayerNorm documentation:

    torch.nn.LayerNorm(normalized_shape, eps=1e-05, elementwise_affine=True, bias=True, device=None, dtype=None)
    
    """
    Applies Layer Normalization over a mini-batch of inputs.
    
        This layer implements the operation as described in
        the paper `Layer Normalization <https://arxiv.org/abs/1607.06450>`__
    
        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
    
        The mean and standard-deviation are calculated over the last `D` dimensions, where `D`
        is the dimension of :attr:`normalized_shape`. For example, if :attr:`normalized_shape`
        is ``(3, 5)`` (a 2-dimensional shape), the mean and standard-deviation are computed over
        the last 2 dimensions of the input (i.e. ``input.mean((-2, -1))``).
        :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
        :attr:`normalized_shape` if :attr:`elementwise_affine` is ``True``.
        The standard-deviation is calculated via the biased estimator, equivalent to
        `torch.var(input, unbiased=False)`.
    """
    

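    As the documentation says, the mean and standard deviation are calculated over the last D dimensions. If you create the layer as nn.LayerNorm(d_model), it expects the input's last dimension to have size d_model and normalizes each d_model-sized vector independently. The leading dimensions are treated as independent positions and do not affect the result, so LayerNorm(d_model) behaves the same on [n_token+1, batch_size, d_model] as on [batch_size, n_token+1, d_model]; no extra permute is needed as long as d_model is the last dimension.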
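A quick check of this, with illustrative sizes: the same LayerNorm applied to (batch, seq, feature) and (seq, batch, feature) gives the same per-token result, while the original channel-first layout [batch_size, d_model, n_token] is rejected because its last dimension is n_token, not d_model.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size, n_token, d_model = 2, 11, 16
ln = nn.LayerNorm(d_model)

x_bf = torch.randn(batch_size, n_token, d_model)  # (batch, seq, feature)
x_sf = x_bf.transpose(0, 1)                       # (seq, batch, feature)

# Each d_model-vector is normalized on its own, so the order of the
# leading (batch, seq) axes does not change the result.
same = torch.allclose(ln(x_bf), ln(x_sf).transpose(0, 1), atol=1e-6)
print(same)  # True

# Channel-first [batch, d_model, n_token]: the last dim is 11, not 16.
x_cf = x_bf.permute(0, 2, 1)
try:
    ln(x_cf)
    failed = False
except RuntimeError:
    failed = True
print(failed)  # True
```

Note that if n_token happened to equal d_model, the channel-first call would not raise; it would silently normalize across time steps instead of features. That is why the permute that puts d_model last matters.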