
Separable convolutions in PyTorch (i.e. two "traditional" 1D-vector convolutions)


I'm trying to implement an image filter in PyTorch that takes two 1D filters of shapes (1,3) and (3,1), which together make up a (3,3) filter. Example applications are the Sobel filter and Gaussian blurring.

I have a NumPy implementation ready, but PyTorch has a different way of working with convolutions that makes it hard to wrap my head around for more traditional applications such as this. How should I proceed?

def decomposed_conv2d(arr,x_kernel,y_kernel):
  """
  Apply two 1D kernels as a part of a 2D convolution.
  The kernels must be decomposed from the 2D kernel
  that was originally intended to be convolved with the array.
  Inputs:
  - x_kernel: Column vector kernel, to be applied along the x axis (axis 0)
  - y_kernel: Row vector kernel, to be applied along the y axis (axis 1)
  """
  arr = np.apply_along_axis(lambda x: np.convolve(x, x_kernel, mode='same'), 0, arr)
  arr = np.apply_along_axis(lambda x: np.convolve(x, y_kernel, mode='same'), 1, arr)
  return arr

Gaussian blurring example:

ax = np.array([-1.,0.,1.])
stdev = 0.5
kernel = np.exp(-0.5 * np.square(ax) / np.square(stdev)) / (stdev * np.sqrt(2*np.pi))
decomposed_conv2d(np.arange(9).reshape((3,3)),kernel,kernel)
>>>array([[0.39126886, 1.24684326, 1.83682264],
       [2.86471127, 4.11155453, 4.48257929],
       [4.7279302 , 6.1004473 , 6.17348398]])

(Note: The total "energy" of this array may not be preserved, especially in small arrays like this because the convolution is discrete. It isn't that critical to this particular problem).

Attempting to do the same in PyTorch following this discussion yields an error:

... # define ax,stdev,kernel, etc.
arr_in = torch.arange(9).reshape(3,3) # for example
arr = arr_in.double().unsqueeze(0) # tried both axes and not unsqueezing as well
x_kernel = torch.from_numpy(kernel)
y_kernel = torch.from_numpy(kernel)

x_kernel = x_kernel.view(1,1,-1)
y_kernel = y_kernel.view(1,1,-1)
arr = F.conv1d(arr,x_kernel,padding=x_kernel.shape[2]//2).squeeze(0)
arr = F.conv1d(arr.transpose(0,1),y_kernel, padding=y_kernel.shape[2] // 2).squeeze(0).transpose(2,1).squeeze(1)

>>> RuntimeError: Given groups=1, weight of size [1, 1, 3], expected input[1, 3, 3] to have 1 channels, but got 3 channels instead

I've juggled with squeezes and unsqueezes so that the dimensions match but I still can't get it to do what I want. I just can't even get the first convolution done this way.


Solution

  • You can make your life a lot easier by using conv2d rather than conv1d.

    Although we use conv2d below, this is effectively still a pair of 1-d convolutions, since we apply an n×1 kernel followed by a 1×n kernel. Thus, we keep all the benefits of a separable convolution (in particular, 2·n rather than n² multiplications per pixel for a kernel of length n).

    import numpy as np
    import torch
    from torch.nn.functional import conv2d
    np.set_printoptions(precision=3)  # For better legibility: show fewer float digits
    
    def decomposed_conv2d_np(arr, x_kernel, y_kernel):  # From the question
        arr = np.apply_along_axis(lambda x: np.convolve(x, x_kernel, mode='same'), 0, arr)
        arr = np.apply_along_axis(lambda x: np.convolve(x, y_kernel, mode='same'), 1, arr)
        return arr
    
    def decomposed_conv2d_torch(arr, x_kernel, y_kernel):  # Proposed
        arr = arr.unsqueeze(0).unsqueeze_(0)  # New 4D view for ``conv2d()``; the caller's tensor keeps its shape
        arr = conv2d(arr, weight=x_kernel.view(1, 1, -1, 1), padding='same')
        arr = conv2d(arr, weight=y_kernel.view(1, 1, 1, -1), padding='same')
        return arr.squeeze_(0).squeeze_(0)  # Make 2D again
    
    ax = np.array([-1.,0.,1.])
    stdev = 0.5
    kernel = np.exp(-0.5 * np.square(ax) / np.square(stdev)) / (stdev * np.sqrt(2 * np.pi))
    array = np.arange(9).reshape((3,3))
    
    print(result_np := decomposed_conv2d_np(array, kernel, kernel))
    # [[0.391 1.247 1.837]
    #  [2.865 4.112 4.483]
    #  [4.728 6.1   6.173]]
    
    array, kernel = torch.from_numpy(array).to(torch.float64), torch.from_numpy(kernel)
    print(result_torch := decomposed_conv2d_torch(array, kernel, kernel).numpy())
    # [[0.391 1.247 1.837]
    #  [2.865 4.112 4.483]
    #  [4.728 6.1   6.173]]
    
    assert np.allclose(result_np, result_torch)
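
One caveat for kernels that are not symmetric, such as the derivative part of a Sobel filter: conv2d computes a cross-correlation, whereas np.convolve reverses the kernel first. For the symmetric Gaussian kernel above the two agree, but for an antisymmetric kernel the sign of the result would be inverted. The following is a minimal sketch of a variant that reverses each kernel before applying it. The function name decomposed_conv2d_flipped and the test values are illustrative assumptions, not part of the original code, and padding="same" assumes a PyTorch version that accepts string padding.

```python
import numpy as np
import torch
from torch.nn.functional import conv2d

def decomposed_conv2d_flipped(arr, x_kernel, y_kernel):
    """Hypothetical variant: reverse each 1D kernel so that conv2d
    (a cross-correlation) reproduces np.convolve (a true convolution)."""
    arr = arr[None, None]  # (H, W) -> (1, 1, H, W), as conv2d expects
    # x_kernel runs along axis 0 (height), y_kernel along axis 1 (width)
    arr = conv2d(arr, x_kernel.flip(0).view(1, 1, -1, 1), padding="same")
    arr = conv2d(arr, y_kernel.flip(0).view(1, 1, 1, -1), padding="same")
    return arr[0, 0]  # back to (H, W)

# Sobel-like pair: symmetric smoothing kernel and antisymmetric derivative kernel
image = np.arange(25, dtype=np.float64).reshape(5, 5)
smooth = np.array([1.0, 2.0, 1.0])
deriv = np.array([1.0, 0.0, -1.0])

# NumPy reference, same procedure as in the question
expected = np.apply_along_axis(lambda v: np.convolve(v, smooth, mode="same"), 0, image)
expected = np.apply_along_axis(lambda v: np.convolve(v, deriv, mode="same"), 1, expected)

result = decomposed_conv2d_flipped(
    torch.from_numpy(image), torch.from_numpy(smooth), torch.from_numpy(deriv)
).numpy()

assert np.allclose(result, expected)
```

If only symmetric kernels are ever used, the reversal is a no-op and can be dropped.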
    

    This solution is based on my answer to a related, earlier question that asked for an implementation of a Gaussian kernel in PyTorch.
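
As for the error in the question: conv1d expects its input in the layout (batch, channels, length), so a tensor of shape (1, 3, 3) is read as one sample with 3 channels, while the weight of shape (1, 1, 3) accepts only 1 channel. If you prefer to stay with conv1d, one way around this is to treat every row (or column) as its own single-channel batch item. The sketch below illustrates that idea; the function name decomposed_conv1d is hypothetical, and it assumes symmetric, odd-length kernels, because conv1d also computes a cross-correlation.

```python
import numpy as np
import torch
import torch.nn.functional as F

def decomposed_conv1d(arr, x_kernel, y_kernel):
    """Hypothetical conv1d counterpart of the conv2d solution.
    Assumes a 2D input and symmetric, odd-length 1D kernels."""
    x_w = x_kernel.view(1, 1, -1)  # (out_channels, in_channels, length)
    y_w = y_kernel.view(1, 1, -1)
    # Filter along axis 0: each column becomes a batch item with 1 channel
    cols = arr.t().unsqueeze(1)  # (W, 1, H)
    arr = F.conv1d(cols, x_w, padding=x_w.shape[2] // 2).squeeze(1).t()  # (H, W)
    # Filter along axis 1: each row becomes a batch item with 1 channel
    rows = arr.unsqueeze(1)  # (H, 1, W)
    return F.conv1d(rows, y_w, padding=y_w.shape[2] // 2).squeeze(1)  # (H, W)

# Same Gaussian example as in the question
ax = np.array([-1.0, 0.0, 1.0])
stdev = 0.5
kernel = torch.from_numpy(
    np.exp(-0.5 * np.square(ax) / np.square(stdev)) / (stdev * np.sqrt(2 * np.pi))
)
arr_in = torch.arange(9, dtype=torch.float64).reshape(3, 3)
out = decomposed_conv1d(arr_in, kernel, kernel)
print(out)
```

The explicit integer padding of kernel_length // 2 keeps the output the same size as the input for odd kernel lengths, which mirrors mode='same' in the NumPy version.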