tensorflow, dataloader

How to use `shard_func` in TensorFlow's `tf.data.Dataset.save`


Background:

I'm working with a large dataset saved in a non-standard format. I can write a pure-Python data reader, but when it is called from a deep-learning data loader such as tf.data.Dataset, a single pass over the data takes forever.

TensorFlow's tf.data.Dataset.save is a lifesaver. It runs through the dataset once using the slow pure-Python reader and re-saves it in a much faster format; a faster data loader can then be created with the associated load method. Great.

Question:

I want to save the large dataset to multiple shards. The save method mentioned above allows this through its shard_func argument, but the documentation only gives an example that writes a single shard. In theory, shards can be assigned based on the elements of the dataset, but I don't see an obvious way to do this.

The output of shard_func must be a np.int64, but the inputs to shard_func are symbolic tensors that cannot be transformed to numpy arrays during execution. So how is it possible to output anything other than a single np.int64 in a deterministic way based on dataset elements?

Here's a simple example. I want to save the dataset [0,...,99] into 10 shards. Here's an attempt that fails:

import numpy as np
import tensorflow as tf
dataset = tf.data.Dataset.range(100)

def custom_shard_function(x):
    """
    Transforms x into a np.int64 between 0,..,9
    """
    return np.int64(x) % 10

dataset.save('/tmp/saved_dataset',shard_func=custom_shard_function)

It fails with the error:

NotImplementedError: Cannot convert a symbolic tf.Tensor (args_0:0) to a numpy array. This error may indicate that you're trying to pass a Tensor to a NumPy call, which is not supported.

I understand this isn't possible, but I also have no idea how to get a dynamic np.int64 out of a symbolic tensor in a deterministic way.

I'm using TF 2.12.


Solution

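  • Simple solution: don't use NumPy. Inside shard_func, x is a symbolic tf.Tensor, so np.int64(x) cannot work. If you take the NumPy call out, x % 10 uses TensorFlow's modulus op and returns a tf.int64 tensor, which shard_func accepts as the shard index.

    This worked for me with TF 2.10.1: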

    import tensorflow as tf
    dataset = tf.data.Dataset.range(100)
    
    def custom_shard_function(x):
        """
        Maps x to a tf.int64 shard index between 0,..,9
        """
        return x % 10
    
    dataset.save('/tmp/saved_dataset',shard_func=custom_shard_function)
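
To check that the sharded save round-trips, and to show the same idea for elements that are not integers, here is a minimal sketch. It is not from the original answer: NUM_SHARDS, shard_by_value, shard_by_key and the temporary directories are hypothetical names chosen for illustration, and it assumes a TF version where tf.data.Dataset.save and tf.data.Dataset.load are available (2.10 or later). For string keys it swaps in tf.strings.to_hash_bucket_fast, which returns an int64 bucket, as one possible deterministic shard function.

```python
import tempfile

import tensorflow as tf

NUM_SHARDS = 10  # hypothetical shard count, chosen for illustration


def shard_by_value(x):
    # x is a symbolic tf.int64 scalar; % dispatches to TensorFlow's
    # modulus op, so the result is a tf.int64 tensor, not a NumPy value.
    return x % NUM_SHARDS


def shard_by_key(key):
    # Assumption: elements are scalar strings. Hashing the string gives a
    # deterministic int64 bucket in [0, NUM_SHARDS).
    return tf.strings.to_hash_bucket_fast(key, NUM_SHARDS)


# Integer elements: save into shards, load back, compare contents.
int_path = tempfile.mkdtemp()
tf.data.Dataset.range(100).save(int_path, shard_func=shard_by_value)
restored = tf.data.Dataset.load(int_path)
# Shards may be read back interleaved, so sort before comparing.
values = sorted(int(v) for v in restored.as_numpy_iterator())

# String elements: same pattern with a hash-based shard function.
keys = ["sample_%03d" % i for i in range(50)]
str_path = tempfile.mkdtemp()
tf.data.Dataset.from_tensor_slices(keys).save(str_path, shard_func=shard_by_key)
restored_keys = sorted(
    k.decode() for k in tf.data.Dataset.load(str_path).as_numpy_iterator()
)
```

The loaded dataset is not guaranteed to return elements in their original order, which is why the sketch sorts before comparing. If order matters, it may be worth looking at the reader_func argument of load.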