transformer-model, encoder, decoder

What are the inputs of the first decoder in the transformer architecture?


Regarding the transformer architecture from the original paper: I have referred to many texts, but I couldn't resolve this question.

Let's start with the input sentence "The cat jumped."

My understanding is that each word is processed in parallel by each encoder. Taking the word "cat" for example: its embedding is produced, then it is positionally encoded to produce another vector. Then, based on the attention computations, a final vector is produced, which is passed to the feed-forward part of the encoder. This process is done in parallel for all the other words of the input sentence, and is repeated up to the 6th encoder. Hence, at the last encoder, we have three outputs, one for each word of the sentence. If this is the case, what does the first decoder receive as input? All three outputs one after the other, in parallel, concatenated, or merged? I don't think concatenation is used, since the decoder has a fixed input size.
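
For concreteness, here is a minimal NumPy sketch of the shapes in the encoder stack as described above. The toy d_model of 8, the single attention head, the simplified sine-only positional encoding, and the random stand-in weights are illustrative assumptions, not values from the paper:

```python
import numpy as np

d_model, n_layers = 8, 6
tokens = ["The", "cat", "jumped"]

rng = np.random.default_rng(0)
x = rng.normal(size=(len(tokens), d_model))        # word embeddings, shape (3, d_model)
pos = np.array([[np.sin(p / 10000 ** (i / d_model)) for i in range(d_model)]
                for p in range(len(tokens))])      # toy positional encoding
x = x + pos                                        # positionally encoded embeddings

def self_attention(h):
    # single-head scaled dot-product attention (toy: Q = K = V = h)
    scores = h @ h.T / np.sqrt(d_model)
    weights = np.exp(scores) / np.exp(scores).sum(axis=-1, keepdims=True)
    return weights @ h

for _ in range(n_layers):                          # the 6 encoder layers
    h = self_attention(x)
    x = np.tanh(h @ rng.normal(size=(d_model, d_model)))  # stand-in feed-forward

print(x.shape)  # (3, 8): one d_model-sized vector per input word
```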


Solution

  • After referring to the excellent notes listed in the ref section below, I figured out the answer and am posting it for other newbies in the future. The gif in ref 3 (the transformer tutorial from Google) gives an accurate picture of what all the decoders process, and the "Run inference" section of that tutorial describes what happens. In general, every decoder makes use of only the K, V outputs (for all the words of "The cat jumped") from the last encoder. These are used to compute attention over the encoder's representation of the input sentence "The cat jumped". The decoding process is otherwise exactly the same as the encoding process; it differs only by the addition of cross-attention. The inputs to every decoder are therefore the output sequence generated so far by the decoder itself, which is initially just the [start] token, together with the K and V values for all the words of "The cat jumped" from the last encoder (a toy cross-attention sketch is given after the references below).

    Here is how the decoding process works. The input sentence "The cat jumped" results in the last encoder outputting K, V values for all three words. Then the [start] token is given as input to the decoder. The decoder's self-attention is calculated using only the output sequence generated so far, which at this point is just the [start] token. Then the query Q values of the [start] token are used with the K, V values of the last encoder output to calculate attention with respect to the encoded representation of the sentence "The cat jumped". We then have four attention values for the [start] token (one from its self-attention and three from cross-attention with the encoded words), which are combined to create a vector representation of the [start] token. This is fed to the decoder's feed-forward network, generating the first output word, ow1 (output word 1), which could be the first translated word in some language.

    The same process is then repeated with the sequence [start] ow1, and the decoder output for that sequence gives rise to the next generated word, ow2. This continues until the [end] token is generated (a sketch of this decoding loop is also given after the references).

    Ref:

    1. https://ai.stackexchange.com/questions/36688/last-linear-layer-of-the-decoder-of-a-transformer

    2. What does the transformer decoder attend to at the last linear layer?: https://www.reddit.com/r/learnmachinelearning/comments/18u929u/comment/kfoktol/

    3. The gif in the tutorial https://www.tensorflow.org/text/tutorials/transformer.
      The actual gif: https://www.tensorflow.org/images/tutorials/transformer/apply_the_transformer_to_machine_translation.gif

    4. http://jalammar.github.io/illustrated-transformer/
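
A toy cross-attention sketch for the step described above: K and V are computed from the last encoder's output for "The cat jumped", while Q comes from the decoder's sequence so far (just the [start] token). The single head, d_model = 8, and the random W_q, W_k, W_v matrices are illustrative assumptions, not values from the paper.

```python
import numpy as np

d_model = 8
rng = np.random.default_rng(1)

enc_out = rng.normal(size=(3, d_model))   # last-encoder output: "The", "cat", "jumped"
dec_seq = rng.normal(size=(1, d_model))   # decoder's sequence so far: just [start]

W_q, W_k, W_v = (rng.normal(size=(d_model, d_model)) for _ in range(3))

Q = dec_seq @ W_q          # queries from the decoder side
K = enc_out @ W_k          # keys from the encoder output
V = enc_out @ W_v          # values from the encoder output

scores = Q @ K.T / np.sqrt(d_model)                        # (1, 3): [start] vs each word
weights = np.exp(scores) / np.exp(scores).sum(-1, keepdims=True)
context = weights @ V                                      # (1, d_model) encoder context

print(weights.round(2))   # how much [start] attends to "The", "cat", "jumped"
```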
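
And a sketch of the decoding loop itself. Here `decoder_step()` is a hypothetical placeholder for one full decoder pass (masked self-attention over the sequence so far, cross-attention over the encoder output, feed-forward, softmax); the greedy next-token choice, the START/END ids, and the dummy step in the usage line are assumptions for illustration.

```python
def greedy_decode(enc_out, decoder_step, start_id, end_id, max_len=20):
    seq = [start_id]                          # decoder input starts as just [start]
    while len(seq) < max_len:
        next_id = decoder_step(seq, enc_out)  # self-attention over seq,
                                              # cross-attention over enc_out K, V
        seq.append(next_id)                   # ow1, ow2, ... appended one at a time
        if next_id == end_id:                 # stop once [end] is generated
            break
    return seq

# Toy usage with a dummy step that just counts up to the [end] token:
START, END = 0, 4
print(greedy_decode(enc_out=None, decoder_step=lambda seq, enc: len(seq),
                    start_id=START, end_id=END))   # [0, 1, 2, 3, 4]
```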