I'm trying to implement an image filter in PyTorch that takes two filters of shapes (1,3) and (3,1), which together build up a (3,3) filter. Example applications are the Sobel filter and Gaussian blurring.
I have a NumPy implementation ready, but PyTorch has a different way of working with convolutions that makes it hard to wrap my head around for more traditional applications such as this. How should I proceed?
import numpy as np

def decomposed_conv2d(arr, x_kernel, y_kernel):
    """
    Apply two 1D kernels as parts of a 2D convolution.
    The kernels must be decomposed from the 2D kernel
    that was originally intended to be convolved with the array.
    Inputs:
    - x_kernel: Column vector kernel, to be applied along the x axis (axis 0)
    - y_kernel: Row vector kernel, to be applied along the y axis (axis 1)
    """
    # Filter every column (axis 0), then every row (axis 1)
    arr = np.apply_along_axis(lambda x: np.convolve(x, x_kernel, mode='same'), 0, arr)
    arr = np.apply_along_axis(lambda x: np.convolve(x, y_kernel, mode='same'), 1, arr)
    return arr
Gaussian blurring example:
ax = np.array([-1.,0.,1.])
stdev = 0.5
kernel = np.exp(-0.5 * np.square(ax) / np.square(stdev)) / (stdev * np.sqrt(2*np.pi))
decomposed_conv2d(np.arange(9).reshape((3,3)),kernel,kernel)
>>>array([[0.39126886, 1.24684326, 1.83682264],
[2.86471127, 4.11155453, 4.48257929],
[4.7279302 , 6.1004473 , 6.17348398]])
(Note: The total "energy" of this array may not be preserved, especially in small arrays like this, because the kernel is sampled at only three points and mode='same' treats everything outside the array as zero. That isn't critical to this particular problem.)
Attempting to do the same in PyTorch following this discussion yields an error:
... # define ax,stdev,kernel, etc.
arr_in = torch.arange(9).reshape(3,3) # for example
arr = arr_in.double().unsqueeze(0) # tried both axes and not unsqueezing as well
x_kernel = torch.from_numpy(kernel)
y_kernel = torch.from_numpy(kernel)
x_kernel = x_kernel.view(1,1,-1)
y_kernel = y_kernel.view(1,1,-1)
arr = F.conv1d(arr,x_kernel,padding=x_kernel.shape[2]//2).squeeze(0)
arr = F.conv1d(arr.transpose(0,1),y_kernel, padding=y_kernel.shape[2] // 2).squeeze(0).transpose(2,1).squeeze(1)
>>> RuntimeError: Given groups=1, weight of size [1, 1, 3], expected input[1, 3, 3] to have 1 channels, but got 3 channels instead
I've juggled with squeezes and unsqueezes so that the dimensions match but I still can't get it to do what I want. I just can't even get the first convolution done this way.
You can make your life a lot easier by using conv2d rather than conv1d.
Although we use conv2d below, this is still a 1-d convolution (or rather, two 1-d convolutions) effectively, since we apply a 1×n kernel. Thus, we still have all benefits of a separable convolution (in particular, 2·n rather than n² multiplications per pixel for a kernel of length n).
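As a quick check of the separability claim, here is a minimal sketch comparing one pass with a full 3×3 kernel against two passes with its 1-d factors. The kernels `col` and `row` and the 5×5 test image are made up for illustration, and `padding='same'` assumes a reasonably recent PyTorch version.

```python
import torch
from torch.nn.functional import conv2d

# Hypothetical 1-d kernels, chosen only for illustration
col = torch.tensor([1., 2., 1.], dtype=torch.float64)   # applied along axis 0
row = torch.tensor([1., 0., -1.], dtype=torch.float64)  # applied along axis 1

# The equivalent 3x3 kernel is the outer product of the two 1-d kernels
full = torch.outer(col, row)

# conv2d expects a 4D input: (batch, channels, height, width)
image = torch.arange(25, dtype=torch.float64).reshape(1, 1, 5, 5)

# One pass with the full 3x3 kernel: 9 multiplications per pixel
one_pass = conv2d(image, full.view(1, 1, 3, 3), padding='same')

# Two passes with the 3x1 and 1x3 kernels: 3 + 3 multiplications per pixel
two_pass = conv2d(image, col.view(1, 1, 3, 1), padding='same')
two_pass = conv2d(two_pass, row.view(1, 1, 1, 3), padding='same')

print(torch.allclose(one_pass, two_pass))  # True
```

Both variants agree, including at the zero-padded borders, so the two-pass version only changes the cost, not the result.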
import numpy as np
import torch
from torch.nn.functional import conv2d
np.set_printoptions(precision=3) # For better legibility: show fewer float digits
def decomposed_conv2d_np(arr, x_kernel, y_kernel):  # From the question
    arr = np.apply_along_axis(lambda x: np.convolve(x, x_kernel, mode='same'), 0, arr)
    arr = np.apply_along_axis(lambda x: np.convolve(x, y_kernel, mode='same'), 1, arr)
    return arr

def decomposed_conv2d_torch(arr, x_kernel, y_kernel):  # Proposed
    arr = arr.unsqueeze(0).unsqueeze_(0)  # Make copy, make 4D for ``conv2d()``
    arr = conv2d(arr, weight=x_kernel.view(1, 1, -1, 1), padding='same')
    arr = conv2d(arr, weight=y_kernel.view(1, 1, 1, -1), padding='same')
    return arr.squeeze_(0).squeeze_(0)  # Make 2D again

ax = np.array([-1.,0.,1.])
stdev = 0.5
kernel = np.exp(-0.5 * np.square(ax) / np.square(stdev)) / (stdev * np.sqrt(2 * np.pi))
array = np.arange(9).reshape((3,3))
print(result_np := decomposed_conv2d_np(array, kernel, kernel))
# [[0.391 1.247 1.837]
# [2.865 4.112 4.483]
# [4.728 6.1 6.173]]
array, kernel = torch.from_numpy(array).to(torch.float64), torch.from_numpy(kernel)
print(result_torch := decomposed_conv2d_torch(array, kernel, kernel).numpy())
# [[0.391 1.247 1.837]
# [2.865 4.112 4.483]
# [4.728 6.1 6.173]]
assert np.allclose(result_np, result_torch)
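One caveat for kernels that are not symmetric, such as the derivative part of a Sobel filter: np.convolve flips the kernel, whereas PyTorch's conv2d computes a cross-correlation and does not. The two implementations above agree because the Gaussian kernel is symmetric. A sketch of a variant that should match the NumPy version for arbitrary odd-length kernels follows; the function name `decomposed_conv2d_torch_flipped` and the test kernels are my own choices for illustration.

```python
import numpy as np
import torch
from torch.nn.functional import conv2d

def decomposed_conv2d_np(arr, x_kernel, y_kernel):  # NumPy reference from the question
    arr = np.apply_along_axis(lambda x: np.convolve(x, x_kernel, mode='same'), 0, arr)
    arr = np.apply_along_axis(lambda x: np.convolve(x, y_kernel, mode='same'), 1, arr)
    return arr

def decomposed_conv2d_torch_flipped(arr, x_kernel, y_kernel):  # Hypothetical variant
    # Reverse both kernels so that conv2d's cross-correlation acts as a convolution
    x_kernel = torch.flip(x_kernel, dims=[0])
    y_kernel = torch.flip(y_kernel, dims=[0])
    arr = arr.unsqueeze(0).unsqueeze(0)  # Make 4D for conv2d()
    arr = conv2d(arr, weight=x_kernel.view(1, 1, -1, 1), padding='same')
    arr = conv2d(arr, weight=y_kernel.view(1, 1, 1, -1), padding='same')
    return arr.squeeze(0).squeeze(0)  # Make 2D again

# Sobel-like factors: a symmetric smoothing kernel and an antisymmetric difference kernel
smooth = np.array([1., 2., 1.])
diff = np.array([1., 0., -1.])
image = np.arange(25, dtype=np.float64).reshape(5, 5)

expected = decomposed_conv2d_np(image, smooth, diff)
actual = decomposed_conv2d_torch_flipped(
    torch.from_numpy(image), torch.from_numpy(smooth), torch.from_numpy(diff)
).numpy()
assert np.allclose(expected, actual)
```

Without the flip, the antisymmetric `diff` kernel would give a result with the opposite sign, so whether to flip depends on which convention you want to reproduce.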
This solution is based on my answer to a related, earlier question that asked for an implementation of a Gaussian kernel in PyTorch.
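For completeness, the conv1d attempt from the question fails because conv1d expects an input of shape (batch, channels, length): the (1, 3, 3) tensor is read as one sample with 3 channels, while the (1, 1, 3) weight expects 1 channel. If you prefer to stay with conv1d, one way around this is to treat each row (or column) as its own batch item with a single channel. The following is only a sketch under that assumption; the helper name `decomposed_conv1d_torch` is hypothetical and it assumes odd-length kernels.

```python
import math
import torch
import torch.nn.functional as F

def decomposed_conv1d_torch(arr, x_kernel, y_kernel):  # Hypothetical helper
    pad_x, pad_y = x_kernel.numel() // 2, y_kernel.numel() // 2  # Assumes odd lengths
    # Filter along axis 0: every column becomes one batch item with 1 channel
    cols = arr.t().unsqueeze(1)                      # (W, 1, H)
    cols = F.conv1d(cols, x_kernel.view(1, 1, -1), padding=pad_x)
    arr = cols.squeeze(1).t()                        # Back to (H, W)
    # Filter along axis 1: every row becomes one batch item with 1 channel
    rows = arr.unsqueeze(1)                          # (H, 1, W)
    rows = F.conv1d(rows, y_kernel.view(1, 1, -1), padding=pad_y)
    return rows.squeeze(1)                           # (H, W)

# Gaussian example from the question
ax = torch.tensor([-1., 0., 1.], dtype=torch.float64)
stdev = 0.5
kernel = torch.exp(-0.5 * ax**2 / stdev**2) / (stdev * math.sqrt(2 * math.pi))
array = torch.arange(9, dtype=torch.float64).reshape(3, 3)
result = decomposed_conv1d_torch(array, kernel, kernel)
print(result)
```

Like conv2d, conv1d computes a cross-correlation, so this matches the NumPy version only for symmetric kernels unless the kernels are reversed first.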