Is there a faster way to count all elements of a tensorflow.data.Dataset than this?
import tensorflow as tf
def count_elements(dataset: tf.data.Dataset):
    # Fold over the dataset, adding 1 per element; every element is read once.
    return dataset.reduce(0, lambda x, _: x + 1).numpy()
"Faster" here also takes memory usage into account, but execution time is paramount. As far as I can see, there is no built-in method for this.
Short answer: no.
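For datasets whose size can be determined from their structure (for example, ones built from in-memory tensors) there's tf.data.experimental.cardinality(dataset). When the size cannot be inferred it returns tf.data.experimental.UNKNOWN_CARDINALITY (-2), and for infinite datasets it returns tf.data.experimental.INFINITE_CARDINALITY (-1). tf.data.Datasets are inherently lazily loaded and can be infinite, so in general there's no way of knowing how many elements a tf.data.Dataset has without iterating through it.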
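A practical compromise is to ask for the cardinality first and only fall back to a full pass (the question's reduce) when it is unknown. The helper name count_elements_fast below is hypothetical, and this is a sketch assuming TensorFlow 2.x in eager mode; it does not make the unknown case any faster, it just skips iteration when TensorFlow already knows the answer.

```python
import tensorflow as tf


def count_elements_fast(dataset: tf.data.Dataset) -> int:
    """Hypothetical helper: count elements, skipping iteration when possible."""
    # cardinality() is computed from the dataset's structure, without reading data.
    n = int(tf.data.experimental.cardinality(dataset).numpy())
    if n == tf.data.experimental.INFINITE_CARDINALITY:
        raise ValueError("dataset is infinite; its elements cannot be counted")
    if n != tf.data.experimental.UNKNOWN_CARDINALITY:
        return n
    # Unknown size (e.g. after filter()): iterate once, as in the question.
    return int(dataset.reduce(0, lambda count, _: count + 1).numpy())


print(count_elements_fast(tf.data.Dataset.range(10)))  # known statically
print(count_elements_fast(
    tf.data.Dataset.range(10).filter(lambda x: tf.equal(x % 2, 0))))  # via reduce
```

For tf.data.Dataset.range(10) the count comes straight from the cardinality; after filter() the cardinality is unknown, so the helper iterates the whole dataset just like the original count_elements.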
Credit: In TensorFlow 2.0, how can I see the number of elements in a dataset?