tensorflow, time-series, masking, attention-model, multivariate-time-series

How to properly mask MultiHeadAttention for sliding window time series data


I have data in the shape (batch, seq_len, features), built as a sliding window over a time series. In essence, I'm using the most recent seq_len steps to predict a single target variable, which means the output at the last of the seq_len positions of my MultiHeadAttention layer should be the predicted value.

I've made many attempts at generating different attention_masks for Keras' MultiHeadAttention, but none of them quite capture the behavior I want, and they inevitably lead to poor results. Ultimately I only want the importance of each of the seq_len query steps relative to the last key step. It's basically an autoregressive additive model using the transformer architecture (encoder only). The last step is to tf.reduce_sum over the entire seq_len to get the output.
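
For concreteness, the kind of mask I have in mind looks roughly like this. It is one of my attempts, not a working solution, and the shapes follow Keras' attention_mask convention of (batch, query_len, key_len):

    import tensorflow as tf

    seq_len = 20  # length of the sliding window

    # Boolean mask of shape (seq_len, seq_len): True means "may attend".
    # Every query position is only allowed to look at the last key position.
    last_key_only = tf.concat(
        [
            tf.zeros((seq_len, seq_len - 1), dtype=tf.bool),
            tf.ones((seq_len, 1), dtype=tf.bool),
        ],
        axis=1,
    )

    # Leading axis of 1 so the mask broadcasts over the batch (and heads).
    attention_mask = last_key_only[tf.newaxis, ...]  # shape (1, seq_len, seq_len)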

Future modifications to the attention layer might include teacher forcing, which should further improve training and reduce the obvious influence of the last value's correlation with itself, but I can't figure out how to mask correctly in the first place for continuous time series data like this. To be clear, this is NOT an NLP model.


Solution

  • I am also trying to mask data with MultiHeadAttention for a multivariate time series in the shape (batch, seq_len, features). I want to mask future values in the decoder during training, and I am not using any padding.

    I'm currently using the get_causal_attention_mask function from the Keras tutorial "English-to-Spanish translation with a sequence-to-sequence Transformer".

    import tensorflow as tf

    # In the tutorial this is a method of the TransformerDecoder layer; `self` is unused.
    def get_causal_attention_mask(self, inputs):
        # Only the shape of `inputs` is used: (batch_size, sequence_length, ...).
        input_shape = tf.shape(inputs)
        batch_size, sequence_length = input_shape[0], input_shape[1]

        # Lower-triangular matrix: query position i may attend to key positions j <= i.
        i = tf.range(sequence_length)[:, tf.newaxis]
        j = tf.range(sequence_length)
        mask = tf.cast(i >= j, dtype="int32")

        # Reshape to (1, sequence_length, sequence_length) and tile across the batch.
        mask = tf.reshape(mask, (1, input_shape[1], input_shape[1]))
        mult = tf.concat(
            [tf.expand_dims(batch_size, -1), tf.constant([1, 1], dtype=tf.int32)],
            axis=0,
        )
        return tf.tile(mask, mult)
    

    I only apply this masking to the decoder self-attention layers. If it is disabled, I see overfitting.
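
    For reference, here is a minimal sketch of how I pass this mask to the attention layer (the window length, feature count, and layer sizes below are assumptions, not values from my notebook):

    import tensorflow as tf

    batch_size, seq_len, features = 32, 20, 8  # assumed shapes
    x = tf.random.normal((batch_size, seq_len, features))

    mha = tf.keras.layers.MultiHeadAttention(num_heads=4, key_dim=features)

    # `self` is unused inside get_causal_attention_mask, so passing None works
    # when the function above is used standalone.
    causal_mask = get_causal_attention_mask(None, x)  # (batch_size, seq_len, seq_len)

    out = mha(query=x, value=x, key=x, attention_mask=causal_mask)
    print(out.shape)  # (32, 20, 8)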

    My code is in this notebook.


    New in TensorFlow/Keras version 2.10 is the use_causal_mask call option. The documentation describes it like so:

    use_causal_mask: A boolean to indicate whether to apply a causal mask to prevent tokens from attending to future tokens (e.g., used in a decoder Transformer).

    The "Add support for automatic mask handling in MultiHeadAttention layer" GitHub issue contains more details.

    The TensorFlow Transformer tutorial illustrates use_causal_mask usage for NLP translation.
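
    For illustration, a minimal sketch of the call option (the shapes and layer sizes are assumptions):

    import tensorflow as tf

    x = tf.random.normal((32, 20, 8))  # (batch, seq_len, features), assumed shapes
    mha = tf.keras.layers.MultiHeadAttention(num_heads=4, key_dim=8)

    # Requires TensorFlow/Keras >= 2.10. Builds the lower-triangular mask internally,
    # equivalent to passing the explicit causal attention_mask above.
    out = mha(query=x, value=x, use_causal_mask=True)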

    I'm developing on Google Colab. Unfortunately, after upgrading to 2.10 I get session crashes or CUDA/cuDNN errors. These problems persist after disabling use_causal_mask, so it's probably an upgrade snafu.

    I'm probably going to drop this and move on. Hope this answer helps someone.