tensorflow, machine-learning, image-augmentation

How does data augmentation work with a preprocessing function that is called with the map method?


In a tutorial I found this code to do data augmentation:

def preprocess_with_augmentation(image, label):
    resized_image = tf.image.resize(image, [224, 224])

    # data augmentation with TensorFlow
    augmented_image = tf.image.random_flip_left_right(resized_image)
    augmented_image = tf.image.random_hue(augmented_image, 0.10)
    augmented_image = tf.image.random_brightness(augmented_image, 0.06)
    augmented_image = tf.image.random_contrast(augmented_image, 0.65, 1.35)

    # run Xception's preprocessing function
    preprocessed_image = tf.keras.applications.xception.preprocess_input(augmented_image)

    print("Working on next ")
    return preprocessed_image, label

This function is used as follows:

train_data = tfds.load('tf_flowers', split="train[:80%]", as_supervised=True)
test_data  = tfds.load('tf_flowers', split="train[80%:100%]", as_supervised=True)
x_augmented_train = train_data.map(preprocess_with_augmentation).batch(32).prefetch(1)
...
history = augmentation_model.fit(x_augmented_train, epochs=10, validation_data=test_data)

How does this create augmented data sets? My conjecture was that iterating over the dataset multiple times applies the preprocessing function again for each epoch, producing a newly augmented dataset every time. This would require that the preprocessing function is called repeatedly and that the random augmentations are independent.

Since I was not sure how it works, I added the print statement to the preprocess_with_augmentation function. However, when I run the program, "Working on next" is printed only once at the very beginning, not during the different epochs.

If my conjecture were right, it should be printed many times.

I thought maybe output is suppressed during the call to the fit function, so I replaced the print with a counter increment, which did not help: the counter shows only one call.

The next thing I tried: I created an iterator with it = iter(x_augmented_train) and looped over it, printing every 100th image. Then I created a second iterator, it2 = iter(x_augmented_train), hoping the images would look different, since the augmentation should not be the same in each epoch. But the images were identical, so I am wondering how this method works. Maybe it does not work?


Solution

  • When you use map() on a tf.data.Dataset, the mapped function is traced once and then executed in graph mode. Therefore Python's print statement fires only once, at tracing time; change it to tf.print, which runs as part of the graph, to see the actual per-element processing.

    See more details: https://www.tensorflow.org/guide/intro_to_graphs#using_tffunction
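
A minimal sketch of this trace-once/execute-many behavior, using a toy zero-filled dataset standing in for tf_flowers (the dataset, element count, and the `calls` counter are illustrative assumptions, not part of the original code):

```python
import tensorflow as tf

# Counter updated inside the mapped function. Unlike a Python int, a
# tf.Variable update is a graph op, so it tracks real executions.
calls = tf.Variable(0, dtype=tf.int64)

def augment(image):
    print("tracing augment()")          # Python print: fires only at trace time
    tf.print("augmenting an element")   # graph print: fires for every element
    calls.assign_add(1)
    return tf.image.random_flip_left_right(image)

# Toy dataset: 4 dummy 8x8 RGB "images" in place of tf_flowers.
dataset = tf.data.Dataset.from_tensor_slices(tf.zeros([4, 8, 8, 3]))
augmented = dataset.map(augment)

for _ in range(2):        # two "epochs": no .cache(), so map re-runs each pass
    for _ in augmented:
        pass

print("graph executions:", int(calls.numpy()))
```

"tracing augment()" appears once, while the counter reaches 8 (4 elements x 2 epochs), confirming that the mapped function's graph is re-executed for every element in every epoch even though its Python body runs only during tracing.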