Tags: python, pytorch, inverse, max-pooling

PyTorch: a process similar to reverse pooling and replicate padding?


I have a tensor A that has shape (batch_size, width, height). Assume that it has these values:

A = torch.tensor([[[0, 1],
                   [1, 0]]])

I am also given a number K that is a positive integer. Let K=2 in this case. I want to do a process that is similar to reverse pooling and replicate padding. This is the expected output:

B = torch.tensor([[[0, 0, 1, 1],
                   [0, 0, 1, 1],
                   [1, 1, 0, 0],
                   [1, 1, 0, 0]]])

Explanation: each element of A is expanded into a block of shape (K, K) and placed in the result tensor. We continue doing this for the other elements, with the stride between blocks equal to the kernel size (that is, K).

How can I do this in PyTorch? Currently A is a binary mask, but it would be better if the method also worked in the non-binary case.


Solution

  • Square expansion

    You can get your desired output by expanding twice:

    def dilate(t, k):
        # Drop the batch dimension: (1, H, W) -> (H, W)
        x = t.squeeze()
        # Repeat each element k times along a new trailing axis: (H, W) -> (H, W, k)
        x = x.unsqueeze(-1).expand([*x.shape, k])
        # And again: (H, W, k) -> (H, W, k, k)
        x = x.unsqueeze(-1).expand([*x.shape, k])
        # Unpack along the leading axis and concatenate twice to fold the
        # k-blocks into the spatial dimensions: (H, W, k, k) -> (W, H*k, k) -> (H*k, W*k)
        x = torch.cat([*x], dim=1)
        x = torch.cat([*x], dim=1)
        # Restore the batch dimension: (H*k, W*k) -> (1, H*k, W*k)
        x = x.unsqueeze(0)
        return x
    
    B = dilate(A, K)
    
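    As a sketch of an equivalent approach (not from the original answer), torch.repeat_interleave repeats each element K times along a given dimension, which expands each entry into a (K, K) block and handles non-binary values as well:

    ```python
    import torch

    A = torch.tensor([[[0, 1],
                       [1, 0]]])
    K = 2

    # Repeat each row K times, then each column K times:
    # (1, H, W) -> (1, H*K, W) -> (1, H*K, W*K)
    B = A.repeat_interleave(K, dim=-2).repeat_interleave(K, dim=-1)
    ```

    Because repeat_interleave works along arbitrary dimensions, this also keeps the batch dimension intact without any squeeze/unsqueeze bookkeeping.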

  • Resizing / interpolating nearest

    If you don't mind corners potentially 'bleeding' in larger expansions (since it uses Euclidean as opposed to Manhattan distance when determining 'nearest' points to interpolate), a simpler method is to just resize:

    import torchvision.transforms.functional as F
    
    # resize defaults to bilinear interpolation, so request nearest explicitly
    B = F.resize(A, A.shape[-1] * K,
                 interpolation=F.InterpolationMode.NEAREST)
    
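    If you'd rather avoid the torchvision dependency, a sketch of the same idea with torch.nn.functional.interpolate in 'nearest' mode; note that interpolate expects a 4D (N, C, H, W) float input, so this adds a leading dimension and casts:

    ```python
    import torch
    import torch.nn.functional as F

    A = torch.tensor([[[0, 1],
                       [1, 0]]])
    K = 2

    # interpolate expects (N, C, H, W), so add a leading dim and cast to float
    B = F.interpolate(A.unsqueeze(0).float(), scale_factor=K, mode="nearest")
    B = B.squeeze(0).long()  # back to (1, H*K, W*K) and an integer dtype
    ```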

    For completeness, the PyTorch documentation describes MaxUnpool2d as follows:

    "MaxUnpool2d takes in as input the output of MaxPool2d including the indices of the maximal values and computes a partial inverse in which all non-maximal values are set to zero."
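    To illustrate that behaviour, here is a minimal pool/unpool round trip (the input values are made up for the demonstration):

    ```python
    import torch
    import torch.nn as nn

    x = torch.tensor([[[ 1.,  2.,  3.,  4.],
                       [ 5.,  6.,  7.,  8.],
                       [ 9., 10., 11., 12.],
                       [13., 14., 15., 16.]]])

    # return_indices=True makes MaxPool2d also return the positions of the maxima
    pool = nn.MaxPool2d(2, stride=2, return_indices=True)
    unpool = nn.MaxUnpool2d(2, stride=2)

    out, indices = pool(x)           # out: 2x2 tensor of block maxima
    restored = unpool(out, indices)  # maxima back at their positions, zeros elsewhere
    ```

    So unlike the dilation above, unpooling does not replicate values into blocks; it scatters each maximum back to its original location.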