tensorflow, keras, transfer-learning, quantization, tfmot

How to set training=False for keras-model/layer outside of the __call__ method?


I’m using Keras with tensorflow-model-optimization (tfmot) for quantization aware training (QAT). My model is built on a pre-trained backbone from keras.applications. As the transfer learning guide recommends, I should call x = base_model(inputs, training=False). However, tfmot does not work with nested submodels. The workaround described in https://stackoverflow.com/a/72265777/23370406 avoids the __call__ method entirely, so there is no place for me to set the training mode to False. What should I do?

The submodel version of the code (incompatible with tfmot):

import keras
from keras import applications, layers, models, utils


inp = layers.Input((None, None, 3))
backbone = applications.vgg16.VGG16(include_top=False,
                                    weights=None)
x = backbone(inp, training=False)

backbone.trainable = False

x = layers.GlobalAveragePooling2D()(x)
x = layers.Dense(10, activation='relu')(x)
out = layers.Dense(1, activation='sigmoid')(x)

model = models.Model(inp, out)

model.summary()

The QAT-compatible version of the code (which leaves no way to disable training mode):

import keras
from keras import applications, layers, models, utils


inp = layers.Input((None, None, 3))

backbone = applications.vgg16.VGG16(include_top=False,
                                    input_tensor=inp,
                                    weights=None)
x = backbone.output

backbone.trainable = False

x = layers.GlobalAveragePooling2D()(x)
x = layers.Dense(10, activation='relu')(x)
out = layers.Dense(1, activation='sigmoid')(x)

model = models.Model(inp, out)

model.summary()

I have read the Keras source code, but have not found a solution that is compatible with keras>=3.0.0. Unfortunately, keras.backend.set_learning_phase was deprecated a few releases ago.

Thanks in advance!


Solution

  • You can use

    x = backbone.call(inp, training=False)
    

    instead of

    x = backbone(inp, training=False)
    

    so that the backbone's individual layers are built into your model rather than a nested submodel. (A complete sketch of this construction follows after the two summaries below.) In your example (slightly shortened here), model.summary() would change from

    Model: "model"
    _________________________________________________________________
     Layer (type)                Output Shape              Param #   
    =================================================================
     input_2 (InputLayer)        [(None, None, None, 3)]   0         
                                                                     
     vgg16 (Functional)          (None, None, None, 512)   14714688  
                                                                     
     global_average_pooling2d (  (None, 512)               0         
     GlobalAveragePooling2D)                                         
                                                                     
     dense (Dense)               (None, 2)                 1026      
                                                                     
    =================================================================
    Total params: 14715714 (56.14 MB)
    Trainable params: 14715714 (56.14 MB)
    Non-trainable params: 0 (0.00 Byte)
    _________________________________________________________________                                      
    

    to

    Model: "model_1"
    _________________________________________________________________
     Layer (type)                Output Shape              Param #   
    =================================================================
     input_3 (InputLayer)        [(None, None, None, 3)]   0         
                                                                     
     block1_conv1 (Conv2D)       (None, None, None, 64)    1792      
                                                                     
     block1_conv2 (Conv2D)       (None, None, None, 64)    36928     
                                                                     
     block1_pool (MaxPooling2D)  (None, None, None, 64)    0         
                                                                     
     block2_conv1 (Conv2D)       (None, None, None, 128)   73856     
                                                                     
     block2_conv2 (Conv2D)       (None, None, None, 128)   147584    
                                                                     
     block2_pool (MaxPooling2D)  (None, None, None, 128)   0         
                                                                     
     block3_conv1 (Conv2D)       (None, None, None, 256)   295168    
                                                                     
     block3_conv2 (Conv2D)       (None, None, None, 256)   590080    
                                                                     
     block3_conv3 (Conv2D)       (None, None, None, 256)   590080    
                                                                     
     block3_pool (MaxPooling2D)  (None, None, None, 256)   0         
                                                                     
     block4_conv1 (Conv2D)       (None, None, None, 512)   1180160   
                                                                     
     block4_conv2 (Conv2D)       (None, None, None, 512)   2359808   
                                                                     
     block4_conv3 (Conv2D)       (None, None, None, 512)   2359808   
                                                                     
     block4_pool (MaxPooling2D)  (None, None, None, 512)   0         
                                                                     
     block5_conv1 (Conv2D)       (None, None, None, 512)   2359808   
                                                                     
     block5_conv2 (Conv2D)       (None, None, None, 512)   2359808   
                                                                     
     block5_conv3 (Conv2D)       (None, None, None, 512)   2359808   
                                                                     
     block5_pool (MaxPooling2D)  (None, None, None, 512)   0         
                                                                     
     global_average_pooling2d_1  (None, 512)               0         
      (GlobalAveragePooling2D)                                       
                                                                     
     dense_1 (Dense)             (None, 2)                 1026      
                                                                     
    =================================================================
    Total params: 14715714 (56.14 MB)
    Trainable params: 14715714 (56.14 MB)
    Non-trainable params: 0 (0.00 Byte)
    _________________________________________________________________
    

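    Putting the pieces together, here is a minimal sketch of the .call-based construction, i.e. the question's first snippet with .call swapped in. This is the pre-Keras-3 approach; see the Edit below for Keras>=3:

    import keras
    from keras import applications, layers, models

    inp = layers.Input((None, None, 3))

    backbone = applications.vgg16.VGG16(include_top=False,
                                        weights=None)
    backbone.trainable = False

    # Calling .call() directly runs the backbone's internal graph on the new
    # input tensor, so its layers are added to the outer model one by one
    # instead of as a single nested "vgg16 (Functional)" layer.
    x = backbone.call(inp, training=False)

    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dense(10, activation='relu')(x)
    out = layers.Dense(1, activation='sigmoid')(x)

    model = models.Model(inp, out)
    model.summary()
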

    Edit: Here is code that is compatible with Keras>=3.x (which ships with TF 2.16):

    import keras
    from keras import applications, layers, models, utils
    
    # load resnet here for testing, because resnet has BatchNormalization layers
    backbone = applications.resnet.ResNet50(include_top=False,
                                        weights=None, 
                                        input_shape=(None, None, 3))
    backbone.trainable = False
    
    x = layers.GlobalAveragePooling2D()(backbone.output)
    x = layers.Dense(10, activation='relu')(x)
    out = layers.Dense(1, activation='sigmoid')(x)
    
    model = models.Model(backbone.inputs, out)
    
    # unfreeze all layers except the BatchNormalization layers
    for layer in model.layers:
      if not isinstance(layer, keras.layers.BatchNormalization):
        layer.trainable = True
    

    I'd argue that this is even better than x = backbone(inp, training=False), because a blanket training=False would also switch dropout layers, and every other layer that behaves differently in training and inference, into inference mode. If you don't want this, you can add e.g. Dropout to the isinstance test in the for loop.
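
    For completeness, a minimal sketch of how the flat model from either approach could then be passed to tfmot. quantize_model is the standard tfmot API; this assumes a TensorFlow/Keras combination that tfmot actually supports (it currently targets tf.keras rather than Keras 3), and train_ds/val_ds are placeholders for your own datasets:

    import tensorflow_model_optimization as tfmot

    # quantize_model wraps every layer of the (now flat) model with quantize
    # wrappers; this is the step that fails when the backbone is a nested submodel.
    qat_model = tfmot.quantization.keras.quantize_model(model)

    qat_model.compile(optimizer='adam',
                      loss='binary_crossentropy',
                      metrics=['accuracy'])

    # train_ds / val_ds stand in for your own tf.data pipelines:
    # qat_model.fit(train_ds, validation_data=val_ds, epochs=5)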