
PyTorch: load data in mini-batches

I have a folder of images structured like this:

Images/
|__img1
|   |__img1_b01.tiff
|   |__img1_b02.tiff
|   |__img1_b03.tiff
|   |__img1_b04.tiff
|   |__img1_b05.tiff
|__img2
|   |__img2_b01.tiff
|   |__img2_b02.tiff
|   |__img2_b03.tiff
|   |__img2_b04.tiff
|   |__img2_b05.tiff
|.. img1000

Each folder represents an image, and each file in a folder represents one band (channel) of that image.

Hence each image has 5 bands, i.e. 5 channels.

I am stuck writing a custom PyTorch dataloader that loads the images in batches of 64,

so that I get Feature batch shape: torch.Size([64, 5, 256, 256]).
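To make the target shape concrete, here is a minimal sketch assuming the `<folder>_bNN.tiff` naming from the tree above. `band_paths` is a hypothetical helper that lists the 5 band files of one image, and zero-filled NumPy arrays stand in for real band data. Stacking 5 bands gives one (5, 256, 256) image; stacking 64 such images gives the (64, 5, 256, 256) batch.

```python
import os

import numpy as np


def band_paths(root, folder, num_bands=5):
    """Hypothetical helper: list the band files of one image folder.

    Assumes the `<folder>_bNN.tiff` naming, e.g. img1/img1_b01.tiff ... img1_b05.tiff.
    """
    return [os.path.join(root, folder, f"{folder}_b{b:02d}.tiff")
            for b in range(1, num_bands + 1)]


# One image = 5 single-channel 256x256 bands stacked along a new first axis.
bands = [np.zeros((256, 256), dtype=np.float32) for _ in range(5)]
image = np.stack(bands, axis=0)          # shape (5, 256, 256)

# One batch = 64 such images stacked along another new first axis.
batch = np.stack([image] * 64, axis=0)   # shape (64, 5, 256, 256)

print(band_paths("Images", "img1"))
print(image.shape, batch.shape)
```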

I have tried the following code:

from torchvision import datasets, transforms
from torch.utils import data

dataset = datasets.ImageFolder(root="Images/",
            transform=transforms.ToTensor())

loader = data.DataLoader(dataset, batch_size=64, shuffle=True)

But it does not give the result I want, which is Feature batch shape: torch.Size([64, 5, 256, 256]).


  • Using datasets.ImageFolder makes PyTorch treat each "band" file as an independent image and treat the folder names (e.g., img1, img2...) as "class labels". Its default loader also converts every file to RGB, so each sample has 3 channels rather than 5.
    To load the 5 image files as different bands/channels of the same image, you'll need to write your own custom Dataset.

    This custom Dataset may look something like this:

    import os

    import numpy as np
    import torch
    from PIL import Image
    from torch.utils.data import Dataset

    class MultiBandDataset(Dataset):
      def __init__(self, root, num_bands):
        self.root = root
        self.num_bands = num_bands
        # all `imgNN` subfolders, sorted so the index -> image mapping is deterministic
        self.imgs = sorted(d for d in os.listdir(root)
                           if os.path.isdir(os.path.join(root, d)))
      def __len__(self):
        return len(self.imgs)  # number of images = number of subfolders
      def __getitem__(self, index):
        multi_band = []
        # get the subfolder
        subf = os.path.join(self.root, self.imgs[index])
        for band in range(self.num_bands):
          # band files are named <folder>_b01.tiff, <folder>_b02.tiff, ...
          path = os.path.join(subf, f'{self.imgs[index]}_b{band+1:02d}.tiff')
          b = Image.open(path).convert("F")  # make sure you are reading a single channel from each image. you need to verify this part.
          multi_band.append(np.array(b).astype(np.float32)[None, ...])  # add singleton channel dimension
        # stack the bands into a (num_bands, H, W) array and return it as a tensor
        return torch.from_numpy(np.concatenate(multi_band, axis=0))

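    Note that this Dataset has no transform hook, so you will probably need to re-implement any augmentations as well, making sure each random transform is applied identically to all bands of an image.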
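Once the Dataset returns one (5, H, W) array or tensor per image, the ordinary DataLoader produces the batch shape asked for in the question. The sketch below assumes a hypothetical RandomBandsDataset that returns zero-filled tensors, so it runs without any TIFF files; with real data you would pass MultiBandDataset("Images/", num_bands=5) instead.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class RandomBandsDataset(Dataset):
    """Hypothetical stand-in for MultiBandDataset: returns (num_bands, H, W) tensors."""

    def __init__(self, length=100, num_bands=5, size=256):
        self.length = length
        self.shape = (num_bands, size, size)

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        return torch.zeros(self.shape, dtype=torch.float32)


# With real data this would be MultiBandDataset("Images/", num_bands=5).
dataset = RandomBandsDataset(length=100)
loader = DataLoader(dataset, batch_size=64, shuffle=True)

# The default collate function stacks 64 samples of shape (5, 256, 256)
# into one batch tensor of shape (64, 5, 256, 256); the last batch holds the remainder.
shapes = [tuple(batch.shape) for batch in loader]
print(shapes)  # [(64, 5, 256, 256), (36, 5, 256, 256)]
```

With 100 samples the last batch contains the remaining 36; pass drop_last=True to DataLoader if every batch must contain exactly 64 images.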