I am working on the Keras seq2seq example here: https://blog.keras.io/a-ten-minute-introduction-to-sequence-to-sequence-learning-in-keras.html
What I have understood from the text is that in the decoder model, each cell's output is the input to the next cell. However, I don't understand how this recursion is implemented in the model. In the link, the decoder model is built as follows:
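```python
decoder_model = Model(
    [decoder_inputs] + decoder_states_inputs,  # inputs: token input + [h, c] state inputs
    [decoder_outputs] + decoder_states)        # outputs: token probabilities + [h, c] states
```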
How does this syntax tell the model that each cell's output is the input to the next cell? More generally, how does this syntax work?
EDIT: If you check the keras.Model documentation, you will see that a model can take a list of keras.Input objects as its inputs argument. Note that [decoder_inputs] + decoder_states_inputs is a list.
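If you look at the documentation for the Keras Model class here, you'll see that Model() takes inputs and outputs as its first and second arguments respectively (Model(inputs, outputs)). Each of these can be a single tensor or a list of tensors, and together they define the inputs and outputs of the model (in your case, a decoder that will be used in the inference loop of the decode_sequence() function at the end of the article you linked).
This call by itself does not create any recursion: it only defines a model that, during inference, is run one timestep at a time. The feedback you are asking about happens in decode_sequence(), which calls decoder_model.predict() in a loop and, on each iteration, passes the token sampled from the previous output and the returned states back in as the next inputs.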
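To see the list syntax in isolation, here is a minimal sketch, assuming TensorFlow 2 with tf.keras, of a toy model with two inputs and two outputs. The names token_in, state_in, probs and new_state and all layer sizes are hypothetical and unrelated to the blog's decoder; the point is only that Model() accepts lists, and predict() then expects and returns lists in the same order.

```python
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

# Two hypothetical inputs: a "token" vector and a "state" vector
token_in = keras.Input(shape=(4,))
state_in = keras.Input(shape=(3,))

# Combine both inputs, then compute two separate outputs from them
merged = layers.Concatenate()([token_in, state_in])
probs = layers.Dense(4, activation="softmax")(merged)
new_state = layers.Dense(3, activation="tanh")(merged)

# A list of inputs and a list of outputs, like [decoder_inputs] + decoder_states_inputs
model = keras.Model([token_in, state_in], [probs, new_state])

# predict() takes one array per input, in list order, and returns one array per output
outputs = model.predict([np.zeros((2, 4)), np.zeros((2, 3))], verbose=0)
print(len(outputs), outputs[0].shape, outputs[1].shape)  # 2 (2, 4) (2, 3)
```

The order of the lists matters: the first array passed to predict() feeds token_in, and the first returned array comes from probs.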
To elaborate on the code snippet you posted, you are passing decoder_inputs and decoder_states_inputs together as the inputs argument of Model(inputs, outputs) to specify the inputs of the decoder model:
decoder_inputs is an Input object (a Keras tensor), created with the Input() function (see Input), that accepts the target tokens (characters) as a sequence of one-hot vectors, each of length num_decoder_tokens.
Similarly, decoder_states_inputs is a list of two Input tensors for the decoder's incoming hidden state and cell state, each of length latent_dim.
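A sketch of how these input tensors could be created, assuming tf.keras. num_decoder_tokens and latent_dim are given arbitrary illustrative values here, not the values used in the blog.

```python
from tensorflow import keras

# Illustrative sizes (assumptions); in the blog they come from the data and model config
num_decoder_tokens = 10
latent_dim = 16

# Sequence of any length (None), each step a one-hot vector of size num_decoder_tokens
decoder_inputs = keras.Input(shape=(None, num_decoder_tokens))

# Incoming hidden state and cell state for the decoder LSTM
decoder_state_input_h = keras.Input(shape=(latent_dim,))
decoder_state_input_c = keras.Input(shape=(latent_dim,))
decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c]

# Plain Python list concatenation: a single list holding all three Input tensors
model_inputs = [decoder_inputs] + decoder_states_inputs
print(len(model_inputs))  # 3
```

Nothing Keras-specific happens in the last step: the + is ordinary list concatenation, and Model() simply receives a three-element list.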
Likewise, you pass decoder_outputs and decoder_states together as the outputs argument of Model(inputs, outputs) to specify the outputs of the model:
decoder_outputs is the output of a densely connected layer with a softmax activation (see Dense), i.e. a probability distribution over the num_decoder_tokens possible characters at each step.
decoder_states is a list containing the hidden state state_h and cell state state_c returned by decoder_lstm.
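Putting the pieces together, here is a self-contained sketch, assuming tf.keras, of the one-step decoder model and a simplified greedy loop in the spirit of the blog's decode_sequence(). The layers are freshly initialized rather than trained, the initial states are zeros instead of the encoder's output, and the start-token index 0 and the five-step limit are hypothetical, so the sampled indices are meaningless. What the sketch shows is that the loop, not the Model definition, feeds each step's outputs back in as the next step's inputs.

```python
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

# Illustrative sizes (assumptions, not the blog's values)
num_decoder_tokens = 10
latent_dim = 16

# In the blog these are the trained layers shared with the training model;
# here they are untrained, so predictions carry no meaning
decoder_lstm = layers.LSTM(latent_dim, return_sequences=True, return_state=True)
decoder_dense = layers.Dense(num_decoder_tokens, activation="softmax")

# Inference inputs: the current token(s) plus the previous hidden and cell states
decoder_inputs = keras.Input(shape=(None, num_decoder_tokens))
decoder_state_input_h = keras.Input(shape=(latent_dim,))
decoder_state_input_c = keras.Input(shape=(latent_dim,))
decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c]

# Run the LSTM from the supplied states and keep the new states it returns
decoder_outputs, state_h, state_c = decoder_lstm(
    decoder_inputs, initial_state=decoder_states_inputs)
decoder_states = [state_h, state_c]
decoder_outputs = decoder_dense(decoder_outputs)

decoder_model = keras.Model(
    [decoder_inputs] + decoder_states_inputs,
    [decoder_outputs] + decoder_states)

# Stand-in for encoder_model.predict(input_seq): zero initial states
states_value = [np.zeros((1, latent_dim)), np.zeros((1, latent_dim))]
target_seq = np.zeros((1, 1, num_decoder_tokens))
target_seq[0, 0, 0] = 1.0  # hypothetical start token at index 0

sampled = []
for _ in range(5):
    output_tokens, h, c = decoder_model.predict([target_seq] + states_value, verbose=0)
    token_index = int(np.argmax(output_tokens[0, -1, :]))
    sampled.append(token_index)
    # The "recursion": this step's prediction and states become the next step's inputs
    target_seq = np.zeros((1, 1, num_decoder_tokens))
    target_seq[0, 0, token_index] = 1.0
    states_value = [h, c]

print(sampled)
```

Because decoder_model returns [h, c] alongside the token probabilities, the loop can carry the LSTM state across separate predict() calls; without those extra outputs, each call would start from scratch.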