python pytorch image-segmentation torchvision

"RuntimeError: Cannot access data pointer of Tensor that doesn't have storage" when using PyTorch vmap


Let's say I have a segmentation model (model) and I want to convert its batched predictions to Pillow images. For simplicity, let's say everything runs on the CPU (no GPU involved).

If I do:

import torch
from torchvision.transforms import ToPILImage

transform = ToPILImage()
model.eval()
for i, (x, y) in enumerate(dataloader):
    y_hat = torch.sigmoid(model(x))  # returns a tensor (batch_size, 1, H, W)
    y_hat = (y_hat > 0.5).float()
    img = transform(y_hat)

I get:

ValueError: pic should be 2/3 dimensional. Got 4 dimensions.

Fair enough. Let me try using vmap to transform it as a batch:

import torch
from torchvision.transforms import ToPILImage

transform = ToPILImage()
batch_transform = torch.func.vmap(transform)
model.eval()
for i, (x, y) in enumerate(dataloader):
    y_hat = torch.sigmoid(model(x))  # returns a tensor (batch_size, 1, H, W)
    y_hat = (y_hat > 0.5).float()
    img = batch_transform(y_hat)

That produces the following error:

RuntimeError: Cannot access data pointer of Tensor that doesn't have storage

Why does this behave this way? Does it have anything to do with the function I've chosen to vmap? I've followed the pattern from the documentation, so I expected this to work. How can I apply this operation to a batch of images?
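The behavior can be reproduced without a model. This is a minimal sketch, assuming PyTorch 2.x (for torch.func) with NumPy installed; masks is a hypothetical random stand-in for the thresholded predictions, and invert and to_numpy are hypothetical helpers. vmap handles a function made only of PyTorch operations, but a function that needs a NumPy array (ToPILImage converts its input to one internally) is expected to fail with a RuntimeError like the one above:

```python
import torch
from torch.func import vmap

# Hypothetical stand-in for the thresholded predictions: (batch, 1, H, W)
masks = (torch.rand(4, 1, 8, 8) > 0.5).float()

def invert(mask):
    # Only PyTorch ops: vmap can map this over the batch dimension
    return 1.0 - mask

def to_numpy(mask):
    # Leaves PyTorch: .numpy() needs the tensor's underlying storage,
    # which the batched tensor vmap passes in does not have
    return mask.numpy()

inverted = vmap(invert)(masks)
print(inverted.shape)  # torch.Size([4, 1, 8, 8])

try:
    vmap(to_numpy)(masks)
except RuntimeError as err:
    print("vmap failed:", err)
```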


Solution

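  • As the first error message says, the ToPILImage transform only accepts a single image: a tensor that is either 2D (H, W) or 3D (C, H, W). torch.func.vmap does not get around this, because it can only map functions built from PyTorch operations that return tensors. ToPILImage converts its input to a NumPy array and then to a PIL.Image, which requires direct access to the tensor's memory, and the batched tensor that vmap passes to the function has no storage of its own. That is where the RuntimeError comes from, so yes, the problem is the function you chose to vmap. Instead, iterate over the batch elements and apply the transform to each one: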

    imgs = [transform(t) for t in y_hat]  # each t has shape (1, H, W)
    

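    Alternatively, if a single image of the whole batch is enough, torchvision.utils.make_grid tiles a 4D (B, C, H, W) tensor (or a list of 3D tensors) into one 3D grid tensor that ToPILImage can convert. Note that make_grid repeats single-channel input to 3 channels, so the result is an RGB image: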

    from torchvision.utils import make_grid

    img = transform(make_grid(y_hat))
    

    If the goal is to write the predictions to disk, the torchvision.utils.save_image utility combines make_grid, the PIL.Image conversion, and saving to the file system in one call (extra keyword arguments such as nrow are passed on to make_grid):

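    from torchvision.utils import save_image

    save_image(y_hat, 'pred.png')  # PNG is lossless, so the binary mask values are preserved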