pytorch, computer-vision, enumerate

PyTorch: How to get the first N items from a DataLoader


My list contains 3000 pictures, but I only want the first N of them (say, 1000) for training. How can I achieve this by changing the loop code?

for i, (image, label) in enumerate(train_loader):


Solution

  • for i, (image, label) in list(enumerate(train_loader))[:1000]:
    

    This is not a good way to partition training and validation data, though. First, DataLoader loads lazily (examples are not read into memory until needed), whereas converting it to a list forces every batch to be loaded up front, which can trigger an out-of-memory error on a large dataset. Second, if the DataLoader shuffles, the slice will not return the same 1000 elements on every run. Third, the slice counts batches rather than individual images, so it only yields 1000 images when batch_size is 1. In general, DataLoader does not support indexing, so it is not well suited to selecting a specific subset of a dataset. Converting it to a list works around this, but at the expense of the features that make DataLoader useful.

    Best practice is to use a separate torch.utils.data.Dataset object for the training and validation partitions, or at least to partition the data at the dataset level rather than stopping training after the first 1000 examples. Then create a separate DataLoader for each partition.
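
As a minimal sketch of the dataset-level partitioning described above, the snippet below uses torch.utils.data.Subset to expose the first N examples as a training set and the rest as a validation set. The TensorDataset of random tensors, the names full_dataset, n_train, train_set and val_set, and the batch size of 32 are all placeholders assumed for illustration; they are not from the original question.

```python
import torch
from torch.utils.data import DataLoader, Subset, TensorDataset

# Stand-in for the real image dataset: 3000 fake 3x8x8 "images" with labels.
# (Hypothetical data; swap in your own Dataset object here.)
full_dataset = TensorDataset(
    torch.randn(3000, 3, 8, 8),
    torch.randint(0, 10, (3000,)),
)

n_train = 1000  # the "N" from the question

# Subset wraps the dataset and exposes only the given indices. It stores the
# index list, not the data, so examples are still fetched lazily on access.
train_set = Subset(full_dataset, list(range(n_train)))
val_set = Subset(full_dataset, list(range(n_train, len(full_dataset))))

# One loader per partition. Shuffling only permutes examples within the
# training partition, so the two partitions never mix.
train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
val_loader = DataLoader(val_set, batch_size=32, shuffle=False)

num_seen = 0
for images, labels in train_loader:
    num_seen += images.size(0)  # a training step would go here
```

Because the split is made on indices, the same 1000 examples are used on every run regardless of shuffling. If the goal is only to cut an existing loop short, itertools.islice(train_loader, k) stops after k batches without building a list, but it still counts batches rather than images and is not stable under shuffling.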