I want to convert images to tensors using torchvision.transforms.ToTensor(). After processing, I displayed the image, but it did not look right. Here is my code:
import numpy as np
from PIL import Image
from torchvision import transforms

trans = transforms.Compose([transforms.ToTensor()])
demo = Image.open(img)              # img is the path to the image file
demo_img = trans(demo)              # float tensor with values in [0.0, 1.0]
demo_array = demo_img.numpy() * 255
Image.fromarray(demo_array.astype(np.uint8)).show()
The original image is: [image]
After processing, it looks like: [image]
It seems that the problem is with the channel axis.
If you look at the torchvision.transforms docs, in particular at ToTensor(), it says:

"Converts a PIL Image or numpy.ndarray (H x W x C) in the range [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]"
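You can confirm that shape and range change with a minimal check; the tiny all-zeros array below is just a stand-in for a real image:

import numpy as np
from torchvision import transforms

arr = np.zeros((4, 5, 3), dtype=np.uint8)  # H = 4, W = 5, C = 3, values in [0, 255]
t = transforms.ToTensor()(arr)
print(t.shape)  # torch.Size([3, 4, 5]): channels moved to the front
print(t.dtype)  # torch.float32, values scaled into [0.0, 1.0]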
So once you perform the transformation and convert back to a numpy array, the shape is (C, H, W), and you need to move the channel axis back to the last position. You can do the following:
demo_array = np.moveaxis(demo_img.numpy()*255, 0, -1)
This transforms the array to shape (H, W, C), so when you convert back to PIL and display it, you get the same image.
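Equivalently, you could reorder the axes on the tensor itself with permute before converting to numpy; this is just an alternative to np.moveaxis, not an extra required step:

demo_array = (demo_img.permute(1, 2, 0).numpy() * 255).astype(np.uint8)  # (C, H, W) -> (H, W, C)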
So in total:
import numpy as np
from PIL import Image
from torchvision import transforms

trans = transforms.Compose([transforms.ToTensor()])
demo = Image.open(img)                                    # img is the path to the image file
demo_img = trans(demo)                                    # (C, H, W) float tensor in [0.0, 1.0]
demo_array = np.moveaxis(demo_img.numpy() * 255, 0, -1)   # back to (H, W, C) in [0, 255]
Image.fromarray(demo_array.astype(np.uint8)).show()       # now displays correctly
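Alternatively, torchvision provides the inverse transform transforms.ToPILImage(), which handles both the axis reordering and the rescaling to [0, 255] for you; a minimal sketch, assuming demo_img is the (C, H, W) float tensor from above:

from torchvision import transforms

to_pil = transforms.ToPILImage()
to_pil(demo_img).show()  # converts the tensor back to a PIL image and displays it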