tf.keras, tf.data.dataset

Passing Argument to a Generator to build a tf.data.Dataset


I am trying to build a TensorFlow dataset from a generator. I have a list of pairs called some_list, where each pair has an integer and some text.

When I do not pass some_list as an argument to the generator, the code works fine:

import tensorflow as tf
import random
import numpy as np

some_list=[(1,'One'),[2,'Two'],[3,'Three'],[4,'Four'],
      (5,'Five'),[6,'Six'],[7,'Seven'],[8,'Eight']]

def text_gen1():
    random.shuffle(some_list)
    size=len(some_list)
    i=0
    while True:
        yield some_list[i][0],some_list[i][1]
        i+=1
        if i>=size:  # wrap around before indexing past the end
            i=0
            random.shuffle(some_list)

# Not passing any argument
tf_dataset1 = tf.data.Dataset.from_generator(text_gen1,output_types=(tf.int32,tf.string),
                                        output_shapes = ((),()))
for count_batch in tf_dataset1.repeat().batch(3).take(2):
    print(count_batch)

(<tf.Tensor: shape=(3,), dtype=int32, numpy=array([7, 1, 2])>, <tf.Tensor: shape=(3,), dtype=string, numpy=array([b'Seven', b'One', b'Two'], dtype=object)>)
(<tf.Tensor: shape=(3,), dtype=int32, numpy=array([3, 5, 4])>, <tf.Tensor: shape=(3,), dtype=string, numpy=array([b'Three', b'Five', b'Four'], dtype=object)>)

However, when I try to pass some_list as an argument, the code fails:

def text_gen2(file_list):
    random.shuffle(file_list)
    size=len(file_list)
    i=0
    while True:
        yield file_list[i][0],file_list[i][1]
        i+=1
        if i>=size:  # wrap around before indexing past the end
            i=0
            random.shuffle(file_list)

tf_dataset2 = tf.data.Dataset.from_generator(text_gen2,args=[some_list],
                                        output_types=(tf.int32,tf.string),output_shapes = ((),()))
for count_batch in tf_dataset2.repeat().batch(3).take(2):
    print(count_batch)

ValueError: Can't convert Python sequence with mixed types to Tensor.

I noticed that when I pass a list of integers as the argument, the code works. However, a list of tuples makes it crash. Can someone shed some light on this?


Solution

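  • The problem is exactly what the error says. from_generator converts every element of args into a tf.Tensor, and a single tf.Tensor cannot hold heterogeneous data types (here int and str). A flat list of integers converts cleanly, which is why that case works. Transposing some_list with zip(*some_list) gives one tuple of integers and one tuple of strings, so each argument is homogeneous. Note that the generator then receives its arguments as NumPy arrays, so the strings arrive as bytes. I made a few changes and came up with the code below.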

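    import tensorflow as tf

    some_list=[(1,'One'),[2,'Two'],[3,'Three'],[4,'Four'],
          (5,'Five'),[6,'Six'],[7,'Seven'],[8,'Eight']]

    def text_gen2(int_list, str_list):
        # Both arguments arrive as NumPy arrays; the strings arrive as bytes
        for x, y in zip(int_list, str_list):
            yield x, y

    # zip(*some_list) transposes the pairs into one tuple of ints and one
    # tuple of strings, so each element of args converts to a homogeneous tensor
    tf_dataset2 = tf.data.Dataset.from_generator(
        text_gen2,
        args=list(zip(*some_list)),
        output_types=(tf.int32, tf.string), output_shapes=((), ())
    )

    # Shuffle individual examples (not whole batches), then repeat and batch;
    # take(11) stops after 11 batches, like the original counter loop
    for count_batch in tf_dataset2.shuffle(buffer_size=8).repeat().batch(4).take(11):
        print(count_batch)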
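If you would rather keep the original one-argument generator, a minimal sketch of another option (assuming TensorFlow 2.4 or later, where from_generator accepts output_signature) is to not pass the list through args at all and instead close over it with a lambda. Values in args are always converted to tensors, while an object captured by the lambda stays a plain Python list. The names text_gen, tf_dataset3 and batches are illustrative only, and the generator is a simplified rewrite that loops over a shuffled copy instead of tracking an index.

```python
import random

import tensorflow as tf

some_list = [(1, 'One'), [2, 'Two'], [3, 'Three'], [4, 'Four'],
             (5, 'Five'), [6, 'Six'], [7, 'Seven'], [8, 'Eight']]


def text_gen(file_list):
    """Yield (int, str) pairs forever, reshuffling before each full pass."""
    items = list(file_list)  # shuffle a copy so the caller's list is untouched
    while True:
        random.shuffle(items)
        for number, text in items:
            yield number, text


# The lambda captures some_list as a Python object, so from_generator never
# tries to convert it to a tensor (nothing is passed via `args`).
tf_dataset3 = tf.data.Dataset.from_generator(
    lambda: text_gen(some_list),
    output_signature=(
        tf.TensorSpec(shape=(), dtype=tf.int32),
        tf.TensorSpec(shape=(), dtype=tf.string),
    ),
)

# The generator is infinite, so no repeat() is needed; take() bounds the loop
batches = list(tf_dataset3.batch(3).take(2))
for numbers, texts in batches:
    print(numbers.numpy(), texts.numpy())
```

Because the generator receives real Python objects here, the strings stay str inside the generator and only become bytes in the output tensors.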