Tags: tensorflow, tfrecord

How do you parse a TFRecord example from a byte-string to a dictionary of tensors?


I am training a multi-task transformer for a project and would like to switch my data structure over to TFRecords because my training is bottlenecked by on-the-fly data generation. I am currently structuring a single sample of data as a dictionary of tensors, like this:

{'continuous_input': tf.Tensor(), 'categorical_input': tf.Tensor(), 'continuous_output': tf.Tensor(), 'categorical_output': tf.Tensor()}

Within a sample, these 4 tensors have the same length, but between samples, these tensors vary in length. The two continuous_ tensors are tf.float32, whereas the two categorical_ tensors are tf.int32. More explicit details of these tensors are in the code below.

I think that I've successfully written my data to TFRecords in the correct format (byte-strings).

Problem statement: I am unable to figure out how to read these TFRecords back into memory and parse the byte-strings into the dictionary-of-tensors structure above. I include a fully reproducible example of my issue below, which uses NumPy v1.23.4 and TensorFlow v2.10.0. It creates fake data with the aforementioned dictionary structure, saves TFRecords to a ./tfrecords folder in your working directory, reloads these TFRecords and attempts to parse them with my function parse_tfrecord_fn(). I know that the issue lies in parse_tfrecord_fn(), but I do not know which tf.io tool resolves it.

Reproducible example:

import os
import os.path as op
import numpy as np
import tensorflow as tf


# Helper functions for writing TFRecords
def _tensor_feature(value):
    """Wrap a tensor as a single-element ``bytes_list`` feature.

    The tensor is serialized with ``tf.io.serialize_tensor``; the resulting
    scalar string tensor is converted to Python bytes via ``.numpy()`` so that
    ``tf.train.BytesList`` can store it.
    """
    payload = tf.io.serialize_tensor(value).numpy()
    bytes_list = tf.train.BytesList(value=[payload])
    return tf.train.Feature(bytes_list=bytes_list)


def create_example(sample):
    """Serialize one sample dict into a ``tf.train.Example`` byte-string.

    Each of the four expected keys is stored as its own serialized-tensor
    feature (via ``_tensor_feature``); other keys in ``sample`` are ignored.
    """
    keys = ("continuous_input", "categorical_input",
            "continuous_output", "categorical_output")
    features = tf.train.Features(
        feature={key: _tensor_feature(sample[key]) for key in keys})
    return tf.train.Example(features=features).SerializeToString()


# Helper functions for reading/preparing TFRecord data

def parse_tfrecord_fn(example):
    """Parse one serialized ``tf.train.Example`` (question version, incomplete).

    Every feature is declared as ``tf.io.VarLenFeature(tf.string)``, so
    ``tf.io.parse_single_example`` returns a dict of SparseTensors whose values
    are still the bytes produced by ``tf.io.serialize_tensor`` at write time.
    Decoding those bytes back into dense tensors is the missing step this
    question asks about (answered in the Solution section).
    """
    feature_names = ("continuous_input", "categorical_input",
                     "continuous_output", "categorical_output")
    feature_description = {name: tf.io.VarLenFeature(tf.string) for name in feature_names}
    # TODO: WHAT GOES HERE? -- the serialized bytes are not decoded yet.
    return tf.io.parse_single_example(example, feature_description)


def get_dataset(filenames, batch_size):
    """Build a shuffled, batched, prefetched dataset from TFRecord files.

    Args:
        filenames: TFRecord file paths; files are read in parallel (AUTOTUNE).
        batch_size: number of parsed examples per batch. The shuffle buffer
            holds ``batch_size * 10`` examples.

    Returns:
        A ``tf.data.Dataset`` of batches produced by ``parse_tfrecord_fn``.
    """
    records = tf.data.TFRecordDataset(filenames, num_parallel_reads=tf.data.AUTOTUNE)
    parsed = records.map(parse_tfrecord_fn, num_parallel_calls=tf.data.AUTOTUNE)
    shuffled = parsed.shuffle(batch_size * 10)
    return shuffled.batch(batch_size).prefetch(tf.data.AUTOTUNE)

# Make fake data: each sample's four tensors share one random length.
num_samples_per_tfrecord = 100
num_train_samples = 1600
num_tfrecords = num_train_samples // num_samples_per_tfrecord
fake_sequence_lengths = np.random.randint(3, 35, num_train_samples)
fake_data = []
for seq_len in fake_sequence_lengths:
    # tf.fill documents `dims` as a 1-D shape, so pass [seq_len] (as the
    # tf.random.uniform calls do) rather than relying on legacy scalar dims.
    fake_data.append({'continuous_input': tf.random.uniform([seq_len], minval=0, maxval=1, dtype=tf.float32),
                      'categorical_input': tf.random.uniform([seq_len], minval=0, maxval=530, dtype=tf.int32),
                      'continuous_output': tf.fill([seq_len], -1.0),
                      'categorical_output': tf.fill([seq_len], -1)})

tfrecords_dir = './tfrecords'
os.makedirs(tfrecords_dir, exist_ok=True)  # create TFRecords output folder

# Write fake data to tfrecord files, num_samples_per_tfrecord samples per file
for tfrec_num in range(num_tfrecords):
    start = tfrec_num * num_samples_per_tfrecord
    samples = fake_data[start:start + num_samples_per_tfrecord]
    with tf.io.TFRecordWriter(op.join(tfrecords_dir, "file_%.2i.tfrec" % tfrec_num)) as writer:
        for sample in samples:
            writer.write(create_example(sample))

# (Try to) Load all the TFRecord data into a (parsed) tf dataset
train_filenames = tf.io.gfile.glob(op.join(tfrecords_dir, "*.tfrec"))

# Problem: the line below doesn't return the original tensors of fake_data, because my parse_tfrecord_fn is wrong
# Question: What must I add to parse_tfrecord_fn to give this the desired behavior?
dataset = get_dataset(train_filenames, batch_size=32)

# For ease of debugging parse_tfrecord_fn(): use a separate raw (unparsed)
# dataset so `dataset` above is not overwritten.
raw_dataset = tf.data.TFRecordDataset(train_filenames, num_parallel_reads=tf.data.AUTOTUNE)
element = raw_dataset.take(1).get_single_element()
parse_tfrecord_fn(element)  # set your breakpoint here, then can step through parse_tfrecord_fn()

The function parse_tfrecord_fn() accepts a byte-string as input, which looks like this:

example = "b'\n\xb4\x03\nj\n\x10continuous_input\x12V\nT\nR\x08\x01\x12\x04\x12\x02\x08\x12"H..."

The command example = tf.io.parse_single_example(example, feature_description), with the arguments defined as in my reproducible example, returns a dictionary of SparseTensors with the desired 4 keys ('continuous_input', 'categorical_input', etc.). However, the values of these SparseTensors are either absent or inaccessible to me, so I cannot extract and parse them, e.g. with tf.io.parse_tensor(example['continuous_input'].values.numpy().tolist()[0], out_type=tf.float32).


Solution

  • I solved the issue and my initial suspicion was correct – it was a simple change needed in the parser function, parse_tfrecord_fn. I include the fully working code below, for anyone this may help going forward. I made a minor modification to the helper functions for writing the TFRecords simply to match common design patterns. The substantive change was in parse_tfrecord_fn.

    Key insights:

    1. Use tf.io.FixedLenFeature([], tf.string) to read any feature that was written as a single-element bytes_list. The byte-string's length varies from example to example, but each feature always holds exactly one string. That fixed count of one (shape []) is what makes it a fixed-length feature. Declaring it as tf.io.VarLenFeature(tf.string) instead wraps that one string in a SparseTensor, which is why the original attempt returned SparseTensors.

    2. Undo tf.io.serialize_tensor() with tf.io.parse_tensor(), passing the tensor's original dtype as the out_type argument (tf.float32 for the continuous_ tensors, tf.int32 for the categorical_ tensors).

    Combining these two insights, the proper flow is as follows:

    import os
    import os.path as op
    import numpy as np
    import tensorflow as tf
    
    
    # Helper functions for writing TFRecords
    def _bytes_feature(value):
        """Wrap a bytes value (or an eager string tensor) as a ``bytes_list`` feature.

        ``tf.train.BytesList`` cannot unpack an EagerTensor, so an eager tensor
        is first converted to its value with ``.numpy()``; anything else is
        passed through unchanged.
        """
        eager_tensor_type = type(tf.constant(0))
        raw = value.numpy() if isinstance(value, eager_tensor_type) else value
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[raw]))
    
    
    def create_example(sample):
        """Serialize one sample dict into a ``tf.train.Example`` byte-string.

        Each expected tensor is serialized with ``tf.io.serialize_tensor`` and
        stored as a single-element ``bytes_list`` feature under its own key;
        other keys in ``sample`` are ignored.
        """
        keys = ("continuous_input", "categorical_input",
                "continuous_output", "categorical_output")
        feature = {key: _bytes_feature(tf.io.serialize_tensor(sample[key])) for key in keys}
        example = tf.train.Example(features=tf.train.Features(feature=feature))
        return example.SerializeToString()
    
    
    # Helper functions for reading/preparing TFRecord data
    def parse_tfrecord_fn(example_to_parse):
        """Parse one serialized ``tf.train.Example`` back into a dict of dense tensors.

        Each feature was written as a single-element ``bytes_list`` holding the
        output of ``tf.io.serialize_tensor``. It is therefore read back as a scalar
        string (``FixedLenFeature([], tf.string)``) and decoded with
        ``tf.io.parse_tensor`` using the tensor's original dtype.

        Args:
            example_to_parse: scalar string tensor holding one serialized Example.

        Returns:
            dict mapping each feature name to a 1-D tensor: tf.float32 for the
            ``continuous_*`` keys, tf.int32 for the ``categorical_*`` keys.

        Raises:
            tf.errors.InvalidArgumentError: if a decoded tensor has the wrong
                dtype or is not 1-D.
        """
        # Original dtype of each serialized feature; parse_tensor must be told it.
        out_types = {
            "continuous_input": tf.float32,
            "categorical_input": tf.int32,
            "continuous_output": tf.float32,
            "categorical_output": tf.int32,
        }
        feature_description = {name: tf.io.FixedLenFeature([], tf.string) for name in out_types}
        parsed_example = tf.io.parse_single_example(example_to_parse, feature_description)
        # parse_tensor cannot infer a static shape, so its output has unknown rank.
        # ensure_shape checks at runtime that every sequence is 1-D and records
        # that static shape, so downstream ops see shape (None,) instead of <unknown>.
        return {name: tf.ensure_shape(tf.io.parse_tensor(parsed_example[name], out_type=dtype), [None])
                for name, dtype in out_types.items()}
    
    
    def get_dataset(filenames, batch_size):
        """Build a shuffled, padded-batch, prefetched dataset from TFRecord files.

        Sequences differ in length between samples, so ``padded_batch`` pads every
        feature to the longest sequence in its batch. Inputs are padded with
        0 / 0.0 and outputs with -1 / -1.0. An incomplete final batch is dropped.

        Args:
            filenames: TFRecord file paths; files are read in parallel (AUTOTUNE).
            batch_size: examples per batch; the shuffle buffer holds ten batches.
        """
        padding_values = {
            'categorical_input': 0,
            'continuous_input': 0.0,
            'categorical_output': -1,
            'continuous_output': -1.0,
        }
        # Every feature is a variable-length 1-D sequence.
        padded_shapes = {name: [None] for name in padding_values}
        records = tf.data.TFRecordDataset(filenames, num_parallel_reads=tf.data.AUTOTUNE)
        examples = records.map(parse_tfrecord_fn, num_parallel_calls=tf.data.AUTOTUNE)
        batches = examples.shuffle(batch_size * 10).padded_batch(
            batch_size=batch_size,
            padding_values=padding_values,
            padded_shapes=padded_shapes,
            drop_remainder=True,
        )
        return batches.prefetch(tf.data.AUTOTUNE)
    
    
    # Make fake data: each sample's four tensors share one random length.
    num_samples_per_tfrecord = 100
    num_train_samples = 1600
    num_tfrecords = num_train_samples // num_samples_per_tfrecord
    fake_sequence_lengths = np.random.randint(3, 35, num_train_samples)
    fake_data = []
    for seq_len in fake_sequence_lengths:
        # tf.fill documents `dims` as a 1-D shape, so pass [seq_len] (as the
        # tf.random.uniform calls do) rather than relying on legacy scalar dims.
        fake_data.append({"continuous_input": tf.random.uniform([seq_len], minval=0, maxval=1, dtype=tf.float32),
                          "categorical_input": tf.random.uniform([seq_len], minval=0, maxval=530, dtype=tf.int32),
                          "continuous_output": tf.fill([seq_len], -1.0),
                          "categorical_output": tf.fill([seq_len], -1)})

    tfrecords_dir = './tfrecords'
    os.makedirs(tfrecords_dir, exist_ok=True)  # create TFRecords output folder

    # Write fake data to tfrecord files, num_samples_per_tfrecord samples per file
    for tfrec_num in range(num_tfrecords):
        start = tfrec_num * num_samples_per_tfrecord
        samples = fake_data[start:start + num_samples_per_tfrecord]
        with tf.io.TFRecordWriter(op.join(tfrecords_dir, "file_%.2i.tfrec" % tfrec_num)) as writer:
            for sample in samples:
                writer.write(create_example(sample))

    # Load all the TFRecord data into a (parsed) tf dataset
    train_filenames = tf.io.gfile.glob(op.join(tfrecords_dir, "*.tfrec"))

    # The line below works now!
    dataset = get_dataset(train_filenames, batch_size=32)

    # Pull a single padded batch to confirm the pipeline yields dense tensors.
    successful_element = next(iter(dataset))

    print(successful_element)