
Creating a tensorflow dataset that outputs a dict


I have a dict with "metadata" for my dataset, of the form {'m1': array_1, 'm2': array_2, ...}. Each of the arrays has shape (N, ...), where N is the number of samples.

The question: Is it possible to create a tf.data.Dataset that outputs a dictionary {'meta_1': sub_array_1, 'meta_2': sub_array_2, ...} on each call to the dataset's iterator.get_next()? Here, sub_array_i should contain the ith metadata for one batch, so it should have shape (batch_sz, ...).

What I tried so far is using tf.data.Dataset.from_generator(), like this:

import numpy as np
import tensorflow as tf

N = 100
# dictionary of arrays:
metadata = {'m1': np.zeros(shape=(N,2)), 'm2': np.ones(shape=(N,3,5))}
num_samples = N

def meta_dict_gen():
    for i in range(num_samples):
        ls = {}
        for key, val in metadata.items():
            ls[key] = val[i]
        yield ls

dataset = tf.data.Dataset.from_generator(meta_dict_gen, output_types=(dict))

The problem seems to be output_types=(dict). The code above raises:

TypeError: Expected DataType for argument 'Tout' not <class 'dict'>.


I'm using TensorFlow 1.8 and Python 3.6.


Solution

  • EDIT: Although the original question was about TensorFlow 1.8, I have updated the answer to TensorFlow 2 (tested on TensorFlow 2.17), which is likely more useful to future readers.

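    It is possible to do what you intend; you just have to be specific about the contents of the dict. Instead of passing dict as the output type, pass an output_signature that maps each key to a tf.TensorSpec giving the shape and dtype of a single sample: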

    import tensorflow as tf
    import numpy as np
    
    N = 100
    # dictionary of arrays:
    metadata = {'m1': np.zeros(shape=(N,2)), 'm2': np.ones(shape=(N,3,5))}
    num_samples = N
    
    def meta_dict_gen():
        for i in range(num_samples):
            ls = {}
            for key, val in metadata.items():
                ls[key] = val[i]
            yield ls
    
    dataset = tf.data.Dataset.from_generator(
        meta_dict_gen,
        output_signature={
            k: tf.TensorSpec(shape=v.shape[1:], dtype=tf.as_dtype(v.dtype))
            for k, v in metadata.items()})
    print(next(iter(dataset)))
    

    Output:

    {'m1': <tf.Tensor: shape=(2,), dtype=float64, numpy=array([0., 0.])>, 'm2': <tf.Tensor: shape=(3, 5), dtype=float64, numpy=
    array([[1., 1., 1., 1., 1.],
           [1., 1., 1., 1., 1.],
           [1., 1., 1., 1., 1.]])>}
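
The output above is a single sample, while the question asks for arrays of shape (batch_sz, ...). A minimal sketch of one way to get that, assuming the arrays fit in memory: tf.data.Dataset.from_tensor_slices accepts the dict directly and slices each array along its first axis, and batch() then stacks consecutive samples key by key. The batch_sz value of 16 is an assumption chosen for illustration; calling .batch(batch_sz) on the generator-based dataset should batch it in the same way.

```python
import numpy as np
import tensorflow as tf

N = 100
batch_sz = 16  # assumed batch size, not taken from the original post
metadata = {'m1': np.zeros(shape=(N, 2)), 'm2': np.ones(shape=(N, 3, 5))}

# from_tensor_slices slices every array in the dict along axis 0,
# so each dataset element is a dict of per-sample tensors.
sliced = tf.data.Dataset.from_tensor_slices(metadata)

# batch() stacks consecutive elements key by key, adding a leading
# batch dimension to every entry of the dict.
batched = sliced.batch(batch_sz)

first_batch = next(iter(batched))
for key, value in first_batch.items():
    print(key, value.shape)
```

With these assumptions, the first batch should hold 'm1' with shape (batch_sz, 2) and 'm2' with shape (batch_sz, 3, 5). Since N is not a multiple of batch_sz here, the final batch is smaller unless drop_remainder=True is passed to batch().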