Tags: python, numpy, tensorflow, keras, transformer-model

How to mask variable-length inputs in a Transformer model when each batch needs to be masked differently?


I'm making a transformer using tensorflow.keras and having issues understanding how the attention_mask works for a MultiHeadAttention layer.

My input is 3-dimensional data. For example, let's assume my whole dataset has 10 elements, each one with length no more than 4:

# whole data
[
  # first item
  [
    [     1,      2,      3],
    [     1,      2,      3],
    [np.nan, np.nan, np.nan],
    [np.nan, np.nan, np.nan],
  ],
  # second item
  [
    [     1,      2,      3],
    [     5,      8,      2],
    [     3,      7,      8],
    [     4,      6,      2],
  ],
  ... # 8 more items
]

So, my mask looks like:

# assume this is a numpy array
mask = [
  [
    [1, 1, 1],
    [1, 1, 1],
    [0, 0, 0],
    [0, 0, 0],
  ],
  [
    [1, 1, 1],
    [1, 1, 1],
    [1, 1, 1],
    [1, 1, 1],
  ],
  ...
]

So the shape of the mask so far is [10, 4, 3]. Let's say I use batch_size = 5. According to the documentation, the attention_mask shape should be [B, T, S] (batch size, query/target sequence length, key/source sequence length). In this example, should it be [5, 4, 4]?

Question

If the mask is calculated only once, what 5 items should I give as a mask? This sounds counterintuitive to me. How should I build the mask?

According to this answer, head_size should also be taken into account, so they also do:

mask = mask[:, tf.newaxis, tf.newaxis, :]

What I've tested

The only time I manage to run the transformer successfully using the attention_mask is when I do:

# Shape test only: an all-ones array of shape (batch_size, data.shape[1], data.shape[2]).
mask = np.ones((batch_size, data.shape[1], data.shape[2]))
# Insert two singleton axes after the batch axis:
# -> (batch_size, 1, 1, data.shape[1], data.shape[2])
mask = mask[:, tf.newaxis, tf.newaxis, :]

Obviously that mask makes no sense, because it is all ones, but it was just to test if it had the correct shape.

The model

I'm using practically the same code as the Keras example for time-series classification with a Transformer:

def transformer_encoder(inputs, head_size, num_heads, ff_dim, dropout=0.0, mask=None):
    """Apply one pre-norm Transformer encoder block to ``inputs``.

    Attention half: LayerNorm -> MultiHeadAttention (self-attention) ->
    Dropout -> residual add with ``inputs``.
    Feed-forward half: LayerNorm -> 1x1 Conv1D (ReLU) -> Dropout ->
    1x1 Conv1D back to the input width -> residual add.

    Args:
        inputs: Input tensor; its last dimension is used as the output width.
        head_size: ``key_dim`` of each attention head.
        num_heads: Number of attention heads.
        ff_dim: Number of filters in the hidden feed-forward Conv1D.
        dropout: Dropout rate used inside the attention layer and by both
            Dropout layers.
        mask: Passed unchanged as ``attention_mask`` to MultiHeadAttention.
            NOTE(review): Keras documents this argument as shape (B, T, S);
            a fixed-size array here ties the model to one batch size.

    Returns:
        Tensor with the same shape as ``inputs``.
    """
    # --- attention sub-block ---
    normed = layers.LayerNormalization(epsilon=1e-6)(inputs)
    attention = layers.MultiHeadAttention(
        key_dim=head_size, num_heads=num_heads, dropout=dropout
    )
    attended = attention(normed, normed, attention_mask=mask)
    residual = layers.Dropout(dropout)(attended) + inputs

    # --- position-wise feed-forward sub-block (1x1 convolutions) ---
    hidden = layers.LayerNormalization(epsilon=1e-6)(residual)
    hidden = layers.Conv1D(filters=ff_dim, kernel_size=1, activation="relu")(hidden)
    hidden = layers.Dropout(dropout)(hidden)
    # Project back to the input width so the residual add lines up.
    projected = layers.Conv1D(filters=inputs.shape[-1], kernel_size=1)(hidden)
    return projected + residual


def build_model(
    n_classes,
    input_shape,
    head_size,
    num_heads,
    ff_dim,
    num_transformer_blocks,
    mlp_units,
    dropout=0.0,
    mlp_dropout=0.0,
    input_mask=None,
) -> keras.Model:
    """Assemble a Transformer classifier from stacked ``transformer_encoder`` blocks.

    Args:
        n_classes: Width of the softmax output layer.
        input_shape: Per-sample input shape passed to ``keras.Input``.
        head_size, num_heads, ff_dim, dropout: Forwarded to every
            ``transformer_encoder`` block.
        num_transformer_blocks: How many encoder blocks to stack.
        mlp_units: Hidden-layer widths of the Dense classification head.
        mlp_dropout: Dropout rate after each hidden Dense layer.
        input_mask: Forwarded as ``mask`` (the ``attention_mask``) to every
            encoder block. NOTE(review): this is one fixed value captured
            when the graph is built, so every batch sees the same mask and its
            batch dimension presumably has to equal the runtime batch size.

    Returns:
        An uncompiled ``keras.Model`` mapping ``inputs`` to class probabilities.
    """
    inputs = keras.Input(shape=input_shape)

    encoded = inputs
    for _block in range(num_transformer_blocks):
        encoded = transformer_encoder(
            encoded, head_size, num_heads, ff_dim, dropout, input_mask
        )

    # channels_first: axis 1 is kept, the trailing axes are averaged away.
    hidden = layers.GlobalAveragePooling2D(data_format="channels_first")(encoded)

    # Dense + Dropout classification head.
    for units in mlp_units:
        hidden = layers.Dense(units, activation="relu")(hidden)
        hidden = layers.Dropout(mlp_dropout)(hidden)

    probabilities = layers.Dense(n_classes, activation="softmax")(hidden)
    return keras.Model(inputs, probabilities)

Solution

  • After a little research and looking at several Transformer model examples, this is what solved the problem for me.

    1. Create a custom TransformerBlock layer that supports masking
    2. Add a mask parameter in the call method of the TransformerBlock and reshape it there.
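    3. Add a Masking layer before the TransformerBlock (pad the data with a finite sentinel such as 0.0 rather than NaN: Masking compares values with `!=`, and since NaN != NaN, a NaN mask value never masks anything)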

    Code:

    class TransformerBlock(layers.Layer):
        """Pre-norm Transformer encoder block that consumes the Keras padding mask.

        Attention half: LayerNorm -> MultiHeadAttention -> Dropout -> residual add.
        Feed-forward half: LayerNorm -> Conv1D(ff_dim, relu) -> Dropout ->
        Conv1D(ff_dim2) -> residual add.

        Args:
            head_size: ``key_dim`` of each attention head.
            num_heads: Number of attention heads.
            ff_dim: Filters of the first (hidden) 1x1 Conv1D.
            ff_dim2: Filters of the second 1x1 Conv1D; should equal the input's
                last dimension so the final residual add lines up.
            rate: Rate of the two Dropout layers (not applied inside the
                attention layer itself).
            **kwargs: Forwarded to ``keras.layers.Layer`` (e.g. ``name``, ``dtype``).
        """

        def __init__(self, head_size, num_heads, ff_dim, ff_dim2, rate=0.1, **kwargs):
            super().__init__(**kwargs)
            # Constructor arguments are kept so get_config() can rebuild the layer.
            self.head_size = head_size
            self.num_heads = num_heads
            self.ff_dim = ff_dim
            self.ff_dim2 = ff_dim2
            self.rate = rate

            self.att = layers.MultiHeadAttention(num_heads=num_heads, key_dim=head_size)
            self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
            self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)
            self.dropout1 = layers.Dropout(rate)
            self.dropout2 = layers.Dropout(rate)
            self.conv1 = layers.Conv1D(filters=ff_dim, kernel_size=1, activation="relu")
            self.conv2 = layers.Conv1D(filters=ff_dim2, kernel_size=1)
            # Lets Keras hand the upstream mask (e.g. from layers.Masking) to
            # call() and propagate it unchanged to the next layer.
            self.supports_masking = True

        def call(self, inputs, training=None, mask=None):
            """Run the block; ``mask`` is the Keras mask propagated from upstream.

            NOTE(review): assumes ``mask`` is (batch, seq), as produced by
            layers.Masking on 3-D (batch, seq, features) inputs. Expanding it to
            (batch, 1, 1, seq) presumably lets it broadcast against the per-head
            attention scores; confirm for inputs of other ranks.
            """
            padding_mask = None
            if mask is not None:
                padding_mask = tf.cast(mask[:, tf.newaxis, tf.newaxis, :], dtype="int32")

            out_norm1 = self.layernorm1(inputs, training=training)
            out_att = self.att(
                out_norm1, out_norm1, training=training, attention_mask=padding_mask
            )
            out_drop1 = self.dropout1(out_att, training=training)
            res = out_drop1 + inputs
            out_norm2 = self.layernorm2(res, training=training)
            out_conv1 = self.conv1(out_norm2, training=training)
            out_drop2 = self.dropout2(out_conv1, training=training)
            out_conv2 = self.conv2(out_drop2, training=training)
            return out_conv2 + res

        def get_config(self):
            """Return the constructor arguments so the layer can be serialized."""
            config = super().get_config()
            config.update(
                {
                    "head_size": self.head_size,
                    "num_heads": self.num_heads,
                    "ff_dim": self.ff_dim,
                    "ff_dim2": self.ff_dim2,
                    "rate": self.rate,
                }
            )
            return config
    
    def build_model(
        n_classes,
        input_shape,
        head_size,
        num_heads,
        ff_dim,
        num_transformer_blocks,
        mlp_units,
        dropout=0.0,
        mlp_dropout=0.0,
        mask=None,
    ) -> keras.Model:
        """Build a Transformer classifier whose attention can skip padded timesteps.

        Args:
            n_classes: Width of the softmax output layer.
            input_shape: Per-sample input shape passed to ``keras.Input``.
            head_size, num_heads, ff_dim: Forwarded to every ``TransformerBlock``.
            num_transformer_blocks: How many blocks to stack.
            mlp_units: Hidden-layer widths of the Dense classification head.
            dropout: Dropout rate of each ``TransformerBlock``.
            mlp_dropout: Dropout rate after each hidden Dense layer.
            mask: Padding *value* passed to ``layers.Masking(mask_value=...)``;
                ``None`` disables masking. Must be a finite value, not NaN.

        Returns:
            An uncompiled ``keras.Model`` mapping ``inputs`` to class probabilities.

        Raises:
            ValueError: If ``mask`` is NaN, because such a mask would never match.
        """
        import math
        import numbers

        # Masking keeps a timestep when any feature != mask_value. NaN != NaN is
        # always true, so a NaN mask value would silently mask nothing (and the
        # NaN inputs would still flow into the network). Fail loudly instead.
        if isinstance(mask, numbers.Real) and math.isnan(mask):
            raise ValueError(
                "mask=NaN cannot be used as a Masking value because NaN != NaN, "
                "so no timestep would ever be masked; replace NaN padding with a "
                "finite sentinel (e.g. 0.0) and pass that value instead."
            )

        inputs = keras.Input(shape=input_shape)
        _x = inputs
        if mask is not None:
            # Attaches a Keras mask that each TransformerBlock receives in call().
            _x = layers.Masking(mask_value=mask)(_x)
        for _ in range(num_transformer_blocks):
            _x = TransformerBlock(
                head_size,
                num_heads,
                ff_dim,
                inputs.shape[-1],  # ff_dim2: match the input width for the residual add
                dropout,
            )(_x)

        # NOTE(review): GlobalAveragePooling2D expects 4-D (batch, ...) input;
        # confirm input_shape has rank 3, otherwise GlobalAveragePooling1D fits.
        _x = layers.GlobalAveragePooling2D(data_format="channels_first")(_x)
        for dim in mlp_units:
            _x = layers.Dense(dim, activation="relu")(_x)
            _x = layers.Dropout(mlp_dropout)(_x)
        outputs = layers.Dense(n_classes, activation="softmax")(_x)
        return keras.Model(inputs, outputs)