I am trying to create a custom dataset in TFRecords
for a CycleGAN model. The model requires a new type of dataset which is not available, so I need to create one. I have a few JPG images of 256x256. Following this link, I created a TFRecords file for my images with the code below:
def _int64_feature(value):
    """Wrap a scalar int as a tf.train.Feature holding an Int64List."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
def _bytes_feature(value):
    """Wrap a bytes value as a tf.train.Feature holding a BytesList."""
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
# images input
def convert_to(images, output_directory, name):
    """Write images to <output_directory>/<name>.tfrecords.

    NOTE(review): this function is written for a 4-D batch
    (num_examples, rows, cols, ...), but the caller below passes a single
    image — so images.shape[0] is the image height and images[index] is one
    pixel row, which is the root cause of the error in this question.
    """
    num_examples = images.shape[0]
    rows = images.shape[1]
    cols = images.shape[2]
    depth = 1  # hard-coded; the JPG input described above is RGB (3 channels)
    filename = os.path.join(output_directory, name + '.tfrecords')
    print('Writing', filename)
    # NOTE(review): the writer is never closed/flushed; prefer a context manager.
    writer = tf.python_io.TFRecordWriter(filename)
    for index in range(num_examples):
        image = images[index]
        # NOTE(review): .tobytes() stores raw pixel bytes, not JPEG data, so
        # tf.image.decode_jpeg on the reading side cannot parse them.
        image_raw = images[index].tobytes()
        example = tf.train.Example(features=tf.train.Features(feature={
            'height': _int64_feature(rows),
            'width': _int64_feature(cols),
            'depth': _int64_feature(depth),
            'image_raw': _bytes_feature(image_raw)}))
        writer.write(example.SerializeToString())
def read_image(file_name, images_path):
    """Load one image from images_path + file_name using scikit-image."""
    image = skimage.io.imread(images_path + file_name)
    return image
def get_name(img_name):
    """Return the part of the file name before the first '_', extension stripped."""
    remove_ext = img_name.split(".")[0]
    name = remove_ext.split("_")
    return name[0]
# Convert each training image to its own TFRecord file.
images_path = "data/train/"
image_list = os.listdir(images_path)
images = []  # NOTE(review): never used afterwards
for img_name in tqdm(image_list):
    tfrec_name = get_name(img_name)
    print(tfrec_name)
    img_data = read_image(img_name, images_path)
    # NOTE(review): passes a single image where convert_to indexes a batch.
    convert_to(img_data, "data/cat_image_tfrecords", tfrec_name+"_cat_image")
Once the TFRecords are written, I use the code below to read and decode them back:
# All TFRecord files produced by the writing step above.
PHOTO_FILENAMES = tf.io.gfile.glob(str('data/cat_image_tfrecords/*.tfrecords'))
# Expected (height, width) of every decoded image.
IMAGE_SIZE = [256, 256]
def decode_image(image):
    """Decode JPEG bytes to a float32 tensor in [-1, 1] of shape IMAGE_SIZE + [3].

    NOTE(review): this is where the question's InvalidArgumentError fires —
    the stored bytes came from .tobytes(), not a JPEG encoder, so
    decode_jpeg rejects them as an unknown format.
    """
    image = tf.image.decode_jpeg(image, channels=3)
    image = (tf.cast(image, tf.float32) / 127.5) - 1
    image = tf.reshape(image, [*IMAGE_SIZE, 3])
    return image
def read_tfrecord(example):
    """Parse one serialized tf.train.Example and return its decoded image."""
    tfrecord_format = {
        'height': tf.io.FixedLenFeature([], tf.int64),
        'width': tf.io.FixedLenFeature([], tf.int64),
        'depth': tf.io.FixedLenFeature([], tf.int64),
        'image_raw': tf.io.FixedLenFeature([], tf.string),
    }
    example = tf.io.parse_single_example(example, tfrecord_format)
    image = decode_image(example['image_raw'])
    return image
def load_dataset(filenames, labeled=True, ordered=False):
    """Build a dataset of decoded images from TFRecord files.

    `labeled` and `ordered` are accepted but unused in the shown code.
    NOTE(review): AUTOTUNE is not defined in the snippet — presumably
    tf.data.AUTOTUNE (or tf.data.experimental.AUTOTUNE); confirm upstream.
    """
    dataset = tf.data.TFRecordDataset(filenames)
    dataset = dataset.map(read_tfrecord, num_parallel_calls=AUTOTUNE)
    return dataset
photo_ds = load_dataset(PHOTO_FILENAMES, labeled=False).batch(1)
# NOTE(review): this line raises the InvalidArgumentError quoted below,
# because the stored bytes are not JPEG-encoded.
example_photo = next(iter(photo_ds))
The decoding does not work since I get the below error in the last line
InvalidArgumentError: Expected image (JPEG, PNG, or GIF), got unknown format starting with ']LBhXKeQFVC4S=/1'
[[{{node DecodeJpeg}}]]
Clearly there is a mismatch between how I am writing the TFRecord in convert_to
function and how I am reading it back in read_tfrecord
function. But I am not sure how to fix it. Any suggestion?
EDIT
@sebastian-sz solution solves the problem. I tried to display one image like below
import matplotlib.pyplot as plt
plt.subplot(121)
plt.title('Photo')
plt.imshow(example_photo[0])
It displays the image but I see that the color/light of the image is much darker than the original image. Not sure what is going on though. Attached screenshot. Original image at the bottom.
There are a few issues in your code:
The problem is in a function convert_to
, in more detail, the function expects a list of images:
(...)
image = images[index]
(...)
However, you are passing a single image
convert_to(img_data, "data/cat_image_tfrecords", tfrec_name+"_cat_image")
Hence, later the image
is of shape (for example) (224, 3),
which is an invalid image shape.
To fix this, change convert_to
to accept a single image.
Skimage's .tobytes
output seems to be incompatible. Consider using tf.io.encode_jpeg(image).numpy()
instead to obtain the image bytes.
I was able to save and read sample image with the following code:
# Saving
def _int64_feature(value):
    """Build a tf.train.Feature wrapping a single int64 value."""
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)
def _bytes_feature(value):
    """Build a tf.train.Feature wrapping a single bytes value."""
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)
# images input
def convert_to(image, output_directory, name):
    """Serialize one image into <output_directory>/<name>.tfrecords.

    Args:
        image: a single uint8 image array (height, width[, channels]) — not a
            batch.
        output_directory: directory the .tfrecords file is written into
            (must already exist).
        name: base file name, without the .tfrecords extension.
    """
    rows = image.shape[0]
    cols = image.shape[1]
    # Record the actual channel count instead of a hard-coded 1; the RGB
    # images being encoded here have 3 channels.
    depth = image.shape[2] if image.ndim == 3 else 1
    filename = os.path.join(output_directory, name + '.tfrecords')
    print('Writing', filename)
    print(image.shape)
    # JPEG-encode so the reader's tf.image.decode_jpeg can parse the bytes.
    image_raw = tf.io.encode_jpeg(image).numpy()
    example = tf.train.Example(features=tf.train.Features(feature={
        'height': _int64_feature(rows),
        'width': _int64_feature(cols),
        'depth': _int64_feature(depth),
        'image_raw': _bytes_feature(image_raw)}))
    # Context manager closes (and flushes) the writer; the original never
    # called writer.close(), so the record could be lost on exit. tf.io is
    # the non-deprecated spelling the rest of this code already uses.
    with tf.io.TFRecordWriter(filename) as writer:
        writer.write(example.SerializeToString())
def read_image(file_name, images_path):
    """Load the image at images_path + file_name with scikit-image."""
    return skimage.io.imread(images_path + file_name)
def get_name(img_name):
    """Return the leading token of a file name: text before the first '_',
    with the extension (everything after the first '.') removed."""
    stem = img_name.partition(".")[0]
    return stem.partition("_")[0]
# Turn every training image into its own single-example TFRecord file.
images_path = "data/train/"
for fname in tqdm(os.listdir(images_path)):
    record_name = get_name(fname)
    print(record_name)
    pixels = read_image(fname, images_path)
    convert_to(pixels, "data/cat_image_tfrecords", record_name + "_cat_image")
# Loading:
# All TFRecord files written by the saving step above.
PHOTO_FILENAMES = tf.io.gfile.glob(str('data/cat_image_tfrecords/*.tfrecords'))
# Target (height, width) every decoded image is resized to.
IMAGE_SIZE = [256, 256]
def decode_image(image):
    """Decode JPEG bytes into a float32 tensor scaled to [-1, 1] and
    resized to IMAGE_SIZE."""
    decoded = tf.image.decode_jpeg(image, channels=3)
    scaled = tf.cast(decoded, tf.float32) / 127.5 - 1
    # Changed this from reshape
    # Consider reshape if all your images have the same shape
    return tf.image.resize(scaled, IMAGE_SIZE)
def read_tfrecord(example):
    """Parse one serialized tf.train.Example and return the decoded image."""
    feature_spec = {
        'height': tf.io.FixedLenFeature([], tf.int64),
        'width': tf.io.FixedLenFeature([], tf.int64),
        'depth': tf.io.FixedLenFeature([], tf.int64),
        'image_raw': tf.io.FixedLenFeature([], tf.string),
    }
    parsed = tf.io.parse_single_example(example, feature_spec)
    return decode_image(parsed['image_raw'])
def load_dataset(filenames, labeled=True, ordered=False):
    """Build a dataset of decoded images from the given TFRecord files.

    `labeled` and `ordered` are accepted for caller compatibility but are
    not used in this snippet.
    """
    records = tf.data.TFRecordDataset(filenames)
    return records.map(read_tfrecord, num_parallel_calls=tf.data.AUTOTUNE)
# Batch size 1: CycleGAN training typically consumes one image at a time.
photo_ds = load_dataset(PHOTO_FILENAMES, labeled=False).batch(1)
# Pull one batch to verify the pipeline decodes end to end.
example_photo = next(iter(photo_ds))