tensorflow, keras, nlp, translation, attention-model

Difference between MultiheadAttention and Attention layer in Tensorflow


What is the difference between the following layers in Tensorflow: tf.keras.layers.Attention, tf.keras.layers.MultiHeadAttention and tf.keras.layers.AdditiveAttention?

Also, how can I implement tf.keras.layers.MultiHeadAttention using fundamental layers like Dense, Add, LayerNormalization, etc.? I want to understand the exact operations happening inside this tutorial.


Solution

  • https://paperswithcode.com/ is a good resource for understanding the nuances of different deep learning terminologies and implementations.

    The general definition of attention mechanism in the transformer model:

    Attention Mechanisms are a component used in neural networks to model long-range interaction, for example across a text in NLP. The key idea is to build shortcuts between a context vector and the input, to allow a model to attend to different parts. - paperswithcode


    In my own words, the attention "shortcut" is created by doing sequential matrix multiplications from the "query" (the inputs) to the "value" (the target that you want to map the inputs to); in between there is a "key" that acts like a signal the query should theoretically make use of to project the query onto the value. The common output of the attention mechanism is a vector/matrix/tensor representation that encodes this shortcut.

    There are many variants of these "shortcuts" (aka attention mechanisms), with researchers trying to find the optimal connection from query + key -> value. See the list on https://paperswithcode.com/methods/category/attention-mechanisms-1
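
    To make this concrete, here is a minimal sketch of the common (scaled) dot-product variant of such a shortcut, written with raw TensorFlow ops; the function name and shapes are illustrative assumptions, not the Keras source:

    import tensorflow as tf

    def dot_product_attention(query, key, value):
        # query: (batch, T_q, dim), key/value: (batch, T_v, dim)
        dim = tf.cast(tf.shape(key)[-1], tf.float32)
        scores = tf.matmul(query, key, transpose_b=True) / tf.sqrt(dim)  # (batch, T_q, T_v)
        weights = tf.nn.softmax(scores, axis=-1)   # how strongly each query position attends to each value position
        return tf.matmul(weights, value), weights  # (batch, T_q, dim), plus the attention weights

    q = tf.random.normal([2, 8, 16])
    v = tf.random.normal([2, 4, 16])
    out, w = dot_product_attention(q, v, v)  # use the value as the key, as Keras Attention does by default
    print(out.shape, w.shape)                # (2, 8, 16) (2, 8, 4)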

    Attention vs MultiHeadAttention


    In my own words, the main differentiator between general Attention and MultiHeadAttention is the redundancy put into the "MultiHead" inputs. If single-head (general) attention maps one Q + K to V, think of multi-head as creating multiple Qs that correspond to multiple Ks, with shortcuts created to the multiple corresponding Vs.

    In code, assuming that the initializations for Attention and MultiHeadAttention are the same, the output_tensor values for the following should be the same:

    import tensorflow as tf
    from tensorflow.keras.layers import Attention, MultiHeadAttention

    # Multi-head attention restricted to a single head.
    layer = MultiHeadAttention(num_heads=1, key_dim=2)
    target = tf.keras.Input(shape=[8, 16])
    source = tf.keras.Input(shape=[4, 16])
    output_tensor, weights = layer(target, source,
                                   return_attention_scores=True)

    # Vanilla (single-head, dot-product) attention.
    layer_vanilla = Attention()
    target_vanilla = tf.keras.Input(shape=[8, 16])
    source_vanilla = tf.keras.Input(shape=[4, 16])
    output_tensor_vanilla, weights_vanilla = layer_vanilla(
        [target_vanilla, source_vanilla], return_attention_scores=True)

    print(output_tensor)
    print(output_tensor_vanilla)
    

    [out]:

    KerasTensor(type_spec=TensorSpec(shape=(None, 8, 16), dtype=tf.float32, name=None), name='multi_head_attention_6/attention_output/add:0', description="created by layer 'multi_head_attention_6'")
    
    KerasTensor(type_spec=TensorSpec(shape=(None, 8, 16), dtype=tf.float32, name=None), name='attention_3/MatMul_1:0', description="created by layer 'attention_3'")
    
    

    https://www.tensorflow.org/api_docs/python/tf/keras/layers/MultiHeadAttention
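
    On the second part of the question (building MultiHeadAttention from fundamental layers), below is a rough, simplified sketch using only Dense layers and raw ops. It mirrors the idea of projecting Q/K/V per head, attending, concatenating, and projecting back; it is not the exact Keras implementation (which uses einsum-based projections, masking, dropout, etc.), and the class and argument names are my own:

    import tensorflow as tf
    from tensorflow.keras.layers import Dense, Layer

    class SimpleMultiHeadAttention(Layer):
        def __init__(self, num_heads, key_dim, output_dim, **kwargs):
            super().__init__(**kwargs)
            self.num_heads = num_heads
            self.key_dim = key_dim
            # One Dense per role, producing all heads at once.
            self.wq = Dense(num_heads * key_dim)
            self.wk = Dense(num_heads * key_dim)
            self.wv = Dense(num_heads * key_dim)
            self.wo = Dense(output_dim)   # final output projection

        def _split_heads(self, x):
            # (batch, T, num_heads * key_dim) -> (batch, num_heads, T, key_dim)
            x = tf.reshape(x, (tf.shape(x)[0], -1, self.num_heads, self.key_dim))
            return tf.transpose(x, [0, 2, 1, 3])

        def call(self, query, value):
            q = self._split_heads(self.wq(query))   # (batch, heads, T_q, key_dim)
            k = self._split_heads(self.wk(value))   # (batch, heads, T_v, key_dim)
            v = self._split_heads(self.wv(value))   # (batch, heads, T_v, key_dim)
            scores = tf.matmul(q, k, transpose_b=True) / tf.sqrt(float(self.key_dim))
            weights = tf.nn.softmax(scores, axis=-1)       # (batch, heads, T_q, T_v)
            heads = tf.matmul(weights, v)                  # (batch, heads, T_q, key_dim)
            heads = tf.transpose(heads, [0, 2, 1, 3])      # (batch, T_q, heads, key_dim)
            concat = tf.reshape(
                heads, (tf.shape(heads)[0], -1, self.num_heads * self.key_dim))
            return self.wo(concat), weights

    mha = SimpleMultiHeadAttention(num_heads=2, key_dim=8, output_dim=16)
    out, w = mha(tf.random.normal([2, 8, 16]), tf.random.normal([2, 4, 16]))
    print(out.shape, w.shape)   # (2, 8, 16) (2, 2, 8, 4)

    The real layer additionally supports attention masks, dropout on the attention weights, and a separate key input, but the core shape bookkeeping is the same.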

    Attention vs AdditiveAttention

    Additive Attention is an interesting one; it is the OG attention mechanism:

    Additive Attention, also known as Bahdanau Attention, uses a one-hidden layer feed-forward network to calculate the attention alignment score

    Details: https://paperswithcode.com/method/additive-attention

    Before "IMOW", lets look at the code:

    from tensorflow.keras.layers import AdditiveAttention

    # Bahdanau-style (additive) attention.
    layer_bdn = AdditiveAttention()
    target_bdn = tf.keras.Input(shape=[8, 16])
    source_bdn = tf.keras.Input(shape=[4, 16])
    output_tensor_bdn, weights_bdn = layer_bdn(
        [target_bdn, source_bdn], return_attention_scores=True)

    print(output_tensor_bdn)
    
    

    [out]:

    <KerasTensor: shape=(None, 8, 16) dtype=float32 (created by layer 'additive_attention')>
    

    Comparing the implementations:

    https://www.diffchecker.com/5i9Viqm9/


    The general Attention has:

    scores = self.concat_score_weight * tf.reduce_sum(
                        tf.tanh(self.scale * (q_reshaped + k_reshaped)), axis=-1
                    )
    

    where the weight is created with initializer="ones" if self.score_mode == "concat":

            if self.score_mode == "concat":
                self.concat_score_weight = self.add_weight(
                    name="concat_score_weight",
                    shape=(),
                    initializer="ones",
                    dtype=self.dtype,
                    trainable=True,
                )
    

    but AdditiveAttention uses the glorot_uniform initializer if self.use_scale is set to True:

            if self.use_scale:
                self.scale = self.add_weight(
                    name="scale",
                    shape=[dim],
                    initializer="glorot_uniform",
                    dtype=self.dtype,
                    trainable=True,
                )
    

    There are further nuances in the implementation though.
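
    To make the comparison concrete, here is an illustrative stand-alone computation of that additive score with broadcasting (variable names are assumptions mirroring the snippets above, not the Keras source):

    import tensorflow as tf

    q = tf.random.normal([2, 8, 16])    # "query"
    k = tf.random.normal([2, 4, 16])    # "key" (also reused as the value below)
    scale = tf.ones([16])               # "ones" vs "glorot_uniform" is the initializer difference discussed above

    q_reshaped = tf.expand_dims(q, axis=-2)   # (2, 8, 1, 16)
    k_reshaped = tf.expand_dims(k, axis=-3)   # (2, 1, 4, 16)
    # q_reshaped + k_reshaped broadcasts to (2, 8, 4, 16)
    scores = tf.reduce_sum(scale * tf.tanh(q_reshaped + k_reshaped), axis=-1)  # (2, 8, 4)
    weights = tf.nn.softmax(scores, axis=-1)
    output = tf.matmul(weights, k)            # (2, 8, 16)
    print(output.shape)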

    In my own words, additive attention is the earlier formulation of the general attention mechanism. Both serve the same purpose as single-headed attention, and if the initializations and scaling are set equally, additive attention == general attention (with score_mode="concat").

    Q: Then what should I be using when choosing the attention layer?

    A: It depends on the ultimate goal. If the goal is to replicate the original Bahdanau paper, then additive attention would be the closest. If not, then vanilla attention is most probably what you want.

    Q: What about multi-head?

    A: In most cases you will want multi-head attention, since single-head attention is just the special case num_heads=1, and the extra heads let the model learn several of these "shortcuts" in parallel.