In the official documentation of tf.nn.raw_rnn, emit_structure is the fourth output of loop_fn when loop_fn is run for the first time.
Later on, emit_structure is used to copy tf.zeros_like(emit_structure) into the minibatch entries that are finished, via emit = tf.where(finished, tf.zeros_like(emit_structure), emit).
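To make the masking step in the pseudo-code concrete, here is a pure-Python sketch (no TensorFlow needed) of what tf.where(finished, tf.zeros_like(emit_structure), emit) does for each minibatch entry. The helper name mask_finished is hypothetical. In this sketch the zeros take their width from each emit row, whereas in the doc's pseudo-code the zero template comes from emit_structure, which is where a None value would cause trouble.

```python
from typing import List, Sequence


def mask_finished(finished: Sequence[bool],
                  emit: Sequence[List[float]]) -> List[List[float]]:
    """Hypothetical pure-Python analogue of
    tf.where(finished, tf.zeros_like(emit_structure), emit).

    Each row of `emit` is one minibatch entry; rows whose `finished`
    flag is True are replaced by zeros of the same width.
    """
    if len(finished) != len(emit):
        raise ValueError("finished and emit must have the same batch size")
    return [[0.0] * len(row) if done else list(row)
            for done, row in zip(finished, emit)]


# Entry 1 has finished, so its output is zeroed out.
finished = [False, True, False]
emit = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
print(mask_finished(finished, emit))  # [[1.0, 2.0], [0.0, 0.0], [5.0, 6.0]]
```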
What I don't get (either I'm missing something or Google's documentation is lousy) is this: in the documented example emit_structure is None, so tf.where(finished, tf.zeros_like(emit_structure), emit) should throw a ValueError, since tf.zeros_like(None) does. Can somebody please fill in what I am missing here?
Yep, the doc is rather confusing here. The key phrase is "in pseudo-code": the example in the doc is not an exact description of what tf.nn.raw_rnn does internally. The actual source code looks like this (it may differ depending on your TensorFlow version):
if emit_structure is not None:
  flat_emit_structure = nest.flatten(emit_structure)
  flat_emit_size = [emit.shape if emit.shape.is_fully_defined() else
                    array_ops.shape(emit) for emit in flat_emit_structure]
  flat_emit_dtypes = [emit.dtype for emit in flat_emit_structure]
else:
  emit_structure = cell.output_size
  flat_emit_size = nest.flatten(emit_structure)
  flat_emit_dtypes = [flat_state[0].dtype] * len(flat_emit_size)
So it handles the case where emit_structure is None by falling back to cell.output_size for the sizes and to the dtype of the first flattened state tensor for the dtypes. That's why nothing actually breaks.
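The branch quoted above can be mirrored in plain Python to see what each case resolves to. This is only a sketch under stated assumptions: FakeTensor, FakeCell, flatten and resolve_emit_spec are hypothetical stand-ins (not TensorFlow APIs), shapes are always treated as fully defined, and dtypes are plain strings.

```python
from dataclasses import dataclass
from typing import Any, List, Sequence, Tuple


@dataclass
class FakeTensor:
    """Hypothetical stand-in for a tensor: only shape and dtype matter here."""
    shape: Tuple[int, ...]
    dtype: str


@dataclass
class FakeCell:
    """Hypothetical stand-in for an RNN cell exposing output_size."""
    output_size: Any


def flatten(structure: Any) -> List[Any]:
    """Tiny analogue of nest.flatten for nested tuples/lists."""
    if isinstance(structure, (list, tuple)):
        out: List[Any] = []
        for item in structure:
            out.extend(flatten(item))
        return out
    return [structure]


def resolve_emit_spec(emit_structure: Any, cell: FakeCell,
                      flat_state: Sequence[FakeTensor]):
    """Mirror the if/else above: use emit_structure when given, otherwise
    fall back to cell.output_size and the first state tensor's dtype."""
    if emit_structure is not None:
        flat = flatten(emit_structure)
        # Assumption: shapes are always fully defined in this sketch.
        sizes = [e.shape for e in flat]
        dtypes = [e.dtype for e in flat]
    else:
        sizes = flatten(cell.output_size)
        dtypes = [flat_state[0].dtype] * len(sizes)
    return sizes, dtypes


cell = FakeCell(output_size=4)
state = [FakeTensor(shape=(2, 4), dtype="float32")]
print(resolve_emit_spec(None, cell, state))  # ([4], ['float32'])

explicit = (FakeTensor(shape=(3,), dtype="int32"),
            FakeTensor(shape=(5,), dtype="float32"))
print(resolve_emit_spec(explicit, cell, state))  # ([(3,), (5,)], ['int32', 'float32'])
```

In the None case the spec comes entirely from cell.output_size plus the state's dtype. In the quoted source, emit_structure has already been replaced by cell.output_size at that point, so the None value never reaches anything that would fail on it.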