I have a TensorFlow dataset based on one .tfrecord file. How do I split the dataset into train and test datasets, e.g. 70% train and 30% test?
Edit:
My TensorFlow version is 1.8. I've checked, and there is no "split_v" function as mentioned in the possible duplicate. Also, I am working with a TFRecord file.
You may use Dataset.take() and Dataset.skip():
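import tensorflow as tf

# DATASET_SIZE is the total number of records; it has to be known or counted beforehand.
train_size = int(0.7 * DATASET_SIZE)
val_size = int(0.15 * DATASET_SIZE)
test_size = int(0.15 * DATASET_SIZE)

full_dataset = tf.data.TFRecordDataset(FLAGS.input_file)
# shuffle() requires a buffer size. Keep the order fixed across iterations,
# otherwise records could move between the splits on every pass over the data.
full_dataset = full_dataset.shuffle(buffer_size=DATASET_SIZE, reshuffle_each_iteration=False)
train_dataset = full_dataset.take(train_size)
test_dataset = full_dataset.skip(train_size)
val_dataset = test_dataset.skip(test_size)
test_dataset = test_dataset.take(test_size)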
For more generality, I gave an example using a 70/15/15 train/val/test split, but if you only need two splits, just ignore the last 2 lines; test_dataset then holds the remaining records (about 30%).
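One thing to watch with the int(...) sizes: they round down, so the three sizes may not add up to DATASET_SIZE. In the snippet above the leftover records end up in val_dataset, because it takes everything after the skip. Here is a small plain-Python sketch of that arithmetic; split_sizes is a hypothetical helper of mine, not a TensorFlow function:

```python
def split_sizes(dataset_size, train_frac=0.7, test_frac=0.15):
    """Return (train, val, test) sizes that always sum to dataset_size.

    Hypothetical helper: train and test sizes are rounded down, and the
    validation split receives whatever is left over.
    """
    train_size = int(train_frac * dataset_size)
    test_size = int(test_frac * dataset_size)
    val_size = dataset_size - train_size - test_size
    return train_size, val_size, test_size


train_size, val_size, test_size = split_sizes(10)
print(train_size, val_size, test_size)  # 7 2 1
```

With 10 records the nominal 15% would be 1.5, so test gets 1 record and val gets the remaining 2.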
Take:
Creates a Dataset with at most count elements from this dataset.
Skip:
Creates a Dataset that skips count elements from this dataset.
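To see why the two calls give non-overlapping splits, here is a plain-Python model of the quoted behaviour, with itertools.islice standing in for take() and skip(). The take and skip functions below are illustrative stand-ins, not the tf.data API, and they assume the source yields its elements in the same order on every pass:

```python
from itertools import islice


def take(items, count):
    """Model of Dataset.take(): at most `count` elements from the front."""
    return list(islice(items, count))


def skip(items, count):
    """Model of Dataset.skip(): everything after the first `count` elements."""
    return list(islice(items, count, None))


records = list(range(10))  # stand-in for 10 records in a fixed order
train = take(records, 7)
rest = skip(records, 7)
val = skip(rest, 1)
test = take(rest, 1)

print(train)  # [0, 1, 2, 3, 4, 5, 6]
print(test)   # [7]
print(val)    # [8, 9]
print(take(records, 50))  # asking for more than exists just returns all 10
```

The splits stay disjoint only as long as the order is fixed, which is why a shuffle that reorders the data on every iteration would leak records between them.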
You may also want to look into Dataset.shard():
Creates a Dataset that includes only 1/num_shards of this dataset.
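As a rough illustration, shard(num_shards, index) keeps every num_shards-th element, starting at position index. The plain-Python model below (shard here is a stand-in function, not the tf.data method) shows that one shard is always a 1/num_shards slice, so it fits an even k-way split more naturally than a 70/30 one:

```python
from itertools import islice


def shard(items, num_shards, index):
    """Model of Dataset.shard(): elements whose position % num_shards == index."""
    return list(islice(items, index, None, num_shards))


records = list(range(10))

# One shard out of five is 20% of the data.
test = shard(records, 5, 0)

# The remaining four shards together form the other 80%.
train = sorted(x for i in range(1, 5) for x in shard(records, 5, i))

print(test)   # [0, 5]
print(train)  # [1, 2, 3, 4, 6, 7, 8, 9]
```

Unlike take/skip, this picks records spread across the whole file, which may help if the file is ordered and you are not shuffling.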