Tags: nlp, huggingface-transformers, bert-language-model, transformer-model, attention-model

No Attention returned even when output_attentions= True


I'm using a pretrained BERT-based model (GitHub link: DNABERT-2).

It uses AutoModelForSequenceClassification and mosaicml/mosaic-bert-base.

My problem is that I cannot extract the attention weights. I have read many posts that suggest passing output_attentions=True to the model, but none of them solved the problem. The output has length 2, and its two elements have shapes torch.Size([1, 7, 768]) and torch.Size([1, 768]). When I access output.attentions, I get None.

I'm not sure where to search and what a solution would be.

I'm providing my whole code:

Defining model, trainer, data, tokenizer:

from copy import deepcopy

from sklearn.metrics import precision_recall_fscore_support import wandb from transformers import TrainerCallback
# END NEW import os import csv import json import logging from dataclasses import dataclass, field from typing import Optional, Dict, Sequence, Tuple, List

import torch import transformers import sklearn import numpy as np from torch.utils.data import Dataset

@dataclass
class ModelArguments:
    """Command-line options selecting the base model and its LoRA settings."""

    # Model identifier or path handed to ``from_pretrained``.
    model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
    use_lora: bool = field(default=False, metadata={"help": "whether to use LoRA"})
    lora_r: int = field(default=8, metadata={"help": "hidden dimension for LoRA"})
    lora_alpha: int = field(default=32, metadata={"help": "alpha for LoRA"})
    lora_dropout: float = field(default=0.05, metadata={"help": "dropout rate for LoRA"})
    # Comma-separated module names; split on "," before being given to LoRA.
    lora_target_modules: str = field(default="query,value", metadata={"help": "where to perform LoRA"})


@dataclass
class DataArguments:
    """Command-line options describing where the data lives and how it is encoded."""

    # Defaults to None, so the annotation has to admit None as well as str.
    data_path: Optional[str] = field(default=None, metadata={"help": "Path to the training data."})
    kmer: int = field(default=-1, metadata={"help": "k-mer for input sequence. -1 means not using k-mer."})


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    """Project defaults layered on top of ``transformers.TrainingArguments``.

    Every field below either overrides a default of the parent class or adds a
    project-specific switch (e.g. ``model_max_length``, ``eval_and_save_results``,
    ``save_model``) that the training script reads.
    """

    cache_dir: Optional[str] = field(default=None)
    run_name: str = field(default="run")
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(default=512, metadata={"help": "Maximum sequence length."})
    gradient_accumulation_steps: int = field(default=1)
    per_device_train_batch_size: int = field(default=1)
    per_device_eval_batch_size: int = field(default=1)
    num_train_epochs: int = field(default=1)
    logging_steps: int = field(default=100)
    save_steps: int = field(default=100)
    fp16: bool = field(default=False)
    # The default is the fraction 0.1, not a step count, so the annotation must
    # admit floats; with ``int`` the argument parser would refuse fractional
    # values given on the command line. (Previous default: 100 steps.)
    # NOTE(review): per the transformers docs a value below 1 is read as a
    # ratio of the total training steps - confirm for the installed version.
    eval_steps: Optional[float] = field(default=0.1)
    evaluation_strategy: str = field(default="steps")
    warmup_steps: int = field(default=50)
    weight_decay: float = field(default=0.01)
    learning_rate: float = field(default=1e-4)
    save_total_limit: int = field(default=3)
    load_best_model_at_end: bool = field(default=True)
    output_dir: str = field(default="output")
    find_unused_parameters: bool = field(default=False)
    checkpointing: bool = field(default=False)
    dataloader_pin_memory: bool = field(default=False)
    # Read by the training script: evaluate on the test split and dump JSON.
    eval_and_save_results: bool = field(default=True)
    # Read by the training script: persist trainer state and model weights.
    save_model: bool = field(default=False)
    seed: int = field(default=42)


def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
    """Dump the trainer's model weights to ``output_dir`` as CPU tensors.

    The state dict is read on every call; the CPU copy and the call to
    ``trainer._save`` only happen when ``trainer.args.should_save`` is true.
    """
    weights = trainer.model.state_dict()
    if not trainer.args.should_save:
        return

    cpu_weights = {}
    for name, tensor in weights.items():
        cpu_weights[name] = tensor.cpu()
    # Drop the reference to the original mapping before writing to disk.
    del weights
    trainer._save(output_dir, state_dict=cpu_weights)  # noqa


""" Get the complement of the original DNA sequence (the reversing variant is commented out, so the base order is kept). """


def get_alter_of_dna_sequence(sequence: str):
    """Return the base-wise complement of ``sequence`` (A<->T, C<->G).

    The order of the bases is kept as-is; the reversing variant is disabled.
    Any character other than A, T, C or G raises ``KeyError``.
    """
    complement = {"A": "T", "T": "A", "C": "G", "G": "C"}
    # Reversed variant, currently disabled: map over reversed(sequence) instead.
    return "".join(map(complement.__getitem__, sequence))


""" Transform a dna sequence to k-mer string """


def generate_kmer_str(sequence: str, k: int) -> str:
    """Return every length-``k`` window of ``sequence``, in order, joined by spaces."""
    window_count = len(sequence) - k + 1
    kmers = []
    for start in range(window_count):
        kmers.append(sequence[start:start + k])
    return " ".join(kmers)


""" Load or generate k-mer string for each DNA sequence. The generated k-mer string will be saved  to the same directory as the original data with the same name but with a suffix of "_{k}mer". """


def load_or_generate_kmer(data_path: str, texts: List[str], k: int) -> List[str]:
    """Load cached k-mer strings for ``texts``, or generate and cache them.

    The cache file path is ``data_path`` with ``.csv`` replaced by
    ``_{k}mer.json``.

    Args:
        data_path: Path of the CSV file the sequences came from.
        texts: DNA sequences to convert; only used when no cache file exists.
        k: Window length of each k-mer.

    Returns:
        One space-separated k-mer string per input sequence (or whatever list
        the existing cache file contains).
    """
    # ``data_path`` is a plain str, so the cache path is derived with str.replace.
    kmer_path = data_path.replace(".csv", f"_{k}mer.json")
    if os.path.exists(kmer_path):
        logging.warning(f"Loading k-mer from {kmer_path}...")
        with open(kmer_path, "r") as f:
            return json.load(f)

    logging.warning("Generating k-mer...")
    kmer = [generate_kmer_str(text, k) for text in texts]
    with open(kmer_path, "w") as f:
        logging.warning(f"Saving k-mer to {kmer_path}...")
        json.dump(kmer, f)

    return kmer


class SupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning.

    Reads a CSV file (the first row is skipped as a header), optionally
    converts the sequences to k-mer strings, tokenizes everything up front and
    keeps the resulting ``input_ids``, ``attention_mask`` and integer labels.
    """

    def __init__(self,
                 data_path: str,
                 tokenizer: transformers.PreTrainedTokenizer,
                 kmer: int = -1):
        """Load, optionally k-merize, and tokenize the CSV at ``data_path``.

        Args:
            data_path: CSV file with rows ``[text, label]`` or
                ``[text1, text2, label]``.
            tokenizer: Callable tokenizer; ``model_max_length`` is read from it.
            kmer: k-mer size, or -1 to feed the raw sequences.

        Raises:
            ValueError: If the file has no data rows or the first data row has
                neither two nor three columns.
        """

        super(SupervisedDataset, self).__init__()

        # load data from the disk
        with open(data_path, "r") as f:
            data = list(csv.reader(f))[1:]
        if not data:
            # Without this guard the column check below fails with IndexError.
            raise ValueError(f"No data rows found in {data_path}.")
        if len(data[0]) == 2:
            # data is in the format of [text, label]
            logging.warning("Perform single sequence classification...")
            texts = [d[0] for d in data]
            labels = [int(d[1]) for d in data]
        # All genes sequences are concat: we don't work with the sequence-pair,
        # But we are tricking the model to think it is single sequence.
        elif len(data[0]) == 3:
            # data is in the format of [text1, text2, label]
            logging.warning("Perform sequence-pair classification...")
            texts = [[d[0], d[1]] for d in data]
            labels = [int(d[2]) for d in data]
        else:
            raise ValueError("Data format not supported.")

        if kmer != -1:
            # get_rank() fails when no process group has been initialised, so
            # the barriers are only used in an actual distributed run; a
            # single-process run behaves like rank -1.
            distributed = (
                torch.distributed.is_available()
                and torch.distributed.is_initialized()
            )
            rank = torch.distributed.get_rank() if distributed else -1

            # only write file on the first process
            if rank not in [0, -1]:
                torch.distributed.barrier()

            logging.warning(f"Using {kmer}-mer as input...")
            texts = load_or_generate_kmer(data_path, texts, kmer)

            if rank == 0:
                torch.distributed.barrier()

        output = tokenizer(
            texts,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        )

        self.input_ids = output["input_ids"]
        # CHANGE
        # NOTE(review): rewrites token id 0 to 2 in the first example only;
        # the intent is not evident from this file - confirm.
        self.input_ids[0][self.input_ids[0] == 0] = 2
        # Change to which tokens we want to attend and to which we don't
        self.attention_mask = output["attention_mask"]
        self.labels = labels
        self.num_labels = len(set(labels))

    def __len__(self):
        """Return the number of tokenized examples."""
        return len(self.input_ids)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        """Return ``input_ids`` and the label of example ``i``.

        The stored attention mask is not part of the returned item.
        """
        return dict(input_ids=self.input_ids[i], labels=self.labels[i])


@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning."""

    # Only ``pad_token_id`` is read from the tokenizer.
    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        """Pad ``input_ids`` to a common length and build the batch.

        Args:
            instances: Items with an ``"input_ids"`` tensor and a ``"labels"``
                value each.

        Returns:
            A dict with padded ``input_ids``, ``labels`` as a long tensor, and
            an ``attention_mask`` that is False wherever ``input_ids`` equals
            the pad token id.
        """
        pad_id = self.tokenizer.pad_token_id
        input_ids = [instance["input_ids"] for instance in instances]
        labels = [instance["labels"] for instance in instances]
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=pad_id
        )
        # Build the integer tensor directly instead of casting a float tensor.
        labels = torch.tensor(labels, dtype=torch.long)
        return dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(pad_id),
        )


""" Manually calculate the accuracy, f1, matthews_correlation, precision, recall with sklearn. """

def calculate_metric_with_sklearn(logits: np.ndarray, labels: np.ndarray):
    """Score argmax predictions against ``labels`` with scikit-learn.

    Three-dimensional logits are flattened to two dimensions first. Positions
    whose label equals -100 are left out of every metric. The result maps
    metric names to values; ``"sum prediction"`` is a ``"<sum>/<count>"``
    string over the kept predictions.
    """
    if logits.ndim == 3:
        # Collapse the leading two axes so argmax runs over the last axis only.
        logits = logits.reshape(-1, logits.shape[-1])
    predicted = np.argmax(logits, axis=-1)

    keep = labels != -100  # -100 marks entries to be ignored
    kept_predictions = predicted[keep]
    kept_labels = labels[keep]

    macro = dict(average="macro", zero_division=0)
    metrics = sklearn.metrics
    return {
        "sum prediction": f'{sum(kept_predictions)}/{len(kept_predictions)}',
        "accuracy": metrics.accuracy_score(kept_labels, kept_predictions),
        "f1": metrics.f1_score(kept_labels, kept_predictions, **macro),
        "matthews_correlation": metrics.matthews_corrcoef(kept_labels, kept_predictions),
        "precision": metrics.precision_score(kept_labels, kept_predictions, **macro),
        "recall": metrics.recall_score(kept_labels, kept_predictions, **macro),
    }

def compute_metrics(eval_pred):
    """Compute metrics used for huggingface trainer.

    Args:
        eval_pred: A ``(logits, labels)`` pair; when ``logits`` is a tuple only
            its first element is used.

    Returns:
        The metric dict produced by ``calculate_metric_with_sklearn``.
    """
    logits, labels = eval_pred
    if isinstance(logits, tuple):  # Unpack logits if it's a tuple
        logits = logits[0]
    return calculate_metric_with_sklearn(logits, labels)


class CustomTrainer(transformers.Trainer):
    """Trainer that records per-batch predictions, labels and losses.

    ``compute_loss`` appends to ``epoch_predictions``, ``epoch_labels`` and
    ``epoch_loss``; nothing in this class clears those lists.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Buffers filled by compute_loss, one entry per call.
        self.epoch_predictions = []
        self.epoch_labels = []
        self.epoch_loss = []

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """
        MAX: Subclassed to compute training accuracy.

        How the loss is computed by Trainer. By default, all models return the loss in
        the first element.

        Subclass and override for custom behavior.

        Args:
            model: Called as ``model(**inputs, output_attentions=True)``.
            inputs: Batch dict; ``"labels"`` is popped from it when a label
                smoother is configured.
            return_outputs: When true, return ``(loss, outputs)``.
            num_items_in_batch: Accepted so that Trainer versions which pass
                this keyword do not fail with a TypeError; not used here.

        Returns:
            The loss, or ``(loss, outputs)`` when ``return_outputs`` is true.
        """
        if self.label_smoother is not None and "labels" in inputs:
            labels = inputs.pop("labels")
        else:
            labels = None
        outputs = model(**inputs, output_attentions=True)
        # TEST
        # Only a missing attribute means "no attention"; other errors should
        # surface instead of being swallowed.
        try:
            attentions = outputs.attentions
        except AttributeError:
            print("No Attention returned")
        else:
            print(f"Attention: {attentions}")

        if "labels" in inputs:
            preds = outputs.logits.detach()

            # Log accuracy
            acc = (
                (preds.argmax(axis=1) == inputs["labels"])
                .type(torch.float)
                .mean()
                .item()
            )
            # Uncomment it if you want to plot the batch accuracy
            # wandb.log({"batch_accuracy": acc})  # Log accuracy

            # Store predictions and labels for epoch-level metrics
            self.epoch_predictions.append(preds.cpu().numpy())
            self.epoch_labels.append(inputs["labels"].cpu().numpy())

        # Save past state if it exists
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]

        if labels is not None:
            loss = self.label_smoother(outputs, labels)
        else:
            loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
            # Uncomment it if you want to plot the batch loss
            # wandb.log({"batch_loss": loss})
            self.epoch_loss.append(loss.item())  # Store loss for epoch-level metrics

        return (loss, outputs) if return_outputs else loss

# Define a custom callback to calculate metrics at the end of each epoch
class CustomCallback(TrainerCallback):
    """Log epoch-level accuracy, loss, precision, recall and F1 to wandb.

    Reads the ``epoch_predictions``, ``epoch_labels`` and ``epoch_loss`` lists
    of the trainer given at construction and resets them after every epoch.
    """

    def __init__(self, trainer) -> None:
        super().__init__()
        self._trainer = trainer

    def on_epoch_end(self, args, state, control, **kwargs):
        """Aggregate the recorded batches, log the metrics, reset the buffers."""
        trainer = self._trainer

        # np.concatenate raises on an empty list, so skip the metrics when no
        # batch was recorded during this epoch.
        if trainer.epoch_predictions:
            # Aggregate predictions and labels for the entire epoch
            epoch_predictions = np.concatenate(trainer.epoch_predictions)
            epoch_labels = np.concatenate(trainer.epoch_labels)
            predicted_classes = epoch_predictions.argmax(axis=1)

            # Compute accuracy
            accuracy = np.mean(predicted_classes == epoch_labels)

            # Compute mean loss
            mean_loss = np.mean(trainer.epoch_loss)

            # Compute precision, recall, and F1-score
            precision, recall, f1, _ = precision_recall_fscore_support(
                epoch_labels, predicted_classes, average="weighted"
            )

            # Log epoch-level metrics
            wandb.log({"epoch_accuracy": accuracy, "epoch_loss": mean_loss})
            wandb.log({"precision": precision, "recall": recall, "f1": f1})

        # Clear stored predictions, labels, and loss for the next epoch
        trainer.epoch_predictions = []
        trainer.epoch_labels = []
        trainer.epoch_loss = []
        return None

        # TODO: use this function to gather the prediction and labels and get the metrics
#%%

Instantiating and training:

from transformer_model import SupervisedDataset, DataCollatorForSupervisedDataset, ModelArguments, \
    TrainingArguments, DataArguments, safe_save_model_for_hf_trainer, CustomTrainer, CustomCallback, \
    compute_metrics


from copy import deepcopy
from transformers import TrainerCallback
# END NEW
import os
import json
import torch
import transformers

from peft import (
    LoraConfig,
    get_peft_model,
    get_peft_model_state_dict,
)

import wandb
# Start a Weights & Biases run as a module-level side effect; the assertion
# checks that the returned handle is the same object as ``wandb.run``.
run = wandb.init()
assert run is wandb.run

def train(device):
    """Fine-tune a sequence classifier and optionally evaluate it on the test split.

    Parses ``ModelArguments``, ``DataArguments`` and ``TrainingArguments`` from
    the command line, builds the tokenizer, loads ``train.csv``, ``dev.csv`` and
    ``test.csv`` from ``data_args.data_path``, moves the model to ``device``,
    optionally wraps it with LoRA, trains with ``CustomTrainer``, and then
    optionally saves the model and writes ``eval_results.json``.

    Args:
        device: Target passed to ``model.to``.
    """
    arg_parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = arg_parser.parse_args_into_dataclasses()
    checkpoint = model_args.model_name_or_path

    # Tokenizer
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        checkpoint,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=True,
        trust_remote_code=True,
    )
    if "InstaDeepAI" in checkpoint:
        tokenizer.eos_token = tokenizer.pad_token

    # Datasets and collator; the splits are loaded in train/dev/test order.
    def load_split(csv_name):
        return SupervisedDataset(
            tokenizer=tokenizer,
            data_path=os.path.join(data_args.data_path, csv_name),
            kmer=data_args.kmer,
        )

    train_dataset = load_split("train.csv")
    val_dataset = load_split("dev.csv")
    test_dataset = load_split("test.csv")
    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)

    # Model, with attention outputs requested at load time
    model = transformers.AutoModelForSequenceClassification.from_pretrained(
        checkpoint,
        cache_dir=training_args.cache_dir,
        num_labels=train_dataset.num_labels,
        trust_remote_code=True,
        output_attentions=True,
    )
    model = model.to(device)

    # Optional LoRA wrapping
    if model_args.use_lora:
        target_modules = model_args.lora_target_modules.split(",")
        lora_config = LoraConfig(
            r=model_args.lora_r,
            lora_alpha=model_args.lora_alpha,
            target_modules=target_modules,
            lora_dropout=model_args.lora_dropout,
            bias="none",
            task_type="SEQ_CLS",
            inference_mode=False,
        )
        model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()

    trainer_kwargs = dict(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        compute_metrics=compute_metrics,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        data_collator=data_collator,
    )
    trainer = CustomTrainer(**trainer_kwargs)
    trainer.add_callback(CustomCallback(trainer))
    trainer.train()

    if training_args.save_model:
        trainer.save_state()
        safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir)

    if not training_args.eval_and_save_results:
        return

    # Evaluate on the held-out test split and store the metrics as JSON.
    results_dir = os.path.join(training_args.output_dir, "results", training_args.run_name)
    results = trainer.evaluate(eval_dataset=test_dataset)
    os.makedirs(results_dir, exist_ok=True)
    results_file = os.path.join(results_dir, "eval_results.json")
    with open(results_file, "w") as f:
        json.dump(results, f)


# Script entry point: only runs when the file is executed directly.
if __name__ == "__main__":
    # Use CUDA when torch reports it as available, otherwise fall back to CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print('Using device:', device)
    # Call the train function with the device
    train(device)

After training, I try to run it on an example:

# Presumably the directory of the fine-tuned checkpoint - confirm.
model_path = './finetune/output/dnabert2'
tokenizer = AutoTokenizer.from_pretrained(model_path)
# Load the model with output_attentions=True.
# NOTE(review): this uses AutoModel, whereas training used
# AutoModelForSequenceClassification - confirm that this is intended.
model = AutoModel.from_pretrained(model_path, trust_remote_code=True, output_attentions=True)
# Tokenize a single example sequence into PyTorch tensors.
model_input = tokenizer("ACTGACGGGTAGTGACTG", return_tensors="pt")

# Forward pass under inference mode; attentions are requested again per call.
with torch.inference_mode():
  output = model(**model_input, output_attentions=True)

My code might have some tests and prints. Let me know if anything is missing. Thank you very much for the help.


Solution

  • The problem was deep in the model structure: the attention was discarded early in the model, so I had to go through the model code to understand what was happening and change it.

    https://huggingface.co/jaandoui/DNABERT2-AttentionExtracted

    I had to extract attention_probs. Here are the changes I have done.