python, keras, batch-normalization

Changing BatchNormalization momentum while training in Keras 3


I'm building a custom model using Keras (Keras 3.3.3 with Python 3.9.19) and I would like to increase the momentum of my BatchNormalization layers during training.

Ideally, I'd like to use a custom model with custom layers rather than a Sequential model, and to make the most of the "standard" training loop (fit and Callbacks) rather than writing a custom training loop.

I know that there is already a solution using TensorFlow, but I'd like to stick to Keras without relying on any backend-specific functionality (e.g., tf.keras).

Now, I know that in TensorFlow one solution would be to use a tf.Variable in my model, like self.bn_momentum = tf.Variable(bn_momentum, trainable=False), pass it to my layers, and update it with a Callback during training.
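
For reference, this is roughly the tf.Variable pattern I am referring to (a sketch of the legacy tf.keras / Keras 2-style approach, with made-up class names; not something I want to keep, since it ties the model to the TensorFlow backend):

import tensorflow as tf

class TFClassifier(tf.keras.Model):
    def __init__(self, bn_momentum=0.5, **kwargs):
        super().__init__(**kwargs)
        # Non-trainable variable shared with the BN layer
        self.bn_momentum = tf.Variable(bn_momentum, trainable=False, dtype=tf.float32)
        self.dense = tf.keras.layers.Dense(16)
        self.bn = tf.keras.layers.BatchNormalization(momentum=self.bn_momentum)
        self.out = tf.keras.layers.Dense(4, activation="softmax")

    def call(self, x, training=None):
        return self.out(self.bn(self.dense(x), training=training))

class TFBNMomentumCallback(tf.keras.callbacks.Callback):
    def on_epoch_begin(self, epoch, logs=None):
        # The BN layer reads the shared variable at every training step
        self.model.bn_momentum.assign(min(0.5 + 0.05 * epoch, 0.99))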

I assumed it would work the same way with Keras. However, when I try to use a keras.Variable I get the following error message:

TypeError: Exception encountered when calling CustomClassifier.call().

float() argument must be a string or a number, not 'Variable'

So I tried getting rid of the keras.Variable and using a simple Python float instead. Training seems to run just fine this way, but I suspect that nothing actually happens under the hood.

I was wondering: why does keras.Variable not behave like tf.Variable here? And is using a simple float a good solution in my case? Thanks in advance for your answers!

Here is a minimal reproducible example:

import numpy as np
import keras
from keras import layers


class CustomClassifier(keras.Model):

    def __init__(self, bn_momentum=0.99, **kwargs):
        
        super().__init__(**kwargs)
        self.bn_momentum = bn_momentum # or keras.Variable(bn_momentum, trainable=False)
        self.input_layer = layers.Dense(8, activation="softmax", name="input_layer")
        self.hidden_layer = CustomLayer(16, bn_momentum=self.bn_momentum, name="hidden_layer")      
        self.output_layer = layers.Dense(4, activation="softmax", name="output_scores")

    def call(self, input_points, training=None):
                
        x = self.input_layer(input_points, training=training)
        x = self.hidden_layer(x, training=training)

        return self.output_layer(x)

   
class CustomLayer(layers.Layer):
    
    def __init__(self, units, bn_momentum, **kwargs):
        
        super().__init__(**kwargs)
        self.units = units
        self.bn_momentum = bn_momentum

    def build(self, batch_input_shape):

        self.dense = layers.Dense(self.units, input_shape=batch_input_shape)
        self.bn = layers.BatchNormalization(momentum=self.bn_momentum)
        self.activation = layers.ReLU()

    def call(self, x, training=None):

        x = self.dense(x)
        x = self.bn(x, training=training)

        return self.activation(x)


class BatchNormalizationMomentumScheduler(keras.callbacks.Callback):
    """The decay rate for batch normalization starts with 0.5 and is gradually 
    increased to 0.99."""

    def __init__(self,):
        super().__init__()

        self.initial_momentum = 0.5
        self.final_momentum = 0.99
        self.rate = 0.05

    def on_train_begin(self, logs=None):

        self.model.bn_momentum = self.initial_momentum
        print(f"Initial BatchNormalization momentum is {self.model.bn_momentum:.3f}.")

    def on_epoch_begin(self, epoch, logs=None):

        if epoch:
            new_bn_momentum = self.initial_momentum + self.rate * epoch
            new_bn_momentum = np.min([new_bn_momentum, self.final_momentum])
            self.model.bn_momentum = new_bn_momentum
            print(f"Epoch {epoch}: BatchNormalization momentum is {self.model.bn_momentum:.3f}.")
            

if __name__ == "__main__":

    # Generate random data
    X = np.random.random((1024, 8))
    y = np.random.choice([0, 1, 2, 3], 1024)

    # Instantiate and train model
    model = CustomClassifier()
    model.build((64, 8))
    model.summary()
    model.compile(loss="sparse_categorical_crossentropy", optimizer="adam")
    history = model.fit(X, y, epochs=10, callbacks=[BatchNormalizationMomentumScheduler()])

    # Check final momentum value
    print("Model momentum after training:", model.bn_momentum)


Solution

  • There are two separate issues here.

    The cleanest fix for your problem is probably to subclass the BatchNormalization layer and turn its momentum into a (non-trainable) variable:

    from typing import Callable

    import keras
    from keras.layers import BatchNormalization
    from keras.callbacks import Callback

    class MyBatchNormalization(BatchNormalization):
        def build(self, input_shape):
            super().build(input_shape)
            self.momentum = self.add_weight(
                name='momentum',
                shape=(),
                initializer=keras.initializers.Constant(self.momentum),
                trainable=False,
            )
        
        def get_config(self):
            config = super().get_config()
            config['momentum'] = -1. # Not used, since momentum is a weight now and will be loaded from saved model
            return config
    
    class BNMomentumSched(Callback):
        def __init__(self, sche: Callable[[int], float]):
            super().__init__()
            self.sche = sche
        
        def on_epoch_begin(self, epoch, logs=None):
            for layer in self.model.layers:
                if isinstance(layer, MyBatchNormalization):
                    layer.momentum.assign(self.sche(epoch))
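
    For example, the schedule from the question (start at 0.5, add 0.05 per epoch, cap at 0.99) can then be passed as a plain function, reusing the model, X and y from the question's example:

    sched = BNMomentumSched(lambda epoch: min(0.5 + 0.05 * epoch, 0.99))
    history = model.fit(X, y, epochs=10, callbacks=[sched])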
    

    Alternatively, changing the plain Python float from a callback will also work if you explicitly pass jit_compile=False to model.compile(), though that can make training an order of magnitude slower.
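
    A minimal sketch of that alternative, reusing the model structure from the question; note that, in my reading, the float has to be set on the BatchNormalization layer itself rather than on a model attribute, otherwise the layer never sees the new value:

    class PlainFloatMomentumScheduler(keras.callbacks.Callback):
        def on_epoch_begin(self, epoch, logs=None):
            # Mutate the plain Python float stored on the layer
            self.model.hidden_layer.bn.momentum = min(0.5 + 0.05 * epoch, 0.99)

    model.compile(loss="sparse_categorical_crossentropy", optimizer="adam",
                  jit_compile=False)  # required so the updated float is actually used
    model.fit(X, y, epochs=10, callbacks=[PlainFloatMomentumScheduler()])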

    Not directly related to the question, but it is recommended to use keras.ops.xxx instead of tf.xxx for tensor operations, so that the code stays compatible with the other backends.
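
    For instance, a backend-agnostic reduction would look like this:

    x = keras.ops.mean(x, axis=-1, keepdims=True)  # instead of tf.reduce_mean(x, axis=-1, keepdims=True)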

    Update

    > When running the code the OP provided with the MyBatchNormalization class, it throws AttributeError: 'CustomLayer' object has no attribute 'bn'

    When a container layer is built, its sublayers are not built, because their input shapes are unknown at build time. Functional and Sequential models are exceptions, but subclassed models do not inherit this behaviour:

    model = CustomClassifier()
    model.build((64, 8))
    assert model.hidden_layer.built is False
    

    Back to the original problem: the code can be modified as follows to get a variable BN momentum while keeping JIT compilation working:

    import numpy as np
    import keras
    from keras import layers
    
    class MyBatchNormalization(keras.layers.BatchNormalization):
        def build(self, input_shape):
            super().build(input_shape)
            self.momentum = self.add_weight(
                name='momentum',
                shape=(),
                initializer=keras.initializers.Constant(self.momentum),
                trainable=False,
            )
    
    class CustomClassifier(keras.Model):
    
        def __init__(self, bn_momentum_init=0.99, **kwargs):
            
            super().__init__(**kwargs)
            self.input_layer = layers.Dense(8, activation="softmax", name="input_layer")
            self.hidden_layer = CustomLayer(16, bn_momentum=bn_momentum_init, name="hidden_layer")      
            self.output_layer = layers.Dense(4, activation="softmax", name="output_scores")
    
        def call(self, input_points, training=None):
                    
            x = self.input_layer(input_points, training=training)
            x = self.hidden_layer(x, training=training)
    
            return self.output_layer(x)
    
        def build(self, input_shape):
            super().build(input_shape)
            self.input_layer.build(input_shape)
            self.hidden_layer.build(input_shape[:-1] + (self.input_layer.units,))
            self.output_layer.build(input_shape[:-1] + (self.hidden_layer.units,))
    
       
    class CustomLayer(layers.Layer):
        
        def __init__(self, units, bn_momentum, **kwargs):
            
            super().__init__(**kwargs)
            self.units = units
            self.bn_momentum = bn_momentum
    
        def build(self, input_shape):
    
            self.dense = layers.Dense(self.units)
            self.bn = MyBatchNormalization(momentum=self.bn_momentum)
            self.activation = layers.ReLU()
            
            self.dense.build(input_shape)
            self.bn.build(input_shape[:-1] + (self.units,))
    
        def call(self, x, training=None):
    
            x = self.dense(x)
            x = self.bn(x, training=training)
    
            return self.activation(x)
    
    
    class BatchNormalizationMomentumScheduler(keras.callbacks.Callback):
        """The decay rate for batch normalization starts with 0.5 and is gradually 
        increased to 0.99."""
    
        def __init__(self,):
            super().__init__()
    
            self.initial_momentum = 0.5
            self.final_momentum = 0.99
            self.rate = 0.05
    
        def on_train_begin(self, logs=None):
    
            self.model.hidden_layer.bn.momentum.assign(self.initial_momentum)
            bn_momentum = float(self.model.hidden_layer.bn.momentum)
            print(f"Initial BatchNormalization momentum is {bn_momentum:.3f}.")
    
        def on_epoch_begin(self, epoch, logs=None):
    
            if epoch:
                new_bn_momentum = self.initial_momentum + self.rate * epoch
                new_bn_momentum = np.min([new_bn_momentum, self.final_momentum])
                self.model.hidden_layer.bn.momentum.assign(new_bn_momentum)
                bn_momentum = float(self.model.hidden_layer.bn.momentum)
                print(f"Epoch {epoch}: BatchNormalization momentum is {bn_momentum:.3f}.")
                
    

    You can check that the model is JIT-compiled by inspecting model.jit_compile (it should be True). Sharing a variable between layers is also possible, but make sure everything happens inside the build() method of some enclosing layer.
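
    For example, with the modified classes above (reusing the random X and y from the question, and assuming the TensorFlow backend, where jit_compile="auto" resolves to XLA compilation):

    model = CustomClassifier()
    model.build((64, 8))
    model.compile(loss="sparse_categorical_crossentropy", optimizer="adam")  # jit_compile defaults to "auto"
    model.fit(X, y, epochs=10, callbacks=[BatchNormalizationMomentumScheduler()])
    print(model.jit_compile)  # True when the train step is XLA-compiled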