
How do I split Tensorflow datasets?


I have a tensorflow dataset based on one .tfrecord file. How do I split the dataset into test and train datasets? E.g. 70% Train and 30% test?

Edit:

My TensorFlow version is 1.8. I've checked, and there is no "split_v" function as mentioned in the possible duplicate. Also, I am working with a tfrecord file.


Solution

  • You may use Dataset.take() and Dataset.skip():

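    import tensorflow as tf

    # DATASET_SIZE (the number of records) must be known in advance.
    train_size = int(0.7 * DATASET_SIZE)
    val_size = int(0.15 * DATASET_SIZE)  # informational; val receives the remainder below
    test_size = int(0.15 * DATASET_SIZE)

    full_dataset = tf.data.TFRecordDataset(FLAGS.input_file)
    # shuffle() requires a buffer size; a buffer as large as the dataset gives a
    # uniform shuffle but holds every record in memory. A fixed seed and
    # reshuffle_each_iteration=False keep the order stable so the splits stay disjoint.
    full_dataset = full_dataset.shuffle(DATASET_SIZE, seed=42, reshuffle_each_iteration=False)
    train_dataset = full_dataset.take(train_size)
    test_dataset = full_dataset.skip(train_size)
    # The next two lines carve the non-train remainder into val and test.
    val_dataset = test_dataset.skip(test_size)
    test_dataset = test_dataset.take(test_size)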
    

    For generality, the example uses a 70/15/15 train/val/test split; if you only need a train and a test set, just ignore the last 2 lines.

    Take:

    Creates a Dataset with at most count elements from this dataset.

    Skip:

    Creates a Dataset that skips count elements from this dataset.

    You may also want to look into Dataset.shard():

    Creates a Dataset that includes only 1/num_shards of this dataset.
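
The take/skip pattern above can be wrapped in two small helpers. This is only a sketch: `split_sizes` and `split_dataset` are hypothetical names, not part of TensorFlow, and the code assumes nothing about the dataset beyond it having `take()` and `skip()` methods, as `tf.data.Dataset` does. Letting the last split absorb the rounding remainder is a design choice made here so that no element is left unassigned.

```python
def split_sizes(total, fractions):
    """Turn fractions such as (0.7, 0.15, 0.15) into element counts.

    Every count except the last is rounded down; the last split absorbs
    the remainder, so the counts always sum to `total`.
    """
    if abs(sum(fractions) - 1.0) > 1e-9:
        raise ValueError("fractions must sum to 1")
    sizes = [int(fraction * total) for fraction in fractions[:-1]]
    sizes.append(total - sum(sizes))
    return sizes


def split_dataset(dataset, sizes):
    """Cut `dataset` into consecutive, non-overlapping pieces.

    Only take() and skip() are called, so any object exposing those two
    methods works, including a tf.data.Dataset.
    """
    parts = []
    remaining = dataset
    for size in sizes:
        parts.append(remaining.take(size))   # first `size` elements
        remaining = remaining.skip(size)     # everything after them
    return parts
```

With a real dataset the call would look like `train_dataset, val_dataset, test_dataset = split_dataset(full_dataset, split_sizes(DATASET_SIZE, (0.7, 0.15, 0.15)))`. The pieces are only disjoint if the dataset yields its elements in the same order on every pass, so any shuffling before the split needs a stable order.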