My question is about how to get batch inputs from multiple (or sharded) tfrecords. I've read the example. The basic pipeline is as follows, taking the training set as an example: (1) first generate a series of tfrecords (e.g., train-000-of-005, train-001-of-005, ...); (2) from these filenames, generate a list and feed it into tf.train.string_input_producer to get a queue; (3) simultaneously generate a tf.RandomShuffleQueue to do other stuff; (4) use tf.train.batch_join to generate batch inputs.
I think this is complex, and I'm not sure I understand the logic of this procedure. In my case, I have a list of .npy files, and I want to generate sharded tfrecords (multiple separate tfrecord files, not just one single large file). Each of these .npy files contains a different number of positive and negative samples (2 classes). A basic method is to generate one single large tfrecord file, but that file is too large (~20 GB), so I resort to sharded tfrecords. Is there a simpler way to do this?
The whole process is simplified by using the Dataset API. Here are both parts: (1) convert numpy arrays to tfrecords, and (2) read the tfrecords to generate batches.
Example arrays:
import numpy as np
import tensorflow as tf

inputs = np.random.normal(size=(5, 32, 32, 3))
labels = np.random.randint(0, 2, size=(5,))
def npy_to_tfrecords(inputs, labels, filename):
    with tf.io.TFRecordWriter(filename) as writer:
        for X, y in zip(inputs, labels):
            # Feature contains a map of string to feature proto objects
            feature = {}
            feature['X'] = tf.train.Feature(float_list=tf.train.FloatList(value=X.flatten()))
            # int() guards against numpy integer types
            feature['y'] = tf.train.Feature(int64_list=tf.train.Int64List(value=[int(y)]))
            # Construct the Example proto object
            example = tf.train.Example(features=tf.train.Features(feature=feature))
            # Serialize the example to a string
            serialized = example.SerializeToString()
            # Write the serialized object to disk
            writer.write(serialized)

npy_to_tfrecords(inputs, labels, 'numpy.tfrecord')
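The function above writes everything into one file. To get sharded files instead, the same per-example serialization can simply be repeated with one writer per shard. Below is a minimal sketch, assuming the data fits in memory as two arrays; `write_sharded_tfrecords` and `_to_example` are hypothetical helper names, and the `<prefix>-000-of-003.tfrecord` naming just mirrors the `train-000-of-005` pattern from the question.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


def _to_example(X, y):
    # Same feature layout as npy_to_tfrecords: flattened float image 'X', int64 label 'y'
    feature = {
        'X': tf.train.Feature(float_list=tf.train.FloatList(value=X.flatten())),
        'y': tf.train.Feature(int64_list=tf.train.Int64List(value=[int(y)])),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature))


def write_sharded_tfrecords(inputs, labels, prefix, num_shards):
    """Hypothetical helper: split (inputs, labels) into num_shards tfrecord files.

    Files are named <prefix>-000-of-<num_shards>.tfrecord, ... and the list of
    written filenames is returned.
    """
    filenames = []
    # np.array_split tolerates sample counts that don't divide evenly
    index_shards = np.array_split(np.arange(len(inputs)), num_shards)
    for shard_id, idx in enumerate(index_shards):
        filename = '%s-%03d-of-%03d.tfrecord' % (prefix, shard_id, num_shards)
        with tf.io.TFRecordWriter(filename) as writer:
            for i in idx:
                writer.write(_to_example(inputs[i], labels[i]).SerializeToString())
        filenames.append(filename)
    return filenames


# Usage: 10 fake samples spread over 3 shards in a temporary directory
inputs = np.random.normal(size=(10, 32, 32, 3))
labels = np.random.randint(0, 2, size=(10,))
out_dir = tempfile.mkdtemp()
shard_files = write_sharded_tfrecords(inputs, labels, os.path.join(out_dir, 'train'), 3)
print(shard_files)
```

If the .npy files already partition the data, an alternative (again an assumption about how those files are laid out) is to `np.load` one file at a time and write it as its own shard, so the full ~20 GB never has to sit in memory at once.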
filenames = ['numpy.tfrecord']
dataset = tf.data.TFRecordDataset(filenames)
# for version 1.5 and above, use tf.data.TFRecordDataset

# example proto decode
def _parse_function(example_proto):
    keys_to_features = {'X': tf.io.FixedLenFeature(shape=(32, 32, 3), dtype=tf.float32),
                        'y': tf.io.FixedLenFeature((), tf.int64, default_value=0)}
    parsed_features = tf.io.parse_single_example(example_proto, keys_to_features)
    return parsed_features['X'], parsed_features['y']

# Parse the record into tensors.
dataset = dataset.map(_parse_function)

# Generate batches
dataset = dataset.batch(5)
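Check that the generated batches are correct:
for X_batch, y_batch in dataset:
    # X was stored as float32, so compare with a float32-level tolerance
    np.testing.assert_allclose(inputs, X_batch.numpy(), rtol=1e-6)
    np.testing.assert_array_equal(labels, y_batch.numpy())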
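Reading many shards works the same way; the Dataset API just takes a file pattern instead of a single filename. The sketch below shows one common way to do it, and the steps roughly correspond to the queue pipeline from the question: `list_files(shuffle=True)` in place of `tf.train.string_input_producer`, `interleave` for reading several shards in parallel, `shuffle` in place of `tf.RandomShuffleQueue`, and `batch` in place of `tf.train.batch_join`. It assumes a recent TensorFlow 2.x; `make_sharded_dataset`, the buffer size, and `cycle_length` are illustrative choices, not values from the answer.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


def _parse_function(example_proto):
    # Same feature spec as above: 32x32x3 float image 'X', int64 label 'y'
    keys_to_features = {
        'X': tf.io.FixedLenFeature(shape=(32, 32, 3), dtype=tf.float32),
        'y': tf.io.FixedLenFeature((), tf.int64, default_value=0),
    }
    parsed = tf.io.parse_single_example(example_proto, keys_to_features)
    return parsed['X'], parsed['y']


def make_sharded_dataset(file_pattern, batch_size, shuffle_buffer=1000, cycle_length=4):
    # Hypothetical helper. Shuffle the order in which shard files are visited
    files = tf.data.Dataset.list_files(file_pattern, shuffle=True)
    # Read records from up to cycle_length shards at once, interleaving them
    dataset = files.interleave(
        lambda f: tf.data.TFRecordDataset(f),
        cycle_length=cycle_length,
        num_parallel_calls=tf.data.AUTOTUNE)
    # Record-level shuffling across the interleaved shards
    dataset = dataset.shuffle(shuffle_buffer)
    dataset = dataset.map(_parse_function, num_parallel_calls=tf.data.AUTOTUNE)
    # The last batch may be smaller than batch_size
    dataset = dataset.batch(batch_size)
    return dataset.prefetch(tf.data.AUTOTUNE)


# Usage: write two tiny shards of 5 records each, then read them back in batches of 4
out_dir = tempfile.mkdtemp()
for shard_id in range(2):
    path = os.path.join(out_dir, 'train-%03d-of-002.tfrecord' % shard_id)
    with tf.io.TFRecordWriter(path) as writer:
        for _ in range(5):
            X = np.random.normal(size=(32, 32, 3))
            y = int(np.random.randint(0, 2))
            example = tf.train.Example(features=tf.train.Features(feature={
                'X': tf.train.Feature(float_list=tf.train.FloatList(value=X.flatten())),
                'y': tf.train.Feature(int64_list=tf.train.Int64List(value=[y])),
            }))
            writer.write(example.SerializeToString())

dataset = make_sharded_dataset(os.path.join(out_dir, 'train-*-of-002.tfrecord'), batch_size=4)
for X_batch, y_batch in dataset:
    print(X_batch.shape, y_batch.shape)
```

Shuffling happens at two levels here: file order (cheap, but coarse) and records within a buffer (finer, but limited by `shuffle_buffer`). With shards that each hold a different mix of positive and negative samples, interleaving several shards before the record-level shuffle helps mix the classes within a batch.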