I am using the feature extractor from ViT as explained here,
and I noticed a behaviour I cannot fully understand.
After loading the dataset as in that colab notebook, I see:
ds['train'].features
{'image_file_path': Value(dtype='string', id=None), 'image':
Image(mode=None, decode=True, id=None), 'labels':
ClassLabel(names=['angular_leaf_spot', 'bean_rust', 'healthy'],
id=None)}
And we can access the features in either way:
ds['train']['labels'][0:5]
[0, 0, 0, 0, 0]
ds['train'][0:2]
{'image_file_path':
['/home/albert/.cache/huggingface/datasets/downloads/extracted/967f0d9f61a7a8de58892c6fab6f02317c06faf3e19fba6a07b0885a9a7142c7/train/angular_leaf_spot/angular_leaf_spot_train.0.jpg',
'/home/albert/.cache/huggingface/datasets/downloads/extracted/967f0d9f61a7a8de58892c6fab6f02317c06faf3e19fba6a07b0885a9a7142c7/train/angular_leaf_spot/angular_leaf_spot_train.1.jpg'],
'image': [<PIL.JpegImagePlugin.JpegImageFile image mode=RGB
size=500x500>, <PIL.JpegImagePlugin.JpegImageFile image mode=RGB
size=500x500>], 'labels': [0, 0]}
But after

from datasets import load_dataset
from transformers import ViTFeatureExtractor

model_name_or_path = 'google/vit-base-patch16-224-in21k'
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name_or_path)
ds = load_dataset('beans')

def transform(example_batch):
    # Take a list of PIL images and turn them into pixel values
    inputs = feature_extractor([x for x in example_batch['image']], return_tensors='pt')
    # Don't forget to include the labels!
    inputs['labels'] = example_batch['labels']
    return inputs

prepared_ds = ds.with_transform(transform)
We see the features are kept:
prepared_ds['train'].features
{'image_file_path': Value(dtype='string', id=None), 'image':
Image(mode=None, decode=True, id=None), 'labels':
ClassLabel(names=['angular_leaf_spot', 'bean_rust', 'healthy'],
id=None)}
prepared_ds['train'][0:2]
{'pixel_values': tensor([[[[-0.5686, -0.5686, -0.5608, ..., -0.0275,
0.1843, -0.2471],
...,
[-0.5843, -0.5922, -0.6078, ..., 0.2627, 0.1608, 0.2000]],
[[-0.7098, -0.7098, -0.7490, ..., -0.3725, -0.1608, -0.6000],
...,
[-0.8824, -0.9059, -0.9216, ..., -0.2549, -0.2000, -0.1216]]],
[[[-0.5137, -0.4902, -0.4196, ..., -0.0275, -0.0039, -0.2157],
...,
[-0.5216, -0.5373, -0.5451, ..., -0.1294, -0.1529, -0.2627]],
[[-0.1843, -0.2000, -0.1529, ..., 0.2157, 0.2078, -0.0902],
...,
[-0.7725, -0.7961, -0.8039, ..., -0.3725, -0.4196, -0.5451]],
[[-0.7569, -0.8510, -0.8353, ..., -0.3255, -0.2706, -0.5608],
...,
[-0.5294, -0.5529, -0.5608, ..., -0.1686, -0.1922, -0.3333]]]]), 'labels': [0, 0]}
But when I try to access the labels directly,

prepared_ds['train']['labels']

I get a KeyError:
```
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
Cell In[32], line 1
----> 1 prepared_ds['train']['labels']

File ~/anaconda3/envs/LLM/lib/python3.12/site-packages/datasets/arrow_dataset.py:2872, in Dataset.__getitem__(self, key)
   2870 def __getitem__(self, key):  # noqa: F811
   2871     """Can be used to index columns (by string names) or rows (by integer index or iterable of indices or bools)."""
-> 2872     return self._getitem(key)

File ~/anaconda3/envs/LLM/lib/python3.12/site-packages/datasets/arrow_dataset.py:2857, in Dataset._getitem(self, key, **kwargs)
   2855 formatter = get_formatter(format_type, features=self._info.features, **format_kwargs)
   2856 pa_subtable = query_table(self._data, key, indices=self._indices)
-> 2857 formatted_output = format_table(
   2858     pa_subtable, key, formatter=formatter, format_columns=format_columns, output_all_columns=output_all_columns
   2859 )
   2860 return formatted_output

File ~/anaconda3/envs/LLM/lib/python3.12/site-packages/datasets/formatting/formatting.py:639, in format_table(table, key, formatter, format_columns, output_all_columns)
    637 python_formatter = PythonFormatter(features=formatter.features)
    638 if format_columns is None:
--> 639     return formatter(pa_table, query_type=query_type)
    640 elif query_type == "column":
    641     if key in format_columns:

File ~/anaconda3/envs/LLM/lib/python3.12/site-packages/datasets/formatting/formatting.py:405, in Formatter.__call__(self, pa_table, query_type)
    403     return self.format_row(pa_table)
    404 elif query_type == "column":
--> 405     return self.format_column(pa_table)
    406 elif query_type == "batch":
    407     return self.format_batch(pa_table)

File ~/anaconda3/envs/LLM/lib/python3.12/site-packages/datasets/formatting/formatting.py:501, in CustomFormatter.format_column(self, pa_table)
    500 def format_column(self, pa_table: pa.Table) -> ColumnFormat:
--> 501     formatted_batch = self.format_batch(pa_table)
    502     if hasattr(formatted_batch, "keys"):
    503         if len(formatted_batch.keys()) > 1:

File ~/anaconda3/envs/LLM/lib/python3.12/site-packages/datasets/formatting/formatting.py:522, in CustomFormatter.format_batch(self, pa_table)
    520 batch = self.python_arrow_extractor().extract_batch(pa_table)
    521 batch = self.python_features_decoder.decode_batch(batch)
--> 522 return self.transform(batch)

Cell In[12], line 5, in transform(example_batch)
      3 def transform(example_batch):
      4     # Take a list of PIL images and turn them into pixel values
----> 5     inputs = feature_extractor([x for x in example_batch['image']], return_tensors='pt')
      7     # Don't forget to include the labels!
      8     inputs['labels'] = example_batch['labels']

KeyError: 'image'
```
It looks like the error occurs because the feature extractor adds 'pixel_values', while the stored features still list 'image'.
The traceback also seems to imply that the transform is re-applied on access.
...
Also: it is not possible to save the dataset to disk:

prepared_ds.save_to_disk(img_path)
```
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[21], line 1
----> 1 dataset.save_to_disk(img_path)

File ~/anaconda3/envs/LLM/lib/python3.13/site-packages/datasets/arrow_dataset.py:1503, in Dataset.save_to_disk(self, dataset_path, max_shard_size, num_shards, num_proc, storage_options)
   1501     json.dumps(state["_format_kwargs"][k])
   1502 except TypeError as e:
-> 1503     raise TypeError(
   1504         str(e) + f"\nThe format kwargs must be JSON serializable, but key '{k}' isn't."
   1505     ) from None
   1506 # Get json serializable dataset info
   1507 dataset_info = asdict(self._info)

TypeError: Object of type function is not JSON serializable
The format kwargs must be JSON serializable, but key 'transform' isn't.
```
Note that the original code in that notebook works perfectly (training, evaluation, etc.). I only hit these errors when I tried to inspect the dataset, save the generated dataset, and otherwise explore the dataset object...
Shouldn't the dataset structure be accessible in a similar way after with_transform() or set_transform()?
Why does it call the transform function again when we just attempt to access one of the features?
I’m hoping you can shed some light on this behaviour...
This is not the way to pick up dataset items after with_transform(). First select a slice of rows by indexing:

prepared_ds_batch = prepared_ds['train'][0:10]

Then use the key 'labels' on the returned batch:

prepared_ds_batch['labels']
[out]: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
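The slice-then-key pattern works because row/slice access hands the on-the-fly transform a batch containing every column. A pure-Python stand-in (a toy transform with integer "images" instead of PIL images, not the real datasets internals) illustrates the flow:

```python
# Toy stand-in for the ViT transform: it reads 'image' unconditionally
# and produces 'pixel_values' plus 'labels'.
def transform(example_batch):
    inputs = {'pixel_values': [x * 2 for x in example_batch['image']]}
    inputs['labels'] = example_batch['labels']
    return inputs

# Row/slice access extracts ALL columns of the selected rows into the batch...
full_batch = {'image': [10, 20], 'labels': [0, 1]}
prepared_batch = transform(full_batch)

# ...so every key the transform produces is available afterwards.
print(prepared_batch['labels'])        # [0, 1]
print(prepared_batch['pixel_values'])  # [20, 40]
```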
Regarding the second issue with saving the data: you cannot save it because of a known issue with transform functions: https://github.com/huggingface/datasets/issues/6221
You can, however, save the dataset with prepared_ds.with_format(None).save_to_disk('test_path'). After loading it again from disk you need to re-apply the transform function.
Edited: You cannot use prepared_ds['train']['labels'] because after with_transform() the dataset is expected to be indexed by integers (row indices or slices), not by a column name: as your traceback shows, column access is also routed through the transform, which then receives only the 'labels' column and fails on the missing 'image' key.
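A pure-Python stand-in (hypothetical names, not the real formatting internals) of why column access blows up: the formatter extracts only the requested column into the batch before calling the transform, so the 'image' key is simply absent:

```python
def transform(example_batch):
    # Mirrors the ViT transform: it unconditionally reads 'image'.
    inputs = {'pixel_values': [x * 2 for x in example_batch['image']]}
    inputs['labels'] = example_batch['labels']
    return inputs

# Column access extracts ONLY the requested column into the batch
# (roughly what CustomFormatter.format_column does internally)...
column_batch = {'labels': [0, 1]}

try:
    transform(column_batch)
except KeyError as err:
    # ...so the transform never sees 'image' and raises.
    missing_key = err.args[0]
    print(f"KeyError: '{missing_key}'")  # prints: KeyError: 'image'
```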