python tensorflow tensorflow-datasets

Split a tf.data.Dataset tuple into several datasets


I have a tf.data.Dataset with the following shape:

<ConcatenateDataset shapes: ((None, None, 12), (None, 5)), types: (tf.float64, tf.float64)>

Can I split this dataset into two datasets with the following shapes?

<Dataset shapes: (None, None, 12), types: tf.float64>
<Dataset shapes: (None, 5), types: tf.float64>

Solution

  • You can use the Dataset.map method to split them.

    Demo:

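    import tensorflow as tf
    
    # Create two random datasets with the same structure as in the question.
    # tf.random.uniform defaults to float32; pass dtype=tf.float64 to match exactly.
    dataset1 = tf.data.Dataset.from_tensor_slices(
        (tf.random.uniform([40, 10, 12]), tf.random.uniform([40, 5]))).batch(16)
    dataset2 = tf.data.Dataset.from_tensor_slices(
        (tf.random.uniform([40, 12, 12]), tf.random.uniform([40, 5]))).batch(16)
    
    # Concatenating relaxes the differing second dimension (10 vs 12) to None.
    dataset = dataset1.concatenate(dataset2)
    print(dataset)
    # <ConcatenateDataset shapes: ((None, None, 12), (None, 5)), types: (tf.float32, tf.float32)>
    # (the exact repr depends on the TensorFlow version)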
    

    To split it, map each element to the component you want to keep:

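    # map passes each (x, y) tuple to the function as two positional arguments.
    data = dataset.map(lambda x, y: x)    # shapes: (None, None, 12)
    labels = dataset.map(lambda x, y: y)  # shapes: (None, 5)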
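If the dataset has more than two components, the same idea can be wrapped in a small helper that calls map once per tuple position. This is a minimal sketch assuming a tuple-structured dataset; split_tuple_dataset and make_selector are hypothetical names, not TensorFlow APIs, and the data is random float64 so the result matches the shapes and types in the question:

```python
import tensorflow as tf


def split_tuple_dataset(dataset):
    """Split a tuple-structured tf.data.Dataset into one dataset per component.

    Hypothetical helper (not part of TensorFlow); it assumes every element
    is a tuple and simply calls Dataset.map once per tuple position.
    """

    def make_selector(index):
        # Dataset.map unpacks each tuple element into positional arguments,
        # so *components receives (features, labels, ...) for every element.
        def select_component(*components):
            return components[index]

        return select_component

    # For a tuple-structured dataset, element_spec is a tuple of TensorSpecs.
    num_components = len(dataset.element_spec)
    return [dataset.map(make_selector(i)) for i in range(num_components)]


# Rebuild a dataset with the structure from the question, using float64.
first = tf.data.Dataset.from_tensor_slices((
    tf.random.uniform([40, 10, 12], dtype=tf.float64),
    tf.random.uniform([40, 5], dtype=tf.float64),
)).batch(16)
second = tf.data.Dataset.from_tensor_slices((
    tf.random.uniform([40, 12, 12], dtype=tf.float64),
    tf.random.uniform([40, 5], dtype=tf.float64),
)).batch(16)
dataset = first.concatenate(second)

feature_ds, label_ds = split_tuple_dataset(dataset)
print(feature_ds.element_spec)  # shape (None, None, 12), dtype float64
print(label_ds.element_spec)    # shape (None, 5), dtype float64
```

Note that each returned dataset re-runs the source pipeline on its own when iterated. If the source involves randomness (for example a shuffle that reshuffles on every iteration), iterating the pieces separately may not keep features and labels aligned; in that case it may be simpler to keep the tuple intact for APIs that accept (features, labels) elements, such as Keras Model.fit.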