I would like to remove the first N layers from a pretrained Keras model. For example, an EfficientNetB0, whose first 3 layers are responsible only for preprocessing:
import tensorflow as tf
efinet = tf.keras.applications.EfficientNetB0(weights=None, include_top=True)
print(efinet.layers[:3])
# [<tensorflow.python.keras.engine.input_layer.InputLayer at 0x7fa9a870e4d0>,
# <tensorflow.python.keras.layers.preprocessing.image_preprocessing.Rescaling at 0x7fa9a61343d0>,
# <tensorflow.python.keras.layers.preprocessing.normalization.Normalization at 0x7fa9a60d21d0>]
As M.Innat mentioned, the first layer is an InputLayer, which should either be kept or re-attached. I would like to remove those layers, but a simple approach like the following throws an error:
cut_input_model = tf.keras.Model(
    inputs=[efinet.layers[3].input],
    outputs=efinet.outputs,
)
This will result in:
ValueError: Graph disconnected: cannot obtain value for tensor KerasTensor(...)
What would be the recommended way to do this?
The reason you get the Graph disconnected error is that you don't specify an Input layer, but that's not the main issue here. Removing an intermediate layer from a Keras model is not always straightforward, whether with the Sequential or the Functional API.
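To illustrate the first point (a sketch, not a full fix): the tensors inside efinet belong to its original graph, whose entry point is the model's own InputLayer, so a new model needs a fresh tf.keras.Input of its own:

# A fresh entry point for a new graph; 224x224x3 is EfficientNetB0's
# default input shape.
new_input = tf.keras.Input(shape=(224, 224, 3))
# Every remaining layer would then have to be called again on new_input
# to form a connected graph; re-applying them is the hard part, as
# explained next.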
With a Sequential model it should be comparatively easy, whereas in a Functional model you need to take care of multi-input blocks (e.g. multiply, add, etc.). For example, if you want to remove an intermediate layer from a Sequential model, you can easily adapt this solution, roughly as sketched below.
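A minimal sketch of the Sequential case, assuming a toy model (the layers and sizes here are arbitrary, not from EfficientNetB0): rebuild the model from its layer list while skipping the unwanted layer; the kept layers retain their weights.

import tensorflow as tf

# Toy Sequential model; the Dropout layer is the one to remove.
seq = tf.keras.Sequential([
    tf.keras.layers.Dense(16, activation="relu", input_shape=(8,)),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(1),
])

# Rebuild from the layer list, skipping index 1; the remaining
# layers (and their weights) are reused as-is.
trimmed = tf.keras.Sequential(
    [layer for i, layer in enumerate(seq.layers) if i != 1]
)
trimmed.build((None, 8))
trimmed.summary()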
But for a Functional model like efficientnet you can't, because of the multi-input internal blocks; you will encounter this error: ValueError: A merged layer should be called on a list of inputs. So that needs a bit more work AFAIK; here is a possible approach to overcome it.
Here I will show a simple workaround for your case, but it's probably not general and can also be unsafe in some cases. It is based on this approach, using the pop method; see why it can be unsafe to use. Okay, let's first load the model.
func_model = tf.keras.applications.EfficientNetB0()

# Inspect the first few layers.
for i, l in enumerate(func_model.layers):
    print(l.name, l.output_shape)
    if i == 8:
        break
input_19 [(None, 224, 224, 3)]
rescaling_13 (None, 224, 224, 3)
normalization_13 (None, 224, 224, 3)
stem_conv_pad (None, 225, 225, 3)
stem_conv (None, 112, 112, 32)
stem_bn (None, 112, 112, 32)
stem_activation (None, 112, 112, 32)
block1a_dwconv (None, 112, 112, 32)
block1a_bn (None, 112, 112, 32)
Next, using the .pop method:
func_model._layers.pop(1)  # remove rescaling
func_model._layers.pop(1)  # remove normalization

for i, l in enumerate(func_model.layers):
    print(l.name, l.output_shape)
    if i == 8:
        break
input_22 [(None, 224, 224, 3)]
stem_conv_pad (None, 225, 225, 3)
stem_conv (None, 112, 112, 32)
stem_bn (None, 112, 112, 32)
stem_activation (None, 112, 112, 32)
block1a_dwconv (None, 112, 112, 32)
block1a_bn (None, 112, 112, 32)
block1a_activation (None, 112, 112, 32)
block1a_se_squeeze (None, 32)
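As a quick sanity check (a sketch, reusing func_model from above), the preprocessing layers no longer show up in the layer list or in model.summary(). Keep in mind that _layers is a private attribute, so popping from it bypasses Keras's own bookkeeping, which is part of why this workaround can be unsafe:

# The rescaling/normalization layers are gone from the layer list.
assert not any(l.name.startswith(("rescaling", "normalization"))
               for l in func_model.layers)
func_model.summary()  # stem_conv_pad now directly follows the input layer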