I am training a convolutional network and I want to stop training once the validation accuracy hits 90%. I thought about using EarlyStopping and setting baseline to .90, but then it stops training whenever the validation accuracy is below that baseline for the given number of epochs (which is just 0 here). So my code is:
es=EarlyStopping(monitor='val_acc',mode='auto',verbose=1,baseline=.90,patience=0)
history = model.fit(training_images, training_labels, validation_data=(test_images, test_labels), epochs=30, verbose=2,callbacks=[es])
When I use this code, training stops after the first epoch with the following output:
Train on 60000 samples, validate on 10000 samples
Epoch 1/30
60000/60000 - 7s - loss: 0.4600 - acc: 0.8330 - val_loss: 0.3426 - val_acc: 0.8787
What else can I try to stop my training once the validation accuracy hits 90% or above?
Here is the rest of the code:
model = tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(64, (3,3), activation='relu', input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Conv2D(32, (3,3), activation='relu', input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(152, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')
])
model.compile(optimizer=Adam(learning_rate=0.001),loss='sparse_categorical_crossentropy', metrics=['accuracy'])
es=EarlyStopping(monitor='val_acc',mode='auto',verbose=1,baseline=.90,patience=0)
history = model.fit(training_images, training_labels, validation_data=(test_images, test_labels), epochs=30, verbose=2,callbacks=[es])
The EarlyStopping callback watches for a monitored value that has stopped improving (increasing or decreasing, depending on the mode), so it is not a good fit for your problem. However, tf.keras allows you to use custom callbacks.
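To see why the original setup halts after the first epoch, here is a simplified pure-Python model of that stopping rule. It is only a sketch of how the older tf.keras version in the question appears to behave, not the Keras source, and details differ between releases. The function name `simulate_early_stopping` and all accuracy values after the first are made up for illustration.

```python
def simulate_early_stopping(val_accs, baseline=None, patience=0, min_delta=0.0):
    """Simplified model of an EarlyStopping-style rule in 'max' mode.

    Returns the 1-based epoch at which training would stop, or None if
    it would run through every epoch. Illustrative approximation only.
    """
    # Assumption: with a baseline, the "best so far" starts at the baseline.
    best = baseline if baseline is not None else float("-inf")
    wait = 0
    for epoch, current in enumerate(val_accs, start=1):
        if current - min_delta > best:
            # The monitored value improved: remember it and reset the counter.
            best = current
            wait = 0
        else:
            # No improvement over best (or over the baseline): count the epoch.
            wait += 1
            if wait >= patience:
                return epoch
    return None


# First value comes from the log in the question; the rest are hypothetical.
val_accs = [0.8787, 0.8950, 0.9050, 0.9100]
print(simulate_early_stopping(val_accs, baseline=0.90, patience=0))  # 1
print(simulate_early_stopping(val_accs, baseline=None, patience=0))  # None
```

In this model, the first epoch's 0.8787 does not beat the 0.90 baseline, so the wait counter reaches the patience of 0 right away. The rule answers "has the metric stopped improving?", not "has the metric reached a target?".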
For your example, a custom callback could look like this:
class MyThresholdCallback(tf.keras.callbacks.Callback):
    def __init__(self, threshold):
        super(MyThresholdCallback, self).__init__()
        self.threshold = threshold

    def on_epoch_end(self, epoch, logs=None):
        # Stop as soon as the validation accuracy reaches the threshold
        val_acc = logs["val_acc"]
        if val_acc >= self.threshold:
            self.model.stop_training = True
For TF version 2.3 or above, you might have to use "val_accuracy" instead of "val_acc". Thank you Christian Westbrook for the note in the comments.
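If you are unsure which key your TensorFlow version reports, one option is to look the value up under both names. The helper below is a hypothetical sketch (the name `get_val_accuracy` is not part of Keras). It fails loudly when neither key is present, so a naming mismatch does not silently disable the stopping rule.

```python
def get_val_accuracy(logs, keys=("val_accuracy", "val_acc")):
    """Return the validation accuracy from a Keras-style logs dict.

    Tries each candidate key in order and raises KeyError if none is
    present, listing the keys that are available.
    """
    logs = logs or {}
    for key in keys:
        if key in logs:
            return logs[key]
    raise KeyError(f"none of {keys} found in logs; available: {sorted(logs)}")


# Example logs dicts; the second one uses made-up values.
print(get_val_accuracy({"val_loss": 0.3426, "val_acc": 0.8787}))    # 0.8787
print(get_val_accuracy({"val_loss": 0.25, "val_accuracy": 0.9123}))  # 0.9123
```

Inside on_epoch_end you could then write val_acc = get_val_accuracy(logs) in place of the direct dictionary lookup.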
At the end of each epoch, the callback above reads the validation accuracy from the logs and compares it with the user-defined threshold (in your case 90%). If the criterion is met, training is stopped.
With that you can simply call:
my_callback = MyThresholdCallback(threshold=0.9)
history = model.fit(training_images, training_labels, validation_data=(test_images, test_labels), epochs=30, verbose=2, callbacks=[my_callback])
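Alternatively, you can use def on_batch_end(...) if you want to stop immediately. This, however, takes the parameters batch, logs instead of epoch, logs. Note that batch-level logs typically carry only training metrics, since validation metrics are computed at the end of an epoch.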