I’m working on an audio recognition task using a Transformer-based model in PyTorch. My input features are generated by a CNN-based embedding layer and have the shape [batch_size, d_model, n_token], where n_token is the sequence length and d_model is the feature dimension.
By default, nn.MultiheadAttention (when batch_first=False) expects input in the shape (seq, batch, feature). To make things more intuitive, I chose to set batch_first=True and then permute my data from [batch_size, d_model, n_token] to [batch_size, n_token, d_model] so that the time dimension comes before the feature dimension. Here’s a simplified code snippet:
# Original shape: [batch_size, d_model, n_token]
data = concat_cls_token(data) # [batch_size, d_model, n_token+1]
data = data.permute(0, 2, 1) # [batch_size, n_token+1, d_model]
multihead_att = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
data, _ = multihead_att(data, data, data)
# Result shape: [batch_size, n_token+1, d_model]
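For reference, here is a self-contained version of that snippet (the concrete sizes are arbitrary, and the torch.cat below is just a stand-in for my actual concat_cls_token helper):

import torch
from torch import nn

batch_size, d_model, n_token, num_heads = 4, 64, 10, 8

data = torch.randn(batch_size, d_model, n_token)          # [4, 64, 10]

# Stand-in for concat_cls_token: prepend one extra token along the time axis.
cls_token = torch.zeros(batch_size, d_model, 1)
data = torch.cat([cls_token, data], dim=2)                # [4, 64, 11]

data = data.permute(0, 2, 1)                              # [4, 11, 64]

multihead_att = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
out, _ = multihead_att(data, data, data)
print(out.shape)                                          # torch.Size([4, 11, 64])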
After applying multi-head attention, I use LayerNorm(d_model) directly on this [batch_size, n_token+1, d_model] tensor. My understanding is that LayerNorm normalizes over the feature dimension, so as long as the feature dimension (d_model) is the last one, it should work fine. But I have two main questions: is that understanding actually correct, and does the permutation I apply beforehand affect the result in any way? Both my advisor and I are a bit uncertain. Below are more details from my forward method and the corresponding AttentionBlock implementation for reference:
def forward(self, x: torch.Tensor):
    # Initial: x is [batch_size, d_model, num_tokens]
    x = self.expand(x)
    x = self.concat_cls_token(x)    # [batch_size, d_model, num_tokens+1]
    x = x.permute(0, 2, 1)          # [batch_size, num_tokens+1, d_model]
    x = self.positional_encoder(x)
    x = self.attention_block(x)     # [batch_size, num_tokens+1, d_model]
    x = x.permute(0, 2, 1)          # [batch_size, d_model, num_tokens+1]
    x = self.get_cls_token(x)       # [batch_size, d_model, 1]
    y = self.class_mlp(x)           # [batch_size, n_classes]
    return y
and the implementation of AttentionBlock:
from collections import OrderedDict

import torch
from torch import nn


class AttentionBlock(nn.Module):
    @staticmethod
    def make_ffn(hidden_dim: int) -> torch.nn.Module:
        return nn.Sequential(
            OrderedDict([
                ("ffn_linear1", nn.Linear(in_features=hidden_dim, out_features=hidden_dim)),
                ("ffn_relu", nn.ReLU()),
                ("ffn_linear2", nn.Linear(in_features=hidden_dim, out_features=hidden_dim)),
            ])
        )

    def __init__(self, embed_dim, n_head):
        super().__init__()
        self.attention = nn.MultiheadAttention(embed_dim, n_head, batch_first=True)
        self.layer_norm1 = nn.LayerNorm(embed_dim)
        self.feed_forward = self.make_ffn(embed_dim)
        self.layer_norm2 = nn.LayerNorm(embed_dim)

    def forward(self, x: torch.Tensor):
        # Post-norm residual blocks: attention, then FFN, each followed by LayerNorm.
        attn_output, _ = self.attention(x, x, x)
        x = self.layer_norm1(x + attn_output)
        ff_output = self.feed_forward(x)
        x = self.layer_norm2(x + ff_output)
        return x
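A quick smoke test with arbitrary sizes confirms the block preserves the input shape:

import torch

block = AttentionBlock(embed_dim=64, n_head=8)
x = torch.randn(4, 11, 64)    # [batch_size, num_tokens+1, d_model]
print(block(x).shape)         # torch.Size([4, 11, 64]) -- shape is preserved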
From the LayerNorm documentation:
torch.nn.LayerNorm(normalized_shape, eps=1e-05, elementwise_affine=True, bias=True, device=None, dtype=None)
""
Applies Layer Normalization over a mini-batch of inputs.
This layer implements the operation as described in
the paper `Layer Normalization <https://arxiv.org/abs/1607.06450>`__
.. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
The mean and standard-deviation are calculated over the last `D` dimensions, where `D`
is the dimension of :attr:`normalized_shape`. For example, if :attr:`normalized_shape`
is ``(3, 5)`` (a 2-dimensional shape), the mean and standard-deviation are computed over
the last 2 dimensions of the input (i.e. ``input.mean((-2, -1))``).
:math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
:attr:`normalized_shape` if :attr:`elementwise_affine` is ``True``.
The standard-deviation is calculated via the biased estimator, equivalent to
`torch.var(input, unbiased=False)`.
"""
As the documentation says, the mean and standard deviation are calculated over the last D dimensions. If you create the layer as nn.LayerNorm(d_model), it assumes the input's last dimension has size d_model and applies layer normalization over that dimension only. The other dimensions of the tensor are irrelevant, so applying it to your [batch_size, n_token+1, d_model] tensor is exactly right.
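You can sanity-check this yourself: at initialization (gamma = 1, beta = 0), nn.LayerNorm(d_model) matches a manual normalization over the last dimension, and the leading dimensions play no role:

import torch
from torch import nn

x = torch.randn(4, 11, 64)                        # [batch_size, n_token+1, d_model]
ln = nn.LayerNorm(64)                             # normalized_shape = d_model

# Manual normalization over the last dimension only (gamma=1, beta=0 at init).
mean = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, unbiased=False, keepdim=True)
manual = (x - mean) / torch.sqrt(var + ln.eps)

print(torch.allclose(ln(x), manual, atol=1e-6))   # True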