I am learning how to use TensorFlow with a piece of old code provided at a workshop a couple of years back (i.e., it should have been tested and working). However, it didn't work, and my investigation led me back to step 1: loading the data and making sure it has loaded properly.
The data reading pipeline is as follows:

1. The function used to load each image/mask pair:

```
@tf.function
def load(path_pair):
    image_path = path_pair[0]
    masks_path = path_pair[1]
    image_raw = tf.io.read_file(image_path)
    image = tf.io.decode_image(
        image_raw, channels=1, dtype=tf.uint8
    )
    masks_raw = tf.io.read_file(masks_path)
    masks = tf.io.decode_image(
        masks_raw, channels=NUM_CONTOURS, dtype=tf.uint8
    )
    return image / 255, masks / 255
```
2. The function used to create the dataset:

```
def create_datasets(dataset_type):
    # get_path_pairs returns a list of (image_path, masks_path) tuples to load
    path_pairs = get_path_pairs(dataset_type)
    dataset = tf.data.Dataset.from_tensor_slices(path_pairs)
    dataset = dataset.shuffle(
        len(path_pairs),
        reshuffle_each_iteration=True,
    )
    dataset = dataset.map(load)
    dataset = dataset.batch(BATCH_SIZE)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)
    return dataset
```
When I use the create_datasets function on a dataset that contains 818 data pairs and check the size of the loaded dataset using `len(dataset)`, it tells me there are only 2 items loaded.
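For reference, the check itself is just the following (the "train" argument is a stand-in for whatever dataset_type value is actually used):

```
dataset = create_datasets("train")  # "train" stands in for the real dataset_type
print(len(dataset))                 # prints 2, even though there are 818 pairs
```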
The problem is that you are batching your dataset, so when you use `len(dataset)` you get the number of batches, not the number of elements in your dataset.
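You can reproduce the effect with a toy dataset; for illustration, assume a BATCH_SIZE of 512, which would explain the 2 you observe, since ceil(818 / 512) = 2:

```
import tensorflow as tf

ds = tf.data.Dataset.range(818)  # stand-in for your 818 path pairs
print(len(ds))                   # 818 elements before batching

batched = ds.batch(512)          # illustrative BATCH_SIZE of 512
print(len(batched))              # 2, i.e. ceil(818 / 512) batches
```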
To get the number of samples you can, for instance, iterate over your batches:
```
num_samples = 0
for batch in dataset:
    # each batch is an (images, masks) tuple; len() of the image tensor
    # gives the number of samples in that batch
    num_samples += len(batch[0])
print(num_samples)
```
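Alternatively, the dataset can report both numbers itself: `cardinality()` (available as a Dataset method since around TF 2.3) gives the batch count, and `unbatch()` undoes the batching so you can count single samples:

```
# Number of batches, equivalent to len(dataset):
print(dataset.cardinality().numpy())                  # 2

# Number of individual samples, counted after unbatching:
num = dataset.unbatch().reduce(0, lambda n, _: n + 1)
print(num.numpy())                                    # 818
```

The simplest option of all is to record `len(path_pairs)` inside create_datasets before `.batch()` is applied, since that is already the element count.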