I am creating a TensorFlow Dataset using the from_generator function. In graph/session mode, it works fine:
import tensorflow as tf
x = {str(i): i for i in range(10)}
def gen():
    for i in x:
        yield x[i]
ds = tf.data.Dataset.from_generator(gen, tf.int32)
batch = ds.make_one_shot_iterator().get_next()
with tf.Session() as sess:
    try:
        while True:
            print(sess.run(batch), end=' ')
    except tf.errors.OutOfRangeError:
        pass
# 0 1 2 3 4 5 6 7 8 9
Surprisingly, however, it fails using eager execution:
import tensorflow as tf
tf.enable_eager_execution()  # TF 1.x only; eager execution is the default in TF 2.x
x = {str(i): i for i in range(10)}
def gen():
    for i in x:
        yield x[i]
ds = tf.data.Dataset.from_generator(gen, tf.int32)
for x in ds:
    print(x, end=' ')
# TypeError: Only integers, slices (`:`), ellipsis (`...`), tf.newaxis (`None`) and scalar tf.int32/tf.int64 tensors are valid indices, got '1'
I was assuming that, since the body of the generator is pure Python that does not get serialized, TensorFlow would not look into the generator, or indeed care about what is in it. But this is apparently not the case. So why does TensorFlow care about what's inside the generator? Assuming the generator cannot be changed, is there a way to work around this problem?
tl;dr The issue is unrelated to TensorFlow. Your loop variable overwrites the previously defined x.
Fact 1: A for loop in Python does not have its own namespace, so its loop variable leaks into the surrounding namespace (globals() in your example).
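A minimal sketch of Fact 1 in plain Python, with no TensorFlow involved. The name value is made up for illustration and plays the role of x in your example:

```python
value = {"a": 1}            # module-level name, like x in the question

for value in [10, 20, 30]:  # the loop variable reuses the same name
    pass

# The loop is over, but its variable is still bound in the surrounding
# namespace, and the dict is no longer reachable through this name.
print(value)  # 30
```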
Fact 2: Closures are "dynamic", i.e. the gen generator only knows that it should look up the name "x" to evaluate x[i]. The actual value of x is resolved when the generator is iterated over.
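Fact 2 can be reproduced in plain Python as well. In this sketch, table is a hypothetical stand-in for x, and the rebinding between the two next() calls is what your for loop does implicitly:

```python
table = {"a": 1, "b": 2}   # hypothetical stand-in for x in the question

def gen():
    for key in table:      # the name `table` is looked up when this line runs
        yield table[key]   # ...and looked up again on every iteration

it = iter(gen())
first = next(it)           # table is still the dict, so this yields 1

table = [100, 200]         # rebind the name between two next() calls

try:
    next(it)               # now evaluates [100, 200]["b"]
except TypeError as err:
    error = err
    print(type(err).__name__)  # TypeError
```

The dict iterator created on the first call keeps working, which is why the failure only shows up at the x[i] lookup and why the error message complains about the index '1' rather than about iteration.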
Putting these two together and unrolling the first two iterations of the for loop, we get the following execution sequence:
ds = tf.data.Dataset.from_generator(gen, tf.int32)
it = iter(ds)
x = next(it)  # Calls the generator, which yields x[i]; x is now a Tensor.
print(x, end=' ')
# Calls the generator as before, but x is no longer a dict, so x[i]
# is actually indexing into a Tensor!
x = next(it)
The fix is simple: use a different loop variable name.
for item in ds:
    print(item, end=' ')
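If you would rather keep the loop variable name, another option is to run the loop inside a function, where the loop variable is local and cannot rebind the module-level x. Here is a sketch in plain Python, with the bare generator standing in for the Dataset. The helper name consume is made up for illustration, and I have not run this variant against TensorFlow itself:

```python
x = {str(i): i for i in range(10)}

def gen():
    for i in x:
        yield x[i]

def consume(iterable):
    # This x is local to consume(), so the outer x that gen() reads
    # stays bound to the dict.
    collected = []
    for x in iterable:
        collected.append(x)
    return collected

result = consume(gen())
print(result)  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
```

The generator itself is untouched here; only the code that consumes it changes.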