Tags: python, deep-learning, neural-network, huggingface-transformers, huggingface

Early stopping in BERT Trainer instances


I am fine-tuning a BERT model for a multiclass classification task. My problem is that I don't know how to add "early stopping" to a Trainer instance. Any ideas?


Solution

  • You need to make a few changes to your training setup before EarlyStoppingCallback() will work correctly.

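    from transformers import EarlyStoppingCallback, IntervalStrategy, Trainer, TrainingArguments
    ...
    ...
    # Defining the TrainingArguments() arguments
    args = TrainingArguments(
       output_dir = "training_with_callbacks",
       evaluation_strategy = IntervalStrategy.STEPS, # "steps"; named eval_strategy in recent releases
       eval_steps = 50, # Evaluate every 50 steps
       save_strategy = IntervalStrategy.STEPS, # Must match the evaluation strategy
       save_steps = 50, # Save a checkpoint at every evaluation (default would be every 500 steps)
       save_total_limit = 5, # Keep at most 5 checkpoints; older ones are deleted, the best one is kept
       learning_rate=2e-5,
       per_device_train_batch_size=batch_size,
       per_device_eval_batch_size=batch_size,
       num_train_epochs=5,
       weight_decay=0.01,
       push_to_hub=False,
       metric_for_best_model = 'f1', # A key returned by compute_metrics()
       greater_is_better=True, # Higher F1 is better (already the default for non-loss metrics)
       load_best_model_at_end=True)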
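To see why step-based evaluation matters, here is a rough sketch of the step arithmetic. It assumes a single device, no gradient accumulation, and a hypothetical training set of 10,000 examples with batch_size = 16 (neither number comes from the question), so each epoch is ceil(examples / batch size) optimizer steps.

```python
import math


def evaluation_steps(n_train_examples, batch_size, num_epochs, eval_steps):
    """Return (global steps at which a step-based evaluation runs, total steps).

    Assumes one device and no gradient accumulation, so an epoch is
    ceil(n_train_examples / batch_size) optimizer steps.
    """
    steps_per_epoch = math.ceil(n_train_examples / batch_size)
    total_steps = steps_per_epoch * num_epochs
    # An evaluation happens every eval_steps steps, up to the last full multiple
    return list(range(eval_steps, total_steps + 1, eval_steps)), total_steps


# Hypothetical dataset size and batch size, for illustration only
evals, total = evaluation_steps(n_train_examples=10_000, batch_size=16, num_epochs=5, eval_steps=50)
print(total, len(evals))  # → 3125 62
```

Under these assumptions the callback gets 62 chances to check the metric instead of 5 with per-epoch evaluation, and with a patience of 3 training can stop after 150 steps without improvement.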
    

    You need to:

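    1. Set load_best_model_at_end = True (EarlyStoppingCallback() asserts that this is True).
    2. Set evaluation_strategy = 'steps' (or IntervalStrategy.STEPS) so the model is evaluated every eval_steps rather than only once per epoch. 'epoch' also works, but evaluation must not be 'no', and save_strategy must be set to the same value.
    3. Set eval_steps = 50 (evaluate the metrics every N steps). With step-based saving, save_steps must be a multiple of eval_steps.
    4. Set metric_for_best_model = 'f1', the name of a metric returned by compute_metrics(). Early stopping uses it to decide whether the model is still improving.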
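The requirements above can be summarized as a small checklist. This is a hypothetical helper, not part of transformers: it takes a plain dict using the TrainingArguments field names from the question, and the defaults it assumes (no evaluation, step-based saving every 500 steps) approximate the library's. The library performs its own, more thorough checks when the arguments and callback are created.

```python
def early_stopping_problems(cfg):
    """List reasons a TrainingArguments-like dict would not support early stopping.

    Hypothetical helper: `cfg` is a plain dict, not a real TrainingArguments object.
    """
    problems = []
    eval_strategy = cfg.get("evaluation_strategy", "no")  # assumed default: no evaluation
    save_strategy = cfg.get("save_strategy", "steps")     # assumed default: save by steps
    if not cfg.get("load_best_model_at_end", False):
        problems.append("load_best_model_at_end must be True")
    if eval_strategy == "no":
        problems.append("evaluation must be enabled")
    elif save_strategy != eval_strategy:
        problems.append("save_strategy must match the evaluation strategy")
    elif eval_strategy == "steps" and cfg.get("save_steps", 500) % cfg.get("eval_steps", 500) != 0:
        problems.append("save_steps must be a multiple of eval_steps")
    return problems


good = {"load_best_model_at_end": True, "evaluation_strategy": "steps", "eval_steps": 50,
        "save_strategy": "steps", "save_steps": 50, "metric_for_best_model": "f1"}
print(early_stopping_problems(good))                              # → []
print(early_stopping_problems({**good, "save_steps": 75}))        # → ['save_steps must be a multiple of eval_steps']
print(early_stopping_problems({"load_best_model_at_end": True}))  # → ['evaluation must be enabled']
```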

    In your Trainer():

    trainer = Trainer(
        model,
        args,
        ...
        compute_metrics=compute_metrics,
        callbacks = [EarlyStoppingCallback(early_stopping_patience=3)]
    )
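The patience rule can be illustrated with a simplified re-implementation. It is a sketch assuming greater_is_better=True: an evaluation counts as an improvement only if it beats the best score so far by more than a threshold (mirroring the callback's early_stopping_threshold, default 0.0); otherwise a counter grows, and training stops once the counter reaches early_stopping_patience. The F1 scores are made up.

```python
def stop_evaluation(scores, patience=3, threshold=0.0):
    """Return the 1-based evaluation after which training would stop, or None.

    Simplified model of the patience logic, assuming higher scores are better.
    """
    best = None
    counter = 0
    for i, score in enumerate(scores, start=1):
        if best is None or (score > best and score - best > threshold):
            counter = 0   # large enough improvement: reset patience
        else:
            counter += 1  # no improvement (or too small): use up patience
        if best is None or score > best:
            best = score  # best metric seen so far
        if counter >= patience:
            return i
    return None


f1_per_eval = [0.70, 0.74, 0.76, 0.75, 0.76, 0.755, 0.78]  # made-up scores
print(stop_evaluation(f1_per_eval, patience=3))  # → 6
```

With eval_steps = 50, the sixth evaluation is global step 300, so the later 0.78 is never reached. That is the trade-off of a small patience.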
    

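    The compute_metrics() function you pass to Trainer() can look like the following. For a multiclass task, precision, recall and F1 need an explicit average such as average='macro'; scikit-learn's default, average='binary', raises an error when there are more than two classes.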

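    import numpy as np
    from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

    def compute_metrics(p):
        # p is an EvalPrediction: logits of shape (n_samples, n_classes) and the true labels
        pred, labels = p.predictions, p.label_ids
        pred = np.argmax(pred, axis=1)
        accuracy = accuracy_score(y_true=labels, y_pred=pred)
        # Multiclass needs an explicit average; the default average='binary' raises an error
        recall = recall_score(y_true=labels, y_pred=pred, average="macro")
        precision = precision_score(y_true=labels, y_pred=pred, average="macro")
        f1 = f1_score(y_true=labels, y_pred=pred, average="macro")
        return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}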
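For intuition about what average="macro" computes, here is a stdlib-only sketch (not scikit-learn's code): precision, recall and F1 are computed per class, over every class that appears in the labels or predictions, then averaged with equal weight. Zero denominators count as 0, which matches scikit-learn's default result in value (scikit-learn also emits a warning). The logits and labels are made-up toy data.

```python
def argmax_rows(logits):
    """Index of the largest logit in each row, like np.argmax(pred, axis=1)."""
    return [max(range(len(row)), key=row.__getitem__) for row in logits]


def macro_metrics(labels, preds):
    """Accuracy plus macro-averaged precision, recall and F1 in plain Python."""
    classes = sorted(set(labels) | set(preds))
    pairs = list(zip(labels, preds))
    precisions, recalls, f1s = [], [], []
    for c in classes:
        tp = sum(1 for y, p in pairs if y == c and p == c)
        fp = sum(1 for y, p in pairs if y != c and p == c)
        fn = sum(1 for y, p in pairs if y == c and p != c)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        precisions.append(precision)
        recalls.append(recall)
        f1s.append(f1)
    n = len(classes)
    return {
        "accuracy": sum(y == p for y, p in pairs) / len(pairs),
        "precision": sum(precisions) / n,  # unweighted mean over classes
        "recall": sum(recalls) / n,
        "f1": sum(f1s) / n,                # mean of per-class F1 scores
    }


logits = [[2.0, 0.1, 0.3], [0.2, 1.5, 0.1], [0.1, 0.3, 0.2], [0.0, 0.1, 0.9]]
labels = [0, 1, 2, 2]
print(macro_metrics(labels, argmax_rows(logits)))  # accuracy 0.75, macro F1 = 7/9
```

average="weighted" would instead weight each class by its number of true examples, which can hide poor performance on rare classes.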
    

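    compute_metrics() must return a dictionary mapping metric names to values. You can compute any metrics you like inside it; the key you name in metric_for_best_model is the one early stopping monitors (the Trainer logs it with an "eval_" prefix, e.g. eval_f1).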
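A sketch of how the metric name is matched, assuming the Trainer's behavior of prefixing evaluation metrics with "eval_" and prepending "eval_" to metric_for_best_model when it is missing. Both functions are hypothetical stand-ins, not transformers APIs. If the name does not match any returned key, early stopping cannot work (the callback warns that it did not find the metric).

```python
def prefix_metrics(metrics, prefix="eval"):
    """Add the prefix used for evaluation metrics, e.g. f1 -> eval_f1."""
    return {
        key if key.startswith(f"{prefix}_") else f"{prefix}_{key}": value
        for key, value in metrics.items()
    }


def monitored_key(metric_for_best_model):
    """Resolve metric_for_best_model by prepending 'eval_' when it is missing."""
    if metric_for_best_model.startswith("eval_"):
        return metric_for_best_model
    return f"eval_{metric_for_best_model}"


logged = prefix_metrics({"accuracy": 0.91, "f1": 0.88})  # made-up values
key = monitored_key("f1")
print(key, logged.get(key))           # → eval_f1 0.88
print(monitored_key("F1") in logged)  # → False: names are case-sensitive
```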

    Note: IntervalStrategy.STEPS and the plain string "steps" are interchangeable, because IntervalStrategy is a string-valued Enum, so either form works. What has changed in newer transformers releases is the argument name: evaluation_strategy is deprecated in favor of eval_strategy.
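Why the two forms are interchangeable can be shown with a local stand-in. It assumes transformers defines IntervalStrategy as an Enum that also subclasses str (its ExplicitEnum base); the class below is a mirror written for illustration, not the library import.

```python
from enum import Enum


class IntervalStrategy(str, Enum):
    """Local stand-in mirroring a str-valued Enum like transformers' IntervalStrategy."""
    NO = "no"
    STEPS = "steps"
    EPOCH = "epoch"


# Members compare equal to their plain string values...
print(IntervalStrategy.STEPS == "steps")                    # → True
# ...and a plain string converts back to the same member
print(IntervalStrategy("steps") is IntervalStrategy.STEPS)  # → True
```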