I'm having trouble understanding how the TensorFlow data API (tensorflow.data.Dataset) works. My input is a list of lists of integers that I want to batch, pad, and concatenate. For example, my data looks like this:

    data = [[1, 2, 3, 4, 5, 6, 7],
            [1, 2, 3, 4],
            [1]]

With a batch size of 3, each inner list should be split into chunks of three, with the last chunk zero-padded:

    [[[1, 2, 3], [4, 5, 6], [7, 0, 0]],
     [[1, 2, 3], [4, 0, 0]],
     [[1, 0, 0]]]

and finally, with all the chunks concatenated into a single list:

    [[1, 2, 3], [4, 5, 6], [7, 0, 0],
     [1, 2, 3], [4, 0, 0], [1, 0, 0]]

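For reference, the intended transformation can be written in plain Python without TensorFlow, which is handy for checking what the pipeline should produce. This is a minimal sketch; the function name `chunk_pad_concat` and its `pad_value` parameter are hypothetical, not part of any library.

```python
def chunk_pad_concat(rows, size, pad_value=0):
    """Split each row into chunks of `size`, right-pad the last chunk, and concatenate."""
    out = []
    for row in rows:
        # Walk the row in steps of `size`; an empty row yields no chunks
        for start in range(0, len(row), size):
            chunk = list(row[start:start + size])
            # Only the final chunk of a row can be short, so pad it up to `size`
            chunk += [pad_value] * (size - len(chunk))
            out.append(chunk)
    return out

data = [[1, 2, 3, 4, 5, 6, 7],
        [1, 2, 3, 4],
        [1]]
print(chunk_pad_concat(data, 3))
# [[1, 2, 3], [4, 5, 6], [7, 0, 0], [1, 2, 3], [4, 0, 0], [1, 0, 0]]
```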
It wasn't easy, but I finally got it to work:
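
    import tensorflow as tf

    Dataset = tf.data.Dataset

    def batch_each(x):
        # Split one row into consecutive chunks of up to 3 elements
        return Dataset.from_tensor_slices(x).batch(3)

    data = [[1, 2, 3, 4, 5, 6, 7],
            [1, 2, 3, 4],
            [1]]
    rt = tf.ragged.constant(data)  # rows have different lengths

    ds = (Dataset
          .from_tensor_slices(rt)               # one element per row
          .flat_map(batch_each)                 # chunk every row and concatenate the chunks
          .padded_batch(1, padded_shapes=(3,))  # zero-pad each chunk to length 3
          .unbatch())                           # remove the extra batch dimension again

    for e in ds:
        print(e)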
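A possible simplification, sketched here assuming TensorFlow 2.x with eager execution, is to pad each chunk directly with `tf.pad` inside a `map`, which avoids the `padded_batch(1, ...)` / `unbatch()` round trip. The names `CHUNK` and `chunk_and_pad` are hypothetical, chosen for illustration only.

```python
import tensorflow as tf

CHUNK = 3  # chunk ("batch") length from the question

data = [[1, 2, 3, 4, 5, 6, 7],
        [1, 2, 3, 4],
        [1]]
rt = tf.ragged.constant(data)

def chunk_and_pad(row):
    # Split one variable-length row into consecutive chunks of up to CHUNK elements
    chunks = tf.data.Dataset.from_tensor_slices(row).batch(CHUNK)
    # Right-pad each chunk with zeros to exactly CHUNK elements and record the static shape
    return chunks.map(
        lambda c: tf.ensure_shape(tf.pad(c, [[0, CHUNK - tf.shape(c)[0]]]), [CHUNK]))

# flat_map concatenates the per-row chunk datasets into one stream of chunks
ds = tf.data.Dataset.from_tensor_slices(rt).flat_map(chunk_and_pad)

for e in ds:
    print(e.numpy().tolist())
```

`tf.ensure_shape` is only there so that downstream ops see a static element shape of `(3,)`, matching what `padded_batch(..., padded_shapes=(3,))` followed by `unbatch()` gives.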