Tags: image, matplotlib, pytorch, tensor, imshow

Looking for a torch.imshow()-like command


Say I have a variable img (currently on the GPU) of size [32, 1, 256, 256], where 32 is the batch size, 1 is the number of channels (grayscale), and the images are 256 x 256 pixels.

Instead of plotting a single image like this:

import matplotlib.pyplot as plt
plt.imshow(img[0, 0].detach().cpu(), cmap='gray'); plt.show()

I wish I could call something like torch.imshow(img, 8, 'gray') and have it plot 8 images from my batch as subplots. Is there anything like that?
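For comparison, the requested behaviour can be approximated with plain matplotlib subplots. This is a minimal sketch, not an existing PyTorch API: show_batch is a hypothetical helper name, and laying out at most 4 images per row is an assumption.

```python
import math

import matplotlib.pyplot as plt
import numpy as np


def show_batch(img, n=8, cmap="gray"):
    """Plot the first n images of a [B, 1, H, W] batch as a grid of subplots.

    Hypothetical helper approximating the wished-for torch.imshow(img, 8, 'gray').
    """
    if hasattr(img, "detach"):              # torch.Tensor: leave autograd, copy to CPU
        img = img.detach().cpu().numpy()
    img = np.asarray(img)
    n = min(n, img.shape[0])                # never ask for more images than the batch has
    cols = min(n, 4)                        # assumption: at most 4 images per row
    rows = math.ceil(n / cols)
    fig, axes = plt.subplots(rows, cols, squeeze=False,
                             figsize=(3 * cols, 3 * rows))
    for i, ax in enumerate(axes.flat):
        if i < n:
            ax.imshow(img[i, 0], cmap=cmap)  # channel 0 of image i
        ax.axis("off")                      # hide ticks and any unused panels
    fig.tight_layout()
    return fig
```

With the batch from the question, show_batch(img, 8) followed by plt.show() would draw 8 images in a 2 x 4 grid.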


Solution

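  • You are looking for torchvision.utils.make_grid. It converts the [32, 1, 256, 256] tensor into a single image tensor that tiles all 32 images in a grid (nrow images per row, 8 by default); pass img[:8] if you only want 8 of them. Single-channel input is repeated to 3 channels, so the result still looks gray. You still need plt to actually draw the grid: make_grid returns a [C, H, W] tensor, so move it to the CPU and permute it to [H, W, C] before calling plt.imshow.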