I have a dataset that contains images of brain tumours, and I want to build a CNN to classify them. The examples I have seen use a directory of images that is already separated into “train” and “test” folders.
However, in this case, the dataset directory structure is as follows:
dataset_dir
|_____tumor_type_1
|_____tumor_type_2
|_____tumor_type_3
|_____no_tumor
Now, I want to make 3 dataloaders (a train_dataloader, a validation_dataloader and a test_dataloader). Does anyone know how to do this in PyTorch without writing a custom script?
You can use torchvision's ImageFolder class (docs here), but you first need to split your data into separate train/valid/test directories in this format:
├── train
│   ├── class1
│   │   ├── image-1.jpg
│   │   ├── image-2.jpg
│   ├── class2
│   │   ├── image-1.jpg
│   │   ├── image-2.jpg
├── valid
│   ├── class1
│   │   ├── image-1.jpg
│   │   ├── image-2.jpg
│   ├── class2
│   │   ├── image-1.jpg
│   │   ├── image-2.jpg
├── test
│   ├── ...
...
To split the images randomly (note that this moves the files out of dataset_dir, so work on a copy if you want to keep the original layout):

    import os
    import random
    import shutil

    test_split = 0.2
    valid_split = 0.2

    # Create the output root with one folder per split.
    for split in ('train', 'valid', 'test'):
        os.makedirs(f'./new_dataset_dir/{split}', exist_ok=True)

    # Every subdirectory of dataset_dir is treated as one class.
    classes = [c for c in os.listdir('./dataset_dir')
               if os.path.isdir(f'./dataset_dir/{c}')]

    for c in classes:
        images = os.listdir(f'./dataset_dir/{c}')
        random.shuffle(images)  # randomise which images land in which split

        num_images = len(images)
        num_test = int(test_split * num_images)
        num_valid = int(valid_split * num_images)
        num_train = num_images - num_test - num_valid

        for split in ('train', 'valid', 'test'):
            os.makedirs(f'./new_dataset_dir/{split}/{c}', exist_ok=True)

        for idx, image in enumerate(images):
            # The first num_train images go to train, the next num_valid
            # to valid, and the remainder to test.
            if idx < num_train:
                split = 'train'
            elif idx < num_train + num_valid:
                split = 'valid'
            else:
                split = 'test'
            shutil.move(f'./dataset_dir/{c}/{image}',
                        f'./new_dataset_dir/{split}/{c}/{image}')

        os.rmdir(f'./dataset_dir/{c}')  # the class folder is empty after the move
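If you would rather keep the original dataset_dir untouched and get the same split on every run, here is a rough sketch of a copying variant. The function names `split_sizes` and `copy_split` and the `seed` parameter are my own for illustration, not part of torchvision; the split logic and rounding follow the script above.

```python
import os
import random
import shutil


def split_sizes(num_images, test_split=0.2, valid_split=0.2):
    """Return (num_train, num_valid, num_test), rounding the fractions down."""
    num_test = int(test_split * num_images)
    num_valid = int(valid_split * num_images)
    return num_images - num_test - num_valid, num_valid, num_test


def copy_split(src_dir, dst_dir, test_split=0.2, valid_split=0.2, seed=0):
    """Copy src_dir/<class>/<image> into dst_dir/<split>/<class>/<image>."""
    rng = random.Random(seed)  # fixed seed so the split is reproducible
    for c in sorted(os.listdir(src_dir)):
        class_dir = os.path.join(src_dir, c)
        if not os.path.isdir(class_dir):
            continue  # skip stray files next to the class folders
        # Sort before shuffling because os.listdir order is arbitrary.
        images = sorted(os.listdir(class_dir))
        rng.shuffle(images)
        num_train, num_valid, _ = split_sizes(len(images), test_split, valid_split)
        for idx, image in enumerate(images):
            if idx < num_train:
                split = "train"
            elif idx < num_train + num_valid:
                split = "valid"
            else:
                split = "test"
            target = os.path.join(dst_dir, split, c)
            os.makedirs(target, exist_ok=True)
            shutil.copy2(os.path.join(class_dir, image), target)


# Because the fractions are rounded down, the leftover images go to train.
print(split_sizes(103))  # (63, 20, 20)
```

You would call it as `copy_split('./dataset_dir', './new_dataset_dir')`. Copying doubles the disk usage, so the moving version may still be preferable for a large dataset.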
Then you can easily create the dataloaders using ImageFolder:
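from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

# One dataset per split; class labels are inferred from the subfolder names.
# ImageFolder returns PIL images unless you pass a transform, so you will
# normally want transform=... here to get tensors that can be batched.
train_dataset = ImageFolder(root='./new_dataset_dir/train')
valid_dataset = ImageFolder(root='./new_dataset_dir/valid')
test_dataset = ImageFolder(root='./new_dataset_dir/test')

train_loader = DataLoader(train_dataset, ...)
valid_loader = DataLoader(valid_dataset, ...)
test_loader = DataLoader(test_dataset, ...)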