
What is the difference between torch.nn.functional.grid_sample and torch.nn.functional.interpolate?


Let's say I have an image I want to downsample to half its resolution using either grid_sample or interpolate from torch.nn.functional. I use mode='bilinear' in both cases.

For grid_sample, I'd do the following:

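import torch

# img has shape (1, C, h, w)
dh = torch.linspace(-1, 1, h // 2)  # normalized y coordinates (steps must be an int)
dw = torch.linspace(-1, 1, w // 2)  # normalized x coordinates
meshy, meshx = torch.meshgrid(dh, dw, indexing='ij')
grid = torch.stack((meshx, meshy), 2)  # grid_sample expects (x, y) order in the last dim
grid = grid.unsqueeze(0)  # add batch dim

x_downsampled = torch.nn.functional.grid_sample(img, grid, mode='bilinear')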

For interpolate, I'd do:

x_downsampled = torch.nn.functional.interpolate(img, size=(h // 2, w // 2), mode='bilinear')

What do the two methods do differently? Which one is better for my example?
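A self-contained sketch of the two calls above, so they can be compared directly. It assumes a random 1x3x8x8 input (the question does not give one), PyTorch 1.10 or newer for the indexing argument of torch.meshgrid, and passes align_corners explicitly to both functions; the helper name downsample_both is only for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
img = torch.rand(1, 3, 8, 8)  # assumed example input of shape (N, C, h, w)
_, _, h, w = img.shape

# Same grid as in the question: normalized coordinates in [-1, 1], (x, y) order.
dh = torch.linspace(-1, 1, h // 2)
dw = torch.linspace(-1, 1, w // 2)
meshy, meshx = torch.meshgrid(dh, dw, indexing="ij")
grid = torch.stack((meshx, meshy), dim=2).unsqueeze(0)  # (1, h // 2, w // 2, 2)


def downsample_both(align_corners):
    """Return (grid_sample result, interpolate result) for one align_corners setting."""
    via_grid = F.grid_sample(img, grid, mode="bilinear", align_corners=align_corners)
    via_interp = F.interpolate(
        img, size=(h // 2, w // 2), mode="bilinear", align_corners=align_corners
    )
    return via_grid, via_interp


match_aligned = torch.allclose(*downsample_both(True), atol=1e-5)
match_default = torch.allclose(*downsample_both(False), atol=1e-5)
print(match_aligned, match_default)  # True False
```

With align_corners=True both functions sample the same positions, because linspace(-1, 1, n) puts the first and last samples on the corner pixel centers. With align_corners=False (the default for both) the two functions map output pixels to input coordinates differently, so the results differ.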


Solution

  • interpolate is limited to scaling (resizing) the input.

    grid_sample is more flexible and can perform any form of warping of the input. For every location in the output, the grid specifies the normalized (x, y) coordinates in the input to sample from.

    In simple terms, interpolate cannot change the ordering of elements in the input (an element to the right of another element is still to the right after interpolation). grid_sample can change the order of the elements, and any arbitrary rearrangement can be achieved through the grid passed to the function.

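    A simple 2D illustration of grid_sample in action, with grid entries written as (x, y) pixel indices for readability (the real function expects coordinates normalized to [-1, 1]):

    input = [[10, 20], [30, 40]]
    grid (contains coordinates) = [[(1,0),(0,0)], [(0,1),(1,1)]]
    
    output will be: [[20, 10], [30, 40]]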
    

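    For your example, both can work, but interpolate is the preferred choice: it computes the sampling positions internally, so you don't have to build a grid yourself. Note that the two only give the same result when they use the same align_corners setting; the linspace(-1, 1, ...) grid in the question corresponds to align_corners=True.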
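The 2D illustration can be reproduced with the real function. A minimal sketch, assuming align_corners=True so that pixel index 0 maps to normalized coordinate -1 and index 1 maps to +1; the grid reads (x, y) = (1, 0), (0, 0) for the first output row and (0, 1), (1, 1) for the second.

```python
import torch
import torch.nn.functional as F

inp = torch.tensor([[10.0, 20.0], [30.0, 40.0]]).view(1, 1, 2, 2)  # (N, C, H, W)

# Each grid entry is the normalized (x, y) input location read by that output pixel.
# Row 0 reads (1, 0) then (0, 0), swapping the two columns; row 1 keeps its order.
grid = torch.tensor([[[1.0, -1.0], [-1.0, -1.0]],
                     [[-1.0, 1.0], [1.0, 1.0]]]).unsqueeze(0)  # (N, H_out, W_out, 2)

out = F.grid_sample(inp, grid, mode="bilinear", align_corners=True)
print(out[0, 0].tolist())  # [[20.0, 10.0], [30.0, 40.0]]
```

Because every grid point lands exactly on a pixel center, bilinear sampling returns the original values unchanged. No choice of size or scale_factor in interpolate can swap two pixels within a row like this.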