python, pytorch, python-imaging-library, data-science

PyTorch transforms.ToTensor() changes the image


I want to convert images to tensors using torchvision.transforms.ToTensor(). After the transformation I converted the tensor back and displayed the image, but it did not look right. Here is my code:

import numpy as np
from PIL import Image
from torchvision import transforms

trans = transforms.Compose([
    transforms.ToTensor()])

demo = Image.open(img)
demo_img = trans(demo)
demo_array = demo_img.numpy()*255
print(Image.fromarray(demo_array.astype(np.uint8)))

The original image is:

original image

After processing, it looks like:

after processing


Solution

  • It seems that the problem is with the channel axis.

    If you look at the torchvision.transforms docs, in particular at ToTensor():

    Converts a PIL Image or numpy.ndarray (H x W x C) in the range [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]

    So once you apply the transformation and convert back to a numpy array, the shape is (C, H, W), and you need to move the channel axis back to the end. You can do the following (the shape change is also demonstrated on a dummy array in the sketch after the full example below):

    demo_array = np.moveaxis(demo_img.numpy()*255, 0, -1)
    

    This moves the array back to shape (H, W, C), so when you convert to PIL and display it, you get the original image. (Equivalent alternatives using permute or ToPILImage are sketched after the full example below.)

    So in total:

    import numpy as np
    from PIL import Image
    from torchvision import transforms

    trans = transforms.Compose([transforms.ToTensor()])

    demo = Image.open(img)
    demo_img = trans(demo)                                    # tensor of shape (C, H, W), values in [0.0, 1.0]
    demo_array = np.moveaxis(demo_img.numpy()*255, 0, -1)     # back to (H, W, C), values in [0, 255]
    print(Image.fromarray(demo_array.astype(np.uint8)))
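
    To see the shape and value-range change concretely, here is a minimal sketch using a dummy array (the array size and variable names are made up purely for illustration):

    import numpy as np
    from torchvision import transforms

    # dummy RGB "image": shape (H, W, C) = (4, 5, 3), uint8 values in [0, 255]
    dummy = np.random.randint(0, 256, size=(4, 5, 3), dtype=np.uint8)

    t = transforms.ToTensor()(dummy)
    print(t.shape)            # torch.Size([3, 4, 5]) -- channels moved to the front
    print(t.dtype)            # torch.float32
    print(t.min(), t.max())   # values rescaled into [0.0, 1.0]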
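
    If you prefer to stay in PyTorch, the same reordering can be done with permute, or you can let torchvision convert the tensor back to a PIL image for you. Here is a sketch of both alternatives, assuming demo_img is the tensor from above:

    import numpy as np
    from PIL import Image
    from torchvision import transforms

    # Option 1: permute the tensor axes (C, H, W) -> (H, W, C) before converting to numpy
    demo_array = (demo_img.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
    img_back = Image.fromarray(demo_array)

    # Option 2: let torchvision convert the (C, H, W) tensor back to a PIL image directly
    img_back = transforms.ToPILImage()(demo_img)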