I'm building a custom model using Keras (keras 3.3.3 with Python 3.9.19) and I would like to increase the momentum of my BatchNormalization layers during training.
Ideally, I'd like to use a custom model with custom layers rather than a Sequential model, and make the most of the "standard" training loop (using `fit` and `Callbacks`) rather than writing a custom training loop.
I know that there is already a solution using TensorFlow, but I'd also like to stick to Keras without relying on any backend functionality (e.g., `tf.keras`).
Now I know that in TensorFlow, one solution would be to use a `tf.Variable` in my model, like `self.bn_momentum = tf.Variable(bn_momentum, trainable=False)`, pass it to my layers, and update it with a `Callback` during training.
I assumed it would work the same way with Keras. However, when I try to use a `keras.Variable`, I get the following error message:
TypeError: Exception encountered when calling CustomClassifier.call().
float() argument must be a string or a number, not 'Variable'
So I tried getting rid of the `keras.Variable` and using a simple float instead. Training seems to go just fine with this, but I suspect that nothing actually changes under the hood.
I was wondering why `keras.Variable` does not behave like `tf.Variable`, and whether using a simple float variable is a good solution in my case. Thanks in advance for your answers!
Here is a minimal reproducible example:
import numpy as np
import keras
from keras import layers
class CustomClassifier(keras.Model):
def __init__(self, bn_momentum=0.99, **kwargs):
super().__init__(**kwargs)
self.bn_momentum = bn_momentum # or keras.Variable(bn_momentum, trainable=False)
self.input_layer = layers.Dense(8, activation="softmax", name="input_layer")
self.hidden_layer = CustomLayer(16, bn_momentum=self.bn_momentum, name="hidden_layer")
self.output_layer = layers.Dense(4, activation="softmax", name="output_scores")
def call(self, input_points, training=None):
x = self.input_layer(input_points, training=training)
x = self.hidden_layer(x, training=training)
return self.output_layer(x)
class CustomLayer(layers.Layer):
def __init__(self, units, bn_momentum, **kwargs):
super().__init__(**kwargs)
self.units = units
self.bn_momentum = bn_momentum
def build(self, batch_input_shape):
self.dense = layers.Dense(self.units, input_shape=batch_input_shape)
self.bn = layers.BatchNormalization(momentum=self.bn_momentum)
self.activation = layers.ReLU()
def call(self, x, training=None):
x = self.dense(x)
x = self.bn(x, training=training)
return self.activation(x)
class BatchNormalizationMomentumScheduler(keras.callbacks.Callback):
"""The decay rate for batch normalization starts with 0.5 and is gradually
increased to 0.99."""
def __init__(self,):
super().__init__()
self.initial_momentum = 0.5
self.final_momentum = 0.99
self.rate = 0.05
def on_train_begin(self, logs=None):
self.model.bn_momentum = self.initial_momentum
print(f"Initial BatchNormalization momentum is {self.model.bn_momentum:.3f}.")
def on_epoch_begin(self, epoch, logs=None):
if epoch:
new_bn_momentum = self.initial_momentum + self.rate * epoch
new_bn_momentum = np.min([new_bn_momentum, self.final_momentum])
self.model.bn_momentum = new_bn_momentum
print(f"Epoch {epoch}: BatchNormalization momentum is {self.model.bn_momentum:.3f}.")
if __name__ == "__main__":
# Generate random data
X = np.random.random((1024, 8))
y = np.random.choice([0, 1, 2, 3], 1024)
    # Instantiate and train the model
model = CustomClassifier()
model.build((64, 8))
model.summary()
model.compile(loss="sparse_categorical_crossentropy", optimizer="adam")
history = model.fit(X, y, epochs=10, callbacks=[BatchNormalizationMomentumScheduler()])
    # Check final momentum value
print("Model momentum after training:", model.bn_momentum)
There are two separate issues here:

1. In Keras 3, when you use the TensorFlow or JAX backend, JIT compilation is enabled automatically. In that case, anything that lives as a native Python value is treated as a constant after tracing, so changing it during the training loop has no effect.
2. In Keras, variables are supposed to be instantiated in the `build()` method. When they are instantiated elsewhere, you may run into a "variable not in the same graph" error.
The best way to address this is probably to subclass the BatchNormalization layer and make the momentum a (non-trainable) variable:
from typing import Callable

import keras
from keras.callbacks import Callback
from keras.layers import BatchNormalization

class MyBatchNormalization(BatchNormalization):
def build(self, input_shape):
super().build(input_shape)
self.momentum = self.add_weight(
name='momentum',
shape=(),
initializer=keras.initializers.Constant(self.momentum),
trainable=False,
)
def get_config(self):
config = super().get_config()
config['momentum'] = -1. # Not used, since momentum is a weight now and will be loaded from saved model
return config
class BNMomentumSched(Callback):
def __init__(self, sche: Callable[[int], float]):
super().__init__()
self.sche = sche
def on_epoch_begin(self, epoch, logs=None):
for layer in self.model.layers:
if isinstance(layer, MyBatchNormalization):
layer.momentum.assign(self.sche(epoch))
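For example, the 0.5 → 0.99 schedule from the question can be expressed as a plain function and handed to the callback. This is just a sketch: it assumes a model whose BatchNormalization layers have been replaced by MyBatchNormalization, and reuses the X / y data from the question.

def momentum_schedule(epoch):
    # Start at 0.5 and add 0.05 per epoch, capped at 0.99.
    return min(0.5 + 0.05 * epoch, 0.99)

model.compile(loss="sparse_categorical_crossentropy", optimizer="adam")
model.fit(X, y, epochs=10, callbacks=[BNMomentumSched(momentum_schedule)])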
Alternatively, changing the native Python number in a callback will also work if you explicitly pass `jit_compile=False` in `model.compile()`, though that can make training an order of magnitude slower.
Not directly related to the question, but it is recommended to use `keras.ops.xxx` instead of `tf.xxx` for operations, so that they stay compatible with other backends.
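For instance, a backend-agnostic equivalent of tf.matmul / tf.reduce_mean (a standalone sketch, not taken from the OP's code):

import keras
from keras import ops

x = keras.random.normal((4, 8))
w = keras.random.normal((8, 2))

# These run unchanged on the TensorFlow, JAX and PyTorch backends.
y = ops.matmul(x, w)
m = ops.mean(y, axis=-1)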
## Update
> When running the code the OP provides with the MyBatchNormalization class, it throws an AttributeError: 'CustomLayer' object has no attribute 'bn' error
When a container layer is built, its sublayers are not built, because their input shapes are unknown at build time. Functional and Sequential models are the exceptions, and subclassed models do not inherit this behaviour:
model = CustomClassifier()
model.build((64, 8))
assert model.hidden_layer.built is False
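In contrast, an actual forward pass does build the sublayers. A quick check (assuming the random X from the question, or any array of shape (n, 8)):

_ = model(X[:64])  # a real call builds CustomLayer and its bn sublayer
assert model.hidden_layer.built is True

This is also consistent with the error above: a callback that touches hidden_layer.bn in on_train_begin runs before any batch has gone through the model, so the attribute does not exist yet.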
Back to the original problem, the code can be modified as follows to get a variable BN momentum while keeping JIT compilation working:
import numpy as np
import keras
from keras import layers
class MyBatchNormalization(keras.layers.BatchNormalization):
def build(self, input_shape):
super().build(input_shape)
self.momentum = self.add_weight(
name='momentum',
shape=(),
initializer=keras.initializers.Constant(self.momentum),
trainable=False,
)
class CustomClassifier(keras.Model):
def __init__(self, bn_momentum_init=0.99, **kwargs):
super().__init__(**kwargs)
self.input_layer = layers.Dense(8, activation="softmax", name="input_layer")
self.hidden_layer = CustomLayer(16, bn_momentum=bn_momentum_init, name="hidden_layer")
self.output_layer = layers.Dense(4, activation="softmax", name="output_scores")
def call(self, input_points, training=None):
x = self.input_layer(input_points, training=training)
x = self.hidden_layer(x, training=training)
return self.output_layer(x)
def build(self, input_shape):
super().build(input_shape)
self.input_layer.build(input_shape)
self.hidden_layer.build(input_shape[:-1] + (self.input_layer.units,))
self.output_layer.build(input_shape[:-1] + (self.hidden_layer.units,))
class CustomLayer(layers.Layer):
def __init__(self, units, bn_momentum, **kwargs):
super().__init__(**kwargs)
self.units = units
self.bn_momentum = bn_momentum
def build(self, input_shape):
self.dense = layers.Dense(self.units)
self.bn = MyBatchNormalization(momentum=self.bn_momentum)
self.activation = layers.ReLU()
self.dense.build(input_shape)
self.bn.build(input_shape[:-1] + (self.units,))
def call(self, x, training=None):
x = self.dense(x)
x = self.bn(x, training=training)
return self.activation(x)
class BatchNormalizationMomentumScheduler(keras.callbacks.Callback):
"""The decay rate for batch normalization starts with 0.5 and is gradually
increased to 0.99."""
def __init__(self,):
super().__init__()
self.initial_momentum = 0.5
self.final_momentum = 0.99
self.rate = 0.05
def on_train_begin(self, logs=None):
self.model.hidden_layer.bn.momentum.assign(self.initial_momentum)
        bn_momentum = float(self.model.hidden_layer.bn.momentum)
print(f"Initial BatchNormalization momentum is {bn_momentum:.3f}.")
def on_epoch_begin(self, epoch, logs=None):
if epoch:
new_bn_momentum = self.initial_momentum + self.rate * epoch
new_bn_momentum = np.min([new_bn_momentum, self.final_momentum])
self.model.hidden_layer.bn.momentum.assign(new_bn_momentum)
bn_momentum = float(self.model.hidden_layer.bn.momentum)
print(f"Epoch {epoch}: BatchNormalization momentum is {bn_momentum:.3f}.")
You can check that the model is JIT-compiled by inspecting `model.jit_compile` (it should be True).
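A minimal end-to-end run with the modified classes, mirroring the `__main__` block from the question (a sketch; the random data is made up):

X = np.random.random((1024, 8))
y = np.random.choice([0, 1, 2, 3], 1024)

model = CustomClassifier()
model.build((64, 8))
model.compile(loss="sparse_categorical_crossentropy", optimizer="adam")
model.fit(X, y, epochs=10, callbacks=[BatchNormalizationMomentumScheduler()])

print(model.jit_compile)                                         # True: JIT stayed on
print("Final momentum:", float(model.hidden_layer.bn.momentum))  # 0.95 with the schedule above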
Variable sharing between layers is also possible, but make sure everything happens inside the `build()` method of some enclosing layer.
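As a rough sketch of what that could look like (a hypothetical SharedMomentumBlock, not part of the code above; it reuses the keras / layers imports and the same trick of overwriting a BatchNormalization layer's momentum attribute):

class SharedMomentumBlock(layers.Layer):
    """Two Dense+BN stages whose BatchNormalization layers read one shared momentum variable."""
    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        # The shared variable is created here, inside build(), so it lives
        # in the same graph as the layers that use it.
        self.shared_momentum = self.add_weight(
            name="shared_momentum",
            shape=(),
            initializer=keras.initializers.Constant(0.5),
            trainable=False,
        )
        self.dense1 = layers.Dense(self.units)
        self.bn1 = layers.BatchNormalization()
        self.dense2 = layers.Dense(self.units)
        self.bn2 = layers.BatchNormalization()
        self.dense1.build(input_shape)
        self.bn1.build(input_shape[:-1] + (self.units,))
        self.dense2.build(input_shape[:-1] + (self.units,))
        self.bn2.build(input_shape[:-1] + (self.units,))
        # Both BN layers now read the same variable; one assign() updates both.
        self.bn1.momentum = self.shared_momentum
        self.bn2.momentum = self.shared_momentum

    def call(self, x, training=None):
        x = self.bn1(self.dense1(x), training=training)
        return self.bn2(self.dense2(x), training=training)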