python, tensorflow, keras, tf.dataset

How to find out the maximum length of a dimension in a ragged dataset


If I have the following dataset built from a ragged tensor, how can I get the maximum length (4 in this example) of all elements?

import tensorflow as tf

ds = tf.data.Dataset.from_tensor_slices(
    tf.ragged.constant([[1, 2, 3, 4], [], [5, 6, 7], [8], []]))

Solution

  • We can use Dataset.reduce, carrying the running maximum of the element lengths as the reduction state:

    ds.reduce(0, lambda state, value: tf.math.maximum(state, len(value))).numpy()
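The same reduction can be written a little more explicitly. This is a minimal sketch, not part of the original answer: the helper name max_element_length is hypothetical, and it computes each length with tf.shape(value)[0] instead of len(value), so the length is produced as a tensor inside the traced reduce function rather than relying on AutoGraph to convert len(). The starting state of 0 is also what you get back for an empty dataset. The last lines show an alternative that assumes you still hold the original ragged tensor: RaggedTensor.row_lengths() gives every row's length directly, without iterating over a dataset.

```python
import tensorflow as tf

# Same data as in the question: each dataset element is a 1-D int32 tensor
# whose length varies (4, 0, 3, 1, 0).
rt = tf.ragged.constant([[1, 2, 3, 4], [], [5, 6, 7], [8], []])
ds = tf.data.Dataset.from_tensor_slices(rt)


def max_element_length(dataset):
    """Hypothetical helper: largest first-dimension length over all elements."""
    # The reduce state is a scalar int32 running maximum; starting at 0 means
    # an empty dataset also yields 0.
    initial = tf.constant(0, dtype=tf.int32)
    return dataset.reduce(
        initial,
        # tf.shape(value)[0] is the element's length as an int32 tensor.
        lambda state, value: tf.math.maximum(state, tf.shape(value)[0]),
    )


print(max_element_length(ds).numpy())  # 4

# Alternative when the ragged tensor itself is available: take the maximum
# of its row lengths directly.
print(tf.reduce_max(rt.row_lengths()).numpy())  # 4
```

Because Dataset.reduce makes a single pass over the data, this also works for datasets that are too large to materialize as one ragged tensor, where the row_lengths() shortcut is not an option.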