I have a 2D RaggedTensor consisting of indices I want from each row of a full tensor, e.g.:
[
[0,4],
[1,2,3],
[5]
]
applied row by row to
[
[200, 305, 400, 20, 20, 105],
[200, 315, 401, 20, 20, 167],
[200, 7, 402, 20, 20, 105],
]
should give
[
[200,20],
[315,401,20],
[105]
]
How can I achieve this in the most efficient way (preferably only with tf functions)? I believe that things like gather_nd are able to take RaggedTensors, but I cannot figure out how it works.
You can use tf.gather with the batch_dims keyword argument:
>>> tf.gather(tensor,indices,batch_dims=1)
<tf.RaggedTensor [[200, 20], [315, 401, 20], [105]]>
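For completeness, here is a self-contained sketch of the call above using the data from the question. It assumes TensorFlow 2.x in eager mode. The names tensor and indices are the ones used in the answer; expected is only a hypothetical plain-Python reference added to check the result. With batch_dims=1, the first dimension is treated as a batch dimension, so row i of indices is looked up in row i of tensor rather than across the whole tensor.

```python
import tensorflow as tf

# Dense tensor from the question: 3 rows, 6 columns.
tensor = tf.constant([
    [200, 305, 400, 20, 20, 105],
    [200, 315, 401, 20, 20, 167],
    [200, 7, 402, 20, 20, 105],
])

# Ragged indices: a different number of column indices per row.
indices = tf.ragged.constant([[0, 4], [1, 2, 3], [5]])

# batch_dims=1 pairs row i of `indices` with row i of `tensor`.
result = tf.gather(tensor, indices, batch_dims=1)
print(result.to_list())

# Plain-Python reference for the same per-row lookup (illustration only).
expected = [
    [row[i] for i in idx]
    for row, idx in zip(tensor.numpy().tolist(), indices.to_list())
]
print(expected)
```

Both prints should show the ragged result from the question. The result stays a tf.RaggedTensor, so no padding or masking is needed afterwards.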