python, machine-learning, pytorch, computer-vision, artificial-intelligence

Pytorch Unfold and Fold: How do I put this image tensor back together again?


I am trying to filter a single-channel 2D image of size 256x256 using unfold to create 16x16 blocks with an overlap of 8. This is shown below:

# I = [256, 256] image
bx = 16              # block size
kernel_size = bx
stride = bx // 2     # overlap of 8
patches = I.unfold(1, kernel_size, stride).unfold(0, kernel_size, stride)  # size = [31, 31, 16, 16]

 

I have started trying to put the image back together with fold, but I’m not quite there yet. I’ve tried to use view to get the image to ‘fit’ the way it’s supposed to, but I don’t see how this would preserve the original image. Perhaps I’m overthinking this.

# patches.shape = [31, 31, 16, 16]
patches = patches.contiguous().view(-1, kernel_size*kernel_size)  # size = [961, 256]
patches = patches.permute(1, 0)  # size = [256, 961]

Any help would be greatly appreciated. Thanks very much.


Solution

  • A slightly less elegant solution than that proposed by Gil:

    I took inspiration from this post on the PyTorch forums, reshaping my image tensor into the standard B x C x H x W form (1 x 1 x 256 x 256). Unfolding:

    # CREATE THE UNFOLDED IMAGE SLICES
    I = image           # shape [256, 256]
    kernel_size = bx    # 16
    stride = int(bx/2)  # 8
    I2 = I.unsqueeze(0).unsqueeze(0)  # shape [1, 1, 256, 256]
    patches2 = I2.unfold(2, kernel_size, stride).unfold(3, kernel_size, stride)
    # shape [1, 1, 31, 31, 16, 16]
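
    As a quick sanity check (a minimal sketch with a dummy tensor, separate from my actual pipeline), the 31 blocks per dimension follow from floor((256 - 16) / 8) + 1 = 31:

    import torch

    I = torch.randn(256, 256)         # dummy single-channel image
    bx, stride = 16, 8
    I2 = I.unsqueeze(0).unsqueeze(0)  # [1, 1, 256, 256]
    patches2 = I2.unfold(2, bx, stride).unfold(3, bx, stride)
    print(patches2.shape)             # torch.Size([1, 1, 31, 31, 16, 16])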
    

    Following this, I apply some transforms and filtering to my stack of patches. Before doing so, I apply a cosine window and normalise:

    # NORMALISE AND WINDOW
    # win is the 16 x 16 cosine window; noise_std is the noise standard deviation
    Pvv = torch.mean(torch.pow(win, 2))*torch.numel(win)*(noise_std**2)  # noise power in a windowed patch
    Pvv = Pvv.double()
    mean_patches = torch.mean(patches2, (4, 5), keepdim=True)  # per-patch mean
    mean_patches = mean_patches.repeat(1, 1, 1, 1, 16, 16)
    window_patches = win.unsqueeze(0).unsqueeze(0).unsqueeze(0).unsqueeze(0).repeat(1, 1, 31, 31, 1, 1)
    zero_mean = patches2 - mean_patches
    windowed_patches = zero_mean * window_patches

    # SOME FILTERING ....

    # ADD MEAN AND WINDOW BEFORE FOLDING BACK TOGETHER.
    filt_data_block = (filt_data_block + mean_patches*window_patches) * window_patches
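
    For anyone reproducing this, `win` could be built along the following lines, i.e. a separable raised-cosine (Hann) window of the block size. This is an illustration rather than my exact window:

    # A possible 16 x 16 cosine window - an illustration,
    # not necessarily the exact window used in the pipeline above.
    win = torch.hann_window(kernel_size, periodic=True).double()
    win = win.unsqueeze(1) * win.unsqueeze(0)  # outer product -> shape [16, 16]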
    

    The above code works for me, but a mask would be simpler. Next, I prepare my tensor of shape [1, 1, 31, 31, 16, 16] to be transformed back into the original [1, 1, 256, 256]:

    # REASSEMBLE THE IMAGE USING FOLD
    import torch.nn.functional as F

    patches = filt_data_block.contiguous().view(1, 1, -1, kernel_size*kernel_size)  # [1, 1, 961, 256]
    patches = patches.permute(0, 1, 3, 2)                                           # [1, 1, 256, 961]
    patches = patches.contiguous().view(1, kernel_size*kernel_size, -1)             # [1, 256, 961]
    IR = F.fold(patches, output_size=(256, 256), kernel_size=kernel_size, stride=stride)
    IR = IR.squeeze()  # [256, 256]
    

    This allowed me to create an overlapping sliding window and seamlessly stitch the image back together. If I cut out the filtering step, the reconstructed image is identical to the original.
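
    Note that F.fold sums the contributions of overlapping patches, which is why the windowing above matters. If all you need is a plain overlap-add average rather than the windowed pipeline, a common trick (a rough sketch, not part of my code above) is to fold a tensor of ones and divide by it:

    # Per-pixel count of overlapping patches, used as a divisor
    ones = torch.ones_like(patches)  # same [1, 256, 961] layout as the patch columns
    divisor = F.fold(ones, output_size=(256, 256), kernel_size=kernel_size, stride=stride)
    IR_avg = (F.fold(patches, output_size=(256, 256), kernel_size=kernel_size, stride=stride) / divisor).squeeze()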