I'm using a pretrained BERT-based model (GitHub: DNABERT-2).
It uses AutoModelForSequenceClassification and mosaicml/mosaic-bert-base.
The problem is that I cannot extract the attention weights. I have read many posts that suggest activating output_attentions=True in the model, but none of them solved the problem.
The output is of length 2 and its two elements have shapes torch.Size([1, 7, 768]) and torch.Size([1, 768]). When I try to access output.attentions, I get None.
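For comparison, this is the behaviour I would expect based on those posts. With a plain BERT checkpoint (bert-base-uncased here, purely as an illustration, not my model), the flag works as described:

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Sanity check with a standard BERT checkpoint (not DNABERT-2): attentions come back.
ref_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
ref_model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
ref_inputs = ref_tokenizer("ACTGACGGGTAGTGACTG", return_tensors="pt")
with torch.inference_mode():
    ref_out = ref_model(**ref_inputs, output_attentions=True)
# One tensor per layer, each of shape (batch, num_heads, seq_len, seq_len)
print(len(ref_out.attentions), ref_out.attentions[0].shape)

With DNABERT-2, the equivalent call is what gives me the None described above.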
I'm not sure where to look or what a solution would be.
I'm providing my whole code:
Defining model, trainer, data, tokenizer:
from copy import deepcopy
from sklearn.metrics import precision_recall_fscore_support
import wandb
from transformers import TrainerCallback
# END NEW
import os
import csv
import json
import logging
from dataclasses import dataclass, field
from typing import Optional, Dict, Sequence, Tuple, List
import torch
import transformers
import sklearn
import numpy as np
from torch.utils.data import Dataset
@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
    use_lora: bool = field(default=False, metadata={"help": "whether to use LoRA"})
    lora_r: int = field(default=8, metadata={"help": "hidden dimension for LoRA"})
    lora_alpha: int = field(default=32, metadata={"help": "alpha for LoRA"})
    lora_dropout: float = field(default=0.05, metadata={"help": "dropout rate for LoRA"})
    lora_target_modules: str = field(default="query,value", metadata={"help": "where to perform LoRA"})
@dataclass
class DataArguments:
    data_path: str = field(default=None, metadata={"help": "Path to the training data."})
    kmer: int = field(default=-1, metadata={"help": "k-mer for input sequence. -1 means not using k-mer."})
@dataclass
class TrainingArguments(transformers.TrainingArguments):
    cache_dir: Optional[str] = field(default=None)
    run_name: str = field(default="run")
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(default=512, metadata={"help": "Maximum sequence length."})
    gradient_accumulation_steps: int = field(default=1)
    per_device_train_batch_size: int = field(default=1)
    per_device_eval_batch_size: int = field(default=1)
    num_train_epochs: int = field(default=1)
    logging_steps: int = field(default=100)
    save_steps: int = field(default=100)
    fp16: bool = field(default=False)
    # START NEW
    # eval_steps: int = field(default=100)
    eval_steps: int = field(default=0.1)
    # END NEW
    evaluation_strategy: str = field(default="steps")
    warmup_steps: int = field(default=50)
    weight_decay: float = field(default=0.01)
    learning_rate: float = field(default=1e-4)
    save_total_limit: int = field(default=3)
    load_best_model_at_end: bool = field(default=True)
    output_dir: str = field(default="output")
    find_unused_parameters: bool = field(default=False)
    checkpointing: bool = field(default=False)
    dataloader_pin_memory: bool = field(default=False)
    eval_and_save_results: bool = field(default=True)
    save_model: bool = field(default=False)
    seed: int = field(default=42)
def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
    """Collects the state dict and dump to disk."""
    state_dict = trainer.model.state_dict()
    if trainer.args.should_save:
        cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
        del state_dict
        trainer._save(output_dir, state_dict=cpu_state_dict)  # noqa
""" Get the reversed complement of the original DNA sequence. """
def get_alter_of_dna_sequence(sequence: str):
MAP = {"A": "T", "T": "A", "C": "G", "G": "C"}
# return "".join([MAP[c] for c in reversed(sequence)])
return "".join([MAP[c] for c in sequence])
""" Transform a dna sequence to k-mer string """
def generate_kmer_str(sequence: str, k: int) -> str:
"""Generate k-mer string from DNA sequence."""
return " ".join([sequence[i:i + k] for i in range(len(sequence) - k + 1)])
""" Load or generate k-mer string for each DNA sequence. The generated k-mer string will be saved to the same directory as the original data with the same name but with a suffix of "_{k}mer". """
def load_or_generate_kmer(data_path: str, texts: List[str], k: int) -> List[str]:
"""Load or generate k-mer string for each DNA sequence."""
kmer_path = data_path.tokenizerreplace(".csv", f"_{k}mer.json")
if os.path.exists(kmer_path):
logging.warning(f"Loading k-mer from {kmer_path}...")
with open(kmer_path, "r") as f:
kmer = json.load(f)
else:
logging.warning(f"Generating k-mer...")
kmer = [generate_kmer_str(text, k) for text in texts]
with open(kmer_path, "w") as f:
logging.warning(f"Saving k-mer to {kmer_path}...")
json.dump(kmer, f)
return kmer
class SupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning."""

    def __init__(self,
                 data_path: str,
                 tokenizer: transformers.PreTrainedTokenizer,
                 kmer: int = -1):
        super(SupervisedDataset, self).__init__()

        # load data from the disk
        with open(data_path, "r") as f:
            data = list(csv.reader(f))[1:]
        if len(data[0]) == 2:
            # data is in the format of [text, label]
            logging.warning("Perform single sequence classification...")
            texts = [d[0] for d in data]
            labels = [int(d[1]) for d in data]
            # All gene sequences are concatenated: we don't work with sequence pairs,
            # but we trick the model into treating the input as a single sequence.
        elif len(data[0]) == 3:
            # data is in the format of [text1, text2, label]
            logging.warning("Perform sequence-pair classification...")
            texts = [[d[0], d[1]] for d in data]
            labels = [int(d[2]) for d in data]
        else:
            raise ValueError("Data format not supported.")

        if kmer != -1:
            # only write file on the first process
            if torch.distributed.get_rank() not in [0, -1]:
                torch.distributed.barrier()
            logging.warning(f"Using {kmer}-mer as input...")
            texts = load_or_generate_kmer(data_path, texts, kmer)
            if torch.distributed.get_rank() == 0:
                torch.distributed.barrier()

        output = tokenizer(
            texts,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        )

        self.input_ids = output["input_ids"]
        # CHANGE
        self.input_ids[0][self.input_ids[0] == 0] = 2
        # Change which tokens we want to attend to and which we don't
        self.attention_mask = output["attention_mask"]
        self.labels = labels
        self.num_labels = len(set(labels))

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        return dict(input_ids=self.input_ids[i], labels=self.labels[i])
@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning."""

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )
        labels = torch.Tensor(labels).long()
        return dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
        )
""" Manually calculate the accuracy, f1, matthews_correlation, precision, recall with sklearn. """
def calculate_metric_with_sklearn(logits: np.ndarray, labels: np.ndarray):
if logits.ndim == 3:
# Reshape logits to 2D if needed
logits = logits.reshape(-1, logits.shape[-1])
predictions = np.argmax(logits, axis=-1)
valid_mask = labels != -100 # Exclude padding tokens (assuming -100 is the padding token ID)
valid_predictions = predictions[valid_mask]
valid_labels = labels[valid_mask]
return {
# START NEW
"sum prediction": f'{sum(valid_predictions)}/{len(valid_predictions)}',
# END NEW
"accuracy": sklearn.metrics.accuracy_score(valid_labels, valid_predictions),
"f1": sklearn.metrics.f1_score(
valid_labels, valid_predictions, average="macro", zero_division=0
),
"matthews_correlation": sklearn.metrics.matthews_corrcoef(
valid_labels, valid_predictions
),
"precision": sklearn.metrics.precision_score(
valid_labels, valid_predictions, average="macro", zero_division=0
),
"recall": sklearn.metrics.recall_score(
valid_labels, valid_predictions, average="macro", zero_division=0
),
}
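# Quick illustration (added comment, not in the original code):
# calculate_metric_with_sklearn(np.array([[0.2, 0.8], [0.9, 0.1]]), np.array([1, 0]))
# returns accuracy 1.0, f1 1.0, etc., and "sum prediction" == "1/2".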
""" Compute metrics used for huggingface trainer. """ def compute_metrics(eval_pred):
logits, labels = eval_pred
if isinstance(logits, tuple): # Unpack logits if it's a tuple
logits = logits[0]
return calculate_metric_with_sklearn(logits, labels)
class CustomTrainer(transformers.Trainer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.epoch_predictions = []
        self.epoch_labels = []
        self.epoch_loss = []

    def compute_loss(self, model, inputs, return_outputs=False):
        """
        MAX: Subclassed to compute training accuracy.

        How the loss is computed by Trainer. By default, all models return the loss in
        the first element. Subclass and override for custom behavior.
        """
        if self.label_smoother is not None and "labels" in inputs:
            labels = inputs.pop("labels")
        else:
            labels = None
        outputs = model(**inputs, output_attentions=True)

        # TEST
        try:
            print(f"Attention: {outputs.attentions}")
        except Exception:
            print("No Attention returned")

        if "labels" in inputs:
            preds = outputs.logits.detach()
            # Log accuracy
            acc = (
                (preds.argmax(axis=1) == inputs["labels"])
                .type(torch.float)
                .mean()
                .item()
            )
            # Uncomment it if you want to plot the batch accuracy
            # wandb.log({"batch_accuracy": acc})  # Log accuracy
            # Store predictions and labels for epoch-level metrics
            self.epoch_predictions.append(preds.cpu().numpy())
            self.epoch_labels.append(inputs["labels"].cpu().numpy())

        # Save past state if it exists
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]

        if labels is not None:
            loss = self.label_smoother(outputs, labels)
        else:
            loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]

        # Uncomment it if you want to plot the batch loss
        # wandb.log({"batch_loss": loss})
        self.epoch_loss.append(loss.item())  # Store loss for epoch-level metrics

        return (loss, outputs) if return_outputs else loss
# Define a custom callback to calculate metrics at the end of each epoch
class CustomCallback(TrainerCallback):
    def __init__(self, trainer) -> None:
        super().__init__()
        self._trainer = trainer

    def on_epoch_end(self, args, state, control, **kwargs):
        # Aggregate predictions and labels for the entire epoch
        epoch_predictions = np.concatenate(self._trainer.epoch_predictions)
        epoch_labels = np.concatenate(self._trainer.epoch_labels)
        # Compute accuracy
        accuracy = np.mean(epoch_predictions.argmax(axis=1) == epoch_labels)
        # Compute mean loss
        mean_loss = np.mean(self._trainer.epoch_loss)
        # Compute precision, recall, and F1-score
        precision, recall, f1, _ = precision_recall_fscore_support(
            epoch_labels, epoch_predictions.argmax(axis=1), average="weighted"
        )
        # Log epoch-level metrics
        wandb.log({"epoch_accuracy": accuracy, "epoch_loss": mean_loss})
        wandb.log({"precision": precision, "recall": recall, "f1": f1})
        # Clear stored predictions, labels, and loss for the next epoch
        self._trainer.epoch_predictions = []
        self._trainer.epoch_labels = []
        self._trainer.epoch_loss = []
        return None
# TODO: use this function to gather the prediction and labels and get the metrics
#%%
Instantiating and training:
from transformer_model import SupervisedDataset, DataCollatorForSupervisedDataset, ModelArguments, \
TrainingArguments, DataArguments, safe_save_model_for_hf_trainer, CustomTrainer, CustomCallback, \
compute_metrics
from copy import deepcopy
from transformers import TrainerCallback
# END NEW
import os
import json
import torch
import transformers
from peft import (
LoraConfig,
get_peft_model,
get_peft_model_state_dict,
)
import wandb
run = wandb.init()
assert run is wandb.run
def train(device):
    parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # load tokenizer
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=True,
        trust_remote_code=True,
    )
    if "InstaDeepAI" in model_args.model_name_or_path:
        tokenizer.eos_token = tokenizer.pad_token

    # define datasets and data collator
    train_dataset = SupervisedDataset(tokenizer=tokenizer,
                                      data_path=os.path.join(data_args.data_path, "train.csv"),
                                      kmer=data_args.kmer)
    val_dataset = SupervisedDataset(tokenizer=tokenizer,
                                    data_path=os.path.join(data_args.data_path, "dev.csv"),
                                    kmer=data_args.kmer)
    test_dataset = SupervisedDataset(tokenizer=tokenizer,
                                     data_path=os.path.join(data_args.data_path, "test.csv"),
                                     kmer=data_args.kmer)
    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)

    # load model
    model = transformers.AutoModelForSequenceClassification.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        num_labels=train_dataset.num_labels,
        trust_remote_code=True,
        output_attentions=True,
    ).to(device)

    # configure LoRA
    if model_args.use_lora:
        lora_config = LoraConfig(
            r=model_args.lora_r,
            lora_alpha=model_args.lora_alpha,
            target_modules=list(model_args.lora_target_modules.split(",")),
            lora_dropout=model_args.lora_dropout,
            bias="none",
            task_type="SEQ_CLS",
            inference_mode=False,
        )
        model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()

    trainer = CustomTrainer(model=model,
                            tokenizer=tokenizer,
                            args=training_args,
                            compute_metrics=compute_metrics,
                            train_dataset=train_dataset,
                            eval_dataset=val_dataset,
                            data_collator=data_collator
                            )
    trainer.add_callback(CustomCallback(trainer))
    trainer.train()
    # train_result = trainer.train()
    # loss = train_result["loss"]
    # print(f"loss issss: {loss}")
    # print(f"Train reusults: {train_result}")  # NEW: result: only returns metrics at the end of training

    if training_args.save_model:
        trainer.save_state()
        safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir)

    # get the evaluation results from trainer
    if training_args.eval_and_save_results:
        results_path = os.path.join(training_args.output_dir, "results", training_args.run_name)
        results = trainer.evaluate(eval_dataset=test_dataset)
        os.makedirs(results_path, exist_ok=True)
        with open(os.path.join(results_path, "eval_results.json"), "w") as f:
            json.dump(results, f)


if __name__ == "__main__":
    # Define device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print('Using device:', device)
    # Call the train function with the device
    train(device)
After training, I try to run the model on an example:
import torch
from transformers import AutoTokenizer, AutoModel

model_path = './finetune/output/dnabert2'
tokenizer = AutoTokenizer.from_pretrained(model_path)
# Load the model with output_attentions=True
model = AutoModel.from_pretrained(model_path, trust_remote_code=True, output_attentions=True)
model_input = tokenizer("ACTGACGGGTAGTGACTG", return_tensors="pt")
with torch.inference_mode():
    output = model(**model_input, output_attentions=True)
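This is how I then inspect what comes back; the shapes and the None quoted at the top of the post are from exactly this kind of check:
# Inspect what the DNABERT-2 forward actually returns.
print(len(output))                          # 2 in my case
print(output[0].shape, output[1].shape)     # torch.Size([1, 7, 768]) and torch.Size([1, 768])
print(getattr(output, "attentions", None))  # None, which is the problem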
My code still contains some tests and print statements. Let me know if anything is missing. Thank you very much for the help.
The problem was deep in the model's structure: the attention weights were discarded early inside the model, so I had to go through the code to understand what was happening and change it.
https://huggingface.co/jaandoui/DNABERT2-AttentionExtracted
I had to extract attention_probs. Here are the changes I made.
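Schematically, the change follows the pattern sketched below (a minimal, self-contained illustration with a made-up class name, not the exact diff; the complete modified files are in the Hugging Face repository linked above). The attention probabilities are already computed inside the self-attention module, so instead of dropping them once the context vectors are built, they have to be returned and passed along:
import math
import torch
import torch.nn as nn

class SketchSelfAttention(nn.Module):
    """Minimal BERT-style self-attention, only to illustrate the pattern of the change."""

    def __init__(self, hidden_size=768, num_heads=12):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.qkv = nn.Linear(hidden_size, 3 * hidden_size)
        self.out = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states, attention_mask=None):
        b, s, h = hidden_states.shape
        q, k, v = self.qkv(hidden_states).chunk(3, dim=-1)
        # reshape to (batch, num_heads, seq_len, head_dim)
        q, k, v = (t.view(b, s, self.num_heads, self.head_dim).transpose(1, 2) for t in (q, k, v))
        scores = q @ k.transpose(-2, -1) / math.sqrt(self.head_dim)
        if attention_mask is not None:
            scores = scores + attention_mask  # additive mask, as in BERT
        attention_probs = scores.softmax(dim=-1)  # (batch, num_heads, seq_len, seq_len)
        context = (attention_probs @ v).transpose(1, 2).reshape(b, s, h)
        # The essential change: return attention_probs instead of discarding it.
        return self.out(context), attention_probs

# e.g.: ctx, att = SketchSelfAttention()(torch.randn(1, 7, 768)) -> att.shape == (1, 12, 7, 7)
In the real model the same extra return value then has to be threaded through every wrapper (attention module, encoder layer, encoder, top-level forward) so that it finally shows up in the model output as the attentions.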