In torch's Dataset, on top of the obligatory __getitem__ method, you can also implement the __getitems__ method. In my case __getitem__ returns a dict, but I can't figure out how to do the same with __getitems__.
import numpy as np
import torch
from torch.utils.data import DataLoader

class StackOverflowDataset(torch.utils.data.Dataset):
    def __init__(self, data):
        self._data = data

    def __getitem__(self, idx):
        return {'item': self._data[idx], 'whatever': idx * self._data[idx] + 3}

    def __getitems__(self, idxs):
        return {'item': self._data[idxs], 'whatever': idxs * self._data[idxs] + 3}

    def __len__(self):
        return len(self._data)

dataset = StackOverflowDataset(np.random.random(5))
for X in DataLoader(dataset, 2):
    print(X)
    break
If I comment out __getitems__, it works, but leaving it in raises a KeyError: 0.
KeyError Traceback (most recent call last)
Cell In[182], line 15
12 return len(self._data)
14 dataset = StackOverflowDataset(np.random.random(5))
---> 15 for X in DataLoader(dataset, 2):
16 print(X)
17 break
File ~/recommenders/venv/lib/python3.12/site-packages/torch/utils/data/dataloader.py:630, in _BaseDataLoaderIter.__next__(self)
627 if self._sampler_iter is None:
628 # TODO(https://github.com/pytorch/pytorch/issues/76750)
629 self._reset() # type: ignore[call-arg]
--> 630 data = self._next_data()
631 self._num_yielded += 1
632 if self._dataset_kind == _DatasetKind.Iterable and \
633 self._IterableDataset_len_called is not None and \
634 self._num_yielded > self._IterableDataset_len_called:
File ~/recommenders/venv/lib/python3.12/site-packages/torch/utils/data/dataloader.py:673, in _SingleProcessDataLoaderIter._next_data(self)
671 def _next_data(self):
672 index = self._next_index() # may raise StopIteration
--> 673 data = self._dataset_fetcher.fetch(index) # may raise StopIteration
674 if self._pin_memory:
675 data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)
File ~/recommenders/venv/lib/python3.12/site-packages/torch/utils/data/_utils/fetch.py:55, in _MapDatasetFetcher.fetch(self, possibly_batched_index)
53 else:
54 data = self.dataset[possibly_batched_index]
---> 55 return self.collate_fn(data)
File ~/recommenders/venv/lib/python3.12/site-packages/torch/utils/data/_utils/collate.py:317, in default_collate(batch)
256 def default_collate(batch):
257 r"""
258 Take in a batch of data and put the elements within the batch into a tensor with an additional outer dimension - batch size.
259
(...)
315 >>> default_collate(batch) # Handle `CustomType` automatically
316 """
--> 317 return collate(batch, collate_fn_map=default_collate_fn_map)
File ~/recommenders/venv/lib/python3.12/site-packages/torch/utils/data/_utils/collate.py:137, in collate(batch, collate_fn_map)
109 def collate(batch, *, collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None):
110 r"""
111 General collate function that handles collection type of element within each batch.
112
(...)
135 for the dictionary of collate functions as `collate_fn_map`.
136 """
--> 137 elem = batch[0]
138 elem_type = type(elem)
140 if collate_fn_map is not None:
KeyError: 0
That's because the default collate function treats whatever __getitems__ returns as a sequence of samples and starts by taking batch[0]; the dict returned here has no key 0, hence the KeyError: 0. The official documentation says:
Subclasses could also optionally implement __getitems__(), for speedup batched samples loading. This method accepts list of indices of samples of batch and returns list of samples.
In other words, __getitems__ should return a list of samples, not a dict.
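So one way to keep per-sample dicts while satisfying the DataLoader is to have __getitems__ return a list of dicts, one per requested index, and let default_collate merge them into a single batched dict. A minimal sketch of that adjustment, reusing the class from the question:

import numpy as np
import torch
from torch.utils.data import DataLoader

class StackOverflowDataset(torch.utils.data.Dataset):
    def __init__(self, data):
        self._data = data

    def __getitem__(self, idx):
        return {'item': self._data[idx], 'whatever': idx * self._data[idx] + 3}

    def __getitems__(self, idxs):
        # Return one dict per requested index; default_collate then merges
        # them into a single dict of batched tensors.
        return [self[i] for i in idxs]

    def __len__(self):
        return len(self._data)

dataset = StackOverflowDataset(np.random.random(5))
for X in DataLoader(dataset, 2):
    print(X)  # e.g. {'item': tensor([...]), 'whatever': tensor([...])}
    break

If per-element indexing is too slow, you could instead compute the batched arrays once (e.g. self._data[idxs]) and split them into per-sample dicts before returning; the only requirement is that __getitems__ returns a sequence with one sample per index.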