How to reshape a ragged tensor?


Suppose you have stacked two sequences of 3-dimensional embeddings into a single ragged tensor:

import tensorflow as tf

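def foo(*args):
    # Build a dense float32 tensor of shape `args` filled with 0, 1, 2, ...
    n_elements = tf.reduce_prod(args)
    # tf.Tensor has no .reshape method; use tf.reshape instead.
    return tf.reshape(tf.range(n_elements, dtype=tf.float32), args)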

c = tf.ragged.stack((foo(2, 3), foo(5, 3)), axis=0)
assert c.shape == [2, None, None]

How can you cast c to shape [2, None, 3], given that you know the innermost dimension is always 3?


Solution

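  • Try flattening the two outer dimensions, converting the result to a dense tensor, and rebuilding the ragged tensor with tf.RaggedTensor.from_row_splits. Reuse the original row splits (here [0, 2, 7], since the two sequences have 2 and 5 rows) so that each row stays in its sequence:

    tf.RaggedTensor.from_row_splits(
        values=c.merge_dims(0, 1).to_tensor(),
        row_splits=c.row_splits).shape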
    
    (2, None, 3)
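
A self-contained sketch of the approach above, assuming TensorFlow 2.x with eager execution and Python 3.8+ (for math.prod). It checks that the rebuilt tensor has a uniform inner dimension and still holds the same data. The names flat, reshaped and alt are illustrative and not part of the original question. The last part shows a possible alternative that builds the tensor from a dense concatenation, so the inner dimension never becomes ragged.

```python
import math

import tensorflow as tf


def foo(*args):
    # Dense float32 tensor of shape `args` filled with 0, 1, 2, ...
    n_elements = math.prod(args)
    return tf.reshape(tf.range(n_elements, dtype=tf.float32), args)


c = tf.ragged.stack((foo(2, 3), foo(5, 3)), axis=0)
assert c.shape.as_list() == [2, None, None]

# Merge the two outer dimensions, densify the (actually uniform) inner
# dimension, then regroup the rows using the original outer row splits.
flat = c.merge_dims(0, 1).to_tensor()  # dense tensor of shape [7, 3]
reshaped = tf.RaggedTensor.from_row_splits(values=flat, row_splits=c.row_splits)

assert reshaped.shape.as_list() == [2, None, 3]
assert reshaped.row_lengths().numpy().tolist() == [2, 5]
assert reshaped.to_list() == c.to_list()  # same data, same grouping

# Alternative (assumes the original dense sequences are still available):
# concatenate them and partition the rows by sequence length.
alt = tf.RaggedTensor.from_row_lengths(
    values=tf.concat([foo(2, 3), foo(5, 3)], axis=0),
    row_lengths=[2, 5],
)
assert alt.shape.as_list() == [2, None, 3]
```

Note that to_tensor() pads short rows with zeros, so this only round-trips cleanly when every inner row really has the same length.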