I'm working on a pytorch project where my data is saved in zarr. Random access on zarr is costly, but thanks to zarr using a blockwise cache, iteration is really quick. To harness this fact, I use an IterableDataset together with multiple workers:
from itertools import islice

import zarr
from torch.utils.data import IterableDataset

class Data(IterableDataset):
    def __init__(self, path, start=None, end=None):
        super(Data, self).__init__()
        store = zarr.DirectoryStore(path)
        self.array = zarr.open(store, mode='r')
        if start is None:
            start = 0
        if end is None:
            end = self.array.shape[0]
        assert end > start
        self.start = start
        self.end = end

    def __iter__(self):
        # islice has to walk past every row before self.start
        return islice(self.array, self.start, self.end)
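For completeness, this is roughly how the workers end up with consecutive start/end ranges: each worker claims a contiguous shard of the full range via a worker_init_fn, following the pattern from the PyTorch IterableDataset documentation (the path, batch size and worker count below are just placeholders):

import math

import torch
from torch.utils.data import DataLoader

def worker_init_fn(worker_id):
    info = torch.utils.data.get_worker_info()
    dataset = info.dataset  # the copy of Data living in this worker process
    overall_start, overall_end = dataset.start, dataset.end
    # give each worker a contiguous slice of [start, end)
    per_worker = int(math.ceil((overall_end - overall_start) / float(info.num_workers)))
    dataset.start = overall_start + info.id * per_worker
    dataset.end = min(dataset.start + per_worker, overall_end)

loader = DataLoader(Data('data.zarr'), batch_size=64,
                    num_workers=4, worker_init_fn=worker_init_fn)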
The issue is that self.array has on the order of 10e9 rows, and for consecutive workers, as self.start and self.end naturally get bigger, creating the generators like itertools.islice(array, start, end) takes a significant amount of time out of my training/validation processes, because islice still has to iterate over the unneeded elements until it reaches start. Once a generator is created for each worker, this works like a charm, but getting there takes too long.

Is there a better way to create such a generator? Or maybe there's a smarter way to use zarr in pytorch?
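For illustration, this is the kind of generator I have in mind, sketched by hand with basic slicing: it jumps straight to start and reads contiguous blocks, instead of decoding everything before start the way islice does (the block size here is arbitrary):

def sliced_rows(array, start, end, block=4096):
    # read a contiguous block at a time, starting at `start` directly
    for i in range(start, end, block):
        chunk = array[i:min(i + block, end)]  # one bulk read from zarr
        for row in chunk:
            yield row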
Update: As of March 2021, this change has been merged into zarr.

I took a small dive into zarr, and it looks like this will most easily be enabled from inside zarr. I have opened an issue here; in the meantime, I made a fork of zarr that implements the function array.islice(start, end).
The dataset __iter__ method then looks like this:

def __iter__(self):
    return self.array.islice(self.start, self.end)
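Nothing else in the pipeline needs to change; with the range split across workers as before, creating the per-worker generators is now fast because islice can start reading at start directly instead of skipping rows one by one. A usage sketch (the path, range and loader settings are placeholders, and worker_init_fn is the range-splitting helper from above):

ds = Data('data.zarr', start=0, end=1_000_000)
loader = DataLoader(ds, batch_size=64, num_workers=4,
                    worker_init_fn=worker_init_fn)
for batch in loader:
    ...  # training/validation step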