I am trying to load a dataset from files and train an AI model on it.
For some reason, when I use for images, targets in dataloader
in my main loop, it loads the targets like:
{
'image_id':[all image ids],
'keypoints':[all keypoint lists],
'labels':[all label lists],
'boxes':[all bboxes]
}
instead of
[
{
'image_id':image_id of first sample,
'keypoints':[list of keypoints of first sample],
'labels':[list of labels of first sample],
'boxes':bbox of first sample
},
...
{
'image_id':image_id of fourth sample,
'keypoints':[list of keypoints of fourth sample],
'labels':[list of labels of fourth sample],
'boxes':bbox of fourth sample
}
]
This is my __getitem__ function:
def __getitem__(self, idx):
    annotation = self.annotations[idx]
    image_id = annotation['image_id']
    file_name = annotation['file_name']
    image_path = f"{self.images_dir}/{file_name}"
    image = Image.open(image_path).convert("RGB")
    bbox = np.array(annotation['bbox'])
    keypoints = np.array([[ann["x"], ann["y"]] for ann in annotation["keypoints"]])
    labels = np.array([kp_num[ann["name"]] for ann in annotation["keypoints"]])
    target = {
        "image_id": image_id,
        "keypoints": torch.tensor(keypoints, dtype=torch.float32),
        "labels": torch.tensor(labels, dtype=torch.int64),
        "boxes": torch.tensor(bbox, dtype=torch.int),
    }
    if self.transform:
        image = self.transform(image)
    return image, target
I hoped it would return a list of targets, but it returns a dict of lists. I tried wrapping target in a single-element list in the return statement, but then each batch was just a list with one entry containing all the info, instead of a list with batch_size many targets. I am using the torch.utils.data DataLoader class.
EDIT: solved it, I just implemented a custom_collate_fn like this:
def custom_collate_fn(batch):
    images = [item[0] for item in batch]
    targets = [item[1] for item in batch]
    return images, targets
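For anyone else hitting this, here is a minimal, self-contained sketch of how the collate function is wired into the DataLoader. The `DummyDataset` class is a stand-in for your real dataset, just so the example runs on its own:

```python
from torch.utils.data import DataLoader, Dataset


class DummyDataset(Dataset):
    """Stand-in for the real dataset; returns (image, target) pairs."""

    def __len__(self):
        return 4

    def __getitem__(self, idx):
        # A placeholder "image" and a per-sample target dict,
        # mirroring the shape of the real __getitem__ above.
        return f"image_{idx}", {"image_id": idx}


def custom_collate_fn(batch):
    images = [item[0] for item in batch]
    targets = [item[1] for item in batch]
    return images, targets


loader = DataLoader(DummyDataset(), batch_size=4, collate_fn=custom_collate_fn)
images, targets = next(iter(loader))
# targets is now a list of 4 per-sample dicts, not a dict of lists
```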
The default collate_fn
stacks each key of your dictionaries. Assuming batch_size=N, each target['key']
will have shape [N, ...]. (This behavior is recommended for compatibility, efficient memory allocation, and fast iteration over datasets.)
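You can see this stacking behavior directly by calling torch's exported default_collate (available in torch.utils.data since PyTorch 1.11) on a small hand-made batch:

```python
import torch
from torch.utils.data import default_collate

# A "batch" of two samples, each a per-sample target dict
batch = [
    {"image_id": 0, "labels": torch.tensor([1, 2])},
    {"image_id": 1, "labels": torch.tensor([3, 4])},
]

out = default_collate(batch)
# The list of dicts becomes one dict of batched tensors:
# out["image_id"] has shape [2], out["labels"] has shape [2, 2]
```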
If you really need the per-sample dicts from your dictionary, try the code below.
def my_collate_fn(batch_sample):
    image = []
    target = []
    for sample in batch_sample:
        # Each `sample` is one return value of your __getitem__()
        image.append(sample[0])
        target.append(sample[1])
    return (torch.stack(image).contiguous(),
            target)  # `target` is no longer a tensor, since dicts cannot be stacked.
# --- later, in your DataLoader init --- #
loader = torch.utils.data.DataLoader(collate_fn=my_collate_fn, **kwargs)
By the way, for the reasons mentioned above, I suggest you use the default collate_fn
instead.