I have created a custom dataset and trained a custom T5ForConditionalGeneration model on it that predicts solutions to quadratic equations, like this:
Input: "4*x^2 + 4*x + 1"
Output: D = 4 ^ 2 - 4 * 4 * 1 4 * 1 4 * 1 4 * 1 4 * 1 4
I need to get accuracy for this model, but when I use Trainer I only get the loss, so I am using a custom metric function (I didn't write it myself; I took it from a similar project):
def compute_metrics4token(eval_pred):
    batch_size = 4
    predictions, labels = eval_pred
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    # Rouge expects a newline after each sentence
    decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds]
    decoded_labels = ["\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels]
    answer_accuracy = []
    token_accuracy = []
    num_correct, num_total = 0, 0
    num_answer = 0
    number_eq = 0
    for p, l in zip(decoded_preds, decoded_labels):
        text_pred = p.split(' ')
        text_labels = l.split(' ')
        m = min(len(text_pred), len(text_labels))
        if np.array_equal(text_pred, text_labels):
            num_answer += 1
        for i, j in zip(text_pred, text_labels):
            if i == j:
                num_correct += 1
        num_total += len(text_labels)
        number_eq += 1
    token_accuracy = num_correct / num_total
    answer_accuracy = num_answer / number_eq
    result = {'token_acc': token_accuracy, 'answer_acc': answer_accuracy}
    result = {key: value for key, value in result.items()}
    for key, value in result.items():
        wandb.log({key: value})
    return {k: round(v, 4) for k, v in result.items()}
The problem is that it doesn't work, and I don't really understand why, or what I can do to get accuracy for my model. Here is my setup and the error I get when the function is used:
args = Seq2SeqTrainingArguments(
    output_dir='./',
    num_train_epochs=10,
    overwrite_output_dir=True,
    evaluation_strategy='steps',
    learning_rate=1e-4,
    logging_steps=100,
    eval_steps=100,
    save_steps=100,
    load_best_model_at_end=True,
    push_to_hub=True,
    weight_decay=0.01,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=4,
)
trainer = Seq2SeqTrainer(model=model, train_dataset=train_dataset, eval_dataset=eval_dataset, args=args,
                         data_collator=data_collator, tokenizer=tokenizer, compute_metrics=compute_metrics4token)
<ipython-input-48-ff7980f6dd66> in compute_metrics4token(eval_pred)
4 # predictions = np.argmax(logits[0])
5 # print(predictions)
----> 6 decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
7 # Replace -100 in the labels as we can't decode them.
8 labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
/usr/local/lib/python3.10/dist-packages/transformers/tokenization_utils_base.py in batch_decode(self, sequences, skip_special_tokens, clean_up_tokenization_spaces, **kwargs)
3444 `List[str]`: The list of decoded sentences.
3445 """
-> 3446 return [
3447 self.decode(
3448 seq,
/usr/local/lib/python3.10/dist-packages/transformers/tokenization_utils_base.py in <listcomp>(.0)
3445 """
3446 return [
-> 3447 self.decode(
3448 seq,
3449 skip_special_tokens=skip_special_tokens,
/usr/local/lib/python3.10/dist-packages/transformers/tokenization_utils_base.py in decode(self, token_ids, skip_special_tokens, clean_up_tokenization_spaces, **kwargs)
3484 token_ids = to_py_obj(token_ids)
3485
-> 3486 return self._decode(
3487 token_ids=token_ids,
3488 skip_special_tokens=skip_special_tokens,
/usr/local/lib/python3.10/dist-packages/transformers/tokenization_utils_fast.py in _decode(self, token_ids, skip_special_tokens, clean_up_tokenization_spaces, **kwargs)
547 if isinstance(token_ids, int):
548 token_ids = [token_ids]
--> 549 text = self._tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
550
551 clean_up_tokenization_spaces = (
TypeError: argument 'ids': 'list' object cannot be interpreted as an integer
When I print out predictions, I get a tuple:
(array([[[-32.777344, -34.593437, -36.065685, ..., -34.78577 ,
          -34.77546 , -34.061115],
         [-58.633934, -32.23472 , -31.735909, ..., -40.335655,
          -40.28701 , -37.208904],
         ...,
         [-58.966076, -23.795956, -34.519627, ..., -35.901787,
          -40.261116, -37.612514]]], dtype=float32),
 array([[[-1.43104442e-03, -2.98473001e-01,  9.49775204e-02, ...,
          -1.77978892e-02,  1.79805323e-01,  1.33578405e-01],
         [-2.35560730e-01,  1.53045550e-01,  5.15255742e-02, ...,
          -1.57466665e-01,  3.49459350e-01,  7.28092641e-02],
         ...,
         [ 4.50411513e-02, -3.61773074e-01, -5.50217964e-02, ...,
           3.68298292e-02,  1.67936400e-01,  1.76781893e-01]]],
       dtype=float32))
I thought that maybe I need to take the argmax of these values, but then I still get errors. If something is unclear, I'd be happy to provide additional information. Thanks for any help.
EDIT:
I am adding some example items from the dataset:
dataset['test'][0:5]
{'text': ['3*x^2 + 9*x + 6 = 0',
'59*x^2 + -59*x + 14 = 0',
'-10*x^2 + 0*x + 0 = 0',
'3*x^2 + 63*x + 330 = 0',
'1*x^2 + -25*x + 156 = 0'],
 'label': ['D = 9^2 - 4 * 3 * 6 = 9; x1 = (-9 + (9)**0.5) // (2 * 3) = -1.0; x2 = (-9 - (9)**0.5) // (2 * 3) = -2.0',
  'D = -59^2 - 4 * 59 * 14 = 177; x1 = (59 + (177)**0.5) // (2 * 59) = 0.0; x2 = (59 - (177)**0.5) // (2 * 59) = 0.0',
  'D = 0^2 - 4 * -10 * 0 = 0; x = 0^2 // (2 * -10) = 0',
  'D = 63^2 - 4 * 3 * 330 = 9; x1 = (-63 + (9)**0.5) // (2 * 3) = -10.0; x2 = (-63 - (9)**0.5) // (2 * 3) = -11.0',
  'D = -25^2 - 4 * 1 * 156 = 1; x1 = (25 + (1)**0.5) // (2 * 1) = 13.0; x2 = (25 - (1)**0.5) // (2 * 1) = 12.0'],
'__index_level_0__': [10803, 14170, 25757, 73733, 25059]}
It seems like the task you're trying to achieve is some sort of "translation" task, so the most appropriate model class to use is AutoModelForSeq2SeqLM.

As for the error itself: when predict_with_generate=True is not set, Seq2SeqTrainer hands your compute_metrics the model's raw outputs, i.e. the tuple of float32 arrays you printed, and tokenizer.batch_decode cannot decode floats. With predict_with_generate=True it hands over generated token ids instead, which is what you want to decode. And since the output sequence is free-form rather than of a fixed, specified length, a sequence-level translation metric may be more appropriate than per-token accuracy.
You can take a look at various translation-related metrics on https://www.kaggle.com/code/alvations/huggingface-evaluate-for-mt-evaluations
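For instance, sacrebleu (the metric used in the snippet below) can be called directly on decoded strings; a minimal sketch, with the example string borrowed from your dataset:

import evaluate

metric = evaluate.load("sacrebleu")
# references is a list of reference lists: several references per prediction are allowed
result = metric.compute(
    predictions=["D = 9^2 - 4 * 3 * 6 = 9"],
    references=[["D = 9^2 - 4 * 3 * 6 = 9"]],
)
print(result["score"])  # 0-100 scale; identical strings score 100.0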
To read the data, you'll have to make sure that the model's forward function receives tokenized inputs and labels, i.e. that each row of your datasets.Dataset object carries token ids of the form

{"text": [0, 1, 2, ... ], "labels": [0, 9, 8, ...]}

and that batching and padding are handled by a DataCollatorForSeq2Seq.
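A small sketch of what that collator does (assuming the same t5-small tokenizer as in the snippet below): it pads input_ids with the tokenizer's pad token, but pads labels with -100, which is exactly why -100 has to be replaced before decoding in compute_metrics:

from transformers import AutoTokenizer, DataCollatorForSeq2Seq

tokenizer = AutoTokenizer.from_pretrained("t5-small")
collator = DataCollatorForSeq2Seq(tokenizer)

features = [
    {"input_ids": tokenizer("3*x^2 + 9*x + 6 = 0").input_ids,
     "labels": tokenizer("D = 9^2 - 4 * 3 * 6 = 9").input_ids},
    {"input_ids": tokenizer("-10*x^2 + 0*x + 0 = 0").input_ids,
     "labels": tokenizer("x = 0").input_ids},
]
batch = collator(features)
print(batch["labels"])  # the shorter label row is right-padded with -100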
And here's a working snippet of how the code (in parts) can be run: https://www.kaggle.com/alvations/how-to-train-a-t5-seq2seq-model-using-custom-data
from datasets import Dataset
import evaluate
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, DataCollatorForSeq2Seq
math_data = {'text': ['3*x^2 + 9*x + 6 = 0',
'59*x^2 + -59*x + 14 = 0',
'-10*x^2 + 0*x + 0 = 0',
'3*x^2 + 63*x + 330 = 0',
'1*x^2 + -25*x + 156 = 0'],
'target': ['D = 9^2 - 4 * 3 * 6 = 9; x1 = (-9 + (9)**0.5) // (2 * 3) = -1.0; x2 = (-9 - (9)**0.5) // (2 * 3) = -2.0',
'D = -59^2 - 4 * 59 * 14 = 177; x1 = (59 + (177)**0.5) // (2 * 59) = 0.0; x2 = (59 - (177)**0.5) // (2 * 59) = 0.0',
'D = 0^2 - 4 * -10 * 0 = 0; x = 0^2 // (2 * -10) = 0',
'D = 63^2 - 4 * 3 * 330 = 9; x1 = (-63 + (9)**0.5) // (2 * 3) = -10.0; x2 = (-63 - (9)**0.5) // (2 * 3) = -11.0',
'D = -25^2 - 4 * 1 * 156 = 1; x1 = (25 + (1)**0.5) // (2 * 1) = 13.0; x2 = (25 - (1)**0.5) // (2 * 1) = 12.0']}
math_data_eval = {'text': ["10 + 9x(x+3y) - 3x^3"], "target": ["10 + 9x^2 + 27xy - 3x^3"]}
ds_train = Dataset.from_dict(math_data)
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
tokenizer = AutoTokenizer.from_pretrained("t5-small")
data_collator = DataCollatorForSeq2Seq(tokenizer)
ds_train = ds_train.map(lambda x: tokenizer(x["text"], truncation=True, padding="max_length", max_length=512))
ds_train = ds_train.map(lambda y: {"labels": tokenizer(y["target"], truncation=True, padding="max_length", max_length=512)["input_ids"]})

ds_eval = Dataset.from_dict(math_data_eval)
ds_eval = ds_eval.map(lambda x: tokenizer(x["text"], truncation=True, padding="max_length", max_length=512))
ds_eval = ds_eval.map(lambda y: {"labels": tokenizer(y["target"], truncation=True, padding="max_length", max_length=512)["input_ids"]})
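To sanity-check the preprocessing, it can help to peek at one mapped example (a sketch; the column names are the ones produced by the map calls above):

example = ds_train[0]
print(example.keys())  # the original columns plus 'input_ids', 'attention_mask', 'labels'
print(example["input_ids"][:10])  # first few input token ids
print(tokenizer.decode(example["labels"], skip_special_tokens=True))  # roughly recovers the target string (symbols missing from the T5 vocabulary, e.g. '^', may be lost)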
import numpy as np
metric = evaluate.load("sacrebleu")
def postprocess_text(preds, labels):
    preds = [pred.strip() for pred in preds]
    labels = [[label.strip()] for label in labels]
    return preds, labels

def compute_metrics(eval_preds):
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]
    # Replace -100s used for padding as we can't decode them
    preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    # Some simple post-processing
    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
    result = metric.compute(predictions=decoded_preds, references=decoded_labels)
    result = {"bleu": result["score"]}
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
    result["gen_len"] = np.mean(prediction_lens)
    result = {k: round(v, 4) for k, v in result.items()}
    return result
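Before training anything, compute_metrics can be unit-tested by feeding it hand-made token ids (a sketch; the string and max_length are arbitrary):

toy = tokenizer(["D = 9^2 - 4 * 3 * 6 = 9"], padding="max_length", max_length=32)
toy_ids = np.array(toy["input_ids"])
print(compute_metrics((toy_ids, toy_ids)))  # predictions equal labels, so bleu should be 100.0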
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
# set training arguments - these params are not really tuned, feel free to change
batch_size = 2  # toy value; it must be defined before it is used below
training_args = Seq2SeqTrainingArguments(
    output_dir="./",
    evaluation_strategy="steps",
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    predict_with_generate=True,  # crucial: pass generated token ids, not raw logits, to compute_metrics
    logging_steps=2,   # set to 1000 for full training
    save_steps=16,     # set to 500 for full training
    eval_steps=4,      # set to 8000 for full training
    warmup_steps=1,    # set to 2000 for full training
    max_steps=16,      # delete for full training
    # overwrite_output_dir=True,
    save_total_limit=1,
    # fp16=True,
)
# instantiate trainer
trainer = Seq2SeqTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=ds_train.with_format("torch"),
    eval_dataset=ds_eval.with_format("torch"),
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)
trainer.train()
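After training, trainer.evaluate() reports the loss together with the bleu and gen_len values from compute_metrics. And if you specifically want the exact-match "answer accuracy" you originally asked about, you can compute it from trainer.predict(); a minimal sketch, reusing the tokenizer and eval set defined above:

results = trainer.evaluate()
print(results)  # contains 'eval_loss', 'eval_bleu', 'eval_gen_len', ...

# exact-match accuracy: fraction of predictions that equal the reference verbatim
pred_out = trainer.predict(ds_eval.with_format("torch"))
pred_ids = np.where(pred_out.predictions != -100, pred_out.predictions, tokenizer.pad_token_id)
label_ids = np.where(pred_out.label_ids != -100, pred_out.label_ids, tokenizer.pad_token_id)
preds = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
refs = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
answer_acc = sum(p.strip() == r.strip() for p, r in zip(preds, refs)) / len(refs)
print(f"answer accuracy: {answer_acc:.4f}")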