I have used the following code to do SFT:
base_model = "google/gemma-3-270m"
it_model = "google/gemma-3-270m-it"
checkpoint_dir = "checkpoint"
learning_rate = 5e-5 #@param {type:"number"}
from datasets import Dataset
from transformers import pipeline
from trl import SFTConfig
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import json
import random
with open(r"PersianNer.json", "r", encoding="utf-8") as f:
    data_list = json.load(f)
random.shuffle(data_list)
dataset = Dataset.from_list(data_list)
print(len(dataset))
del data_list
dataset = dataset.train_test_split(test_size=0.2, shuffle=False)
# Print the first training example's messages
print(dataset["train"][0]["messages"])
# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    torch_dtype="auto",
    device_map="auto",
    attn_implementation="eager"
)
tokenizer = AutoTokenizer.from_pretrained(base_model)
it_tokenizer = AutoTokenizer.from_pretrained(it_model)
tokenizer.chat_template = it_tokenizer.chat_template
print(f"Device: {model.device}")
print(f"DType: {model.dtype}")
torch_dtype = model.dtype
args = SFTConfig(
    output_dir=checkpoint_dir,           # directory to save and repository id
    max_length=1024,                     # max sequence length for model and packing of the dataset
    packing=False,                       # Groups multiple samples in the dataset into a single sequence
    num_train_epochs=5,                  # number of training epochs
    per_device_train_batch_size=4,       # batch size per device during training
    gradient_checkpointing=False,        # Caching is incompatible with gradient checkpointing
    optim="adamw_torch_fused",           # use fused adamw optimizer
    logging_steps=1,                     # log every step
    save_strategy="epoch",               # save checkpoint every epoch
    eval_strategy="epoch",               # evaluate checkpoint every epoch
    learning_rate=learning_rate,         # learning rate
    fp16=True if torch_dtype == torch.float16 else False,   # use float16 precision
    bf16=True if torch_dtype == torch.bfloat16 else False,  # use bfloat16 precision
    lr_scheduler_type="constant",        # use constant learning rate scheduler
    dataset_kwargs={
        "add_special_tokens": False,     # Template with special tokens
        "append_concat_token": True,     # Add EOS token as separator token between examples
    }
)
from trl import SFTTrainer
# Create Trainer object
trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=dataset['train'],
    eval_dataset=dataset['test'],
    processing_class=tokenizer,
)
trainer.train()
# Save the final model to the output directory
trainer.save_model()
import matplotlib.pyplot as plt
# Access the log history
log_history = trainer.state.log_history
# Extract training / validation loss
train_losses = [log["loss"] for log in log_history if "loss" in log]
epoch_train = [log["epoch"] for log in log_history if "loss" in log]
eval_losses = [log["eval_loss"] for log in log_history if "eval_loss" in log]
epoch_eval = [log["epoch"] for log in log_history if "eval_loss" in log]
# Plot the training loss
plt.plot(epoch_train, train_losses, label="Training Loss")
plt.plot(epoch_eval, eval_losses, label="Validation Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Training and Validation Loss per Epoch")
plt.legend()
plt.grid(True)
plt.show()
from transformers import AutoTokenizer, AutoModelForCausalLM
model_id = checkpoint_dir
# Load Model
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype="auto",
    device_map="auto",
    attn_implementation="eager"
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
from transformers import pipeline
# Load the model and tokenizer into the pipeline
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
def test(test_sample):
    # Convert the test sample's user message into a prompt with the Gemma chat template
    prompt = pipe.tokenizer.apply_chat_template(test_sample["messages"][:1], tokenize=False, add_generation_prompt=True)
    outputs = pipe(prompt, max_new_tokens=256, disable_compile=True)
    # Print the user query, the reference answer, and the generated answer
    print(f"Question:\n{test_sample['messages'][0]['content']}")
    print(f"Original Answer:\n{test_sample['messages'][1]['content']}")
    print(f"Generated Answer:\n{outputs[0]['generated_text'][len(prompt):].strip()}")
    print("-"*80)
# Test with an unseen dataset
for item in dataset['test']:
    test(item)
The problem is that after fine-tuning, the model generates repetitive model turns, like:
Original Answer:
[{'text': 'روسیه', 'type': 'LOC'}]
Generated Answer:
[{'text': 'روسیه', 'type': 'LOC'}]
model
[{'text': 'روسیه', 'type': 'LOC'}]
model
[{'text': 'روسیه', 'type': 'LOC'}]
model
[{'text': 'روسیه', 'type': 'LOC'}]
model
[{'text': 'روسیه', 'type': 'LOC'}]
model
[{'text': 'روسیه', 'type': 'LOC'}]
model
[{'text': 'روسیه', 'type': 'LOC'}]
model
[{'text': 'روسیه', 'type': 'LOC'}]
model
[{'text': 'روسیه', 'type': 'LOC'}]
model
[{'text': 'روسیه', 'type': 'LOC'}]
model
[{'text': 'روسیه', 'type': 'LOC'}]
model
[{'text': 'روسیه', 'type': 'LOC'}]
model
[{'text':
It seems to keep generating until it gets close to max_new_tokens. However, if I simply switch the model to the instruction-tuned version, i.e. "google/gemma-3-270m-it", it works fine and stops at the end of the first turn.
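One way to narrow this down (a minimal diagnostic sketch, assuming both checkpoints ship a generation_config.json) is to compare the generation configs of the two checkpoints directly:

from transformers import GenerationConfig
# Compare the stopping tokens declared by the base and IT checkpoints
base_cfg = GenerationConfig.from_pretrained("google/gemma-3-270m")
it_cfg = GenerationConfig.from_pretrained("google/gemma-3-270m-it")
print("base eos_token_id:", base_cfg.eos_token_id)
print("it   eos_token_id:", it_cfg.eos_token_id)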
For future reference:
The problem is in
model.generation_config.eos_token_id
which in the raw base model is:
[1]
but in the IT version is:
[1, 106]
106 is the <end_of_turn> token, so the fine-tuned base model never treats <end_of_turn> as a stopping token and keeps starting new model turns until it runs out of new tokens.
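A minimal sketch of a fix, assuming the tokenizer loaded above still has <end_of_turn> in its vocabulary (the base and IT checkpoints share the same vocabulary): register that token as an additional EOS token before generating.

# Sketch: stop generation at <end_of_turn>, matching the IT checkpoint's behavior
end_of_turn_id = tokenizer.convert_tokens_to_ids("<end_of_turn>")   # 106 for Gemma 3
model.generation_config.eos_token_id = [tokenizer.eos_token_id, end_of_turn_id]
# (Equivalently, pass eos_token_id=[1, 106] as a generate kwarg in the pipe(...) call inside test().)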