Tags: python, pytorch, pre-trained-model

TypeError when trying to display transformed images in PyTorch


I'm having trouble defining the image transforms in PyTorch. Here are the transforms I need:

# Per-channel mean/std passed to transforms.Normalize below (these values match
# the statistics commonly used with torchvision's ImageNet-pretrained models).
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# NOTE(review): in both pipelines below, Normalize comes before PILToTensor,
# so Normalize receives a PIL image. Normalize only accepts tensors, which is
# what raises "img should be Tensor Image. Got <class 'PIL.Image.Image'>".
# Also note that PILToTensor yields a uint8 tensor while Normalize needs a
# float tensor, so simply moving PILToTensor earlier is not enough.
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.Resize((256, 256), interpolation=torchvision.transforms.InterpolationMode.BILINEAR),
    transforms.CenterCrop(size=[224,224]),
    transforms.Normalize(mean, std),
    transforms.PILToTensor()
    
    
])

# Evaluation pipeline: same resize/crop/normalize steps, without the random flip.
test_transform = transforms.Compose([
    transforms.Resize((256, 256), interpolation=torchvision.transforms.InterpolationMode.BILINEAR),
    transforms.CenterCrop(size=[224,224]),
    transforms.Normalize(mean, std),
    transforms.PILToTensor(),
])

Then, I create the loaders for the images of the dataset:

BATCH_SIZE = 32

# Full training folder; 15% of it is carved out below as a validation split.
trainset = torchvision.datasets.ImageFolder(root='CVPR2023_project_2_and_3_data/train/', loader=open_image)

trainset_classes = trainset.classes.copy()

subset_size = int(0.15*len(trainset))

# A second ImageFolder over the same directory, so the validation split can
# carry its own transform: the two Subsets below wrap different dataset
# objects, so setting .transform on one does not affect the other.
validset = torchvision.datasets.ImageFolder(root='CVPR2023_project_2_and_3_data/train/', loader=open_image)

# Random permutation of sample indices; the first subset_size go to validation.
indices = torch.randperm(len(trainset))

valid_indices = indices[:subset_size]
train_indices = indices[subset_size:]

trainset = Subset(trainset, train_indices)
validset = Subset(validset, valid_indices)

# Training split: augmenting transform (includes a random horizontal flip).
trainset.dataset.transform = train_transform
# Validation split: evaluation transform without random augmentation.
validset.dataset.transform = test_transform

# Shuffle the training split each epoch; validation/test order stays fixed.
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, pin_memory=True)
validloader = torch.utils.data.DataLoader(validset, batch_size=BATCH_SIZE, shuffle=False, pin_memory=True)

# NOTE(review): the test set uses Image.open as its loader while train/valid
# use open_image — confirm this difference is intended.
testset = torchvision.datasets.ImageFolder(root='CVPR2023_project_2_and_3_data/test/', transform=test_transform, loader=Image.open)
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False, pin_memory=True)

# `trainset` now refers to the Subset, so the size of the whole train folder
# must be read from the underlying dataset rather than from the split.
print(f'entire train folder: {len(trainset.dataset)}, entire test folder: {len(testset)}, splitted trainset: {len(trainset)},  splitted validset: {len(validset)}')

Then I load a pre-trained network and freeze all the layers but the last one:

# Pretrained AlexNet, fetched through torch.hub from the pytorch/vision repo (tag v0.10.0).
model = torch.hub.load('pytorch/vision:v0.10.0', 'alexnet', pretrained=True)

# Swap the final classifier layer for a freshly initialized 15-way output layer.
model.classifier[6] = torch.nn.Linear(in_features=4096, out_features=15, bias=True)

# Freeze every parameter of the network, then unfreeze only the new output
# layer, so that training updates just that layer.
model.requires_grad_(False)
model.classifier[6].requires_grad_(True)

Then, I define a function for showing an image and I try to print one:

def imshow(img, mean=None, std=None):
    """Show one image tensor with matplotlib, without axes.

    Args:
        img: a 3-D tensor in channels-first order. It is transposed to
            channels-last because that is the layout ``plt.imshow`` takes.
        mean: optional per-channel means that were given to
            ``transforms.Normalize``.
        std: optional per-channel standard deviations that were given to
            ``transforms.Normalize``.

    When ``mean`` and ``std`` are both supplied, the normalization is undone
    (``img * std + mean``) and the result is clipped to [0, 1] before display.
    Without this, matplotlib clips the normalized float image itself and the
    colors come out wrong. The un-normalization assumes the tensor was in
    [0, 1] before ``Normalize`` (as produced by ``transforms.ToTensor``).

    Raises:
        ValueError: if only one of ``mean`` and ``std`` is given.
    """
    if (mean is None) != (std is None):
        raise ValueError("mean and std must be given together")
    # .numpy() refuses tensors that live on the GPU or require grad.
    # detach().cpu() handles both, and returns the same data for a plain
    # CPU tensor.
    npimg = img.detach().cpu().numpy()
    if mean is not None:
        # Reshape to (C, 1, 1) so the per-channel stats broadcast over H and W.
        std_arr = np.asarray(std, dtype=np.float32).reshape(-1, 1, 1)
        mean_arr = np.asarray(mean, dtype=np.float32).reshape(-1, 1, 1)
        npimg = np.clip(npimg * std_arr + mean_arr, 0.0, 1.0)
    plt.axis("off")
    plt.imshow(np.transpose(npimg, (1, 2, 0)))

# Draw one batch from the training loader. Every sample passes through
# train_transform here, which is where the TypeError surfaces.
images, labels = next(iter(trainloader)) # <-- error

# Prints the first image tensor of the batch (imshow is not called here).
print(images[0])

This last piece of code does not work and the program crashes with the following error:

img should be Tensor Image. Got <class 'PIL.Image.Image'>

I have already tried changing the order of the transforms, but then I get the opposite error:

img should be <class 'PIL.Image.Image'> Image. Got Tensor

Can anyone explain how I should solve this error?

Thank you in advance for your patience


Solution

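  • I think the problem is that you should convert the image to a tensor before calling transforms.Normalize(mean, std), because Normalize does not accept PIL images as input (see the torchvision transforms documentation). Note also that Normalize requires a floating-point tensor. transforms.PILToTensor returns uint8 values in [0, 255], whereas transforms.ToTensor returns floats scaled to [0, 1], which is the range the given mean and std assume.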

    # ToTensor (not PILToTensor) must come before Normalize. Normalize only
    # accepts a floating-point tensor, and ToTensor both converts the PIL
    # image to a tensor and scales it from uint8 [0, 255] to float [0, 1],
    # the range the mean/std values (0.485, 0.229, ...) are expressed in.
    # PILToTensor keeps uint8 values, so Normalize would raise a TypeError
    # ("Input tensor should be a float tensor").
    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.Resize((256, 256), interpolation=torchvision.transforms.InterpolationMode.BILINEAR),
        transforms.CenterCrop(size=[224,224]),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    
    # Evaluation pipeline: same steps without the random flip.
    test_transform = transforms.Compose([
        transforms.Resize((256, 256), interpolation=torchvision.transforms.InterpolationMode.BILINEAR),
        transforms.CenterCrop(size=[224,224]),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])