python tensorflow tensorflow-datasets tfrecord

How to write a tfrecord file and read it? The error is: truncated record at 0' failed with Read less bytes than requested [Op:IteratorGetNext]


I want to use TFRecord files to handle large MRI images, but I don't know how. Below are my code, the error, and a link to the data. (Sorry that the code is a bit long.)

About the data:

First, I collect and sort the image and label file paths:

image_data_path = './drive/MyDrive/Brain Tumour/Task01_BrainTumour/imagesTr/'
label_data_path = './drive/MyDrive/Brain Tumour/Task01_BrainTumour/labelsTr/'

# Full paths of every non-hidden entry in each directory, in sorted order.
# NOTE(review): presumably the sort is what pairs image_paths[i] with
# label_paths[i] -- confirm the two directories use matching file names.
image_paths = sorted(
    image_data_path + entry
    for entry in os.listdir(image_data_path)
    if not entry.startswith(".")
)

label_paths = sorted(
    label_data_path + entry
    for entry in os.listdir(label_data_path)
    if not entry.startswith(".")
)

Then I define a function that loads one sample (an image and its label) from .nii files, using nibabel:

def load_one_sample(image_path, label_path):
  """Load one image file and its label file with nibabel.

  Args:
    image_path: path of the image file.
    label_path: path of the label file.

  Returns:
    A tuple (image, label) of arrays; the label array is cast to int.
  """
  image_volume = nib.load(image_path).get_fdata()
  # The label data comes back as float64, so it is converted to int.
  label_volume = nib.load(label_path).get_fdata()
  return image_volume, label_volume.astype(int)

Here I write some helper functions, 'float' for images and 'int' for labels:

def float_feature(value):
  """Wrap a sequence of floats in a tf.train.Feature backed by a FloatList."""
  float_list = tf.train.FloatList(value=value)
  return tf.train.Feature(float_list=float_list)

def int64_feature(value):
  """Wrap a sequence of integers in a tf.train.Feature backed by an Int64List."""
  int64_list = tf.train.Int64List(value=value)
  return tf.train.Feature(int64_list=int64_list)

def create_example(image_path, label_path):
  """Build a tf.train.Example from one image/label file pair.

  Both arrays are flattened with ravel() before being stored, so their
  original shapes are not recorded in the Example itself.

  Args:
    image_path: path of the image file.
    label_path: path of the label file.

  Returns:
    A tf.train.Example with an 'image' float feature and a 'label' int64
    feature.
  """
  image_volume, label_volume = load_one_sample(image_path, label_path)
  features = tf.train.Features(feature={
      'image': float_feature(image_volume.ravel()),
      'label': int64_feature(label_volume.ravel()),
  })
  return tf.train.Example(features=features)

def parse_tfrecord(example):
  """Parse one serialized tf.train.Example into a dict of tensors.

  Args:
    example: a serialized Example, as yielded by a TFRecordDataset.

  Returns:
    A dict with 'image' (float32, shape [240, 240, 155, 4]) and 'label'
    (int64, shape [240, 240, 155]).
  """
  image_shape = [240, 240, 155, 4]
  label_shape = [240, 240, 155]
  schema = {
      'image': tf.io.FixedLenFeature(image_shape, tf.float32),
      'label': tf.io.FixedLenFeature(label_shape, tf.int64),
  }
  return tf.io.parse_single_example(example, schema)

Then I write a TFRecord file containing only one example and read it back:

# Write a single serialized Example to 'test.tfrecords'.
test_writer = tf.io.TFRecordWriter('test.tfrecords')

example = create_example(image_paths[0], label_paths[0])
test_writer.write(example.SerializeToString())
# NOTE(review): test_writer is never closed or flushed before the file is
# read back below.

# Read the file back and parse every record with parse_tfrecord.
serialised_example = tf.data.TFRecordDataset('test.tfrecords')
parsed_example = serialised_example.map(parse_tfrecord)

Finally, I try to plot one image and get this error message:

# Show one 2-D slice of the first parsed record: index 100 along the third
# axis and index 0 along the fourth axis of the 'image' tensor.
for record in parsed_example.take(1):
  image_slice = record['image'][:, :, 100, 0]
  plt.imshow(image_slice)

Error: truncated record at 0' failed with Read less bytes than requested [Op:IteratorGetNext]

Datalink: https://drive.google.com/drive/folders/1HqEgzS8BV2c7xYNrZdEAnrHk7osJJ--2 (task 1 - Brain Tumour)

Where did I go wrong?


Solution

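  • This error occurs because you never call close() on the writer after writing the example. TFRecordWriter buffers its output, so until close() (or flush()) is called the record may not be fully written to disk, and the reader then finds a truncated file. Here is a working example with random arrays: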

    import tensorflow as tf
    import matplotlib.pyplot as plt
    import numpy as np
    
    def float_feature(value):
      """Wrap a sequence of floats in a tf.train.Feature backed by a FloatList."""
      float_list = tf.train.FloatList(value=value)
      return tf.train.Feature(float_list=float_list)
    
    def int64_feature(value):
      """Wrap a sequence of integers in a tf.train.Feature backed by an Int64List."""
      int64_list = tf.train.Int64List(value=value)
      return tf.train.Feature(int64_list=int64_list)
    
    def create_example():
      """Build a tf.train.Example from randomly generated image and label arrays.

      The image is a random float array of shape (16, 16, 155, 4) and the label
      a random integer array of shape (16, 16, 155) with values in [0, 20).
      Both are flattened with ravel() before being stored.

      Returns:
        A tf.train.Example with an 'image' float feature and a 'label' int64
        feature.
      """
      random_image = np.random.random((16, 16, 155, 4))
      random_label = np.random.randint(20, size=(16, 16, 155))
      features = tf.train.Features(feature={
          'image': float_feature(random_image.ravel()),
          'label': int64_feature(random_label.ravel()),
      })
      return tf.train.Example(features=features)
    
    def parse_tfrecord(example):
      """Parse one serialized tf.train.Example into a dict of tensors.

      Args:
        example: a serialized Example, as yielded by a TFRecordDataset.

      Returns:
        A dict with 'image' (float32, shape [16, 16, 155, 4]) and 'label'
        (int64, shape [16, 16, 155]).
      """
      image_shape = [16, 16, 155, 4]
      label_shape = [16, 16, 155]
      schema = {
          'image': tf.io.FixedLenFeature(image_shape, tf.float32),
          'label': tf.io.FixedLenFeature(label_shape, tf.int64),
      }
      return tf.io.parse_single_example(example, schema)
    
    # Write one serialized Example. Leaving the with-block closes the writer
    # (even if building or writing the example raises), which flushes the
    # record to disk. Reading the file before the writer has been closed is
    # what produces the "truncated record" error.
    with tf.io.TFRecordWriter('test.tfrecords') as test_writer:
      example = create_example()
      test_writer.write(example.SerializeToString())

    # Read the file back and parse every record with parse_tfrecord.
    serialised_example = tf.data.TFRecordDataset('test.tfrecords')
    parsed_example = serialised_example.map(parse_tfrecord)

    # Show one 2-D slice of the first parsed record: index 100 along the third
    # axis and index 0 along the fourth axis of the 'image' tensor.
    for features in parsed_example.take(1):
      plt.imshow(features['image'][:, :, 100, 0])
    

    enter image description here