python-3.x, tensorflow, recurrent-neural-network, tensorflow-slim

Unable to understand tf.nn.raw_rnn


In the official documentation of tf.nn.raw_rnn, the pseudo-code has loop_fn return an emit_structure the first time loop_fn is called (with cell_output=None and cell_state=None).

Later on, emit_structure is used to write zeros into the minibatch entries that have already finished, via emit = tf.where(finished, tf.zeros_like(emit_structure), emit).

What I can't follow (whether through my own lack of understanding or lousy documentation on Google's part) is this: emit_structure is None on that first call, so tf.where(finished, tf.zeros_like(emit_structure), emit) should throw a ValueError, since tf.zeros_like(None) does. Can somebody please fill in what I am missing here?
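
To isolate the part that worries me, here is the failure on its own (a tiny check, assuming TF 1.x semantics; the printed message is approximate):

    import tensorflow as tf

    # zeros_like cannot convert None to a tensor, which is the call the
    # doc's pseudo-code appears to make when emit_structure is None.
    try:
        tf.zeros_like(None)
    except ValueError as e:
        print("ValueError:", e)  # e.g. "None values not supported."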


Solution

  • Yep, the doc is rather confusing in this place. The key term there is "in pseudo-code": the example in the doc isn't literally what the internals of tf.nn.raw_rnn execute.

    The relevant part of the source looks like this (comments added; it may differ depending on your TensorFlow version):

    if emit_structure is not None:
      # loop_fn returned an actual emit structure on the first call:
      # record its shapes and dtypes.
      flat_emit_structure = nest.flatten(emit_structure)
      flat_emit_size = [emit.shape if emit.shape.is_fully_defined() else
                        array_ops.shape(emit) for emit in flat_emit_structure]
      flat_emit_dtypes = [emit.dtype for emit in flat_emit_structure]
    else:
      # emit_output was None on the first call: fall back to
      # cell.output_size and reuse the state's dtype.
      emit_structure = cell.output_size
      flat_emit_size = nest.flatten(emit_structure)
      flat_emit_dtypes = [flat_state[0].dtype] * len(flat_emit_size)
    

    So raw_rnn itself handles the case where emit_structure is None by simply falling back to cell.output_size. That's why nothing really breaks.
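
    To see this end to end, here is a minimal sketch of a raw_rnn call (assuming TF 1.x or tf.compat.v1; the sizes, the LSTMCell and the random inputs are all invented for illustration) where loop_fn returns None as emit_output on the first call and nothing breaks:

    import tensorflow as tf

    max_time, batch_size, input_depth, num_units = 10, 4, 8, 16

    inputs = tf.random_normal([max_time, batch_size, input_depth])
    sequence_length = tf.fill([batch_size], max_time)

    inputs_ta = tf.TensorArray(dtype=tf.float32, size=max_time)
    inputs_ta = inputs_ta.unstack(inputs)

    cell = tf.nn.rnn_cell.LSTMCell(num_units)

    def loop_fn(time, cell_output, cell_state, loop_state):
      if cell_output is None:
        # First call: there is no cell output yet, so emit_output is None
        # here and raw_rnn falls back to cell.output_size internally.
        next_cell_state = cell.zero_state(batch_size, tf.float32)
        emit_output = None
      else:
        next_cell_state = cell_state
        emit_output = cell_output
      elements_finished = (time >= sequence_length)
      finished = tf.reduce_all(elements_finished)
      # Feed zeros once every sequence in the batch has finished.
      next_input = tf.cond(
          finished,
          lambda: tf.zeros([batch_size, input_depth], dtype=tf.float32),
          lambda: inputs_ta.read(time))
      return (elements_finished, next_input, next_cell_state,
              emit_output, None)

    outputs_ta, final_state, _ = tf.nn.raw_rnn(cell, loop_fn)
    outputs = outputs_ta.stack()  # [max_time, batch_size, num_units]

    Because emit_output was None on that first call, the emitted structure defaults to cell.output_size (num_units here), and the zeros written for already-finished entries take the state's dtype, exactly as in the source excerpt above.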