
tf.data.Dataset: how to get the dataset size (number of elements in an epoch)?


Let's say I have defined a dataset in this way:

filename_dataset = tf.data.Dataset.list_files("{}/*.png".format(dataset))

How can I get the number of elements in the dataset (i.e., the number of individual elements that make up one epoch)?

I know that tf.data.Dataset already knows the size of the dataset, because the repeat() method can repeat the input pipeline for a specified number of epochs. So there must be a way to get this information.


Solution

  • In graph mode (TF 1.x), tf.data.Dataset.list_files creates a tensor called MatchingFiles:0 (with the appropriate name-scope prefix, if any).

    You could evaluate

    tf.shape(tf.get_default_graph().get_tensor_by_name('MatchingFiles:0'))[0]
    

    in a session (e.g. with sess.run) to get the number of files.

    Of course, this works only in simple cases, in particular when each image file holds exactly one sample (or a known number of samples).
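    Outside graph mode, a framework-free sketch of the same count is to expand the pattern with Python's glob module. This assumes a simple pattern such as "<dir>/*.png", for which glob and TensorFlow's file matching should agree (edge cases may differ), and, as above, one sample per file; count_matching_files is a hypothetical helper name, not a TensorFlow API.

```python
import glob
import os
import tempfile


def count_matching_files(directory: str, pattern: str = "*.png") -> int:
    """Return the number of files in `directory` matching `pattern`.

    Hypothetical helper: for a simple pattern it should give the same count
    as the MatchingFiles tensor created by list_files, without TensorFlow.
    """
    return len(glob.glob(os.path.join(directory, pattern)))


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        # Three PNG placeholders plus one file the pattern must not match.
        for name in ("a.png", "b.png", "c.png", "notes.txt"):
            open(os.path.join(tmp, name), "w").close()
        print(count_matching_files(tmp))  # prints 3
```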

    In more complex situations, e.g. when you do not know the number of samples in each file, you can only observe the number of samples as an epoch ends.

    To do this, you can watch the epoch count kept by your dataset: repeat() creates a private (undocumented) member called _count that counts the number of epochs. By observing it during iteration, you can spot when it changes and compute the dataset size from the number of samples seen up to that point.

    This counter may be buried in the chain of datasets created by successive method calls (map(), batch(), etc.), so it has to be dug out, for example like this:

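    import warnings

    import tensorflow as tf

    d = my_dataset
    # RepeatDataset is not exposed publicly; recover its class from a repeated
    # dataset. tf.data.Dataset is abstract, so build a concrete one first.
    RepeatDataset = type(tf.data.Dataset.range(1).repeat())
    try:
      # Walk up the chain of wrapped datasets until the RepeatDataset is found.
      while not isinstance(d, RepeatDataset):
        d = d._input_dataset
    except AttributeError:
      # Reached a source dataset without finding repeat() in the pipeline.
      warnings.warn('no epoch counter found')
      epoch_counter = None
    else:
      epoch_counter = d._count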
    

    Note that this technique does not give an exact dataset size, because the batch during which epoch_counter is incremented typically mixes samples from two successive epochs. The result is therefore accurate only to within one batch size.
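    The batch-size bound can be checked with a pure-Python simulation; this models the idea only and is not tf.data itself. The hypothetical generator repeated_batches yields fixed-size batches from a dataset repeated indefinitely, together with the 0-based epoch of each batch's last element, standing in for epoch_counter. Counting samples until that value first changes pins the true size n to a window one batch wide.

```python
from typing import Iterator, List, Tuple


def repeated_batches(n_samples: int, batch_size: int) -> Iterator[Tuple[List[int], int]]:
    """Yield (batch, epoch) pairs from `n_samples` elements repeated forever.

    `epoch` is the 0-based epoch of the batch's last element; it plays the
    role of the epoch counter observed after each step (hypothetical model).
    """
    i = 0  # global index of the next element across all epochs
    while True:
        batch = [(i + k) % n_samples for k in range(batch_size)]
        i += batch_size
        yield batch, (i - 1) // n_samples


def dataset_size_bounds(batches: Iterator[Tuple[List[int], int]]) -> Tuple[int, int]:
    """Return (low, high) bounds on the dataset size from the first epoch change."""
    seen = 0
    for batch, epoch in batches:
        seen += len(batch)
        if epoch > 0:
            # This batch covers global indices seen - len(batch) .. seen - 1,
            # and one of them is the first element of epoch 1 (index n).
            return seen - len(batch), seen - 1
    raise ValueError("stream ended before the first epoch completed")


if __name__ == "__main__":
    print(dataset_size_bounds(repeated_batches(10, 3)))  # (9, 11)
    print(dataset_size_bounds(repeated_batches(10, 1)))  # (10, 10)
```

    With a batch size of 1 the window collapses to a single value; with larger batches it is batch_size - 1 wide, which is the imprecision described above.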