
Batching and padding using the TensorFlow data API


I'm having trouble understanding how the TensorFlow data API (tensorflow.data.Dataset) works. My input is a list of lists of integers that I want to batch (split into fixed-size chunks), pad, and concatenate. For example, my data looks like this:

data = [[1, 2, 3, 4, 5, 6, 7],
        [1, 2, 3, 4],
        [1]]

With a batch size of 3, each inner list should be split into chunks of 3 and the last chunk zero-padded:

[[[1, 2, 3], [4, 5, 6], [7, 0, 0]],
 [[1, 2, 3], [4, 0, 0]],
 [[1, 0, 0]]]

and finally, after concatenating the chunks:

[[1, 2, 3], [4, 5, 6], [7, 0, 0],
 [1, 2, 3], [4, 0, 0], [1, 0, 0]]

Solution

  • It wasn't easy, but I finally got it to work:

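    import tensorflow as tf

    def batch_each(x):
        # Split one row into consecutive chunks of up to 3 elements.
        return tf.data.Dataset.from_tensor_slices(x).batch(3)

    data = [[1, 2, 3, 4, 5, 6, 7],
            [1, 2, 3, 4],
            [1]]
    # A ragged tensor can hold rows of different lengths.
    rt = tf.ragged.constant(data)
    ds = (tf.data.Dataset
          .from_tensor_slices(rt)               # one element per row
          .flat_map(batch_each)                 # one element per chunk
          .padded_batch(1, padded_shapes=(3,))  # zero-pad each chunk to length 3
          .unbatch())                           # drop the size-1 batch dimension
    for e in ds:
        print(e)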
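A variant that avoids the `padded_batch(1, ...)` / `unbatch()` round trip is to pad each row to a multiple of 3 with `tf.pad` and reshape it into rows before `flat_map` flattens them. This is a sketch, assuming TensorFlow 2.x with eager execution; `split_row` and `CHUNK` are hypothetical names. The final `tf.stack` collects the rows into the single concatenated tensor from the question.

```python
import tensorflow as tf

CHUNK = 3  # hypothetical constant: chunk length from the question


def split_row(row):
    # Round the row length up to a multiple of CHUNK and zero-pad to it.
    n = tf.shape(row)[0]
    padded_len = (n + CHUNK - 1) // CHUNK * CHUNK
    row = tf.pad(row, [[0, padded_len - n]])
    # Reshape the padded row into (num_chunks, CHUNK) and emit one chunk per element.
    return tf.data.Dataset.from_tensor_slices(tf.reshape(row, [-1, CHUNK]))


data = [[1, 2, 3, 4, 5, 6, 7],
        [1, 2, 3, 4],
        [1]]
ds = tf.data.Dataset.from_tensor_slices(tf.ragged.constant(data)).flat_map(split_row)

# Stack every chunk into one (num_chunks, CHUNK) tensor.
result = tf.stack(list(ds))
print(result.numpy().tolist())
```

Padding before splitting means each chunk already has a fixed length, so no extra batching step is needed to pad it.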