keras, early-stopping

Keras: make EarlyStopping use the validation loss of the untrained network instead of Inf


I have implemented early stopping together with ModelCheckpoint. The goal is to select the best trained model version and return the corresponding epoch number.

from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
import numpy as np

early_stopping = EarlyStopping(monitor='val_loss', patience=earlyStoppingEpochs, restore_best_weights=False)
checkpoint_filepath = '/tmp/checkpoint.weights.h5'
checkpoint = ModelCheckpoint(filepath=checkpoint_filepath, save_weights_only=True, monitor='val_loss', mode='min', save_best_only=True, verbose=verbose)
history = model.fit(Train_Data, Train_Label, epochs=epochs, batch_size=batch_size, validation_data=(Val_Data, Val_Label), verbose=verbose, callbacks=[early_stopping, checkpoint])

# 1-based index of the epoch with the lowest validation loss
hist = history.history['val_loss']
finalEpochs = np.argmin(hist) + 1
# restore the best weights saved by ModelCheckpoint
model.load_weights(checkpoint_filepath)

Everything works. The only point that still bothers me is that the validation loss of the untrained model is not taken into account. Instead, the initial loss is set to inf: "Epoch 1: val_loss improved from inf to 2.35898, saving model to /tmp/checkpoint.weights.h5"

In my application, however, it is possible that even the first training epoch worsens the validation loss. In that case the untrained model should be returned as the best model, with an ideal epoch number of 0.
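For illustration, the selection I am after would look roughly like this, with initial_val_loss standing for the validation loss of the untrained model (the names are only placeholders):

initial_val_loss = model.evaluate(Val_Data, Val_Label, verbose=0)  # before fit; take element 0 if metrics are compiled
# ... model.fit(...) as above ...
all_losses = [initial_val_loss] + history.history['val_loss']
finalEpochs = np.argmin(all_losses)  # 0 means: keep the untrained model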

Do you know whether it is possible to adjust this behaviour so that EarlyStopping does not start from inf, but from the validation loss of the untrained model?


Solution

  • You can evaluate the model manually before training begins and use that value as the initial best value for early stopping. Here is how you can adjust your code:

    1. Evaluate the untrained model on the validation set to get the initial validation loss.
    2. Modify the EarlyStopping callback to use this initial validation loss as the baseline.

    Example:

    from tensorflow.keras.callbacks import EarlyStopping

    # evaluate() returns the loss (take the first element if metrics are compiled)
    initial_val_loss = model.evaluate(Val_Data, Val_Label, verbose=0)

    class CustomEarlyStopping(EarlyStopping):
        def __init__(self, initial_val_loss, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.initial_val_loss = initial_val_loss

        def on_train_begin(self, logs=None):
            super().on_train_begin(logs)       # the stock callback resets self.best to inf here
            self.best = self.initial_val_loss  # start from the untrained model's loss instead
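
    Note that ModelCheckpoint still starts from inf, so the untrained weights are never written to disk and np.argmin(hist) alone can never point at "epoch 0". A minimal sketch of the complete wiring, reusing the variable names from your question (initial_weights_path is just an extra file name chosen for this example):

    # save the untrained weights so "epoch 0" can be restored later
    initial_weights_path = '/tmp/initial.weights.h5'
    model.save_weights(initial_weights_path)

    early_stopping = CustomEarlyStopping(initial_val_loss, monitor='val_loss',
                                         patience=earlyStoppingEpochs,
                                         restore_best_weights=False)
    history = model.fit(Train_Data, Train_Label, epochs=epochs, batch_size=batch_size,
                        validation_data=(Val_Data, Val_Label), verbose=verbose,
                        callbacks=[early_stopping, checkpoint])

    hist = history.history['val_loss']
    if initial_val_loss <= min(hist):
        # no epoch improved on the untrained model: keep it and report epoch 0
        finalEpochs = 0
        model.load_weights(initial_weights_path)
    else:
        finalEpochs = np.argmin(hist) + 1
        model.load_weights(checkpoint_filepath)

    This only changes where the patience counter of EarlyStopping starts counting; ModelCheckpoint keeps saving the best epoch of the actual training, which is what the else branch restores.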
    

    Hope this will help.