python tensorflow ragged-tensors

Sample from ragged tensor


I have a RaggedTensor whose row lengths range from 1 up to 10k. I would like to randomly select elements from it, with an upper limit on the number per row, in a scalable way. For example:

vect = [[1,2,3],[4,5],[6],[7,8,9,10,11,12,13]]
limit = 3
sample(vect, limit)

-> output: [[1,2,3],[4,5],[6],[7,9,11]]

My idea was to select the whole row when len_row <= limit and to sample randomly otherwise. Can this be done with TensorFlow operations in fewer than batch_size steps, i.e. without looping over every row?


Solution

  • You can try using tf.map_fn in graph mode:

    import tensorflow as tf
    
    vect = tf.ragged.constant([[1,2,3],[4,5],[6],[7,8,9,10,11,12,13]])
    
    @tf.function
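    def sample(x, samples=3):
      # x is a single row of the ragged tensor (a 1-D tensor).
      length = tf.shape(x)[0]
      # Keep short rows unchanged; otherwise gather `samples` random positions.
      x = tf.cond(
          tf.less_equal(length, samples),
          lambda: x,
          lambda: tf.gather(x, tf.random.shuffle(tf.range(length))[:samples]))
      return x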
    
    c = tf.map_fn(sample, vect)
    print(c)
    # Example output (the sampled row varies between runs):
    # <tf.RaggedTensor [[1, 2, 3], [4, 5], [6], [12, 7, 9]]>
    

    Note that tf.vectorized_map would probably be faster, but at the time of writing there is a bug affecting its use with ragged tensors. Using tf.while_loop is another option.
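
A loop-free alternative can be sketched using only flat operations on the ragged tensor's values. This is a minimal sketch, not benchmarked: the helper name sample_rows and its parameters rt and limit are hypothetical, and it assumes a RaggedTensor with ragged_rank 1. It gives every element a random rank within its own row and keeps the elements whose rank is below the limit. Unlike the shuffle-based version, the kept elements stay in their original order, as in the question's example output.

```python
import tensorflow as tf


def sample_rows(rt, limit):
    """Keep at most `limit` randomly chosen elements per row (hypothetical helper)."""
    values = rt.values
    row_ids = rt.value_rowids()  # row index of every flat value

    # Lexicographic sort by (row id, random key): sort by the random key first,
    # then stably by row id, so each row's elements stay contiguous but shuffled.
    noise = tf.random.uniform(tf.shape(row_ids))
    by_noise = tf.argsort(noise)
    by_row = tf.argsort(tf.gather(row_ids, by_noise), stable=True)
    order = tf.gather(by_noise, by_row)

    # Position of each original element in the shuffled order, relative to
    # the start of its row: a random rank in [0, row_length).
    rank = tf.cast(tf.math.invert_permutation(order), row_ids.dtype)
    rank_in_row = rank - tf.gather(rt.row_starts(), row_ids)

    # Keep elements with rank below the limit; short rows are kept whole.
    keep = rank_in_row < limit
    return tf.RaggedTensor.from_value_rowids(
        tf.boolean_mask(values, keep),
        tf.boolean_mask(row_ids, keep),
        nrows=rt.nrows())


vect = tf.ragged.constant([[1, 2, 3], [4, 5], [6], [7, 8, 9, 10, 11, 12, 13]])
result = sample_rows(vect, 3)
print(result)
```

The cost is dominated by two sorts over all flat values, so there is no per-row Python or graph loop. Whether this is actually faster than tf.map_fn for a given batch would need to be measured.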