I am trying to extract the hidden states of a transformer model:
from transformers import AutoModel
import torch
from transformers import AutoTokenizer
model_ckpt = "distilbert-base-uncased"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
model = AutoModel.from_pretrained(model_ckpt).to(device)
from datasets import load_dataset
emotions = load_dataset("emotion", ignore_verifications=True)
# tokenize data
def tokenize(batch):
    return tokenizer(batch["text"], padding=True, truncation=True)
emotions_encoded = emotions.map(tokenize, batched=True, batch_size=None)
def extract_hidden_states(batch):
    inputs = {k: v.to(device) for k, v in batch.items()
              if k in tokenizer.model_input_names}
    with torch.no_grad():
        last_hidden_state = model(*inputs).last_hidden_state
    return {"hidden_state": last_hidden_state[:, 0].cpu().numpy()}
# convert input_ids and attention_mask columns to "torch" format
emotions_encoded.set_format("torch", columns=["input_ids", "attention_mask", "label"])
# extract hidden states
emotions_hidden = emotions_encoded.map(extract_hidden_states, batched=True)
However, on running the last line I get the error 'str' object has no attribute 'size'.
I've tried downgrading the transformers package, but that didn't fix it. Some posts online suggest it may have to do with the transformers package returning a dictionary by default, but I don't know how to work around that.
Full error:
AttributeError Traceback (most recent call last)
Cell In[8], line 5
2 emotions_encoded.set_format("torch", columns=["input_ids", "attention_mask", "label"])
4 # extract hidden states
----> 5 emotions_hidden = emotions_encoded.map(extract_hidden_states, batched=True)
File ~\Anaconda3\envs\ml\lib\site-packages\datasets\dataset_dict.py:851, in DatasetDict.map(self, function, with_indices, with_rank, input_columns, batched, batch_size, drop_last_batch, remove_columns, keep_in_memory, load_from_cache_file, cache_file_names, writer_batch_size, features, disable_nullable, fn_kwargs, num_proc, desc)
848 if cache_file_names is None:
849 cache_file_names = {k: None for k in self}
850 return DatasetDict(
--> 851 {
852 k: dataset.map(
853 function=function,
854 with_indices=with_indices,
855 with_rank=with_rank,
856 input_columns=input_columns,
857 batched=batched,
858 batch_size=batch_size,
859 drop_last_batch=drop_last_batch,
860 remove_columns=remove_columns,
861 keep_in_memory=keep_in_memory,
862 load_from_cache_file=load_from_cache_file,
863 cache_file_name=cache_file_names[k],
864 writer_batch_size=writer_batch_size,
865 features=features,
866 disable_nullable=disable_nullable,
867 fn_kwargs=fn_kwargs,
868 num_proc=num_proc,
869 desc=desc,
870 )
871 for k, dataset in self.items()
872 }
873 )
File ~\Anaconda3\envs\ml\lib\site-packages\datasets\dataset_dict.py:852, in <dictcomp>(.0)
848 if cache_file_names is None:
849 cache_file_names = {k: None for k in self}
850 return DatasetDict(
851 {
--> 852 k: dataset.map(
853 function=function,
854 with_indices=with_indices,
855 with_rank=with_rank,
856 input_columns=input_columns,
857 batched=batched,
858 batch_size=batch_size,
859 drop_last_batch=drop_last_batch,
860 remove_columns=remove_columns,
861 keep_in_memory=keep_in_memory,
862 load_from_cache_file=load_from_cache_file,
863 cache_file_name=cache_file_names[k],
864 writer_batch_size=writer_batch_size,
865 features=features,
866 disable_nullable=disable_nullable,
867 fn_kwargs=fn_kwargs,
868 num_proc=num_proc,
869 desc=desc,
870 )
871 for k, dataset in self.items()
872 }
873 )
File ~\Anaconda3\envs\ml\lib\site-packages\datasets\arrow_dataset.py:578, in transmit_tasks.<locals>.wrapper(*args, **kwargs)
576 self: "Dataset" = kwargs.pop("self")
577 # apply actual function
--> 578 out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
579 datasets: List["Dataset"] = list(out.values()) if isinstance(out, dict) else [out]
580 for dataset in datasets:
581 # Remove task templates if a column mapping of the template is no longer valid
File ~\Anaconda3\envs\ml\lib\site-packages\datasets\arrow_dataset.py:543, in transmit_format.<locals>.wrapper(*args, **kwargs)
536 self_format = {
537 "type": self._format_type,
538 "format_kwargs": self._format_kwargs,
539 "columns": self._format_columns,
540 "output_all_columns": self._output_all_columns,
541 }
542 # apply actual function
--> 543 out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
544 datasets: List["Dataset"] = list(out.values()) if isinstance(out, dict) else [out]
545 # re-apply format to the output
File ~\Anaconda3\envs\ml\lib\site-packages\datasets\arrow_dataset.py:3073, in Dataset.map(self, function, with_indices, with_rank, input_columns, batched, batch_size, drop_last_batch, remove_columns, keep_in_memory, load_from_cache_file, cache_file_name, writer_batch_size, features, disable_nullable, fn_kwargs, num_proc, suffix_template, new_fingerprint, desc)
3065 if transformed_dataset is None:
3066 with logging.tqdm(
3067 disable=not logging.is_progress_bar_enabled(),
3068 unit=" examples",
(...)
3071 desc=desc or "Map",
3072 ) as pbar:
-> 3073 for rank, done, content in Dataset._map_single(**dataset_kwargs):
3074 if done:
3075 shards_done += 1
File ~\Anaconda3\envs\ml\lib\site-packages\datasets\arrow_dataset.py:3449, in Dataset._map_single(shard, function, with_indices, with_rank, input_columns, batched, batch_size, drop_last_batch, remove_columns, keep_in_memory, cache_file_name, writer_batch_size, features, disable_nullable, fn_kwargs, new_fingerprint, rank, offset)
3445 indices = list(
3446 range(*(slice(i, i + batch_size).indices(shard.num_rows)))
3447 ) # Something simpler?
3448 try:
-> 3449 batch = apply_function_on_filtered_inputs(
3450 batch,
3451 indices,
3452 check_same_num_examples=len(shard.list_indexes()) > 0,
3453 offset=offset,
3454 )
3455 except NumExamplesMismatchError:
3456 raise DatasetTransformationNotAllowedError(
3457 "Using `.map` in batched mode on a dataset with attached indexes is allowed only if it doesn't create or remove existing examples. You can first run `.drop_index() to remove your index and then re-add it."
3458 ) from None
File ~\Anaconda3\envs\ml\lib\site-packages\datasets\arrow_dataset.py:3330, in Dataset._map_single.<locals>.apply_function_on_filtered_inputs(pa_inputs, indices, check_same_num_examples, offset)
3328 if with_rank:
3329 additional_args += (rank,)
-> 3330 processed_inputs = function(*fn_args, *additional_args, **fn_kwargs)
3331 if isinstance(processed_inputs, LazyDict):
3332 processed_inputs = {
3333 k: v for k, v in processed_inputs.data.items() if k not in processed_inputs.keys_to_format
3334 }
Cell In[7], line 6, in extract_hidden_states(batch)
3 inputs = {k:v.to(device) for k,v in batch.items()
4 if k in tokenizer.model_input_names}
5 with torch.no_grad():
----> 6 last_hidden_state = model(*inputs).last_hidden_state
7 return{"hidden_state": last_hidden_state[:,0].cpu().numpy()}
File ~\Anaconda3\envs\ml\lib\site-packages\torch\nn\modules\module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
1509 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1510 else:
-> 1511 return self._call_impl(*args, **kwargs)
File ~\Anaconda3\envs\ml\lib\site-packages\torch\nn\modules\module.py:1520, in Module._call_impl(self, *args, **kwargs)
1515 # If we don't have any hooks, we want to skip the rest of the logic in
1516 # this function, and just call forward.
1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1518 or _global_backward_pre_hooks or _global_backward_hooks
1519 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520 return forward_call(*args, **kwargs)
1522 try:
1523 result = None
File ~\Anaconda3\envs\ml\lib\site-packages\transformers\models\distilbert\modeling_distilbert.py:593, in DistilBertModel.forward(self, input_ids, attention_mask, head_mask, inputs_embeds, output_attentions, output_hidden_states, return_dict)
591 elif input_ids is not None:
592 self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
--> 593 input_shape = input_ids.size()
594 elif inputs_embeds is not None:
595 input_shape = inputs_embeds.size()[:-1]
AttributeError: 'str' object has no attribute 'size'
The issue is happening inside your extract_hidden_states() function, but not in the dictionary filtering: it's the model call. inputs is a dictionary, and model(*inputs) unpacks only its keys, so the model receives the string "input_ids" as a positional argument instead of the tensor. DistilBertModel.forward then calls .size() on that string, which is the 'str' object has no attribute 'size' error you're seeing.
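To see the difference, here is a minimal sketch with made-up toy values (not your actual batch) showing that * unpacks only a dictionary's keys, while ** unpacks the key/value pairs as keyword arguments:

# toy stand-in for the batch dictionary
inputs = {"input_ids": [[101, 102]], "attention_mask": [[1, 1]]}

def forward(input_ids=None, attention_mask=None):
    print(type(input_ids))  # what the "model" actually receives

forward(*inputs)   # passes the keys "input_ids", "attention_mask" -> <class 'str'>
forward(**inputs)  # passes the values -> <class 'list'>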
You can modify your function this way:
def extract_hidden_states(batch):
    inputs = {k: v for k, v in batch.items() if k in tokenizer.model_input_names}
    # these are already tensors thanks to set_format("torch"); detach and move them to the device
    inputs = {k: v.clone().detach().to(device) for k, v in inputs.items()}
    with torch.no_grad():
        outputs = model(**inputs)  # unpack the dict as keyword arguments
        last_hidden_state = outputs.last_hidden_state
    return {"hidden_state": last_hidden_state[:, 0].cpu().numpy()}