I've been trying to figure this out for a while. I found this similar question, but I don't think the proposed answer actually addresses it.
During inference with an Encoder/Decoder Transformer, my understanding is that we don't pre-pad the Decoder input sequence to match the Encoder sequence length (i.e. we pass in [start_id,] and not [start_id, pad_id, pad_id, ...]).
I might be missing something, but when I don't pre-pad, the attention mechanism cannot correctly compute the matrix multiplication because the Decoder input has seq_length = 1 while the Encoder seq_length is > 1 (T). For reference (see attached pic), I identified the tensor shapes at each step, and you can see where the last matmul step cannot be performed given incompatible tensor shapes.
What am I missing? Am I supposed to pre-pad the Decoder input? Or do I truncate the Encoder output to match the Decoder length? Something else?
You've got the order mixed up. For cross attention in an encoder/decoder transformer, the query comes from the decoder, and the key/value come from the encoder.
You can check this in the original Attention Is All You Need paper. In section 3.2.3:
In "encoder-decoder attention" layers, the queries come from the previous
decoder layer, and the memory keys and values come from the output of the
encoder.
Say our decoder query is of size (bs, sl_d, n) and our key/value from the encoder is of size (bs, sl_e, n).
When we compute attention, our attention map from the query and key will be of size (bs, sl_d, sl_e). We then matmul the attention map with the value of size (bs, sl_e, n), giving a result of size (bs, sl_d, n), which matches the initial query size from the decoder.
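
To make the shape bookkeeping concrete, here is a minimal sketch in PyTorch with made-up sizes (sl_d = 1, as in your case where the decoder has only seen the start token, and sl_e = 7 for the encoder output); the names and numbers are just placeholders:

```python
import torch

bs, sl_d, sl_e, n = 2, 1, 7, 64

q = torch.randn(bs, sl_d, n)  # query from the decoder (just the start token here)
k = torch.randn(bs, sl_e, n)  # keys from the encoder output
v = torch.randn(bs, sl_e, n)  # values from the encoder output

# attention map: (bs, sl_d, n) @ (bs, n, sl_e) -> (bs, sl_d, sl_e)
attn = torch.softmax(q @ k.transpose(-2, -1) / n**0.5, dim=-1)

# weighted sum of values: (bs, sl_d, sl_e) @ (bs, sl_e, n) -> (bs, sl_d, n)
out = attn @ v

print(attn.shape)  # torch.Size([2, 1, 7])
print(out.shape)   # torch.Size([2, 1, 64])
```

Nothing here requires sl_d and sl_e to be equal, which is why no pre-padding of the decoder input is needed.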
I omitted the multi-head aspect for brevity, but the same applies there, just with an extra head axis and reshaping of the d_model dimension.
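
For completeness, a rough sketch of the same computation with multiple heads, again with illustrative sizes: d_model is split into n_heads * d_head, the head axis sits next to the batch axis, and the shapes per head follow the same pattern as above.

```python
import torch

bs, sl_d, sl_e, d_model, n_heads = 2, 1, 7, 64, 8
d_head = d_model // n_heads

q = torch.randn(bs, sl_d, d_model)  # decoder-side queries
k = torch.randn(bs, sl_e, d_model)  # encoder-side keys
v = torch.randn(bs, sl_e, d_model)  # encoder-side values

# split d_model into heads: (bs, seq, d_model) -> (bs, n_heads, seq, d_head)
def split_heads(x):
    return x.view(bs, -1, n_heads, d_head).transpose(1, 2)

qh, kh, vh = split_heads(q), split_heads(k), split_heads(v)

# (bs, n_heads, sl_d, d_head) @ (bs, n_heads, d_head, sl_e) -> (bs, n_heads, sl_d, sl_e)
attn = torch.softmax(qh @ kh.transpose(-2, -1) / d_head**0.5, dim=-1)

# (bs, n_heads, sl_d, sl_e) @ (bs, n_heads, sl_e, d_head) -> (bs, n_heads, sl_d, d_head)
out = attn @ vh

# merge heads back: -> (bs, sl_d, d_model)
out = out.transpose(1, 2).reshape(bs, sl_d, d_model)
print(out.shape)  # torch.Size([2, 1, 64])
```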