
Windowing a TensorFlow dataset without losing cardinality information?


tf.data.Dataset.window returns a new dataset whose elements are themselves datasets: one nested dataset per window, each containing that window's elements. If you have a dataset (say, Dataset.range(10)) and want a dataset of windows like [0 1 2] [1 2 3] ... [7 8 9], there's a trick to do that with window plus flat_map:

>>> d = tf.data.Dataset.range(10).window(3, shift=1, drop_remainder=True).flat_map(lambda x: x.batch(3))
>>> print(list(d))
[<tf.Tensor: shape=(3,), dtype=int64, numpy=array([0, 1, 2])>, <tf.Tensor: shape=(3,), dtype=int64, numpy=array([1, 2, 3])>, <tf.Tensor: shape=(3,), dtype=int64, numpy=array([2, 3, 4])>, <tf.Tensor: shape=(3,), dtype=int64, numpy=array([3, 4, 5])>, <tf.Tensor: shape=(3,), dtype=int64, numpy=array([4, 5, 6])>, <tf.Tensor: shape=(3,), dtype=int64, numpy=array([5, 6, 7])>, <tf.Tensor: shape=(3,), dtype=int64, numpy=array([6, 7, 8])>, <tf.Tensor: shape=(3,), dtype=int64, numpy=array([7, 8, 9])>]
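To see why the flat_map step is needed, here is a small sketch that inspects what window() produces on its own. Each element is a nested dataset (its element_spec is a tf.data.DatasetSpec, not a tensor spec), and batch(3) turns each nested dataset into a single tensor of shape (3,) that flat_map then splices into one flat dataset.

```python
import tensorflow as tf

# window() yields a dataset whose elements are themselves datasets (one per window).
windows = tf.data.Dataset.range(10).window(3, shift=1, drop_remainder=True)

# The element spec is a DatasetSpec, so these elements are not tensors yet.
print(windows.element_spec)

# Each nested dataset holds the 3 scalar elements of one window.
first_window = next(iter(windows))
print([int(v) for v in first_window])

# batch(3) collapses each 3-element nested dataset into one tensor of shape (3,);
# flat_map concatenates those one-element datasets into a flat dataset of windows.
flat = windows.flat_map(lambda w: w.batch(3))
print(next(iter(flat)).numpy())
```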

However, the flat_map causes the dataset to lose cardinality information:

>>> d.cardinality()
<tf.Tensor: shape=(), dtype=int64, numpy=-2>

(-2 is tf.data.UNKNOWN_CARDINALITY; see the question "Tensorflow 2.0: flat_map() to flatten Dataset of Dataset returns cardinality -2".)
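Cardinality is inferred from the structure of the pipeline rather than by iterating over it, so whether it survives depends on the transformation. A quick sketch comparing a 1:1 transformation (map) with data-dependent ones (filter, flat_map); the specific lambdas are arbitrary examples, not part of the original question.

```python
import tensorflow as tf

base = tf.data.Dataset.range(10)

# map() is strictly 1:1, so tf.data can derive its cardinality from the input's.
mapped = base.map(lambda x: x * 2)

# filter() drops a data-dependent number of elements, so its size cannot be
# known without running the pipeline.
filtered = base.filter(lambda x: tf.equal(x % 2, 0))

# flat_map() concatenates datasets whose sizes depend on the data; in the TF
# versions discussed here it is reported as unknown as well.
flattened = base.flat_map(lambda x: tf.data.Dataset.from_tensors(x))

for name, ds in [("map", mapped), ("filter", filtered), ("flat_map", flattened)]:
    card = int(ds.cardinality().numpy())
    label = "UNKNOWN" if card == tf.data.UNKNOWN_CARDINALITY else card
    print(f"{name}: {label}")
```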

I would like to create a dataset of such windows while retaining the cardinality information. Unknown cardinality is a minor annoyance in Keras: the training progress bar has to run through one full epoch before it can show an ETA. I tried .take(n_windows), computing n_windows myself, but the result still had UNKNOWN_CARDINALITY.
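For reference, a sketch of computing n_windows by hand. With drop_remainder=True, only full windows are kept, so the count is (n - size) // shift + 1 (or 0 if n < size). Before flat_map, window() still reports a known cardinality, which the formula can be checked against; num_windows is a hypothetical helper name.

```python
import tensorflow as tf

def num_windows(n, size, shift):
    """Hypothetical helper: number of full windows of `size` elements,
    advancing by `shift`, over `n` elements (drop_remainder=True)."""
    if n < size:
        return 0
    return (n - size) // shift + 1

for n, size, shift in [(10, 3, 1), (10, 4, 2), (10, 3, 3), (2, 3, 1)]:
    windowed = tf.data.Dataset.range(n).window(size, shift=shift, drop_remainder=True)
    # window() itself still knows its cardinality, so compare it with the formula.
    print(n, size, shift, num_windows(n, size, shift), int(windowed.cardinality().numpy()))
```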

Is there some way to window a dataset without losing cardinality information?


Solution

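  • The main issue is that cardinality is computed statically, from the structure of the pipeline, without iterating over the data. The number of elements flat_map produces depends on the contents of each nested dataset, so its cardinality cannot be inferred. For the same reason, take() applied to a dataset of unknown cardinality still reports unknown cardinality. You can refer to this issue for details.

    Since you know the relationship between the flat_map input and output (each window is batched into exactly one element), the solution is to set the cardinality yourself using tf.data.experimental.assert_cardinality. The asserted value is also checked at runtime: iterating the dataset raises an error if it yields a different number of elements.

    This example shows how to restore the window cardinality: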

    import tensorflow as tf
    
    ds = tf.data.Dataset.range(10)
    print("Original cardinality ->", ds.cardinality().numpy())
    # Output:
    # Original cardinality -> 10
    
    ds = ds.window(3, shift=1, drop_remainder=True)
    # The cardinality is still known at this point.
    # With drop_remainder=True only full windows are kept: (10 - 3) // 1 + 1 = 8.
    window_cardinality = ds.cardinality()
    print("window cardinality ->", window_cardinality.numpy())
    # Output:
    # window cardinality -> 8
    
    ds = ds.flat_map(lambda x: x.batch(3))
    # after flat_map the inferred cardinality is lost.
    print("flat cardinality ->", ds.cardinality().numpy())
    # Output:
    # flat cardinality -> -2
    
    # Each window is batched into exactly one element, so flat_map is 1:1 here and
    # we can restore the window cardinality (8, not the original 10).
    ds = ds.apply(tf.data.experimental.assert_cardinality(window_cardinality))
    print("dataset cardinality ->", ds.cardinality().numpy())
    print("length of dataset ->", len(list(ds)))
    # Output: 
    # dataset cardinality -> 8
    # length of dataset -> 8
    
    for idx, x in ds.enumerate():
        print(f"{idx} -> {x}")
    # Output:
    # 0 -> [0 1 2]
    # 1 -> [1 2 3]
    # 2 -> [2 3 4]
    # 3 -> [3 4 5]
    # 4 -> [4 5 6]
    # 5 -> [5 6 7]
    # 6 -> [6 7 8]
    # 7 -> [7 8 9]
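The same steps can be wrapped in a reusable function. This is a sketch under the assumption that the input dataset yields scalar elements and has a known, finite cardinality (if the input's cardinality is already unknown, there is nothing correct to assert); sliding_windows is a hypothetical name. It also shows that assert_cardinality is enforced when iterating, so a wrong count fails loudly instead of silently mislabeling the dataset.

```python
import tensorflow as tf

def sliding_windows(ds, size, shift=1):
    """Hypothetical helper: window + flat_map + assert_cardinality in one step.

    Assumes `ds` yields scalar elements and has a known, finite cardinality.
    """
    windowed = ds.window(size, shift=shift, drop_remainder=True)
    n_windows = windowed.cardinality()  # still known before flat_map
    # Each nested window dataset becomes exactly one tensor of shape (size,).
    flat = windowed.flat_map(lambda w: w.batch(size))
    # flat_map is 1:1 here, so re-attach the window count.
    return flat.apply(tf.data.experimental.assert_cardinality(n_windows))

ds = sliding_windows(tf.data.Dataset.range(10), size=3)
print(int(ds.cardinality().numpy()))

# Asserting a wrong count: the error surfaces only when the dataset is iterated.
bad = tf.data.Dataset.range(10).apply(tf.data.experimental.assert_cardinality(12))
try:
    for _ in bad:
        pass
except tf.errors.OpError as e:
    print("assertion failed:", type(e).__name__)
```

With the cardinality known up front, Keras should be able to size the progress bar from the first epoch.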