Tags: python, tensorflow, deep-learning, recurrent-neural-network

Accumulating output from a graph using tf.while_loop


I have an RNN that is stacked on top of a CNN. The CNN was created and trained separately. To clarify things, let's suppose the CNN takes input in the form of a [BATCH SIZE, H, W, C] placeholder (H = height, W = width, C = number of channels).

Now, when stacked on top of the RNN, the overall input to the combined network will have the shape [BATCH SIZE, TIME SEQUENCE, H, W, C], i.e. each sample in the minibatch consists of TIME SEQUENCE images. Moreover, the time sequences are variable in length. There is a separate placeholder called sequence_lengths with shape [BATCH SIZE] that contains scalar values corresponding to the length of each sample in the minibatch. The value of TIME SEQUENCE corresponds to the maximum possible sequence length, and for shorter samples the remaining values are padded with zeros.
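
For concreteness, the placeholders look something like this (the names images and sequence_lengths and the dimensions here are just placeholders for illustration):

    import tensorflow as tf

    H, W, C = 64, 64, 3        # example frame height, width, channels
    max_sequence_length = 10   # TIME SEQUENCE: the maximum sequence length

    # [BATCH SIZE, TIME SEQUENCE, H, W, C]; shorter samples are
    # zero-padded along the time axis up to max_sequence_length.
    images = tf.placeholder(
        tf.float32, [None, max_sequence_length, H, W, C], name='images')

    # One scalar length per minibatch sample.
    sequence_lengths = tf.placeholder(tf.int32, [None], name='sequence_lengths')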

What I want to do

I want to accumulate the output from the CNN in a tensor of shape [BATCH SIZE, TIME SEQUENCE, 1] (the last dimension just contains the final score output by the CNN for each time step of each batch element) so that I can forward this entire chunk of information to the RNN that is stacked on top of the CNN. The tricky part is that I also want to be able to back-propagate the error from the RNN to the CNN (the CNN is already pre-trained, but I would like to fine-tune the weights a bit), so I have to stay inside the graph, i.e. I can't make any calls to session.run().

Inside my_cnn_model.process_input, I'm just passing the input through a vanilla CNN. All the variables in it are created with reuse=tf.AUTO_REUSE, which should ensure that the while loop reuses the same weights across all loop iterations.
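
As a rough sketch (the layer sizes here are made up, not my actual architecture), process_input looks something like this:

    def process_input(frame_batch):
        # reuse=tf.AUTO_REUSE creates the variables on the first call and
        # reuses the same ones on every later call, so each while-loop
        # iteration shares a single set of CNN weights.
        with tf.variable_scope('cnn', reuse=tf.AUTO_REUSE):
            x = tf.layers.conv2d(frame_batch, filters=16, kernel_size=3,
                                 activation=tf.nn.relu, name='conv1')
            x = tf.layers.flatten(x)
            score = tf.layers.dense(x, units=1, name='fc')  # one score per frame
        return score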

The exact problem

image_output_sequence is a tf.Variable, but when tf.while_loop calls the body function, the loop variable it passes in has been converted into a Tensor, to which sliced assignments can't be made. I get the error message: Sliced assignment is only supported for variables
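
A stripped-down reproduction of the failure (with made-up shapes) looks like this; the error is raised at graph-construction time, as soon as the body is traced:

    ios = tf.Variable(tf.zeros([4, 10, 1], tf.float32))

    def body(lc, ios):
        update = tf.zeros([10, 1], tf.float32)
        # Inside the loop body, `ios` is a Tensor snapshot of the variable,
        # so slicing it yields a plain Tensor whose assign() raises:
        # ValueError: Sliced assignment is only supported for variables
        ios = ios[lc].assign(update)
        return lc + 1, ios

    tf.while_loop(lambda lc, ios: lc < 4, body, [tf.constant(0), ios])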

The problem persists even if I use another layout, such as a tuple of BATCH SIZE Tensors, each with shape [TIME SEQUENCE, H, W, C].

I'm open to a complete redesign of the code as well, as long as it gets the job done nicely.


Solution

  • The solution is to use an object of type tf.TensorArray, which is designed specifically for this kind of in-graph accumulation. The following line:

    image_output_sequence = tf.Variable(tf.zeros([batch_size, max_sequence_length, 1], tf.float32))
    

    is replaced by:

    image_output_sequence = tf.TensorArray(size=batch_size, dtype=tf.float32, element_shape=[max_sequence_length, 1], infer_shape=True)
    

    A TensorArray doesn't actually require a fixed shape for each element, but in my case the shape is fixed, so it's better to enforce it.

    Then inside the body function, replace this:

    ios[lc].assign(padded_cnn_output)
    

    with:

    ios = ios.write(lc, padded_cnn_output)
    

    Then, after the tf.while_loop call, the TensorArray returned in the loop variables (result here) can be stacked to form a regular Tensor for further processing:

    stacked_tensor = result.stack()
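
    Putting the pieces together, a minimal end-to-end sketch of the loop could look like the following. process_input stands in for the per-frame CNN pass sketched in the question, and batch_size is assumed to be a static Python int:

    batch_size = 4  # assumed static for this sketch

    # TF1 disallows creating variables inside a while_loop body, so build
    # the CNN variables once outside the loop; the AUTO_REUSE call inside
    # body then only reuses the existing weights.
    process_input(images[0])

    ios = tf.TensorArray(size=batch_size, dtype=tf.float32,
                         element_shape=[max_sequence_length, 1],
                         infer_shape=True)

    def cond(lc, ios):
        return lc < batch_size

    def body(lc, ios):
        # images[lc] has shape [TIME SEQUENCE, H, W, C]; the CNN treats
        # the time axis as a batch axis and yields one score per frame, so
        # padded_cnn_output has shape [max_sequence_length, 1].
        padded_cnn_output = process_input(images[lc])
        ios = ios.write(lc, padded_cnn_output)
        return lc + 1, ios

    _, result = tf.while_loop(cond, body, [tf.constant(0), ios])

    # [BATCH SIZE, TIME SEQUENCE, 1], ready to feed into the RNN; gradients
    # flow through TensorArray.write and stack, so the CNN can be fine-tuned.
    stacked_tensor = result.stack()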