tensorflow tensorflow-datasets

Interleaving multiple TensorFlow datasets together


The current TensorFlow dataset interleave functionality is basically an interleaved flat-map that takes a single dataset as input. Given the current API, what's the best way to interleave multiple datasets together? Say they have already been constructed and I have a list of them. I want to produce elements from them alternately, and I want to support lists with more than 2 datasets (i.e., stacked zips and interleaves would be pretty ugly).

Thanks! :)

@mrry might be able to help.


Solution

    Even though this is not "clean", it is the only workaround I came up with.

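    import tensorflow as tf
    
    datasets = [tf.data.Dataset...]
    
    def concat_datasets(elements):
        # Wrap each element in a one-element dataset and chain them in order.
        # All elements must share the same dtype/shape structure.
        ds0 = tf.data.Dataset.from_tensors(elements[0])
        for element in elements[1:]:
            ds0 = ds0.concatenate(tf.data.Dataset.from_tensors(element))
        return ds0
    
    # zip yields one element from every dataset per step (stopping at the
    # shortest dataset); flat_map then emits those elements one at a time.
    ds = tf.data.Dataset.zip(tuple(datasets)).flat_map(
        lambda *args: concat_datasets(args)
    )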
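
To make the behaviour concrete, here is a minimal sketch that runs the workaround on three toy datasets and compares it with a possible alternative. The toy `range` datasets and the helper names `interleave_with_zip` and `interleave_with_choice` are made up for illustration. The alternative assumes TensorFlow 2.7 or newer, where `tf.data.Dataset.choose_from_datasets` is available (older releases expose it as `tf.data.experimental.choose_from_datasets`), and it picks the source dataset with a repeating index dataset rather than zipping.

```python
import tensorflow as tf

# Toy inputs (hypothetical): three equally long datasets of int64 scalars.
datasets = [
    tf.data.Dataset.range(0, 3),
    tf.data.Dataset.range(10, 13),
    tf.data.Dataset.range(20, 23),
]


def interleave_with_zip(datasets):
    """The zip + flat_map workaround: one element per dataset per step."""

    def to_dataset(*elements):
        # Turn the zipped tuple back into a dataset of individual elements.
        ds = tf.data.Dataset.from_tensors(elements[0])
        for element in elements[1:]:
            ds = ds.concatenate(tf.data.Dataset.from_tensors(element))
        return ds

    return tf.data.Dataset.zip(tuple(datasets)).flat_map(to_dataset)


def interleave_with_choice(datasets):
    """Alternative sketch: cycle through dataset indices 0, 1, ..., n-1."""
    # The choice dataset must yield int64 scalars; Dataset.range does.
    choice = tf.data.Dataset.range(len(datasets)).repeat()
    # By default this stops as soon as the selected dataset is exhausted.
    return tf.data.Dataset.choose_from_datasets(datasets, choice)


zipped = [int(x) for x in interleave_with_zip(datasets).as_numpy_iterator()]
chosen = [int(x) for x in interleave_with_choice(datasets).as_numpy_iterator()]

print(zipped)  # [0, 10, 20, 1, 11, 21, 2, 12, 22]
print(chosen)  # [0, 10, 20, 1, 11, 21, 2, 12, 22]
```

Both variants stop around the point where the shortest dataset runs out, so datasets of unequal length would lose their trailing elements unless they are repeated or padded first.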