I'm trying to replace the swish activation with the relu activation in the pretrained TF model EfficientNetB0. EfficientNetB0 uses swish activation in its Conv2D and Activation layers. This SO post is very similar to what I'm looking for. I also found an answer that works for models without skip connections. Below is the code:
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import ReLU

def replace_swish_with_relu(model):
    '''
    Modify passed model by replacing swish activation with relu
    '''
    for layer in tuple(model.layers):
        layer_type = type(layer).__name__
        if hasattr(layer, 'activation') and layer.activation.__name__ == 'swish':
            print(layer_type, layer.activation.__name__)
            if layer_type == "Conv2D":
                # conv layer with swish activation.
                # Do something
                layer.activation = ReLU()  # This didn't work
            else:
                # activation layer
                # Do something
                layer = tf.keras.layers.Activation('relu', name=layer.name + "_relu")  # This didn't work
    return model
# load pretrained efficientNet
model = tf.keras.applications.EfficientNetB0(
    include_top=True, weights='imagenet', input_tensor=None,
    input_shape=(224, 224, 3), pooling=None, classes=1000,
    classifier_activation='softmax')

# convert swish activation to relu activation
model = replace_swish_with_relu(model)
model.save("efficientNet-relu")
How can I modify replace_swish_with_relu to replace the swish activations with relu in the passed model?
Thank you for any pointers/help.
layer.activation points to the tf.keras.activations.swish function. We can modify it to point to tf.keras.activations.relu instead. Below is the modified replace_swish_with_relu:
def replace_swish_with_relu(model):
    '''
    Modify passed model by replacing swish activation with relu
    '''
    for layer in tuple(model.layers):
        layer_type = type(layer).__name__
        if hasattr(layer, 'activation') and layer.activation.__name__ == 'swish':
            print(layer_type, layer.activation.__name__)
            if layer_type == "Conv2D":
                # conv layer with swish activation
                layer.activation = tf.keras.activations.relu
            else:
                # activation layer
                layer.activation = tf.keras.activations.relu
    return model
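To apply it, call the function on the loaded model and check that no layer still reports swish. Below is a minimal sanity check, assuming the model variable from the question:

model = replace_swish_with_relu(model)

# Verify that no layer still points to the swish activation.
remaining = [layer.name for layer in model.layers
             if hasattr(layer, 'activation') and layer.activation.__name__ == 'swish']
print("Layers still using swish:", remaining)  # expected: []

model.save("efficientNet-relu")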
Note: If you modify the activation function, you need to retrain the model so the pretrained weights adapt to the new activation. Related.
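For completeness, a minimal fine-tuning sketch after the swap; train_ds here is a placeholder tf.data.Dataset of (image, integer-label) batches, and the optimizer and epoch settings are assumptions, not part of the original answer:

# Fine-tune so the pretrained weights adapt to the new relu activation.
# `train_ds` is a hypothetical dataset yielding (image, label) batches.
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'])
model.fit(train_ds, epochs=5)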