The title says it all: an IterableDataset used with a multi-worker DataLoader yields more batches than it should (it seems that each worker yields all of the batches separately). Here is an MWE:
import torch

class ToyDataset(torch.utils.data.IterableDataset):
    def __iter__(self):
        data = torch.arange(len(self))
        yield from data

    def __len__(self):
        return 386

dataset = ToyDataset()
loader = torch.utils.data.DataLoader(dataset, batch_size=256, num_workers=2)
print(len(loader), len(list(loader)))  # 2 4
Is there something I'm missing? Is this a bug in PyTorch (though that seems highly unlikely)? And, most importantly, is there any way around this?
I also asked about this on the PyTorch discussion forums, but it didn't get much attention.
This behavior is expected and is explained in the IterableDataset documentation:

When a subclass is used with DataLoader, each item in the dataset will be yielded from the DataLoader iterator. When num_workers > 0, each worker process will have a different copy of the dataset object, so it is often desired to configure each copy independently to avoid having duplicate data returned from the workers. get_worker_info(), when called in a worker process, returns information about the worker. It can be used in either the dataset's __iter__() method or the DataLoader's worker_init_fn option to modify each copy's behavior.
The linked documentation page also gives two examples of using an IterableDataset with multiple workers: one using the worker info inside the dataset's __iter__ method, the other using the worker_init_fn option of the DataLoader (a sketch of that second approach is included after the example below).
As a simple example:
import torch
import math

class ToyDataset(torch.utils.data.IterableDataset):
    def __init__(self, start, end):
        self.start = start
        self.end = end

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            # Single-process loading: iterate over the full range
            iter_start = self.start
            iter_end = self.end
        else:
            # Multi-process loading: give each worker its own slice of the range
            per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
            worker_id = worker_info.id
            iter_start = self.start + worker_id * per_worker
            iter_end = min(iter_start + per_worker, self.end)
        data = torch.arange(iter_start, iter_end)
        yield from data

dataset = ToyDataset(0, 386)
loader = torch.utils.data.DataLoader(dataset, batch_size=256, num_workers=2)
print(len(list(loader)))
> 2
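For completeness, here is a rough sketch of the second approach, following the worker_init_fn pattern from the same documentation page: the dataset itself stays naive, and the init function, which runs once in each worker process, shrinks that worker's copy of the start/end bounds before iteration begins. The splitting arithmetic is the same as above.

import torch
import math

class ToyDataset(torch.utils.data.IterableDataset):
    def __init__(self, start, end):
        self.start = start
        self.end = end

    def __iter__(self):
        # No worker logic here; worker_init_fn adjusts start/end per worker
        yield from torch.arange(self.start, self.end)

def worker_init_fn(worker_id):
    worker_info = torch.utils.data.get_worker_info()
    dataset = worker_info.dataset  # this worker's copy of the dataset object
    overall_start, overall_end = dataset.start, dataset.end
    # Narrow this copy's range to the worker's own shard
    per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers)))
    dataset.start = overall_start + worker_info.id * per_worker
    dataset.end = min(dataset.start + per_worker, overall_end)

dataset = ToyDataset(0, 386)
loader = torch.utils.data.DataLoader(dataset, batch_size=256, num_workers=2,
                                     worker_init_fn=worker_init_fn)
print(len(list(loader)))  # 2

Either way, the key point is the same: with num_workers > 0 every worker gets its own copy of the dataset, so it is your responsibility to make each copy yield only its share of the data.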