Tags: python, tensorflow, tensorflow2.0, tensor, ragged

TensorFlow gather, concatenate, and then pad operation?


I have a 2D tensor in TensorFlow 2 (Python). How can I pick out and concatenate rows based on a ragged array of row indices, and then pad the shorter rows with zeros so that all rows end up with the same length?

Here is an example of what I have:

data = tf.constant([
            [300, 301, 302],
            [100, 101, 102],
            [200, 201, 202],
            [120, 121, 122],
            [210, 211, 212],
            [410, 411, 412],
            [110, 111, 112],
            [400, 401, 402],
        ], dtype=tf.float32)

row_ids = [ [ 1, 6, 3 ], [ 2, 4 ], [ 0 ], [ 7, 5] ]

And this is what I would like to get:

desired_result = tf.constant([
        [ 100, 101, 102, 110, 111, 112, 120, 121, 122],
        [ 200, 201, 202, 210, 211, 212,   0,   0,   0],
        [ 300, 301, 302,   0,   0,   0,   0,   0,   0],
        [ 400, 401, 402, 410, 411, 412,   0,   0,   0]
    ], 
    dtype=tf.float32
)

I have tried tf.RaggedTensor.from_value_rowids() and tf.gather_nd() combined with tf.concat(), but without success.

I need to backpropagate through this operation, so I have to stick to TensorFlow 2 operations.

Any suggestions would be greatly appreciated! Thanks!


Solution

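  • If I understand correctly, you can solve this more simply. tf.gather accepts a ragged index tensor and returns a RaggedTensor; to_tensor() pads the shorter rows with zeros, and a reshape then flattens the last two dimensions so the gathered rows sit side by side: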

    import tensorflow as tf
    
    data = tf.constant([
                [300, 301, 302],
                [100, 101, 102],
                [200, 201, 202],
                [120, 121, 122],
                [210, 211, 212],
                [410, 411, 412],
                [110, 111, 112],
                [400, 401, 402],
            ], dtype=tf.float32)
    
    row_ids = tf.ragged.constant([ [ 1, 6, 3 ], [ 2, 4 ], [ 0 ], [ 7, 5] ])
    
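    # Gather rows with the ragged indices, then pad to a dense [4, 3, 3] tensor (zeros by default).
    t = tf.gather(data, row_ids).to_tensor()
    # Flatten the last two dimensions so each output row holds its gathered rows end to end.
    t = tf.reshape(t, [tf.shape(t)[0], tf.reduce_prod(tf.shape(t)[1:])])
    
    # Resulting value of t: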
    <tf.Tensor: shape=(4, 9), dtype=float32, numpy=
    array([[100., 101., 102., 110., 111., 112., 120., 121., 122.],
           [200., 201., 202., 210., 211., 212.,   0.,   0.,   0.],
           [300., 301., 302.,   0.,   0.,   0.,   0.,   0.,   0.],
           [400., 401., 402., 410., 411., 412.,   0.,   0.,   0.]],
          dtype=float32)>
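
Since the question mentions backpropagation, here is a minimal sketch for checking that gradients flow through the gather, pad, and reshape steps. The helper name gather_concat_pad is hypothetical (it is not a TensorFlow function), and the sum-of-outputs loss is only a stand-in for a real loss. The gradient is passed through tf.convert_to_tensor because the gradient of a gather may come back as tf.IndexedSlices rather than a dense tensor.

```python
import tensorflow as tf

data = tf.constant([
    [300, 301, 302],
    [100, 101, 102],
    [200, 201, 202],
    [120, 121, 122],
    [210, 211, 212],
    [410, 411, 412],
    [110, 111, 112],
    [400, 401, 402],
], dtype=tf.float32)

row_ids = tf.ragged.constant([[1, 6, 3], [2, 4], [0], [7, 5]])


def gather_concat_pad(values, ids):
    """Hypothetical helper: gather rows by ragged ids, zero-pad, flatten per row."""
    gathered = tf.gather(values, ids)   # RaggedTensor: one ragged group of rows per entry in ids
    padded = gathered.to_tensor()       # dense tensor, shorter groups padded with zeros
    # -1 lets TensorFlow infer the flattened width (max group size * row width).
    return tf.reshape(padded, [tf.shape(padded)[0], -1])


with tf.GradientTape() as tape:
    tape.watch(data)                    # data is a constant, so it must be watched explicitly
    out = gather_concat_pad(data, row_ids)
    loss = tf.reduce_sum(out)           # stand-in loss, for illustration only

# Densify in case the gradient is returned as tf.IndexedSlices.
grad = tf.convert_to_tensor(tape.gradient(loss, data))

print(out.shape)   # (4, 9)
print(grad.shape)  # (8, 3)
```

In this example every row of data is gathered exactly once and the padded zeros do not depend on data, so the gradient of the summed output with respect to data should be a tensor of ones. The reshape here uses -1 instead of tf.reduce_prod(tf.shape(t)[1:]); both should yield the same shape for this input.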