python tensorflow2.0 tensorflow-datasets

What is the fastest method to count the elements of a tensorflow.data.Dataset?


Is there a faster way to count all elements of a tensorflow.data.Dataset than this?

import tensorflow as tf

def count_elements(dataset: tf.data.Dataset):
    return dataset.reduce(0, lambda x, _: x + 1).numpy()

"Faster" should also take memory usage into account, but execution time is paramount. As far as I can tell, there is no built-in method for this.


Solution

  • Short answer: "No".

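    There is tf.data.experimental.cardinality(dataset), which returns the element count when TensorFlow can determine it without iterating (for example, for a dataset built with from_tensor_slices or range). Otherwise it returns tf.data.experimental.UNKNOWN_CARDINALITY (for example, after filter) or tf.data.experimental.INFINITE_CARDINALITY (for example, after repeat()). Because tf.data.Datasets are lazily loaded and can be infinite, in general there is no way to know how many elements a tf.data.Dataset contains without iterating through it.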

    Credit: In TensorFlow 2.0, how can I see the number of elements in a dataset?
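A minimal sketch combining both ideas: ask for the cardinality first and fall back to a full reduce pass only when TensorFlow reports it as unknown. The helper name count_elements is reused from the question; the int64 counter is an assumption added to avoid int32 overflow on very large datasets. A dataset that is infinite but not statically known to be (for example, one backed by an endless generator) still reports unknown cardinality, so the fallback would never finish in that case.

```python
import tensorflow as tf


def count_elements(dataset: tf.data.Dataset) -> int:
    """Count the elements of `dataset`, iterating only when unavoidable."""
    n = int(tf.data.experimental.cardinality(dataset).numpy())
    if n == tf.data.experimental.INFINITE_CARDINALITY:
        # e.g. dataset.repeat() with no count: counting would never end.
        raise ValueError("Dataset is infinite; its elements cannot be counted.")
    if n != tf.data.experimental.UNKNOWN_CARDINALITY:
        # Size is known statically, so no iteration is needed.
        return n
    # Size is unknown (e.g. after filter()): make one pass and count.
    count = dataset.reduce(tf.constant(0, tf.int64), lambda acc, _: acc + 1)
    return int(count.numpy())
```

For example, count_elements(tf.data.Dataset.range(10).batch(3)) returns 4 straight from the cardinality, while the same dataset after a filter is counted by iterating.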