In the TensorFlow Keras implementation of Multi-Head Attention, instead of evaluating the numerator QKᵀ first and then dividing, as in the usual formulation softmax(QKᵀ/√dₖ)V, they evaluate Q/√dₖ first and add this comment:
> Note: Applying scalar multiply at the smaller end of einsum improves XLA performance, but may introduce slight numeric differences in the Transformer attention head.
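In context, the pattern looks roughly like this. This is a paraphrased sketch, not the actual Keras source (the real layer builds its einsum equation string dynamically):

```python
import math
import tensorflow as tf

# Paraphrased sketch of the pattern in MultiHeadAttention._compute_attention.
# Shapes assumed here: query (B, T, N, key_dim), key (B, S, N, key_dim),
# scores (B, N, T, S).

def scores_scale_query_first(query, key, key_dim):
    # Fold 1/sqrt(key_dim) into the projected query (the "smaller end")
    # before the einsum, as the comment describes.
    query = tf.multiply(query, 1.0 / math.sqrt(float(key_dim)))
    return tf.einsum("bsnh,btnh->bnts", key, query)

def scores_scale_after(query, key, key_dim):
    # The alternative asked about below: scale the (B, N, T, S) scores afterwards.
    scores = tf.einsum("bsnh,btnh->bnts", key, query)
    return scores / math.sqrt(float(key_dim))
```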
How is it faster this way? Wouldn't dividing after the einsum be just as fast?
What the comment suggests is that the number of elements in `query` is smaller than the number of elements in `key` or `attention_scores` in the following equation:
attention_scores = tf.einsum(self._dot_product_equation, key, query)
Given the documented dimensions:

- `query`: Projected query `Tensor` of shape `(B, T, N, key_dim)`.
- `key`: Projected key `Tensor` of shape `(B, S, N, key_dim)`.
Assuming that `_dot_product_equation` is simply doing a batched matrix multiplication over the batch and head dimensions, then per batch and head Q is T x key_dim and K is S x key_dim, so the product Q @ K.T is T x S. If S > key_dim, the number of scalar multiplications is expected to be smaller when the scaling is applied to Q (the left factor) rather than to the T x S product.
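To put rough numbers on that, here is a small back-of-the-envelope sketch; the shape values are made-up examples, not taken from the question:

```python
# Hypothetical example shapes (illustrative values only).
B, T, S, N, key_dim = 8, 128, 512, 4, 64

# Scaling the projected query touches B * T * N * key_dim elements.
mults_scale_query = B * T * N * key_dim   # 262,144

# Scaling attention_scores afterwards touches B * N * T * S elements.
mults_scale_scores = B * N * T * S        # 2,097,152

# Scaling the query is the cheaper placement whenever key_dim < S.
print(mults_scale_query, mults_scale_scores)
```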
But either way the scaling should not be the dominant part of the computation, except if S > T * key_dim (or XLA has a bug).
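A quick numerical sketch (per-head sizes are made up) illustrates both points: the two placements agree only up to float32 rounding, which is the "slight numeric differences" the comment warns about, and the matmul itself dwarfs either scaling step:

```python
import numpy as np

# Hypothetical per-head sizes (illustrative values only).
T, S, key_dim = 128, 512, 64
rng = np.random.default_rng(0)
q = rng.standard_normal((T, key_dim)).astype(np.float32)
k = rng.standard_normal((S, key_dim)).astype(np.float32)

scale = np.float32(1.0 / np.sqrt(key_dim))
scores_scale_q = (q * scale) @ k.T    # scale the smaller operand first
scores_scale_out = (q @ k.T) * scale  # scale the T x S result afterwards

# Identical up to float32 rounding ("slight numeric differences").
print(np.max(np.abs(scores_scale_q - scores_scale_out)))

# Rough multiplication counts per head: the matmul dominates either scaling choice.
print(T * S * key_dim,  # einsum itself: 4,194,304
      T * key_dim,      # scale query:       8,192
      T * S)            # scale scores:     65,536
```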