Mismatch between computational complexity of Additive attention and RNN cell


According to the "Attention Is All You Need" paper: "Additive attention [the classic attention used with RNNs, by Bahdanau et al.] computes the compatibility function using a feed-forward network with a single hidden layer. While the two are similar in theoretical complexity, ..."

Indeed, Table 1 of the paper gives a per-layer complexity of O(n²·d) for self-attention, which uses dot-product (Transformer) attention, so additive attention should also be O(n²·d).

However, if we look closer at additive attention, it is in fact an RNN cell, which has a computational complexity of O(n·d²) (according to the same table).

Thus, shouldn't the computational complexity of additive attention be O(n·d²) instead of O(n²·d)?


Solution

  • Your claim that additive attention is in fact an RNN cell is what is leading you astray. Additive attention is implemented using a fully connected shallow (one hidden layer) feedforward neural network "between" the encoder and decoder RNNs, as shown in Figure 1 below and described in the original paper by Bahdanau et al. (p. 3) [1]:

    ... an alignment model which scores how well the inputs around position j and the output at position i match. The score is based on the RNN hidden state s_{i−1} (just before emitting y_i, Eq. (4)) and the j-th annotation h_j of the input sentence.

    We parametrize the alignment model a as a feedforward neural network which is jointly trained with all the other components of the proposed system...
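    For reference, the appendix of [1] writes this alignment model out explicitly. In the paper's notation, with T_x the source length and W_a, U_a, v_a the weights of the single hidden layer and its output, the scores, attention weights and context vector are:

$$
e_{ij} = a(s_{i-1}, h_j) = v_a^\top \tanh\left(W_a s_{i-1} + U_a h_j\right)
$$

$$
\alpha_{ij} = \frac{\exp(e_{ij})}{\sum_{k=1}^{T_x} \exp(e_{ik})}, \qquad c_i = \sum_{j=1}^{T_x} \alpha_{ij} h_j
$$

    Since U_a h_j does not depend on the decoder step i, [1] notes it can be precomputed once for all decoder steps.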

    Figure 1: Attention mechanism diagram from [2].

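    Thus, each alignment score is computed by adding a linear projection of the previous decoder hidden state s_{i−1} to a linear projection of the encoder annotation h_j, passing the sum through the single tanh hidden layer and reducing the result to a scalar. No state is carried from one encoder position to the next, so additive attention is not an RNN cell. Its O(n²·d) cost comes from scoring every (decoder step, encoder position) pair at O(d) per pair (the two projections themselves add an O(n·d²) term, as the query/key projections do in dot-product attention), whereas the O(n·d²) of a recurrent layer comes from one d×d matrix-vector product per step over n sequential steps.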
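    A minimal pure-Python sketch of the scoring step for one decoder step, assuming for simplicity that the decoder state, the annotations and the hidden layer all have the same size d (in [1] the bidirectional annotations are twice as wide). The helper names (matvec, additive_scores, softmax, context_vector, rand_matrix) and the random toy weights are hypothetical, for illustration only, not code from [1] or [2].

```python
import math
import random

def matvec(M, x):
    """Dense matrix-vector product: len(M) * len(x) multiply-adds."""
    return [sum(m * xi for m, xi in zip(row, x)) for row in M]

def additive_scores(s_prev, H, W_a, U_a, v_a):
    """Scores e_j = v_a . tanh(W_a s_prev + U_a h_j) for every annotation h_j in H.

    Each score depends only on (s_prev, h_j): no state is passed from
    h_{j-1} to h_j, unlike an RNN cell.
    """
    ws = matvec(W_a, s_prev)  # decoder-side projection, once per decoder step
    scores = []
    for h in H:
        uh = matvec(U_a, h)  # encoder-side projection (could be cached across decoder steps)
        hidden = [math.tanh(a + b) for a, b in zip(ws, uh)]  # the single hidden layer
        scores.append(sum(v * z for v, z in zip(v_a, hidden)))  # O(d) work per pair
    return scores

def softmax(xs):
    """Attention weights alpha_j; subtracting the max keeps exp() stable."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def context_vector(weights, H):
    """Weighted sum c = sum_j alpha_j h_j of the encoder annotations."""
    return [sum(w * h[k] for w, h in zip(weights, H)) for k in range(len(H[0]))]

def rand_matrix(rows, cols, rng):
    return [[rng.uniform(-0.5, 0.5) for _ in range(cols)] for _ in range(rows)]

# Hypothetical toy sizes: n encoder positions, every dimension equal to d.
rng = random.Random(0)
n, d = 5, 4
H = rand_matrix(n, d, rng)                        # encoder annotations h_1..h_n
s_prev = [rng.uniform(-1, 1) for _ in range(d)]   # decoder state s_{i-1}
W_a, U_a = rand_matrix(d, d, rng), rand_matrix(d, d, rng)
v_a = [rng.uniform(-0.5, 0.5) for _ in range(d)]

scores = additive_scores(s_prev, H, W_a, U_a, v_a)
alpha = softmax(scores)
c = context_vector(alpha, H)
```

    Reversing the annotations simply reverses the scores (additive_scores(s_prev, H[::-1], W_a, U_a, v_a) equals scores[::-1]), because each score depends only on s_{i−1} and its own h_j; an RNN cell, by contrast, threads a hidden state through the positions in order.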

    References

    [1] Bahdanau, D., Cho, K. and Bengio, Y., 2014. Neural machine translation by jointly learning to align and translate. arXiv preprint arXiv:1409.0473.

    [2] Arbel, N., 2019. Attention in RNNs. Medium blog post.