So basically, I have a ragged tensor (e.g. [[1, 2, 3], [4, 5], [6]]) and I want to concatenate its rows with a special separator between them, such as a specific number, say 0. The result would be [[1, 2, 3, 0, 4, 5, 0, 6]]. This is something like joining strings, but I want to do it with ragged integers. I haven't found a solution that I can also turn into a @tf.function. The purpose is to concatenate the tokens of a document's sentences, where the separator marks where one sentence ends and the next one starts.
Try using tf.concat and ragged.merge_dims:
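```python
import tensorflow as tf

ragged = tf.ragged.constant([[1, 2, 3], [4, 5], [6]])
rows = ragged.bounding_shape()[0]  # number of rows (3 here)
# One separator for every row except the last: [[0], [0], []]
separators = tf.concat([tf.expand_dims(tf.repeat([0], repeats=rows - 1), axis=-1),
                        tf.ragged.constant([[]], dtype=tf.int32)], axis=0)
# Append the separators to the end of each row: [[1, 2, 3, 0], [4, 5, 0], [6]]
ragged = tf.concat([ragged, separators], axis=-1)
# Flatten the rows into one sequence, then add back a leading dimension
ragged = tf.expand_dims(ragged.merge_dims(0, 1), axis=0)
print(ragged)
# tf.Tensor([[1 2 3 0 4 5 0 6]], shape=(1, 8), dtype=int32)
```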
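Since the question asks for something usable inside a @tf.function, here is a minimal sketch of the same idea wrapped in one. It assumes a 2-D ragged integer tensor of shape [num_sentences, (num_tokens)] and a separator value that never occurs as a real token id; the name join_with_separator and its sep parameter are hypothetical. Rather than building separators for all rows but the last, it appends one separator to every row, flattens, and drops the trailing separator.

```python
import tensorflow as tf


@tf.function
def join_with_separator(rt, sep=0):
    """Join the rows of a 2-D RaggedTensor into one row, with `sep` between rows.

    Hypothetical helper: assumes `rt` has shape [num_rows, (row_length)].
    """
    # One separator per row, shape [num_rows, 1], same dtype as the tokens.
    seps = tf.fill([rt.nrows(), 1], tf.cast(sep, rt.dtype))
    # Append it to the end of every row: [[1, 2, 3, 0], [4, 5, 0], [6, 0]]
    joined = tf.concat([rt, seps], axis=1)
    # Flatten all rows and drop the separator that follows the last row.
    flat = joined.flat_values[:-1]
    # Add a leading dimension so the result has shape [1, num_tokens].
    return flat[tf.newaxis, :]


sentences = tf.ragged.constant([[1, 2, 3], [4, 5], [6]])
print(join_with_separator(sentences))
# tf.Tensor([[1 2 3 0 4 5 0 6]], shape=(1, 8), dtype=int32)
```

Because every row gets a separator, an empty sentence in the middle still produces one, much like joining strings: [[1], [], [2]] becomes [[1, 0, 0, 2]].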