python pytorch interpolation

PyTorch F.interpolate for many dimensions (e.g. 4D, 5D, 6D, 7D, ..., n-D)


I want to apply N-d interpolation to an (N+2)-d tensor for N>3.

import torch
import torch.nn.functional as F

x = torch.randn(1, 1, 2, 3, 4, 5, 6, 7)
output_size = (7, 6, 5, 4, 3, 2)
y = F.interpolate(x, size=output_size, mode="linear")

The above code gives the following error:

NotImplementedError: Input Error: Only 3D, 4D and 5D input Tensors supported (got 8D) for the modes: nearest | linear | bilinear | bicubic | trilinear | area | nearest-exact (got linear)

Note that the first two dimensions are batch size and channels (B, C), and are thus not interpolated, as stated in the docs.

How do I apply N-d linear interpolation for N>3?


Solution

  • N-d linear interpolation is equivalent to applying 1-D linear interpolation along each interpolated dimension in succession: the per-axis weights simply multiply, so the order of the passes does not matter. Wikipedia gives the following example diagram for 2-D (bilinear) interpolation:

    [Figure: Bilinear interpolation. Apply linear interpolation along each dimension.]
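
As a quick check of the claim that linear interpolation decomposes into 1-D passes, the sketch below compares a single `mode="bilinear"` call with two successive `mode="linear"` calls on a small 4-D tensor. The shapes and the variable names (`reference`, `separable`, `max_abs_diff`) are illustrative assumptions, and the default `align_corners=False` is assumed throughout.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(1, 1, 4, 5)  # (B, C, H, W)
out_h, out_w = 7, 3

# Reference: one bilinear call over both spatial dimensions.
reference = F.interpolate(x, size=(out_h, out_w), mode="bilinear")

# Pass 1: interpolate along W only. Folding H into the channel axis gives a
# 3-D tensor (B, C*H, W), which mode="linear" accepts.
step = F.interpolate(x.reshape(1, 4, 5), size=out_w, mode="linear")  # (1, 4, 3)
step = step.reshape(1, 1, 4, out_w)

# Pass 2: interpolate along H only. Move H to the last axis and fold W
# into the channel axis, then undo the permutation afterwards.
step = step.permute(0, 1, 3, 2).reshape(1, out_w, 4)   # (1, 3, 4)
step = F.interpolate(step, size=out_h, mode="linear")  # (1, 3, 7)
separable = step.reshape(1, 1, out_w, out_h).permute(0, 1, 3, 2)  # (1, 1, 7, 3)

# The two results should agree up to floating-point rounding.
max_abs_diff = (separable - reference).abs().max().item()
print(separable.shape, max_abs_diff)
```

The same idea carries over to any number of dimensions, since each pass only touches one axis.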

    Thus, one simple method is:

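    import math  # math.prod requires Python 3.8+

    import torch
    import torch.nn.functional as F


    def interpolate(input, size, scale_factor=None):
        """Linearly interpolate every dimension after (B, C) to `size`."""
        assert input.ndim >= 3
        if scale_factor is not None:
            raise NotImplementedError("only `size` is supported")
        output_shape = (*input.shape[:2], *size)
        # `size` must give one target length per spatial dimension.
        assert len(input.shape) == len(output_shape)
        # Apply linear interpolation to each spatial dimension in turn.
        for i in range(2, 2 + len(size)):
            # Flatten to 4-D: (B, C * leading dims, dim i, trailing dims).
            input_tail = math.prod(input.shape[i + 1 :])
            # Bilinear mode with the last axis left at its current length
            # resizes only dim i (default align_corners=False).
            input = F.interpolate(
                input.reshape(
                    input.shape[0], math.prod(input.shape[1:i]), input.shape[i], input_tail
                ),
                size=(output_shape[i], input_tail),
                mode="bilinear",
            ).reshape(*input.shape[:i], output_shape[i], *input.shape[i + 1 :])
        return input.reshape(output_shape)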
    

    Usage:

    x = torch.randn(1, 1, 2, 3, 4, 5, 6, 7)
    output_size = (7, 6, 5, 4, 3, 2)
    y = interpolate(x, size=output_size)
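
One way to gain confidence in this approach is to compare it with the built-in `mode="trilinear"` on a 5-D input, where PyTorch provides a reference, and then confirm the output shape for the 8-D case from the question. In the sketch below, `interpolate_nd` is a hypothetical name that restates the same per-dimension loop so the snippet runs on its own. The tolerance of `1e-5` is an assumption about float32 rounding, not a documented guarantee.

```python
import math

import torch
import torch.nn.functional as F


def interpolate_nd(x, size):
    """Per-dimension linear interpolation (restated here for self-containment)."""
    out_shape = (*x.shape[:2], *size)
    assert x.ndim == len(out_shape)
    for i in range(2, x.ndim):
        tail = math.prod(x.shape[i + 1 :])
        # Resize dim i only; the flattened trailing axis keeps its length.
        x = F.interpolate(
            x.reshape(x.shape[0], math.prod(x.shape[1:i]), x.shape[i], tail),
            size=(out_shape[i], tail),
            mode="bilinear",
        ).reshape(*x.shape[:i], out_shape[i], *x.shape[i + 1 :])
    return x


torch.manual_seed(0)

# 5-D case: compare against the built-in trilinear mode.
x5 = torch.randn(2, 3, 4, 5, 6)
size5 = (7, 3, 9)
expected = F.interpolate(x5, size=size5, mode="trilinear")
actual = interpolate_nd(x5, size5)
matches_trilinear = torch.allclose(actual, expected, atol=1e-5)

# 8-D case from the question: only the shape can be checked directly.
x8 = torch.randn(1, 1, 2, 3, 4, 5, 6, 7)
y8 = interpolate_nd(x8, (7, 6, 5, 4, 3, 2))

print(matches_trilinear, tuple(y8.shape))
```

Because each pass flattens all trailing dimensions into one axis, the intermediate tensors stay 4-D regardless of N, at the cost of one `F.interpolate` call per interpolated dimension.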