
TensorFlow: How to use a RaggedTensor as an index into a normal tensor?


I have a 2D RaggedTensor consisting of indices I want from each row of a full tensor, e.g.:

[
    [0,4],
    [1,2,3],
    [5]
]

into

[
    [200, 305, 400, 20, 20, 105],
    [200, 315, 401, 20, 20, 167],
    [200, 7, 402, 20, 20, 105],
]

gives

[
    [200,20],
    [315,401,20],
    [105]
]

How can I achieve this as efficiently as possible, preferably using only tf functions? I believe functions such as gather_nd can take RaggedTensors, but I cannot figure out how to use them for this.


Solution

  • You can use tf.gather with the batch_dims keyword argument. With batch_dims=1, each row of indices is looked up in the corresponding row of tensor:

    >>> tf.gather(tensor, indices, batch_dims=1)
    <tf.RaggedTensor [[200, 20], [315, 401, 20], [105]]>
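
For completeness, here is a minimal self-contained sketch of the same call, assuming TensorFlow 2.x and using the data from the question. The names `rows`, `ragged_indices`, `result` and `expected` are only illustrative; the plain-Python list comprehension is not part of the solution, it just spells out what the lookup is expected to return.

```python
import tensorflow as tf

# Data taken from the question.
rows = [
    [200, 305, 400, 20, 20, 105],
    [200, 315, 401, 20, 20, 167],
    [200, 7, 402, 20, 20, 105],
]
ragged_indices = [[0, 4], [1, 2, 3], [5]]

tensor = tf.constant(rows)                    # dense tensor, shape (3, 6)
indices = tf.ragged.constant(ragged_indices)  # RaggedTensor with 3 rows

# batch_dims=1 pairs row i of `indices` with row i of `tensor`,
# so result[i][j] == tensor[i][indices[i][j]].
result = tf.gather(tensor, indices, batch_dims=1)

# Plain-Python reference for the same lookup, used only as a sanity check.
expected = [[row[j] for j in idx] for row, idx in zip(rows, ragged_indices)]

print(result.to_list())  # [[200, 20], [315, 401, 20], [105]]
```

The result is itself a RaggedTensor, so it keeps the row lengths of `indices`; `to_list()` converts it to nested Python lists for inspection.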