testing, tensorflow, tensorflow-datasets

Limiting the number of items in a tf.data.Dataset


tl;dr: Can I limit the number of elements in a tf.data.Dataset?

I have a training and evaluation loop that processes the entire dataset it is given. This is not ideal for testing, since going through the whole dataset takes forever. I can test this code either by creating a mock dataset or by limiting the number of elements in the dataset so that the code only goes through, say, the first 10 datapoints. How can I do the latter?

Thanks


Solution

  • The simplest way to take only the first n elements of a Dataset is to use Dataset.take(n), which returns a new Dataset containing at most n elements. For example:

    large_dataset = ...
    small_dataset = large_dataset.take(10)
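
To see this end to end, here is a minimal sketch that assumes TensorFlow 2.x. The tf.data.Dataset.range(100) dataset is only a hypothetical stand-in for the real large_dataset, and first_ten is an illustrative name, not something from the question.

```python
import tensorflow as tf

# Hypothetical stand-in for the real dataset: the integers 0..99.
large_dataset = tf.data.Dataset.range(100)

# take(10) returns a new dataset with at most the first 10 elements;
# large_dataset itself is left unchanged.
small_dataset = large_dataset.take(10)

# Materialize the small dataset to check what the loop would see.
first_ten = [int(x) for x in small_dataset.as_numpy_iterator()]
print(first_ten)  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
```

If the dataset has fewer than n elements, take(n) simply yields all of them, so the same test code should also work against small fixtures.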