Suppose I have a custom layer which computes the loss for me, using external trainable variables, in TF 2.4 (and yes, I know this is a silly example and loss, it is just for reproducibility; the actual loss is much more complex):
import numpy as np
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from tensorflow.keras.layers import Dense, Layer, Input
from tensorflow.keras import Model
from tensorflow.keras.callbacks import EarlyStopping
import tensorflow as tf
n_col = 10
n_row = 1000
X = np.random.normal(size=(n_row, n_col))
beta = np.arange(10)
y = X @ beta
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
class MyLoss(Layer):
    def __init__(self, var1, var2):
        super(MyLoss, self).__init__()
        # var1 and var2 are trainable tf.Variables tracked by the layer,
        # so the optimizer updates them along with the rest of the weights.
        self.var1 = tf.Variable(var1)
        self.var2 = tf.Variable(var2)

    def get_vars(self):
        return self.var1, self.var2

    def custom_loss(self, y_true, y_pred):
        return self.var1 ** 2 * tf.math.reduce_mean(tf.math.square(y_true - y_pred)) + self.var2 ** 2

    def call(self, y_true, y_pred):
        # Register the loss on the layer; the model is compiled without a loss.
        self.add_loss(self.custom_loss(y_true, y_pred))
        return y_pred
inputs = Input(shape=(X_train.shape[1],))
y_input = Input(shape=(1,))
hidden1 = Dense(10)(inputs)
output = Dense(1)(hidden1)
my_loss = MyLoss(0.5, 0.5)(y_input, output)  # var1 and var2 are initialized here
model = Model(inputs=[inputs, y_input], outputs=my_loss)
model.compile(optimizer='adam')
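(A quick sanity check, my own addition rather than part of the original setup: because var1 and var2 are created as tf.Variables inside a Layer, Keras tracks them as trainable weights of the model, which is why the optimizer updates them even though the model is compiled without a loss.)
# Sanity check: the two loss variables show up as trainable weights
# of the MyLoss layer (and therefore of the model).
print(model.layers[-1].trainable_weights)  # two scalar variables, both initialized to 0.5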
Training this model is simple:
history = model.fit([X_train, y_train], None,
                    batch_size=32, epochs=100, validation_split=0.1, verbose=0,
                    callbacks=[EarlyStopping(monitor='val_loss', patience=5)])
And if we write a custom Callback or train epoch by epoch, we can see how var1 and var2 converge to 0, as would be expected:
var1_list = []
var2_list = []
for i in range(100):
    if i % 10 == 0:
        print('step %d' % i)
    model.fit([X_train, y_train], None,
              batch_size=32, epochs=1, validation_split=0.1, verbose=0)
    var1, var2 = model.layers[-1].get_vars()
    var1_list.append(var1.numpy())
    var2_list.append(var2.numpy())
plt.plot(var1_list, label='var1')
plt.plot(var2_list, 'r', label='var2')
plt.legend()
plt.show()
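The epoch-by-epoch loop above is one of the two options mentioned; the custom-Callback option could look roughly like the following sketch (my own illustrative code, with VarsHistory as a made-up name), which records the variables after each epoch inside a single model.fit call:
class VarsHistory(tf.keras.callbacks.Callback):
    # Records var1 and var2 after every epoch, replacing the manual loop above.
    def on_train_begin(self, logs=None):
        self.var1_list, self.var2_list = [], []

    def on_epoch_end(self, epoch, logs=None):
        var1, var2 = self.model.layers[-1].get_vars()
        self.var1_list.append(var1.numpy())
        self.var2_list.append(var2.numpy())

vars_history = VarsHistory()
model.fit([X_train, y_train], None,
          batch_size=32, epochs=100, validation_split=0.1, verbose=0,
          callbacks=[vars_history])
plt.plot(vars_history.var1_list, label='var1')
plt.plot(vars_history.var2_list, 'r', label='var2')
plt.legend()
plt.show()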
Short question: how do I make the model stop (EarlyStopping with some patience) according to the convergence of var1 and var2 (i.e. their vector size, self.var1**2 + self.var2**2; and again, assume the loss is much more complex and you cannot just add this vector size to the loss)?
Longer question: (if you have the time/patience)

- Can one wrap the convergence of var1 and var2 into a custom Metric and make EarlyStopping track it?
- How do you make EarlyStopping focus on "convergence" when all it's got is mode "min" or "max"? (I wonder, could we extend EarlyStopping instead of extending Callback?)
- How do you make EarlyStopping pay attention to both, i.e. "stop if you don't see improvement in loss AND improvement in convergence for patience=10"?

Well, at least for the "shorter question" this turned out quite simple. Following this example from the TF docs, I implemented EarlyStopping with the twist of focusing on the variables' norm:
class EarlyStoppingAtVarsConvergence(tf.keras.callbacks.Callback):
    def __init__(self, norm_thresh=0.01, patience=0):
        super(EarlyStoppingAtVarsConvergence, self).__init__()
        self.norm_thresh = norm_thresh
        self.patience = patience

    def on_train_begin(self, logs=None):
        # Number of epochs waited while the norm hasn't changed by more than norm_thresh.
        self.wait = 0
        # The epoch the training stops at.
        self.stopped_epoch = 0
        # Initialize the variables norm.
        self.vars_norm = self.get_vars_norm()

    def get_vars_norm(self):
        var1, var2 = self.model.layers[-1].get_vars()
        return var1**2 + var2**2

    def on_epoch_end(self, epoch, logs=None):
        current_norm = self.get_vars_norm()
        if np.abs(current_norm - self.vars_norm) > self.norm_thresh:
            self.vars_norm = current_norm
            self.wait = 0
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.stopped_epoch = epoch
                self.model.stop_training = True

    def on_train_end(self, logs=None):
        if self.stopped_epoch > 0:
            print("Epoch %05d: early stopping" % (self.stopped_epoch + 1))
Then the model would be run with:
history = model.fit([X_train, y_train], None,
                    batch_size=32, epochs=100, validation_split=0.1, verbose=0,
                    callbacks=[EarlyStoppingAtVarsConvergence(patience=5)])
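As for the longer question, one direction I can think of (an untested sketch, not a definitive answer): instead of re-implementing the stopping logic, a small callback could write the variables' norm into the logs dict, and the built-in EarlyStopping could then monitor that key with mode='min' and a min_delta. The name LogVarsNorm is my own, and this relies on callbacks in TF 2.x sharing the same logs dict at on_epoch_end (so LogVarsNorm must come before EarlyStopping in the callbacks list). It only captures "the norm stopped decreasing by more than min_delta", which is weaker than true convergence, and it still doesn't give the loss-AND-convergence condition; for that, extending EarlyStopping itself seems like the cleaner route.
class LogVarsNorm(tf.keras.callbacks.Callback):
    # Hypothetical helper: exposes the variables norm as a "metric" in logs
    # so that downstream callbacks such as EarlyStopping can monitor it by name.
    def on_epoch_end(self, epoch, logs=None):
        var1, var2 = self.model.layers[-1].get_vars()
        logs = logs if logs is not None else {}
        logs['vars_norm'] = float(var1**2 + var2**2)

history = model.fit([X_train, y_train], None,
                    batch_size=32, epochs=100, validation_split=0.1, verbose=0,
                    callbacks=[LogVarsNorm(),
                               EarlyStopping(monitor='vars_norm', mode='min',
                                             min_delta=0.01, patience=5)])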