Tags: python, machine-learning, neural-network, huggingface-transformers, huggingface-datasets

Unexpected Hugging Face dataset structure after set_transform or with_transform


I am using the feature extractor from ViT as explained here, and I noticed a weird behaviour I cannot fully understand.

After loading the dataset as in that Colab notebook, I see:

ds['train'].features

{'image_file_path': Value(dtype='string', id=None),
 'image': Image(mode=None, decode=True, id=None),
 'labels': ClassLabel(names=['angular_leaf_spot', 'bean_rust', 'healthy'], id=None)}

And we can access the features in both ways:

ds['train']['labels'][0:5]

[0, 0, 0, 0, 0]

ds['train'][0:2]

{'image_file_path': ['/home/albert/.cache/huggingface/datasets/downloads/extracted/967f0d9f61a7a8de58892c6fab6f02317c06faf3e19fba6a07b0885a9a7142c7/train/angular_leaf_spot/angular_leaf_spot_train.0.jpg',
  '/home/albert/.cache/huggingface/datasets/downloads/extracted/967f0d9f61a7a8de58892c6fab6f02317c06faf3e19fba6a07b0885a9a7142c7/train/angular_leaf_spot/angular_leaf_spot_train.1.jpg'],
 'image': [<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=500x500>,
  <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=500x500>],
 'labels': [0, 0]}

But after

from datasets import load_dataset
from transformers import ViTFeatureExtractor

model_name_or_path = 'google/vit-base-patch16-224-in21k'
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name_or_path)
ds = load_dataset('beans')

def transform(example_batch):
    # Take a list of PIL images and turn them to pixel values
    inputs = feature_extractor([x for x in example_batch['image']], return_tensors='pt')
    # Don't forget to include the labels!
    inputs['labels'] = example_batch['labels']
    return inputs

prepared_ds = ds.with_transform(transform)

We see the features are kept:

prepared_ds['train'].features

{'image_file_path': Value(dtype='string', id=None),
 'image': Image(mode=None, decode=True, id=None),
 'labels': ClassLabel(names=['angular_leaf_spot', 'bean_rust', 'healthy'], id=None)}

prepared_ds['train'][0:2]

{'pixel_values': tensor([[[[-0.5686, -0.5686, -0.5608,  ..., -0.0275,  0.1843, -0.2471],
          ...,
          [-0.5843, -0.5922, -0.6078,  ...,  0.2627,  0.1608,  0.2000]],

         [[-0.7098, -0.7098, -0.7490,  ..., -0.3725, -0.1608, -0.6000],
          ...,
          [-0.8824, -0.9059, -0.9216,  ..., -0.2549, -0.2000, -0.1216]]],

        [[[-0.5137, -0.4902, -0.4196,  ..., -0.0275, -0.0039, -0.2157],
          ...,
          [-0.5216, -0.5373, -0.5451,  ..., -0.1294, -0.1529, -0.2627]],

         [[-0.1843, -0.2000, -0.1529,  ...,  0.2157,  0.2078, -0.0902],
          ...,
          [-0.7725, -0.7961, -0.8039,  ..., -0.3725, -0.4196, -0.5451]],

         [[-0.7569, -0.8510, -0.8353,  ..., -0.3255, -0.2706, -0.5608],
          ...,
          [-0.5294, -0.5529, -0.5608,  ..., -0.1686, -0.1922, -0.3333]]]]), 'labels': [0, 0]}

But when I try to access the labels column directly

prepared_ds['train']['labels']

I get a KeyError:

```
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
Cell In[32], line 1
----> 1 prepared_ds['train']['labels']

File ~/anaconda3/envs/LLM/lib/python3.12/site-packages/datasets/arrow_dataset.py:2872, in Dataset.__getitem__(self, key)
   2870 def __getitem__(self, key):  # noqa: F811
   2871     """Can be used to index columns (by string names) or rows (by integer index or iterable of indices or bools)."""
-> 2872     return self._getitem(key)

File ~/anaconda3/envs/LLM/lib/python3.12/site-packages/datasets/arrow_dataset.py:2857, in Dataset._getitem(self, key, **kwargs)
   2855 formatter = get_formatter(format_type, features=self._info.features, **format_kwargs)
   2856 pa_subtable = query_table(self._data, key, indices=self._indices)
-> 2857 formatted_output = format_table(
   2858     pa_subtable, key, formatter=formatter, format_columns=format_columns, output_all_columns=output_all_columns
   2859 )
   2860 return formatted_output

File ~/anaconda3/envs/LLM/lib/python3.12/site-packages/datasets/formatting/formatting.py:639, in format_table(table, key, formatter, format_columns, output_all_columns)
    637 python_formatter = PythonFormatter(features=formatter.features)
    638 if format_columns is None:
--> 639     return formatter(pa_table, query_type=query_type)
    640 elif query_type == "column":
    641     if key in format_columns:

File ~/anaconda3/envs/LLM/lib/python3.12/site-packages/datasets/formatting/formatting.py:405, in Formatter.__call__(self, pa_table, query_type)
    403     return self.format_row(pa_table)
    404 elif query_type == "column":
--> 405     return self.format_column(pa_table)
    406 elif query_type == "batch":
    407     return self.format_batch(pa_table)

File ~/anaconda3/envs/LLM/lib/python3.12/site-packages/datasets/formatting/formatting.py:501, in CustomFormatter.format_column(self, pa_table)
    500 def format_column(self, pa_table: pa.Table) -> ColumnFormat:
--> 501     formatted_batch = self.format_batch(pa_table)
    502     if hasattr(formatted_batch, "keys"):
    503         if len(formatted_batch.keys()) > 1:

File ~/anaconda3/envs/LLM/lib/python3.12/site-packages/datasets/formatting/formatting.py:522, in CustomFormatter.format_batch(self, pa_table)
    520 batch = self.python_arrow_extractor().extract_batch(pa_table)
    521 batch = self.python_features_decoder.decode_batch(batch)
--> 522 return self.transform(batch)

Cell In[12], line 5, in transform(example_batch)
      3 def transform(example_batch):
      4     # Take a list of PIL images and turn them to pixel values
----> 5     inputs = feature_extractor([x for x in example_batch['image']], return_tensors='pt')
      7     # Don't forget to include the labels!
      8     inputs['labels'] = example_batch['labels']

KeyError: 'image'
```

It looks like the error occurs because the feature extractor produced 'pixel_values' while the dataset's features still list 'image'. The traceback also seems to show that the transform function is applied again on every access...

Also, it is not possible to save the dataset to disk:

    prepared_ds.save_to_disk(img_path)
```
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[21], line 1
----> 1 dataset.save_to_disk(img_path)

File ~/anaconda3/envs/LLM/lib/python3.13/site-packages/datasets/arrow_dataset.py:1503, in Dataset.save_to_disk(self, dataset_path, max_shard_size, num_shards, num_proc, storage_options)
   1501         json.dumps(state["_format_kwargs"][k])
   1502     except TypeError as e:
-> 1503         raise TypeError(
   1504             str(e) + f"\nThe format kwargs must be JSON serializable, but key '{k}' isn't."
   1505         ) from None
   1506 # Get json serializable dataset info
   1507 dataset_info = asdict(self._info)

TypeError: Object of type function is not JSON serializable
The format kwargs must be JSON serializable, but key 'transform' isn't.
```

Note that the original code in that notebook works perfectly (training, evaluation, etc.). I only ran into these errors because I tried to inspect the dataset, save the generated dataset, and so on, while exploring the dataset object...

Shouldn't the dataset structure be accessible in the same way after with_transform() or set_transform()? Why is the transform function called again when we merely try to access one of the features?

I’m hoping you can shed some light on this behaviour...


Solution

  • This is not how you access items in a dataset that has a transform set. First pick up a batch of rows by slicing with integer indices:

    prepared_ds_batch = prepared_ds['train'][0:10]
    

    Then you can use the key labels on the returned batch:

    prepared_ds_batch['labels']
    [out]: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
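
    If you just want the raw labels column without triggering the transform at all, one workaround (a sketch: with_format(None) gives you a view of the dataset with the transform removed) is:

    prepared_ds['train'].with_format(None)['labels'][0:5]
    [out]: [0, 0, 0, 0, 0]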
    

    Regarding the second issue with saving the data: you cannot save it because the transform function is stored in the dataset's format kwargs, and a Python function is not JSON serializable. This is a known issue with transform functions: https://github.com/huggingface/datasets/issues/6221

    You can, however, save the dataset with prepared_ds.with_format(None).save_to_disk('test_path'), which strips the transform before saving. After loading it from disk again, you need to re-apply the transform function.
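
    A minimal sketch of that round trip (assuming the transform function defined in the question is still in scope; 'test_path' is just an example path):

    from datasets import load_from_disk

    # drop the non-serializable transform before writing to disk
    prepared_ds.with_format(None).save_to_disk('test_path')

    # after reloading, the transform has to be attached again
    reloaded_ds = load_from_disk('test_path').with_transform(transform)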

    Edited: You cannot use prepared_ds['train']['labels'] on a transformed dataset. As the traceback shows, a column query extracts only the 'labels' column before calling your transform, so the batch handed to transform has no 'image' key and the call fails with KeyError: 'image'. Index the transformed dataset with integer indices or slices instead, and read 'labels' from the returned batch.
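
    To illustrate the difference (the commented-out line is the failing call from the question):

    # slice query: all columns are decoded and passed to the transform,
    # so example_batch contains 'image' and the call succeeds
    prepared_ds['train'][0:2]['labels']
    [out]: [0, 0]

    # column query: only the 'labels' column is extracted before the
    # transform runs, so example_batch has no 'image' key
    # prepared_ds['train']['labels']   # -> KeyError: 'image'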