Tags: python, machine-learning, neural-network, conv-neural-network, keras

How to tell Keras to stop training based on loss value?


Currently I use the following code:

callbacks = [
    EarlyStopping(monitor='val_loss', patience=2, verbose=0),
    ModelCheckpoint(kfold_weights_path, monitor='val_loss', save_best_only=True, verbose=0),
]
model.fit(X_train.astype('float32'), Y_train, batch_size=batch_size, nb_epoch=nb_epoch,
      shuffle=True, verbose=1, validation_data=(X_valid, Y_valid),
      callbacks=callbacks)

This tells Keras to stop training when the validation loss has not improved for 2 epochs. But I want to stop training as soon as the loss becomes smaller than some constant "THR":

if val_loss < THR:
    break

I've seen in the documentation that it is possible to write your own callback: http://keras.io/callbacks/ But I found nothing on how to stop the training process. I need some advice.


Solution

  • I found the answer. I looked into the Keras sources and found the code for EarlyStopping. I made my own callback based on it:

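    import warnings

    from keras.callbacks import Callback


    class EarlyStoppingByLossVal(Callback):
        """Stop training once the monitored value drops below `value`."""

        def __init__(self, monitor='val_loss', value=0.00001, verbose=0):
            # Initialize the Callback base class (not its parent).
            super(EarlyStoppingByLossVal, self).__init__()
            self.monitor = monitor
            self.value = value
            self.verbose = verbose

        def on_epoch_end(self, epoch, logs=None):
            logs = logs or {}
            current = logs.get(self.monitor)
            if current is None:
                # Nothing to compare against, so warn and skip this epoch.
                warnings.warn("Early stopping requires %s available!" % self.monitor, RuntimeWarning)
                return

            if current < self.value:
                if self.verbose > 0:
                    print("Epoch %05d: early stopping THR" % epoch)
                # Setting this flag asks Keras to end training after this epoch.
                self.model.stop_training = True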
    

    And usage:

    callbacks = [
        EarlyStoppingByLossVal(monitor='val_loss', value=0.00001, verbose=1),
        # EarlyStopping(monitor='val_loss', patience=2, verbose=0),
        ModelCheckpoint(kfold_weights_path, monitor='val_loss', save_best_only=True, verbose=0),
    ]
    # Note: Keras 2 and later call this argument `epochs`; `nb_epoch` is the Keras 1 name.
    model.fit(X_train.astype('float32'), Y_train, batch_size=batch_size, nb_epoch=nb_epoch,
          shuffle=True, verbose=1, validation_data=(X_valid, Y_valid),
          callbacks=callbacks)
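To see why setting `stop_training` from a callback is enough, here is a minimal sketch that runs without Keras. It assumes, as a simplification, that the training loop calls every callback's `on_epoch_end` and then checks the model's `stop_training` flag. `ToyModel` and `ThresholdStopper` are hypothetical stand-ins written for this illustration, not Keras classes, and the list of losses is made up.

```python
import warnings


class ThresholdStopper:
    """Keras-free imitation of the threshold callback from the answer."""

    def __init__(self, monitor="val_loss", value=1e-5):
        self.monitor = monitor
        self.value = value
        self.model = None  # assigned by the training loop before the first epoch

    def on_epoch_end(self, epoch, logs=None):
        current = (logs or {}).get(self.monitor)
        if current is None:
            warnings.warn("Early stopping requires %s available!" % self.monitor,
                          RuntimeWarning)
            return
        if current < self.value:
            # Only raise the flag; the loop itself decides when to exit.
            self.model.stop_training = True


class ToyModel:
    """Hypothetical stand-in for a model; fit() replays a fixed list of losses."""

    def __init__(self):
        self.stop_training = False

    def fit(self, val_losses, callbacks):
        self.stop_training = False
        for cb in callbacks:
            cb.model = self
        epochs_run = 0
        for epoch, val_loss in enumerate(val_losses):
            epochs_run += 1
            for cb in callbacks:
                cb.on_epoch_end(epoch, {"val_loss": val_loss})
            if self.stop_training:  # flag is checked once per epoch
                break
        return epochs_run


model = ToyModel()
epochs_run = model.fit([0.9, 0.4, 0.05, 0.01], [ThresholdStopper(value=0.1)])
print(epochs_run)  # 3: the third loss (0.05) is the first one below 0.1
```

Because the flag is only checked at the end of an epoch, the epoch in which the threshold is crossed always runs to completion; the callback cannot interrupt it midway.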