The PyTorch DataLoader turns datasets into iterables. I already have a generator that yields the data samples I want to use for training and testing. I use a generator because the total number of samples is too large to store in memory, and I would like to load the samples in batches for training.
What is the best way to do this? Can I do it without a custom DataLoader? The PyTorch DataLoader won't accept the generator as input. Below is a minimal example of what I want to do, which produces the error "object of type 'generator' has no len()".
import torch
from torch import nn
from torch.utils.data import DataLoader

def example_generator():
    for i in range(10):
        yield i

BATCH_SIZE = 3
train_dataloader = DataLoader(example_generator(),
                              batch_size=BATCH_SIZE,
                              shuffle=False)

print(f"Length of train_dataloader: {len(train_dataloader)} batches of {BATCH_SIZE}")
I am trying to take the data from an iterator while taking advantage of the functionality of the PyTorch DataLoader. The example above is minimal, but it reproduces the error I get.
Edit: I want to be able to use this approach for complex generators whose length is not known in advance.
PyTorch's DataLoader actually has official support for iterable datasets; the dataset just has to be an instance of a subclass of torch.utils.data.IterableDataset:

"An iterable-style dataset is an instance of a subclass of IterableDataset that implements the __iter__() protocol, and represents an iterable over data samples."
So your code would be written as:
from torch.utils.data import IterableDataset

class MyIterableDataset(IterableDataset):
    def __init__(self, iterable):
        self.iterable = iterable

    def __iter__(self):
        return iter(self.iterable)

...

BATCH_SIZE = 3
train_dataloader = DataLoader(MyIterableDataset(example_generator()),
                              batch_size=BATCH_SIZE,
                              shuffle=False)