I used the HuggingFace transformers library to train a BERT model for sequence classification.
Training runs fine on the GPU, but evaluation (which also runs on the GPU) is far too slow. For example, in a sanity check with just 20 short text inputs, the evaluation runtime is about 160 seconds per step.
Here's the code snippet:
def compute_metrics(eval_pred):
    accuracy_metric = evaluate.load("accuracy")
    f1_metric = evaluate.load("f1", average="macro")
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    accuracy = accuracy_metric.compute(predictions=predictions, references=labels)
    f1_score = f1_metric.compute(predictions=predictions, references=labels, average="macro")
    return {**accuracy, **f1_score}

model = AutoModelForSequenceClassification.from_pretrained(
    base_model_path,
    num_labels=num_labels,
    id2label=id2label,
    label2id=label2id
)

training_args = TrainingArguments(
    output_dir=".",
    learning_rate=lr,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=n_epoch,
    weight_decay=weight_decay,
    evaluation_strategy="steps",
    eval_steps=eval_steps,
    logging_strategy="steps",
    logging_steps=logging_steps,
    save_strategy="steps",
    save_steps=saving_steps,
    load_best_model_at_end=True,
    report_to=["tensorboard"],
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train_ds,
    eval_dataset=tokenized_valid_ds,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

trainer.train()
The properties of the environment:
transformers 4.29.2
Python 3.10.9
The training configuration is as follows:
len(train_data) ~= 36K
len(valid_data) ~= 2K
len(test_data) ~= 2K
model_name = 'bert-base-uncased'
per_device_train_batch_size=16
per_device_eval_batch_size=16
num_train_epochs=30
P.S.: All inputs are short (fewer than ten tokens).
Can anyone suggest a way to reduce the time overhead of the evaluation process?
I finally found the problem. It's caused by the evaluate.load() calls inside the compute_metrics function. evaluate.load() seems to have significant time overhead, so it shouldn't be called inside functions such as compute_metrics, which run many times. I moved the two load() calls out of compute_metrics, and evaluation is fast now.
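A minimal sketch of that change, in case it helps someone else. The factory name make_compute_metrics is my own (hypothetical, not part of transformers or evaluate); it just takes metric objects that were loaded once and returns the callback that Trainer calls, so no load() happens during evaluation. It assumes the metric objects expose compute(predictions=..., references=..., **kwargs), as the ones returned by evaluate.load() do.

```python
import numpy as np


def make_compute_metrics(accuracy_metric, f1_metric):
    """Build a compute_metrics callback around metric objects loaded once.

    make_compute_metrics is a hypothetical helper name; the metric objects
    are expected to be loaded by the caller, outside the callback.
    """

    def compute_metrics(eval_pred):
        # eval_pred unpacks into (model outputs, reference labels)
        logits, labels = eval_pred
        predictions = np.argmax(logits, axis=1)
        accuracy = accuracy_metric.compute(
            predictions=predictions, references=labels
        )
        f1_score = f1_metric.compute(
            predictions=predictions, references=labels, average="macro"
        )
        # Merge the two result dicts into one dict of metrics
        return {**accuracy, **f1_score}

    return compute_metrics


# Usage (assumes the `evaluate` package is installed and can fetch the metrics):
#   import evaluate
#   compute_metrics = make_compute_metrics(
#       evaluate.load("accuracy"), evaluate.load("f1")
#   )
#   ...then pass compute_metrics=compute_metrics to Trainer as before.
```

Loading the two metrics as plain module-level variables above compute_metrics would work just as well; the closure only avoids relying on globals.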