If I have the following dataset built from a ragged tensor, how can I get the maximum length (4 in this example) of all elements?
```python
import tensorflow as tf

ds = tf.data.Dataset.from_tensor_slices(
    tf.ragged.constant([[1, 2, 3, 4], [], [5, 6, 7], [8], []]))
```
We can use the `reduce` method of `tf.data.Dataset`:
```python
ds.reduce(0, lambda state, value: tf.math.maximum(state, len(value))).numpy()
```
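For reference, here is a self-contained sketch of the same fold. It assumes TensorFlow 2.x with eager execution. Two changes are swapped in relative to the one-liner: the row length comes from `tf.shape(value)[0]` instead of `len(value)`, and the initial state is an explicit `int32` constant. The `tf.shape` call does not rely on AutoGraph rewriting `len` inside the traced reduce function. The variable names `rt`, `max_len` and `max_len_direct` are illustrative only.

```python
import tensorflow as tf

# Same data as in the question: each dataset element is one (possibly empty) row.
rt = tf.ragged.constant([[1, 2, 3, 4], [], [5, 6, 7], [8], []])
ds = tf.data.Dataset.from_tensor_slices(rt)

# Fold over the dataset, keeping the largest row length seen so far.
# tf.shape(value)[0] is the row length as an int32 tensor, which matches
# the dtype of the int32 initial state, as reduce requires.
max_len = ds.reduce(
    tf.constant(0, dtype=tf.int32),
    lambda state, value: tf.math.maximum(state, tf.shape(value)[0]),
).numpy()
print(max_len)  # 4

# If the ragged tensor itself is still available, the dataset is not needed:
# row_lengths() returns the length of every row (as int64).
max_len_direct = tf.reduce_max(rt.row_lengths()).numpy()
print(max_len_direct)  # 4
```

`reduce` walks the whole dataset once, so this works for any finite dataset. The `row_lengths()` shortcut only applies when you still hold the original `tf.RaggedTensor`.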