python, tensorflow, tensorflow-datasets, data-preprocessing

What does TensorFlow's flat_map + window.batch() do to a dataset/array?


I'm following one of the online courses about time series prediction using TensorFlow. The function used to convert a NumPy array (the time series) into a TensorFlow dataset for an LSTM-based model is already given (with my comment lines):

import tensorflow as tf

def windowed_dataset(series, window_size, batch_size, shuffle_buffer):
    # creating a dataset of scalar tensors from the array
    dataset = tf.data.Dataset.from_tensor_slices(series)
    # cutting the dataset into fixed-size windows
    dataset = dataset.window(window_size + 1, shift=1, drop_remainder=True)
    # joining windows into a batch?
    dataset = dataset.flat_map(lambda window: window.batch(window_size + 1))
    # separating each window into features/label
    dataset = dataset.shuffle(shuffle_buffer).map(lambda window: (window[:-1], window[-1]))
    dataset = dataset.batch(batch_size).prefetch(1)
    return dataset
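
For context, a minimal sketch of how the function might be called (the series and parameters are toy values for illustration):

import numpy as np

series = np.arange(20, dtype=np.float32)
dataset = windowed_dataset(series, window_size=4, batch_size=8, shuffle_buffer=10)
for X_batch, y_batch in dataset.take(1):
    print(X_batch.shape, y_batch.shape)  # (8, 4) (8,)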

This code works fine, but I want to understand it better so I can modify/adapt it for my needs.

If I remove the dataset.flat_map(lambda window: window.batch(window_size + 1)) operation, I get TypeError: '_VariantDataset' object is not subscriptable, pointing to the line lambda window: (window[:-1], window[-1]).

I managed to rewrite part of this code (skipping the shuffling) as a NumPy-based version:

import numpy as np
import tensorflow as tf
from numpy.lib.stride_tricks import sliding_window_view

def windowed_dataset_np(series, window_size):
    values = sliding_window_view(series, window_size)
    X = values[:, :-1]
    X = tf.convert_to_tensor(np.expand_dims(X, axis=-1))
    y = values[:, -1]
    return X, y

The syntax for fitting the model looks a bit different, but it works fine.
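
To illustrate what I mean (model stands for an already-compiled Keras model; this is a sketch, not my exact code):

# tf.data version: the dataset already yields (features, label) batches
model.fit(windowed_dataset(series, window_size, batch_size, shuffle_buffer), epochs=10)

# NumPy version: features and labels are passed separately
X, y = windowed_dataset_np(series, window_size)
model.fit(X, y, batch_size=batch_size, epochs=10)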

My two questions are:

  1. What does dataset.flat_map(lambda window: window.batch(window_size + 1)) achieve?
  2. Is the second code really equivalent to the first three operations in the original function?

Solution

  • I would break down the operations into smaller parts to really understand what is happening. Applying window to a dataset actually creates a dataset of windowed datasets, each containing a sequence of tensors:

    import tensorflow as tf
    
    window_size = 2
    dataset = tf.data.Dataset.range(7)
    dataset = dataset.window(window_size + 1, shift=1, drop_remainder=True)  
    
    for i, window in enumerate(dataset):
      print('{}. windowed dataset'.format(i + 1))
      for w in window:
        print(w)
    
    1. windowed dataset
    tf.Tensor(0, shape=(), dtype=int64)
    tf.Tensor(1, shape=(), dtype=int64)
    tf.Tensor(2, shape=(), dtype=int64)
    2. windowed dataset
    tf.Tensor(1, shape=(), dtype=int64)
    tf.Tensor(2, shape=(), dtype=int64)
    tf.Tensor(3, shape=(), dtype=int64)
    3. windowed dataset
    tf.Tensor(2, shape=(), dtype=int64)
    tf.Tensor(3, shape=(), dtype=int64)
    tf.Tensor(4, shape=(), dtype=int64)
    4. windowed dataset
    tf.Tensor(3, shape=(), dtype=int64)
    tf.Tensor(4, shape=(), dtype=int64)
    tf.Tensor(5, shape=(), dtype=int64)
    5. windowed dataset
    tf.Tensor(4, shape=(), dtype=int64)
    tf.Tensor(5, shape=(), dtype=int64)
    tf.Tensor(6, shape=(), dtype=int64)
    

    Notice how the window is always shifted by one position due to the parameter shift=1. Now, the operation flat_map is used here to flatten the dataset of datasets into a dataset of tensors; however, you still want to keep the windowed sequences you created, so each inner dataset is first batched into a single tensor of window_size + 1 elements with window.batch:

    dataset = dataset.flat_map(lambda window: window.batch(window_size + 1))
    for w in dataset:
      print(w)
    
    tf.Tensor([0 1 2], shape=(3,), dtype=int64)
    tf.Tensor([1 2 3], shape=(3,), dtype=int64)
    tf.Tensor([2 3 4], shape=(3,), dtype=int64)
    tf.Tensor([3 4 5], shape=(3,), dtype=int64)
    tf.Tensor([4 5 6], shape=(3,), dtype=int64)
    

    You could also first flatten the dataset of datasets and then apply batch to recreate the windowed sequences:

    dataset = dataset.flat_map(lambda window: window).batch(window_size + 1)
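
    In this case the result is the same five tensors as before: drop_remainder=True guarantees that every window contributes exactly window_size + 1 elements to the flattened stream, so batching by window_size + 1 regroups them into the original windows.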
    

    Or only flatten the dataset of datasets:

    dataset = dataset.flat_map(lambda window: window)
    for w in dataset:
      print(w)
    
    tf.Tensor(0, shape=(), dtype=int64)
    tf.Tensor(1, shape=(), dtype=int64)
    tf.Tensor(2, shape=(), dtype=int64)
    tf.Tensor(1, shape=(), dtype=int64)
    tf.Tensor(2, shape=(), dtype=int64)
    tf.Tensor(3, shape=(), dtype=int64)
    tf.Tensor(2, shape=(), dtype=int64)
    tf.Tensor(3, shape=(), dtype=int64)
    tf.Tensor(4, shape=(), dtype=int64)
    tf.Tensor(3, shape=(), dtype=int64)
    tf.Tensor(4, shape=(), dtype=int64)
    tf.Tensor(5, shape=(), dtype=int64)
    tf.Tensor(4, shape=(), dtype=int64)
    tf.Tensor(5, shape=(), dtype=int64)
    tf.Tensor(6, shape=(), dtype=int64)
    

    But that is probably not what you want. This also explains the TypeError from your question: without the flat_map step, map receives each window as a _VariantDataset rather than a tensor, and a _VariantDataset does not support indexing like window[:-1]. Regarding this line in your question: dataset = dataset.shuffle(shuffle_buffer).map(lambda window: (window[:-1], window[-1])), it is pretty straightforward. It simply splits each window into a feature sequence and a label, using the last element of the window as the label:

    dataset = dataset.shuffle(2).map(lambda window: (window[:-1], window[-1]))
    for w in dataset:
      print(w)
    
    (<tf.Tensor: shape=(2,), dtype=int64, numpy=array([1, 2])>, <tf.Tensor: shape=(), dtype=int64, numpy=3>)
    (<tf.Tensor: shape=(2,), dtype=int64, numpy=array([2, 3])>, <tf.Tensor: shape=(), dtype=int64, numpy=4>)
    (<tf.Tensor: shape=(2,), dtype=int64, numpy=array([3, 4])>, <tf.Tensor: shape=(), dtype=int64, numpy=5>)
    (<tf.Tensor: shape=(2,), dtype=int64, numpy=array([4, 5])>, <tf.Tensor: shape=(), dtype=int64, numpy=6>)
    (<tf.Tensor: shape=(2,), dtype=int64, numpy=array([0, 1])>, <tf.Tensor: shape=(), dtype=int64, numpy=2>)
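
    As for your second question: a quick check on the same toy series (a sketch for illustration) suggests the NumPy version matches the first three operations only if sliding_window_view is given window_size + 1, so that each row carries window_size features plus one label, exactly like the windows above:

    import numpy as np
    from numpy.lib.stride_tricks import sliding_window_view
    
    series = np.arange(7)
    window_size = 2
    
    # windows of window_size + 1 elements, mirroring
    # dataset.window(window_size + 1, shift=1, drop_remainder=True)
    values = sliding_window_view(series, window_size + 1)
    X, y = values[:, :-1], values[:, -1]
    print(X)
    # [[0 1]
    #  [1 2]
    #  [2 3]
    #  [3 4]
    #  [4 5]]
    print(y)
    # [2 3 4 5 6]

    With sliding_window_view(series, window_size), as in your windowed_dataset_np, each row would hold only window_size - 1 features after the window[:-1] split, which is not what the TensorFlow pipeline produces.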