
Transformers: Cross Attention Tensor Shapes During Inference Mode


I've been trying to figure this out for a while. I found this similar question, but I don't think the proposed answer actually addresses it.

During inference with an Encoder/Decoder Transformer, my understanding is that we don't pre-pad the Decoder input sequence to match the Encoder sequence length (i.e. we pass in [start_id] and not [start_id, pad_id, pad_id, ...]).

I might be missing something, but when I don't pre-pad, the attention mechanism can't compute the matrix multiplication, because the Decoder input has seq_length = 1 while the Encoder seq_length is T > 1. For reference (see attached pic), I worked out the tensor shapes at each step, and you can see that the last matmul can't be performed because the shapes are incompatible.

What am I missing? Am I supposed to pre-pad the Decoder input? Or should I truncate the Encoder output to match the Decoder length? Something else?



Solution

  • You've got the roles of the two inputs mixed up. For cross attention in an encoder/decoder transformer, the query comes from the decoder, and the key and value come from the encoder.

    You can confirm this in section 3.2.3 of the original Attention Is All You Need paper:

    In "encoder-decoder attention" layers, the queries come from the previous
    decoder layer, and the memory keys and values come from the output of the
    encoder.
    

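    Say the decoder query has size (bs, sl_d, n) and the key/value from the encoder have size (bs, sl_e, n), where bs is the batch size, sl_d and sl_e are the decoder and encoder sequence lengths, and n is the model dimension.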

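    When we compute attention, the attention map from the query and key (Q K^T, followed by a softmax over the last axis) has size (bs, sl_d, sl_e). We then matmul the attention map with the value of size (bs, sl_e, n), giving a result of size (bs, sl_d, n), which matches the original query size from the decoder. The only dimension the decoder and encoder must share is n; sl_d and sl_e can differ freely, so there is no need to pad the decoder input or truncate the encoder output.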

    I omitted the multi-head aspect for brevity, but the same applies there, just with an extra head axis and the d_model (n) dimension reshaped into (num_heads, d_model / num_heads).
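A minimal PyTorch sketch of these shapes, using random tensors as stand-ins for the decoder hidden states and the encoder output. The sizes are arbitrary, and the manual part skips the learned Q/K/V projections for brevity (an illustrative simplification). It runs the single-head computation, the same computation split into heads, and PyTorch's `nn.MultiheadAttention` with the query from the decoder and key/value from the encoder; `sl_d = 1` mimics the first inference step.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bs, sl_d, sl_e, n, h = 2, 1, 7, 16, 4  # sl_d = 1: the decoder has only seen [start_id]

dec = torch.randn(bs, sl_d, n)  # stand-in for decoder hidden states (query side)
enc = torch.randn(bs, sl_e, n)  # stand-in for encoder output (key/value side)

# Single head, learned projections omitted: Q from the decoder, K and V from the encoder.
q, k, v = dec, enc, enc
scores = q @ k.transpose(-2, -1) / n ** 0.5  # (bs, sl_d, sl_e)
attn = scores.softmax(dim=-1)                # each decoder position weights all sl_e encoder positions
out = attn @ v                               # (bs, sl_d, n), same shape as the query

# Multi-head: split n into h heads of size d_h, which adds an extra head axis.
d_h = n // h
qh = q.view(bs, sl_d, h, d_h).transpose(1, 2)  # (bs, h, sl_d, d_h)
kh = k.view(bs, sl_e, h, d_h).transpose(1, 2)  # (bs, h, sl_e, d_h)
vh = v.view(bs, sl_e, h, d_h).transpose(1, 2)  # (bs, h, sl_e, d_h)
attn_h = (qh @ kh.transpose(-2, -1) / d_h ** 0.5).softmax(dim=-1)  # (bs, h, sl_d, sl_e)
out_h = (attn_h @ vh).transpose(1, 2).reshape(bs, sl_d, n)          # merge heads: (bs, sl_d, n)

# Same shapes from PyTorch's built-in module (which adds its own learned projections).
mha = nn.MultiheadAttention(embed_dim=n, num_heads=h, batch_first=True)
mh_out, mh_weights = mha(dec, enc, enc)  # query, key, value
print(out.shape, out_h.shape, mh_out.shape, mh_weights.shape)
```

Nothing here requires sl_d == sl_e: changing either one only changes the corresponding axis of the attention map, and the output always keeps the query's shape.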
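To see the same thing end to end at inference time, here is a minimal greedy-decoding sketch with `torch.nn.Transformer`. The decoder input starts as just `[start_id]` with no padding and grows by one token per step, while the encoder output (`memory`) keeps length T throughout. The sizes, `start_id`, and the toy embedding and output projection are hypothetical choices for illustration. The model is untrained, so the generated ids are meaningless; only the shapes matter.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes and token id, chosen only for illustration.
d_model, vocab_size, src_len = 32, 100, 7
start_id, max_new_tokens = 1, 5

# batch_first=True so tensors are (batch, seq_len, d_model).
model = nn.Transformer(d_model=d_model, nhead=4, num_encoder_layers=2,
                       num_decoder_layers=2, dim_feedforward=64,
                       dropout=0.0, batch_first=True)
embed = nn.Embedding(vocab_size, d_model)   # toy shared embedding (assumption)
to_logits = nn.Linear(d_model, vocab_size)  # toy output projection (assumption)
model.eval()

src = torch.randint(2, vocab_size, (1, src_len))  # encoder input, length T = 7
tgt = torch.tensor([[start_id]])                  # decoder input: just [start_id], no padding

with torch.no_grad():
    memory = model.encoder(embed(src))            # (1, T, d_model), computed once
    for _ in range(max_new_tokens):
        # Causal mask over the decoder tokens generated so far.
        mask = model.generate_square_subsequent_mask(tgt.size(1))
        out = model.decoder(embed(tgt), memory, tgt_mask=mask)  # (1, len(tgt), d_model)
        next_id = to_logits(out[:, -1]).argmax(dim=-1, keepdim=True)  # (1, 1)
        tgt = torch.cat([tgt, next_id], dim=1)    # decoder input grows by one token

print(memory.shape, tgt.shape)
```

At every step the cross-attention inside `model.decoder` sees a query of length `len(tgt)` and keys/values of length T, and the two never need to match.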