I'm working with a large dataset saved in a non-standard format. I can write a pure-Python data reader, but when it is called from deep-learning data loaders such as `tf.data.Dataset`, a single pass over the data takes forever.
TensorFlow's `tf.data.Dataset.save` is a lifesaver. It runs through the dataset once using the slow pure-Python reader and re-saves it in a much faster format. A fast data loader can then be created with the associated `load` method. Great.
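A minimal sketch of that save-then-load workflow, assuming TF 2.10 or later (where `Dataset.save` and `Dataset.load` exist). The generator `slow_python_reader` is a hypothetical stand-in for the pure-Python reader, and a temporary directory replaces a real output path.

```python
import tempfile

import tensorflow as tf


def slow_python_reader():
    # Hypothetical stand-in for a pure-Python reader of a custom file format.
    for i in range(100):
        yield i


slow_dataset = tf.data.Dataset.from_generator(
    slow_python_reader,
    output_signature=tf.TensorSpec(shape=(), dtype=tf.int64),
)

save_dir = tempfile.mkdtemp()

# One slow pass through the Python reader; elements are written to disk.
slow_dataset.save(save_dir)

# Later passes read the saved files instead of calling the Python reader.
fast_dataset = tf.data.Dataset.load(save_dir)

# Sorted because the loaded order is not guaranteed to match the original.
values = sorted(int(v) for v in fast_dataset.as_numpy_iterator())
```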
I want to save the large dataset to multiple shards. The `save` method allows this through its `shard_func` argument, but the documentation only gives an example that writes a single shard. In theory, shards can be assigned based on the elements of the dataset, but I don't see an obvious way to do this.
The output of `shard_func` must be an `np.int64`, but the inputs to `shard_func` are symbolic tensors that cannot be converted to NumPy arrays during execution. So how can it return anything other than a single constant `np.int64` while still depending deterministically on the dataset elements?
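A small sketch of what `shard_func` is working with, assuming `save` traces `shard_func` the same way `Dataset.map` traces its function: the argument is a symbolic `tf.Tensor` matching `dataset.element_spec`, so NumPy conversion fails, but TensorFlow ops are recorded and then evaluated once per element. `inspect_and_shard` is a hypothetical helper; mapping it over the dataset previews the per-element shard IDs without writing anything to disk.

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(100)
print(dataset.element_spec)  # TensorSpec(shape=(), dtype=tf.int64, name=None)


def inspect_and_shard(x):
    # Runs once, at trace time: x is a symbolic tensor, not a concrete value.
    print("traced with:", x.dtype, x.shape)
    # TensorFlow ops stay symbolic and are evaluated for each element later.
    return tf.math.floormod(x, 10)


# map() preserves order, so the result lines up with the input elements.
shard_ids = list(dataset.map(inspect_and_shard).as_numpy_iterator())
```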
Here's a simple example. I want to save the dataset `[0, ..., 99]` into 10 shards. Here's an attempt that fails:
```python
import numpy as np
import tensorflow as tf

dataset = tf.data.Dataset.range(100)

def custom_shard_function(x):
    """
    Transforms x into a np.int64 between 0, ..., 9
    """
    return np.int64(x) % 10

dataset.save('/tmp/saved_dataset', shard_func=custom_shard_function)
```
with the error:

```
NotImplementedError: Cannot convert a symbolic tf.Tensor (args_0:0) to a numpy array. This error may indicate that you're trying to pass a Tensor to a NumPy call, which is not supported.
```
I understand why this doesn't work, but I have no idea how else to get a dynamic `np.int64` out of a symbolic tensor in a deterministic way.
I'm using TF 2.12.
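Simple solution: don't use NumPy. If you remove the `np.int64` call, `x` stays a `tf.Tensor` and `%` dispatches to TensorFlow's modulo op (`tf.math.floormod`), so `shard_func` returns an int64 scalar tensor for each element.

This worked for me with TF 2.10.1: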
```python
import tensorflow as tf

dataset = tf.data.Dataset.range(100)

def custom_shard_function(x):
    """
    Maps x to an int64 shard ID between 0 and 9 using TensorFlow ops only.
    """
    return x % 10

dataset.save('/tmp/saved_dataset', shard_func=custom_shard_function)
```
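The same idea extends to elements that contain no convenient integer. Below is a sketch adapted from the position-based pattern in the `tf.data.experimental.snapshot` documentation: `enumerate()` pairs each element with an int64 index, `shard_func` uses that index, and a `map` drops it after loading. `NUM_SHARDS` and the string records are illustrative assumptions, and the sketch assumes `save` unpacks tuple elements into separate `shard_func` arguments the way `map` does.

```python
import tempfile

import tensorflow as tf

NUM_SHARDS = 10  # illustrative choice

# Hypothetical elements with no natural integer key.
dataset = tf.data.Dataset.from_tensor_slices([f"record-{i}" for i in range(100)])

save_dir = tempfile.mkdtemp()

# enumerate() turns each element into (index, element); the int64 index
# gives a deterministic, position-based shard ID.
dataset.enumerate().save(
    save_dir,
    shard_func=lambda index, element: index % NUM_SHARDS,
)

# Drop the index again after loading.
restored = tf.data.Dataset.load(save_dir).map(lambda index, element: element)
records = sorted(r.decode() for r in restored.as_numpy_iterator())
```

Loaded elements may not come back in their original order (by default the shards are read with an interleave), which is why the sketch sorts before comparing. If the shard should depend on content rather than position, `tf.strings.to_hash_bucket_fast(element, NUM_SHARDS)` returns an int64 bucket ID for a string tensor and could serve as the shard function instead.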