Tags: python, pytorch, torchvision

Iteration over torch Dataset not loading multiple targets


I am trying to load a dataset from files and train an AI model on it. For some reason, when I use for images, targets in dataloader in my main loop, it loads the targets like this:

{
    'image_id':[all image ids],
    'keypoints':[all keypoint lists],
    'labels':[all label lists],
    'boxes':[all bboxes]
}

instead of

[
    {
        'image_id':image_id of first sample,
        'keypoints':[list of keypoints of first sample],
        'labels':[list of labels of first sample],
        'boxes':bbox of first sample
    },
...
    {
        'image_id':image_id of fourth sample,
        'keypoints':[list of keypoints of fourth sample],
        'labels':[list of labels of fourth sample],
        'boxes':bbox of fourth sample
    }
]
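
To make the difference concrete, here is a minimal, self-contained sketch (dummy shapes and values, an arbitrary 17 keypoints per sample) of what the default collate does to a list of per-sample dicts; it produces exactly the dict-of-batched-values shown above:

import torch
from torch.utils.data import default_collate  # torch.utils.data.dataloader.default_collate on older versions

# Four dummy samples shaped like my target dicts
samples = [
    {"image_id": torch.tensor(i),
     "keypoints": torch.zeros(17, 2),
     "labels": torch.zeros(17, dtype=torch.int64),
     "boxes": torch.zeros(4)}
    for i in range(4)
]

batched = default_collate(samples)
print(batched.keys())               # dict_keys(['image_id', 'keypoints', 'labels', 'boxes'])
print(batched["keypoints"].shape)   # torch.Size([4, 17, 2]) -- one stacked tensor, not 4 dicts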

This is my __getitem__ function:

def __getitem__(self, idx):
    annotation = self.annotations[idx]
    image_id = annotation['image_id']
    file_name = annotation['file_name']
    image_path = f"{self.images_dir}/{file_name}"
    image = Image.open(image_path).convert("RGB")
    bbox = np.array(annotation['bbox'])
    keypoints = np.array([[ann["x"],ann["y"]] for ann in annotation["keypoints"]])
    # kp_num maps each keypoint name to its integer label (defined elsewhere in the class)
    labels = np.array([kp_num[ann["name"]] for ann in annotation["keypoints"]])
    target = {
        "image_id":image_id,
        "keypoints": torch.tensor(keypoints, dtype=torch.float32),
        "labels": torch.tensor(labels, dtype=torch.int64),
        "boxes":torch.tensor(bbox, dtype=torch.int)
    }
    
    if self.transform:
        image = self.transform(image)

    return image, target

I hoped it would return a list of targets, but it returns a dict of lists. I tried wrapping target in a list in the return statement, but that just produced a list with a single entry containing all the info, instead of a list with batch_size-many targets. I am using the torch.utils.data.DataLoader class.
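
For reference, this is roughly how the loader is built and iterated (simplified sketch; the real dataset/transform setup is omitted):

from torch.utils.data import DataLoader

# Simplified sketch; `dataset` is the custom Dataset whose __getitem__ is shown above.
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

for images, targets in dataloader:
    # With the default collate_fn, `targets` arrives here as ONE dict of
    # batched values rather than a list of per-sample dicts.
    ...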

EDIT: solved it, I just implemented a custom_collate_fn like this:

def custom_collate_fn(batch):
    images = [item[0] for item in batch]
    targets = [item[1] for item in batch] 
    return images, targets
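
Passing it to the DataLoader then gives one list of images and one list of target dicts per batch (sketch, names as above):

from torch.utils.data import DataLoader

dataloader = DataLoader(dataset, batch_size=4, shuffle=True,
                        collate_fn=custom_collate_fn)

for images, targets in dataloader:
    # images:  list of batch_size image tensors
    # targets: list of batch_size target dicts, as expected
    ...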

Solution

  • The default collate_fn stacks each key of your dictionary across the batch: assuming batch_size=N, each target['key'] comes out with shape [N, ...]. (This is generally the recommended behavior, for compatibility, memory allocation, and efficient iteration over the dataset.)

    If you really need the per-sample dictionaries instead, try the code below.

    def my_collate_fn(batch_sample):
        image = []
        target = []
        for sample in batch_sample:
            # Each `sample` is one (image, target) pair returned by your __getitem__()
            image.append(sample[0])
            target.append(sample[1])

        # Note: torch.stack assumes every image tensor in the batch has the same shape.
        return (torch.stack(image).contiguous(),
                target)  # `target` is no longer a tensor, since dicts cannot be stacked.

    # --- later, when creating your DataLoader (dataset etc. go in **kwargs) --- #
    loader = torch.utils.data.DataLoader(collate_fn=my_collate_fn, **kwargs)
    

    BTW, for the reasons mentioned above, I would still suggest using the default collate_fn where you can.
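
    As an aside, if you do keep the list-of-dicts format (the torchvision detection/keypoint models expect targets as a list of per-image dicts during training), a compact equivalent of the collate above is the zip idiom; the name zip_collate_fn here is just illustrative:

    import torch

    def zip_collate_fn(batch):
        # Turns [(img1, tgt1), (img2, tgt2), ...] into
        # ((img1, img2, ...), (tgt1, tgt2, ...)) without stacking anything.
        return tuple(zip(*batch))

    loader = torch.utils.data.DataLoader(dataset, batch_size=4,
                                         collate_fn=zip_collate_fn)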