
Connecting BatchDataset with Keras VGG16 preprocess_input


I am using tf.keras.preprocessing.image_dataset_from_directory to get a BatchDataset, where the dataset has 10 classes.

I am trying to integrate this BatchDataset with a Keras VGG16 network. From the docs:

Note: each Keras Application expects a specific kind of input preprocessing. For VGG16, call tf.keras.applications.vgg16.preprocess_input on your inputs before passing them to the model.

However, I am struggling to get this preprocess_input working with a BatchDataset. Can you please help me figure out how to connect these two dots?

Please see the code below:

train_ds = tf.keras.preprocessing.image_dataset_from_directory(train_data_dir, image_size=(224, 224))
train_ds = tf.keras.applications.vgg16.preprocess_input(train_ds)

This raises TypeError: 'BatchDataset' object is not subscriptable:

Traceback (most recent call last):
  ...
  File "/path/to/venv/lib/python3.10/site-packages/keras/applications/vgg16.py", line 232, in preprocess_input
    return imagenet_utils.preprocess_input(
  File "/path/to/venv/lib/python3.10/site-packages/keras/applications/imagenet_utils.py", line 117, in preprocess_input
    return _preprocess_symbolic_input(
  File "/path/to/venv/lib/python3.10/site-packages/keras/applications/imagenet_utils.py", line 278, in _preprocess_symbolic_input
    x = x[..., ::-1]
TypeError: 'BatchDataset' object is not subscriptable

An answer to "TypeError: 'DatasetV1Adapter' object is not subscriptable" (which I found via "BatchDataset not subscriptable when trying to format Python dictionary as table") suggested using:

train_ds = tf.keras.applications.vgg16.preprocess_input(
    list(train_ds.as_numpy_iterator())
)

However, this also fails:

Traceback (most recent call last):
  ...
  File "/path/to/venv/lib/python3.10/site-packages/keras/applications/vgg16.py", line 232, in preprocess_input
    return imagenet_utils.preprocess_input(
  File "/path/to/venv/lib/python3.10/site-packages/keras/applications/imagenet_utils.py", line 117, in preprocess_input
    return _preprocess_symbolic_input(
  File "/path/to/venv/lib/python3.10/site-packages/keras/applications/imagenet_utils.py", line 278, in _preprocess_symbolic_input
    x = x[..., ::-1]
TypeError: list indices must be integers or slices, not tuple

This is all using Python==3.10.3 with tensorflow==2.8.0.

How can I get this working? Thank you in advance.


Solution

  • Okay, I figured it out. preprocess_input needs a tf.Tensor of images, not a tf.data.Dataset. Iterating over the Dataset yields (images, labels) batches, and the images in each batch are a Tensor.

    This can be done in a few ways:

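    import tensorflow as tf

    train_ds = tf.keras.preprocessing.image_dataset_from_directory(...)

    # Option 1: preprocess a single batch.
    # Each dataset element is an (images, labels) tuple; [0] selects the images.
    batch_images = next(iter(train_ds))[0]
    preprocessed_images = tf.keras.applications.vgg16.preprocess_input(batch_images)

    # Option 2: preprocess every batch while iterating over the dataset.
    # Note that preprocessed_images is overwritten on each iteration.
    for batch_images, batch_labels in train_ds:
        preprocessed_images = tf.keras.applications.vgg16.preprocess_input(batch_images)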
    

    If you convert option 2 into a generator that yields (preprocessed_images, batch_labels) pairs, it can be passed directly into the downstream model.fit. Cheers!
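To make the generator idea concrete, here is a minimal sketch. It is not the exact code I ran: the random images below are a stand-in for the output of image_dataset_from_directory, and the names preprocessed_batches and mapped_ds are made up for illustration. It also shows tf.data.Dataset.map as an alternative that keeps everything as a Dataset; that variant is an addition to the options above.

```python
import numpy as np
import tensorflow as tf

preprocess_input = tf.keras.applications.vgg16.preprocess_input

# Assumption: synthetic stand-in for image_dataset_from_directory.
# 8 random 224x224 RGB images with integer labels, in batches of 4.
images = np.random.uniform(0, 255, size=(8, 224, 224, 3)).astype("float32")
labels = (np.arange(8) % 10).astype("int32")
train_ds = tf.data.Dataset.from_tensor_slices((images, labels)).batch(4)


def preprocessed_batches(dataset):
    """Option 2 as a generator (hypothetical helper name).

    Yields (preprocessed_images, labels) for every batch in the dataset.
    """
    for batch_images, batch_labels in dataset:
        yield preprocess_input(batch_images), batch_labels


# Alternative: apply the same preprocessing lazily inside the tf.data
# pipeline, so the result is still a Dataset of (images, labels) batches.
mapped_ds = train_ds.map(
    lambda batch_images, batch_labels: (preprocess_input(batch_images), batch_labels),
    num_parallel_calls=tf.data.AUTOTUNE,
)

for batch_images, batch_labels in mapped_ds.take(1):
    print(batch_images.shape, batch_labels.shape)
```

Either preprocessed_batches(train_ds) or mapped_ds can then be handed to model.fit. One caveat with the plain generator: it is exhausted after a single pass over the data, so for multi-epoch training the mapped Dataset is probably the more convenient choice, since it can be iterated again each epoch.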