I am going through TensorFlow's tutorial on Neural Machine Translation with Attention.
It has the following code for the Decoder:
class Decoder(tf.keras.Model):
    def __init__(self, vocab_size, embedding_dim, dec_units, batch_sz):
        super(Decoder, self).__init__()
        self.batch_sz = batch_sz
        self.dec_units = dec_units
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
        self.gru = tf.keras.layers.GRU(self.dec_units,
                                       return_sequences=True,
                                       return_state=True,
                                       recurrent_initializer='glorot_uniform')
        self.fc = tf.keras.layers.Dense(vocab_size)

        # used for attention
        self.attention = BahdanauAttention(self.dec_units)

    def call(self, x, hidden, enc_output):
        # enc_output shape == (batch_size, max_length, hidden_size)
        context_vector, attention_weights = self.attention(hidden, enc_output)

        # x shape after passing through embedding == (batch_size, 1, embedding_dim)
        x = self.embedding(x)

        # x shape after concatenation == (batch_size, 1, embedding_dim + hidden_size)
        x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1)

        # passing the concatenated vector to the GRU
        output, state = self.gru(x)

        # output shape == (batch_size * 1, hidden_size)
        output = tf.reshape(output, (-1, output.shape[2]))

        # output shape == (batch_size, vocab)
        x = self.fc(output)

        return x, state, attention_weights
What I don't understand here is that the decoder's GRU cell is not connected to the encoder by initializing it with the encoder's last hidden state:
output, state = self.gru(x)
# Why is it not initialized with the hidden state of the encoder?
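What I expected was something along these lines instead (purely illustrative, assuming the encoder's final hidden state enc_hidden were somehow available inside call):

output, state = self.gru(x, initial_state=enc_hidden)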
As per my understanding, there is a connection between the encoder and the decoder only when the decoder is initialized with the "thought vector", i.e. the last hidden state of the encoder.
Why is that missing from TensorFlow's official tutorial? Is it a bug? Or am I missing something here?
Could someone help me understand ?
This is very well summarized by this detailed NMT guide, which compares the classic seq2seq NMT architecture against the attention-based encoder-decoder one.
Vanilla seq2seq: The decoder also needs to have access to the source information, and one simple way to achieve that is to initialize it with the last hidden state of the encoder, encoder_state.
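As a minimal, self-contained sketch of that vanilla connection (the layer sizes and the random tensors below are made up purely for illustration), the decoder RNN would be seeded with the encoder's final state roughly like this:

import tensorflow as tf

# Toy encoder/decoder GRUs; 256 units and the random "embeddings" are arbitrary.
encoder_gru = tf.keras.layers.GRU(256, return_sequences=True, return_state=True)
decoder_gru = tf.keras.layers.GRU(256, return_sequences=True, return_state=True)

source_emb = tf.random.normal((64, 10, 128))   # (batch_size, src_len, embedding_dim)
target_emb = tf.random.normal((64, 12, 128))   # (batch_size, tgt_len, embedding_dim)

enc_outputs, encoder_state = encoder_gru(source_emb)

# The "thought vector" connection: the encoder's last hidden state
# becomes the decoder's initial state.
dec_outputs, dec_state = decoder_gru(target_emb, initial_state=encoder_state)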
Attention-based encoder-decoder: Remember that in the vanilla seq2seq model, we pass the last source state from the encoder to the decoder when starting the decoding process. This works well for short and medium-length sentences; however, for long sentences, the single fixed-size hidden state becomes an information bottleneck. Instead of discarding all of the hidden states computed in the source RNN, the attention mechanism provides an approach that allows the decoder to peek at them (treating them as a dynamic memory of the source information). By doing so, the attention mechanism improves the translation of longer sentences.
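For reference, the tutorial's BahdanauAttention layer implements exactly this "peeking": it scores every encoder time step against the decoder's current hidden state and returns a context vector that is a weighted sum over all encoder outputs. The sketch below is reconstructed from the tutorial, so treat the exact details as approximate:

import tensorflow as tf

class BahdanauAttention(tf.keras.layers.Layer):
    def __init__(self, units):
        super(BahdanauAttention, self).__init__()
        self.W1 = tf.keras.layers.Dense(units)
        self.W2 = tf.keras.layers.Dense(units)
        self.V = tf.keras.layers.Dense(1)

    def call(self, query, values):
        # query  == decoder hidden state, shape (batch_size, hidden_size)
        # values == all encoder outputs,  shape (batch_size, max_length, hidden_size)
        query_with_time_axis = tf.expand_dims(query, 1)

        # score shape == (batch_size, max_length, 1)
        score = self.V(tf.nn.tanh(self.W1(query_with_time_axis) + self.W2(values)))

        # attention_weights shape == (batch_size, max_length, 1)
        attention_weights = tf.nn.softmax(score, axis=1)

        # context_vector shape after the sum == (batch_size, hidden_size)
        context_vector = attention_weights * values
        context_vector = tf.reduce_sum(context_vector, axis=1)

        return context_vector, attention_weights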
In both cases, you can use teacher forcing to better train the model.
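To make the encoder/decoder connection concrete, here is a condensed sketch of a training step in the style of the tutorial (identifiers such as targ_lang, loss_function and initialize_hidden_state are assumptions based on the tutorial and may not match it verbatim). The encoder's final hidden state is handed to the decoder as dec_hidden, which is the query fed to the attention layer at the first step; that is exactly the "connection" the question asks about.

import tensorflow as tf

def train_step_sketch(inp, targ, encoder, decoder, optimizer, loss_function,
                      targ_lang, batch_size):
    """Hypothetical condensation of the tutorial's training step."""
    loss = 0.0
    with tf.GradientTape() as tape:
        enc_hidden = encoder.initialize_hidden_state()
        enc_output, enc_hidden = encoder(inp, enc_hidden)

        # The encoder/decoder connection: the encoder's last hidden state seeds
        # the decoder's attention query, not the decoder GRU's initial_state.
        dec_hidden = enc_hidden
        dec_input = tf.expand_dims([targ_lang.word_index['<start>']] * batch_size, 1)

        for t in range(1, targ.shape[1]):
            # At every step the decoder also "peeks" at all encoder outputs
            # through its attention layer.
            predictions, dec_hidden, _ = decoder(dec_input, dec_hidden, enc_output)
            loss += loss_function(targ[:, t], predictions)

            # Teacher forcing: feed the ground-truth token as the next input.
            dec_input = tf.expand_dims(targ[:, t], 1)

    variables = encoder.trainable_variables + decoder.trainable_variables
    gradients = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(gradients, variables))
    return loss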
TL;DR: the attention mechanism is what lets the decoder "peek" into the encoder's states, instead of you explicitly passing the encoder's final state to the decoder.