I want to load image sequences of a fixed length into batches of the same size (for example sequence length = batch size = 7).
There are multiple directories, each containing the images of one sequence, and the number of images per directory varies. The sequences from different directories are not related to each other.
With my current code I can process several subdirectories, but if a directory does not contain enough images to fill a batch, the remaining images are taken from the next directory. I would like to avoid this.
Instead, if the current directory has too few remaining images to fill a batch, those leftover images should be discarded and the next batch should be filled only with images from the next directory. This way I want to avoid mixing unrelated image sequences in the same batch. If a directory does not contain enough images to form even a single batch, it should be skipped completely.
So for example with a sequence length/batch size of 7: a directory with 15 images should yield 2 batches (the one leftover image is discarded), and a directory with only 5 images should be skipped entirely.
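In other words, each directory should contribute floor(n_images / 7) complete batches and any remainder should be dropped. Just to illustrate the rule I am after (this helper is not part of my pipeline):

def expected_batches(n_images: int, seq_len: int = 7) -> int:
    # Only complete, unmixed batches count; the remainder
    # (n_images % seq_len) is discarded.
    return n_images // seq_len

print(expected_batches(15))  # 2 -> one leftover image is dropped
print(expected_batches(5))   # 0 -> directory is skipped entirely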
I’m still learning, but I think this can be done with a custom batch sampler? Unfortunately, I’m having some problems with this. Maybe someone can help me find a solution.
This is my current code:
import os

import numpy as np
import torch
from skimage import io  # or wherever your imread comes from
from torch.utils.data import DataLoader, Dataset


class MainDataset(Dataset):
    def __init__(self, img_dir, use_folder_name=False):
        self.gt_images = self._load_main_dataset(img_dir)
        self.dataset_len = len(self.gt_images)
        self.use_folder_name = use_folder_name

    def __len__(self):
        return self.dataset_len

    def __getitem__(self, idx):
        img_dir = self.gt_images[idx]
        img_name = self._get_name(img_dir)
        gt = self._load_img(img_dir)

        # Skip non-image files
        if gt is None:
            return None

        gt = torch.from_numpy(gt).permute(2, 0, 1)
        return gt, img_name

    def _get_name(self, img_dir):
        if self.use_folder_name:
            return img_dir.split(os.sep)[-2]
        else:
            return img_dir.split(os.sep)[-1].split('.')[0]

    def _load_main_dataset(self, img_dir):
        if not os.path.isdir(img_dir):
            return [img_dir]

        gt_images = []
        for root, dirs, files in os.walk(img_dir):
            for file in files:
                if not is_valid_file(file):
                    continue
                gt_images.append(os.path.join(root, file))

        gt_images.sort()
        return gt_images

    def _load_img(self, img_path):
        gt_image = io.imread(img_path)
        gt_image_bd = getBitDepth(gt_image)  # my own helper (definition omitted)
        gt_image = np.array(gt_image).astype(np.float32) / ((2 ** (gt_image_bd / 3)) - 1)
        return gt_image


def is_valid_file(file_name: str):
    # Check if the file has a valid image extension
    valid_image_extensions = ['.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.tif']
    for ext in valid_image_extensions:
        if file_name.lower().endswith(ext):
            return True
    return False


sequence_data_store = MainDataset(img_dir=sdr_img_dir, use_folder_name=True)
sequence_loader = DataLoader(sequence_data_store, num_workers=0, pin_memory=False)
While a batch sampler might be a good idea if you want a generic custom dataset that you can sample in different ways, I would prefer a more straightforward approach here.
I would build a data structure in __init__ that already contains all the image sequences you will iterate over. The fact is that your Dataset class is currently lying: it reports the total number of individual image files as its length, while what you actually want to sample is sequences, and their count depends on how many images each folder contains.
Moreover, your dataset returns one image at a time while you are expecting sequences.
Some information about the actual structure of the dataset is also missing from your question. Nevertheless, here is a proposal for the Dataset class:
import os

import numpy as np
import torch
from skimage import io  # or wherever your imread comes from
from torch.utils.data import DataLoader, Dataset


class MainDataset(Dataset):
    def __init__(self, img_dir, use_folder_name=False, seq_len=7):
        self.seq_len = seq_len
        self.gt_images = self._load_main_dataset(img_dir)
        self.use_folder_name = use_folder_name

    def __len__(self):
        # The length is now the number of complete sequences, not images
        return len(self.gt_images)

    def __getitem__(self, idx):
        label, sequence = self.gt_images[idx]
        image_sequence = []
        for image_path in sequence:
            loaded_image = self._load_img(image_path)
            loaded_image = torch.from_numpy(loaded_image).permute(2, 0, 1)
            image_sequence.append(loaded_image)
        all_sequence = torch.stack(image_sequence, dim=0)
        # Return a tensor of shape (seq_len, C, H, W) and the label
        return all_sequence, label

    def _get_name(self, img_dir):
        if self.use_folder_name:
            return img_dir.split(os.sep)[-2]
        else:
            return img_dir.split(os.sep)[-1].split('.')[0]

    def _load_main_dataset(self, img_dir):
        if not os.path.isdir(img_dir):
            # I don't really know why you didn't throw an exception here;
            # with the tuple structure below, the old `return [img_dir]`
            # fallback would break __getitem__ anyway.
            raise ValueError(f'{img_dir} is not a directory')

        gt_images = []
        # Why use walk? What is the structure of the dataset?
        for root, dirs, files in os.walk(img_dir):
            # This list accumulates the images of the current sequence.
            # It is reset for every directory, so an incomplete sequence
            # at the end of a directory is dropped and never mixed with
            # the images of the next directory. Directories with fewer
            # than seq_len images therefore contribute nothing.
            image_sequence = []
            # Sort the file names first: os.walk gives no ordering guarantee.
            # Note that list.sort() sorts in place and returns None, so your
            # original `sorted_sequence = image_sequence.sort()` stored None.
            for file in sorted(files):
                if not is_valid_file(file):
                    continue
                img_path = os.path.join(root, file)
                image_sequence.append(img_path)
                if len(image_sequence) == self.seq_len:
                    # Derive the label from the first path of the chunk
                    label = self._get_name(image_sequence[0])
                    gt_images.append((label, image_sequence))
                    image_sequence = []

        # Now gt_images is a list of tuples (label, sequence)
        return gt_images

    def _load_img(self, img_path):
        gt_image = io.imread(img_path)
        gt_image_bd = getBitDepth(gt_image)  # your helper, not shown in the question
        gt_image = np.array(gt_image).astype(np.float32) / ((2 ** (gt_image_bd / 3)) - 1)
        return gt_image


def is_valid_file(file_name: str):
    # Check if the file has a valid image extension
    valid_image_extensions = ['.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.tif']
    for ext in valid_image_extensions:
        if file_name.lower().endswith(ext):
            return True
    return False


sequence_data_store = MainDataset(img_dir=sdr_img_dir, use_folder_name=True)
sequence_loader = DataLoader(sequence_data_store, num_workers=0, pin_memory=False)
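Note that with this class every item is already a full sequence, so the batch_size argument of the DataLoader now counts sequences rather than single images. A quick sanity check of the shapes (C is 3 for RGB; H and W depend on your images):

for sequences, labels in sequence_loader:
    # default batch_size=1 -> (1, seq_len, C, H, W)
    print(sequences.shape)  # e.g. torch.Size([1, 7, 3, H, W])
    print(labels)           # one folder name per sequence in the batch
    break

And for completeness, since you asked about a custom batch sampler: if you would rather keep your original one-image-per-item Dataset, a batch_sampler can produce index batches that never cross directory boundaries. A minimal sketch, assuming the dataset still exposes the flat, sorted path list as dataset.gt_images like your original class (the class and variable names here are just illustrative):

class PerDirectoryBatchSampler:
    """Yields batches of indices that never cross directory boundaries."""

    def __init__(self, paths, batch_size=7):
        self.batches = []
        by_dir = {}
        # Group the (already sorted) flat index list by parent directory
        for idx, path in enumerate(paths):
            by_dir.setdefault(os.path.dirname(path), []).append(idx)
        for indices in by_dir.values():
            # Keep only complete batches; directories with fewer than
            # batch_size images contribute nothing (leftovers are dropped)
            for i in range(0, len(indices) - batch_size + 1, batch_size):
                self.batches.append(indices[i:i + batch_size])

    def __iter__(self):
        return iter(self.batches)

    def __len__(self):
        return len(self.batches)


batch_sampler = PerDirectoryBatchSampler(sequence_data_store.gt_images, batch_size=7)
sequence_loader = DataLoader(sequence_data_store, batch_sampler=batch_sampler, num_workers=0)

The DataLoader then yields exactly the per-directory batches you described: directories with fewer than batch_size images are skipped because they never produce a complete slice.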