I'm implementing a U-Net based architecture in PyTorch. At train time, I have patches of size 256x256, which don't cause any problem. However, at test time, I have full-HD images (1920x1080). This causes a problem during the skip connections.
Downsampling 1920x1080 three times gives 240x135. If I downsample one more time, the resolution becomes 120x68, which when upsampled gives 240x136. Now I cannot concatenate these two feature maps. How can I solve this?
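A minimal repro of the mismatch (a sketch, assuming stride-2 max pooling with ceil_mode=True, which is what rounds 135 up to 68 as described above):

import torch
import torch.nn.functional as F

feat = torch.zeros(1, 64, 135, 240)                     # encoder feature after 3 downsamplings
down = F.max_pool2d(feat, 2, stride=2, ceil_mode=True)  # 4th downsampling: 135 -> 68, 240 -> 120
up = F.interpolate(down, scale_factor=2)                # decoder upsampling: 68 -> 136, 120 -> 240
print(feat.shape, up.shape)  # (1, 64, 135, 240) vs (1, 64, 136, 240)
# torch.cat([feat, up], dim=1) raises a RuntimeError: spatial sizes 135 vs 136 don't match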
PS: I thought this was a fairly common problem, but I couldn't find any solution, or even a mention of this problem, anywhere on the web. Am I missing something?
It is a very common problem in segmentation networks, where skip connections are often involved in the decoding process. Networks usually (depending on the actual architecture) require input sizes whose side lengths are integer multiples of the largest stride (8, 16, 32, etc.).
There are two main ways:

1. Resize the input to the nearest feasible size (e.g., with interpolation).
2. Pad the input to the next feasible size, and crop afterward.

I prefer (2) because (1) causes small changes in the pixel values for all the pixels, leading to unnecessary blurriness. Note that we usually need to recover the original shape afterward in both methods.
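For completeness, here is a minimal sketch of option (1), assuming bilinear interpolation to the nearest multiple of the stride (resize_to is a hypothetical helper, not part of the snippet below; the rest of this answer implements option (2)):

import torch
import torch.nn.functional as F

def resize_to(x, stride):
    # Option (1): resize each side to the nearest multiple of `stride`.
    h, w = x.shape[-2:]
    new_h = max(stride, round(h / stride) * stride)
    new_w = max(stride, round(w / stride) * stride)
    out = F.interpolate(x, size=(new_h, new_w), mode='bilinear', align_corners=False)
    return out, (h, w)  # keep the original size to resize the output back

x = torch.zeros(4, 3, 1080, 1920)
x_resized, orig_size = resize_to(x, 16)  # 1080 -> 1088, 1920 -> 1920
# Run the network on x_resized, then resize its output back:
# y = F.interpolate(y_resized, size=orig_size, mode='bilinear', align_corners=False)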
My favorite code snippet for this task (symmetric padding for height/width):
import torch
import torch.nn.functional as F

def pad_to(x, stride):
    h, w = x.shape[-2:]

    # Round height/width up to the next multiple of the stride.
    if h % stride > 0:
        new_h = h + stride - h % stride
    else:
        new_h = h
    if w % stride > 0:
        new_w = w + stride - w % stride
    else:
        new_w = w

    # Split the extra pixels as evenly as possible between the two sides.
    # F.pad expects (left, right, top, bottom) for a 4D tensor.
    lh, uh = int((new_h - h) / 2), int(new_h - h) - int((new_h - h) / 2)
    lw, uw = int((new_w - w) / 2), int(new_w - w) - int((new_w - w) / 2)
    pads = (lw, uw, lh, uh)

    # Zero-padding by default.
    # See others at https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.pad
    out = F.pad(x, pads, "constant", 0)

    return out, pads

def unpad(x, pad):
    # Crop off the padding added by pad_to.
    if pad[2] + pad[3] > 0:
        x = x[:, :, pad[2]:-pad[3], :]
    if pad[0] + pad[1] > 0:
        x = x[:, :, :, pad[0]:-pad[1]]
    return x
A test snippet:
x = torch.zeros(4, 3, 1080, 1920) # Raw data
x_pad, pads = pad_to(x, 16) # Padded data, feed this to your network
x_unpad = unpad(x_pad, pads) # Un-pad the network output to recover the original shape
print('Original: ', x.shape)
print('Padded: ', x_pad.shape)
print('Recovered: ', x_unpad.shape)
Output:
Original: torch.Size([4, 3, 1080, 1920])
Padded: torch.Size([4, 3, 1088, 1920])
Recovered: torch.Size([4, 3, 1080, 1920])
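In practice you pad the input, run the network on the padded tensor, and un-pad the network's output. A minimal sketch, assuming a segmentation-style net whose output spatial size matches its input (net is a placeholder for your model):

def run_inference(net, x, stride=16):
    # Pad to a feasible size, run the model, then crop the prediction
    # back to the original resolution.
    x_pad, pads = pad_to(x, stride)
    y_pad = net(x_pad)  # assumes output H/W == input H/W
    return unpad(y_pad, pads)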
Reference: https://github.com/seoungwugoh/STM/blob/905f11492a6692dd0d0fa395881a8ec09b211a36/helpers.py#L33