I'm working on a test with TensorFlow. My dataset is split into two folders. I configured the batch_size, height and width for train_data, but then I can't display the images with matplotlib or use the dataset in the model.
# Import dataset
import os
import pathlib

import matplotlib.pyplot as plt
import tensorflow as tf

data_dir = pathlib.Path(r'C:\Users\vion1\Ele\Engie\Exercices\DL\Pikachu\dataset')
image_count = len(list(data_dir.glob('*/*')))
print(image_count)
#374
batch_size = 32
img_height = 256
img_width = 256
train_data = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="training",
    seed=42,
    image_size=(img_height, img_width),
    batch_size=batch_size,
)
class_names = train_data.class_names
print(train_data)
#Found 374 files belonging to 2 classes.
#Using 300 files for training.
#<BatchDataset shapes: ((None, 256, 256, 3), (None,)), types: (tf.float32, tf.int32)>
plt.figure(figsize=(10, 10))
for images, labels in train_data.take(1):
    for i in range(3):
        ax = plt.subplot(1, 3, i + 1)
        plt.imshow(images[i].numpy().astype("uint8"))
        plt.axis("off")
The error is:
InvalidArgumentError: Unknown image file format. One of JPEG, PNG, GIF, BMP required.
[[{{node decode_image/DecodeImage}}]] [Op:IteratorGetNext]
I think that train_data.take(1) isn't reading the files, but I can't understand why or how to fix it. Any ideas?
The code you posted looks correct. As the error message says, the most likely cause is that one or more files in your dataset directory are not in one of the supported formats (JPEG, PNG, GIF or BMP).
To find the offending files, you can use the code below. Here I'm using the example dataset from the documentation:
import matplotlib.pyplot as plt
import numpy as np
import os
import PIL
import tensorflow as tf
from tensorflow import keras
import pathlib
dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
data_dir = tf.keras.utils.get_file('flower_photos', origin=dataset_url, untar=True)
data_dir = pathlib.Path(data_dir)
roses = list(data_dir.glob('roses/*'))
Now, let's check the unique file extensions in the roses directory.
file_names = [str(i) for i in roses]
unique_files = set(i.split('.')[-1] for i in file_names)
print(unique_files)
Output:
{'jpg'}
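The same check can be run over every class folder at once, which matches the two-folder layout in the question. This is a minimal sketch, assuming data_dir contains one subfolder per class (the layout image_dataset_from_directory expects); the function name count_extensions is hypothetical and only used for illustration:

```python
import pathlib
from collections import Counter


def count_extensions(data_dir):
    """Count lower-cased file extensions across all class subfolders of data_dir."""
    data_dir = pathlib.Path(data_dir)
    counts = Counter()
    for path in data_dir.glob("*/*"):
        # Skip nested directories; only files are loaded as images.
        if path.is_file():
            # suffix is '' for files without an extension
            counts[path.suffix.lower()] += 1
    return counts


# Example: point this at your own dataset folder.
# print(count_extensions(data_dir))
```

Lower-casing the suffix means that, for example, .JPG and .jpg are counted together.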
If the output contains any file type other than the allowed ones, you need to recheck your data (for example, remove or convert those files). Otherwise, you can follow the same procedure as in this document.
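An extension check may not be enough on its own. image_dataset_from_directory only indexes files ending in .bmp, .gif, .jpeg, .jpg or .png, and the output in the question reports the same 374 files that glob counted, which suggests every file already has an allowed extension. The decoder, however, identifies the format from the file's leading bytes, so a file whose content is something else (for example a WebP image saved under a .jpg name, or an empty or truncated download) still fails with this error. The sketch below flags such files by checking those leading signature bytes. It is an approximation written for illustration (the names SIGNATURES, sniff_format and find_unsupported are hypothetical), not TensorFlow's own code, and it will not catch a file that has a valid header but is corrupted further in.

```python
import pathlib

# Leading "magic" bytes for the four formats named in the error message.
SIGNATURES = {
    "jpeg": (b"\xff\xd8\xff",),
    "png": (b"\x89PNG\r\n\x1a\n",),
    "gif": (b"GIF87a", b"GIF89a"),
    "bmp": (b"BM",),
}


def sniff_format(path):
    """Return 'jpeg', 'png', 'gif' or 'bmp' based on the file header, else None."""
    with open(path, "rb") as f:
        header = f.read(8)
    for fmt, prefixes in SIGNATURES.items():
        # bytes.startswith accepts a tuple of candidate prefixes
        if header.startswith(prefixes):
            return fmt
    return None


def find_unsupported(data_dir):
    """List files in the class subfolders whose header matches none of the formats."""
    data_dir = pathlib.Path(data_dir)
    return sorted(
        p for p in data_dir.glob("*/*")
        if p.is_file() and sniff_format(p) is None
    )


# Example: print files that are likely to trigger "Unknown image file format".
# for bad in find_unsupported(data_dir):
#     print(bad)
```

Checking the header bytes mirrors how the decoder recognises the format, so it can flag files that look fine by name; deleting or re-saving the flagged files as real JPEG/PNG images should let the dataset load.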