Tags: python, tensorflow, tensorflow2.0, tf.keras, keras-2

TensorFlow, Keras: Replace Activation layer in pretrained model


I'm trying to replace the swish activation with relu in the pretrained TF model EfficientNetB0. EfficientNetB0 uses swish in both its Conv2D and Activation layers. This SO post is very similar to what I'm looking for. I also found an answer that works for models without skip connections. Below is the code:

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import ReLU

def replace_swish_with_relu(model):
    '''
    Modify passed model by replacing swish activation with relu
    '''
    for layer in tuple(model.layers):
        layer_type = type(layer).__name__
        if hasattr(layer, 'activation') and layer.activation.__name__ == 'swish':
            print(layer_type, layer.activation.__name__)
            if layer_type == "Conv2D":
                # conv layer with swish activation.
                # Do something
                layer.activation = ReLU() # This didn't work
            else:
                # activation layer
                # Do something
                layer = tf.keras.layers.Activation('relu', name=layer.name + "_relu") # This didn't work
    return model

# load pretrained efficientNet
model = tf.keras.applications.EfficientNetB0(
    include_top=True, weights='imagenet', input_tensor=None,
    input_shape=(224, 224, 3), pooling=None, classes=1000,
    classifier_activation='softmax')

# convert swish activation to relu activation
model = replace_swish_with_relu(model)
model.save("efficientNet-relu")

How can I modify replace_swish_with_relu to replace the swish activations with relu in the passed model?

Thank you for any pointers/help.


Solution

  • layer.activation holds a reference to the tf.keras.activations.swish function. We can simply reassign it to tf.keras.activations.relu. Below is the modified replace_swish_with_relu:

    def replace_swish_with_relu(model):
        '''
        Modify the passed model by replacing the swish activation
        with relu wherever it appears.
        '''
        for layer in tuple(model.layers):
            if hasattr(layer, 'activation') and layer.activation.__name__ == 'swish':
                print(type(layer).__name__, layer.activation.__name__)
                # The same reassignment works for both Conv2D layers
                # (fused activation) and standalone Activation layers.
                layer.activation = tf.keras.activations.relu
        return model
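
    For example, a quick sanity check (a sketch that reuses the model-loading code from the question) confirms that no layer still reports swish after the call:

    import tensorflow as tf

    model = tf.keras.applications.EfficientNetB0(
        include_top=True, weights='imagenet',
        input_shape=(224, 224, 3))
    model = replace_swish_with_relu(model)

    # Collect any layers that still use swish; expect an empty list.
    remaining = [layer.name for layer in model.layers
                 if hasattr(layer, 'activation')
                 and layer.activation.__name__ == 'swish']
    print("Layers still using swish:", remaining)  # []

    model.save("efficientNet-relu")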
    

    Note: If you modify the activation function, you need to retrain (or at least fine-tune) the model so its weights adapt to the new activation. Related.
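
    A minimal fine-tuning sketch, assuming a tf.data.Dataset named train_ds that yields (image, label) batches; the dataset name and hyperparameters are illustrative placeholders, not part of the original post:

    # Fine-tune so the imagenet weights adapt to the relu activation.
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'])
    model.fit(train_ds, epochs=5)  # train_ds: placeholder dataset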