python, tensorflow, keras, pytorch, jax

How Can I Use GPU to Accelerate Image Augmentation?


When setting up image augmentation pipelines with keras.layers.Random* or other augmentation or preprocessing methods, we often integrate them with a data loader such as the tf.data API, which runs mainly on the CPU. Heavy augmentation on the CPU can then become a significant bottleneck: preparing each batch takes longer than the GPU needs to consume it, leaving the GPU underutilized and hurting overall training performance.

To address this, is it possible to offload augmentation processing to the GPU, enabling faster execution and better resource utilization? If so, how can this be implemented effectively?


Solution

  • We can speed up processing and improve resource utilization by offloading data augmentation to the GPU. I'll demonstrate how to do this in Keras. Note that the approach differs slightly depending on the task, such as classification, detection, or segmentation.

    Classification

    Let’s take a classification task as an example. If we use the tf.data API to apply an augmentation pipeline, the processing will run on the CPU. Here's how it can be done.

    import numpy as np
    import tensorflow as tf
    import keras
    from keras import layers
    
    # A dummy batch of images and one-hot labels
    a = np.ones((4, 224, 224, 3)).astype(np.float32)
    b = np.ones((4, 2)).astype(np.float32)
    
    augmentation_layers = keras.Sequential(
        [
            layers.RandomFlip("horizontal"),
            layers.RandomRotation(0.1),
            layers.RandomZoom(0.2),
        ]
    )
    
    dataset = tf.data.Dataset.from_tensor_slices((a, b))
    dataset = dataset.batch(3, drop_remainder=True)
    dataset = dataset.map(
        lambda x, y: (augmentation_layers(x), y), 
        num_parallel_calls=tf.data.AUTOTUNE
    )
    x, y = next(iter(dataset))
    x.shape, y.shape
    (TensorShape([3, 224, 224, 3]), TensorShape([3, 2]))
    

    But for heavy augmentation pipelines, it's better to include them inside the model to take advantage of GPU acceleration.

    # The augmentation layers are part of the model, so during training they
    # execute on the same device as the rest of the forward pass (the GPU).
    inputs = keras.Input(shape=(224, 224, 3))
    processed = augmentation_layers(inputs)
    backbone = keras.applications.EfficientNetB0(
        include_top=True, pooling='avg'
    )(processed)
    output = keras.layers.Dense(10)(backbone)
    model = keras.Model(inputs, output)
    model.count_params() / 1e6
    5.340581
    

    Here, we set the augmentation pipeline right after keras.Input. Note that augmentation layers placed inside the model only transform the inputs and never touch the target vector, so augmentations like CutMix or MixUp, which also modify the labels, won't work with this approach. For such cases, I'll explore another solution while testing with a segmentation task.
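
    To see why, here is a minimal, hypothetical MixUp sketch (with a plain uniform draw standing in for the usual Beta-distributed mixing coefficient): it has to transform the images and the labels together, which a layer sitting between keras.Input and the backbone cannot do.

    def mixup(images, labels):
        # MixUp blends each sample with a randomly paired sample from the same
        # batch; the labels must be blended with the same mixing weight.
        batch_size = tf.shape(images)[0]
        lam = tf.random.uniform((), 0.0, 1.0)
        index = tf.random.shuffle(tf.range(batch_size))
        mixed_images = lam * images + (1.0 - lam) * tf.gather(images, index)
        mixed_labels = lam * labels + (1.0 - lam) * tf.gather(labels, index)
        return mixed_images, mixed_labels
    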

    Segmentation

    I'll use this dataset to compare execution times; it's a binary segmentation task. Additionally, I'll run it with Keras 3, which supports multiple backends.

    import os
    os.environ["KERAS_BACKEND"] = "tensorflow" # torch, jax
    
    import keras
    from keras import layers
    import tensorflow as tf
    keras.__version__ # 3.4.1
    
    # ref https://keras.io/examples/vision/oxford_pets_image_segmentation/
    # u-net model
    def get_model(img_size, num_classes, classifier_activation):
        ...
        # Add a per-pixel classification layer
        outputs = layers.Conv2D(
            num_classes, 
            3, 
            activation=classifier_activation, 
            padding="same", 
            dtype='float32'
        )(x)
    
        # Define the model
        model = keras.Model(inputs, outputs)
        return model
    
    
    img_size = (224, 224)
    num_classes = 1
    classifier_activation = 'sigmoid'
    model = get_model(
        img_size, 
        num_classes=num_classes, 
        classifier_activation=classifier_activation
    )
    

    Let's define the augmentation pipeline. The image and mask are concatenated along the channel axis before augmentation so that the same random transform is applied to both, and then split apart again afterwards.

    augmentation_layers = [
        layers.RandomFlip("horizontal_and_vertical")
    ]
    
    def augment_data(images, masks):
        combined = tf.concat([images, tf.cast(masks, tf.float32)], axis=-1)
        for layer in augmentation_layers:
            combined = layer(combined)
        images_augmented = combined[..., :3]
        masks_augmented = tf.cast(combined[..., 3:], tf.int32)
        return images_augmented, masks_augmented
    

    Let's use the tf.data API to build the dataloader. First, I'll run the model with a dataloader that includes the augmentation pipeline. These augmentations will run on the CPU, and I'll record the execution time.

    def read_image(image_path, mask=False):
        image = tf.io.read_file(image_path)
        
        if mask:
            image = tf.image.decode_png(image, channels=1)
            image.set_shape([None, None, 1])
            image = tf.image.resize(images=image, size=[IMAGE_SIZE, IMAGE_SIZE])
            image = tf.cast(image, tf.int32)
        else:
            image = tf.image.decode_png(image, channels=3)
            image.set_shape([None, None, 3])
            image = tf.image.resize(images=image, size=[IMAGE_SIZE, IMAGE_SIZE])
            image = image / 255.
            
        return image
    
    def load_data(image_list, mask_list):
        image = read_image(image_list)
        mask  = read_image(mask_list, mask=True)
        return image, mask
    
    def data_generator(image_list, mask_list):
        dataset = tf.data.Dataset.from_tensor_slices((image_list, mask_list))
        dataset = dataset.shuffle(8*BATCH_SIZE) 
        dataset = dataset.map(load_data, num_parallel_calls=tf.data.AUTOTUNE)
        dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)
    
        # Augmenting on CPU
        dataset = dataset.map(
            augment_data, num_parallel_calls=tf.data.AUTOTUNE
        )
        dataset = dataset.prefetch(tf.data.AUTOTUNE)
        return dataset
    
    IMAGE_SIZE = 224
    BATCH_SIZE = 16
    
    train_dataset = data_generator(images, masks)
    print("Train Dataset:", train_dataset)
    Train Dataset: <_PrefetchDataset element_spec=(TensorSpec(shape=(16, 224, 224, 3), dtype=tf.float32, name=None), TensorSpec(shape=(16, 224, 224, 1), dtype=tf.int32, name=None))>
    

    Now, let's compile it and run it.

    optim = keras.optimizers.Adam(0.001)
    bce   = keras.losses.BinaryCrossentropy()
    metrics = ["accuracy"]
    model.compile(
        optimizer=optim, 
        loss=bce, 
        metrics=metrics
    )
    
    %%time
    epochs = 2
    model.fit(
        train_dataset, 
        epochs=epochs, 
    )
    Epoch 1/2
    318/318 ━ 65s 140ms/step - accuracy: 0.9519 - loss: 0.2087
    Epoch 2/2
    318/318 ━ 44s 139ms/step - accuracy: 0.9860 - loss: 0.0338
    CPU times: user 5min 38s, sys: 14.2 s, total: 5min 52s
    Wall time: 1min 48s
    

    Next, we will remove the augmentation layers from the dataloader.

    def data_generator(image_list, mask_list):
        dataset = tf.data.Dataset.from_tensor_slices((image_list, mask_list))
        dataset = dataset.shuffle(8*BATCH_SIZE)
        dataset = dataset.map(load_data, num_parallel_calls=tf.data.AUTOTUNE)
        dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)
        dataset = dataset.prefetch(tf.data.AUTOTUNE)
        return dataset
    
    IMAGE_SIZE = 224
    BATCH_SIZE = 16
    
    train_dataset = data_generator(images, masks)
    

    To offload augmentation to the GPU, we’ll create a custom model class, override the train_step, and use the augment_data method that we defined earlier. Here's how to structure it:

    class ExtendedModel(keras.Model):
        def __init__(self, model, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.model = model
    
        def train_step(self, data):
            x, y = data
            # The augmentation now runs on the GPU as part of the training step.
            x, y = augment_data(x, y)
            return super().train_step((x, y))
    
        def call(self, inputs):
            return self.model(inputs)
    
        def save(
            self, filepath,
            overwrite=True,
            include_optimizer=True,
            save_format=None,
        ):
            # Overriding this method will allow us to use the `ModelCheckpoint`
            self.model.save(
                filepath=filepath,
                overwrite=overwrite,
                save_format=save_format,
                include_optimizer=include_optimizer,
            )
    

    Now that we've defined the custom model with GPU-accelerated augmentation, let's compile and run it. It should be faster than the version that augments on the CPU.

    model = get_model(
        img_size, 
        num_classes=num_classes, 
        classifier_activation=classifier_activation
    )
    emodel = ExtendedModel(model)
    optim = keras.optimizers.Adam(0.001)
    bce   = keras.losses.BinaryCrossentropy()
    metrics = ["accuracy"]
    emodel.compile(
        optimizer=optim, 
        loss=bce, 
        metrics=metrics
    )
    
    %%time
    epochs = 2
    emodel.fit(
        train_dataset, 
        epochs=epochs, 
        callbacks=[
            keras.callbacks.ModelCheckpoint(
                filepath='model.{epoch:02d}-{loss:.3f}.keras',
                monitor='loss',
                mode='min',
                save_best_only=True
            )
        ]
    )
    Epoch 1/2
    318/318 ━ 54s 111ms/step - accuracy: 0.8885 - loss: 0.2748
    Epoch 2/2
    318/318 ━ 35s 111ms/step - accuracy: 0.9754 - loss: 0.0585
    CPU times: user 4min 43s, sys: 3.81 s, total: 4min 47s
    Wall time: 1min 29s
    

    So, augmenting on the CPU took 65 + 44 = 109 seconds in total, while augmenting on the GPU took 54 + 35 = 89 seconds, roughly an 18% improvement. This approach can also be applied to object detection tasks, where both image manipulation and bounding-box adjustments are needed, as sketched below.
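
    For example, a minimal, hypothetical augment function for detection might look like the following sketch (assuming boxes are normalized [x_min, y_min, x_max, y_max]); it could be called inside train_step exactly like augment_data above.

    def augment_detection(images, boxes):
        # boxes: (batch, num_boxes, 4), normalized [x_min, y_min, x_max, y_max]
        flip = tf.random.uniform(()) > 0.5

        def flipped():
            # Flip the images and mirror the box x-coordinates accordingly.
            flipped_images = tf.image.flip_left_right(images)
            x_min, y_min, x_max, y_max = tf.unstack(boxes, axis=-1)
            flipped_boxes = tf.stack(
                [1.0 - x_max, y_min, 1.0 - x_min, y_max], axis=-1
            )
            return flipped_images, flipped_boxes

        return tf.cond(flip, flipped, lambda: (images, boxes))
    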

    As shown in the ExtendedModel class above, we override the save method so that keras.callbacks.ModelCheckpoint saves the wrapped model (the full network, without the training-time augmentation step). Inference can then be performed as shown below.

    loaded_model = keras.saving.load_model(
        "/kaggle/working/model.02-0.0585.keras"
    )
    x, y = next(iter(train_dataset))
    output = loaded_model.predict(x)
    1/1 ━━━━━━━━━━━━━━━━━━━━ 2s 2s/step
    

    Update

    In order to run the above code with multiple backends (i.e., tensorflow, torch, and jax), we need to ensure that the augment_data function used in ExtendedModel relies on the backend-agnostic keras.ops functions, as follows.

    def augment_data(images, masks):
        combined = keras.ops.concatenate(
            [images, keras.ops.cast(masks, 'float32')], axis=-1
        )
        for layer in augmentation_layers:
            combined = layer(combined)
        images_augmented = combined[..., :3]
        masks_augmented = keras.ops.cast(combined[..., 3:], 'int32')
        return images_augmented, masks_augmented
    

    Additionally, to make the pipeline flexible across backends, we can update ExtendedModel as follows. This code can now run with the tensorflow, jax, and torch backends.

    class ExtendedModel(keras.Model):
        ...
    
        def train_step(self, *args, **kwargs):
            # Keras 3 calls train_step with a backend-specific signature:
            # the JAX backend also passes the model/optimizer state.
            if keras.backend.backend() == "jax":
                return self._jax_train_step(*args, **kwargs)
            elif keras.backend.backend() == "tensorflow":
                return self._tensorflow_train_step(*args, **kwargs)
            elif keras.backend.backend() == "torch":
                return self._torch_train_step(*args, **kwargs)
    
        def _jax_train_step(self, state, data):
            x, y = data
            x, y = augment_data(x, y)
            return super().train_step(state, (x, y))
    
        def _tensorflow_train_step(self, data):
            x, y = data
            x, y = augment_data(x, y)
            return super().train_step((x, y))
    
        def _torch_train_step(self, data):
            x, y = data
            x, y = augment_data(x, y)
            return super().train_step((x, y))
    
        ...