Implementing a 3D Gaussian blur using separable 2D convolutions in PyTorch


I'm trying to implement a Gaussian-like blurring of a 3D volume in PyTorch. I can do a 2D blur of a 2D image by convolving with a 2D Gaussian kernel easily enough, and the same approach seems to work for 3D with a 3D Gaussian kernel. However, it is very slow in 3D (especially with larger sigmas/kernel sizes). I understand this can instead be done by convolving three times with the 2D kernel, which should be much faster, but I can't get it to work. My test case is below.

import torch
import torch.nn.functional as F

VOL_SIZE = 21


def make_gaussian_kernel(sigma):
    ks = int(sigma * 5)
    if ks % 2 == 0:
        ks += 1
    ts = torch.arange(-(ks // 2), ks // 2 + 1, dtype=torch.float)  # integer offsets centered on 0
    gauss = torch.exp(-(ts / sigma) ** 2 / 2)
    kernel = gauss / gauss.sum()

    return kernel


def test_3d_gaussian_blur(blur_sigma=2):
    # Make a test volume
    vol = torch.zeros([VOL_SIZE] * 3)
    vol[VOL_SIZE // 2, VOL_SIZE // 2, VOL_SIZE // 2] = 1

    # 3D convolution
    vol_in = vol.reshape(1, 1, *vol.shape)
    k = make_gaussian_kernel(blur_sigma)
    k3d = torch.einsum('i,j,k->ijk', k, k, k)
    k3d = k3d / k3d.sum()
    vol_3d = F.conv3d(vol_in, k3d.reshape(1, 1, *k3d.shape), stride=1, padding=len(k) // 2)

    # Separable 2D convolution
    vol_in = vol.reshape(1, *vol.shape)
    k2d = torch.einsum('i,j->ij', k, k)
    k2d = k2d / k2d.sum()
    k2d = k2d.expand(VOL_SIZE, 1, *k2d.shape)
    for i in range(3):
        vol_in = vol_in.permute(0, 3, 1, 2)
        vol_in = F.conv2d(vol_in, k2d, stride=1, padding=len(k) // 2, groups=VOL_SIZE)
    vol_3d_sep = vol_in

    torch.allclose(vol_3d, vol_3d_sep)  # --> False

Any help would be very much appreciated!


Solution

  • You theoretically can compute the 3D Gaussian convolution using three 2D convolutions, but that means you have to shrink the 2D kernel: each pass blurs two of the three axes, so every axis ends up convolved twice. Since Gaussian variances add under convolution, each pass needs sigma / sqrt(2) to reproduce a single blur with sigma, as in the sketch below.
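
    For illustration, here is a minimal sketch of that three-pass 2D variant (blur_3d_via_2d is a hypothetical helper name; it reuses make_gaussian_kernel and the permutation loop from the question). Because the kernels are sampled and truncated, it matches the direct 3D convolution only approximately:

    def blur_3d_via_2d(vol, blur_sigma=2):
        # two passes per axis, each with sigma / sqrt(2), compose to one
        # pass with sigma, since Gaussian variances add under convolution
        k2 = make_gaussian_kernel(blur_sigma / 2 ** 0.5)
        k2d = torch.einsum('i,j->ij', k2, k2)
        k2d = k2d.expand(VOL_SIZE, 1, len(k2), len(k2))
        vol_in = vol.reshape(1, *vol.shape)
        for _ in range(3):
            # cycle the axes so each pair of axes is blurred exactly once
            vol_in = vol_in.permute(0, 3, 1, 2)
            vol_in = F.conv2d(vol_in, k2d, stride=1, padding=len(k2) // 2, groups=VOL_SIZE)
        return vol_in  # approximately equal to the direct 3D result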

    But computationally more efficient (and what you usually want) is a separation into 1D kernels. I changed the second part of your function to implement this. (And I must say I really liked your permutation-based approach!) Since you're working with a 3D volume, the conv2d and conv1d functions don't fit naturally, so the best option is really just to use conv3d, even though you're only computing 1D convolutions.

    Note that allclose uses an absolute tolerance of 1e-8 by default, which we do not reach with this method, probably due to accumulated rounding and cancellation errors.

    def test_3d_gaussian_blur(blur_sigma=2):
        # Make a test volume
        vol = torch.randn([VOL_SIZE] * 3) # using something other than zeros
        vol[VOL_SIZE // 2, VOL_SIZE // 2, VOL_SIZE // 2] = 1
    
        # 3D convolution
        vol_in = vol.reshape(1, 1, *vol.shape)
        k = make_gaussian_kernel(blur_sigma)
        k3d = torch.einsum('i,j,k->ijk', k, k, k)
        k3d = k3d / k3d.sum()
        vol_3d = F.conv3d(vol_in, k3d.reshape(1, 1, *k3d.shape), stride=1, padding=len(k) // 2)
    
        # Separable 1D convolution
        vol_in = vol[None, None, ...]
        # k2d = torch.einsum('i,j->ij', k, k)
        # k2d = k2d / k2d.sum() # not necessary if the kernel already sums to one, check:
        # print(f'{k2d.sum()=}')
        k1d = k[None, None, :, None, None]
        for i in range(3):
            vol_in = vol_in.permute(0, 1, 4, 2, 3)
            vol_in = F.conv3d(vol_in, k1d, stride=1, padding=(len(k) // 2, 0, 0))
        vol_3d_sep = vol_in
        print((vol_3d - vol_3d_sep).abs().max())  # something ~1e-7
        print(torch.allclose(vol_3d, vol_3d_sep))  # allclose checks against ~1e-8
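
    Given the ~1e-7 maximum difference printed above, a check with a correspondingly looser tolerance should pass, for example by appending inside the function:

        print(torch.allclose(vol_3d, vol_3d_sep, atol=1e-6))  # should print True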
    

    Addendum: If you really want to abuse conv2d to process the volumes, you can try:

    # separate the 3D kernel into 1D + 2D
    vol_in = vol[None, None, ...]
    k2d = torch.einsum('i,j->ij', k, k)
    k2d = k2d.expand(VOL_SIZE, 1, len(k), len(k))
    # k2d = k2d / k2d.sum() # not necessary if the kernel already sums to one, check:
    # print(f'{k2d.sum()=}')
    k1d = k[None, None, :, None, None]
    vol_in = F.conv3d(vol_in, k1d, stride=1, padding=(len(k) // 2, 0, 0))
    vol_in = vol_in[0, ...]
    # abuse the conv2d groups argument for the volume dimension; works only for 1-channel volumes
    vol_in = F.conv2d(vol_in, k2d, stride=1, padding=(len(k) // 2, len(k) // 2), groups=VOL_SIZE)
    vol_3d_sep = vol_in
    

    Or using exclusively conv2d you could do:

    # separate the 3D kernel into 1D + 2D
    vol_in = vol[None, ...]
    # 1D kernel
    k1d = k[None, None, :, None]
    k1d = k1d.expand(VOL_SIZE, 1, len(k), 1)
    # 2D kernel
    k2d = torch.einsum('i,j->ij', k, k)
    k2d = k2d.expand(VOL_SIZE, 1, len(k), len(k))
    vol_in = vol_in.permute(0, 2, 1, 3)
    vol_in = F.conv2d(vol_in, k1d, stride=1, padding=(len(k) // 2, 0), groups=VOL_SIZE)
    vol_in = vol_in.permute(0, 2, 1, 3)
    vol_in = F.conv2d(vol_in, k2d, stride=1, padding=(len(k) // 2, len(k) // 2), groups=VOL_SIZE)
    vol_3d_sep = vol_in
    

    These should still be faster than three consecutive 2D convolutions.
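
    As a rough sanity check of the speed claims, here is a minimal CPU timing sketch (time_fn, full_3d, and separable_1d are ad-hoc names; absolute numbers depend on hardware and on the volume/kernel sizes):

    import time

    def time_fn(fn, n=20):
        # warm-up once, then average wall-clock time over n runs
        fn()
        t0 = time.perf_counter()
        for _ in range(n):
            fn()
        return (time.perf_counter() - t0) / n

    vol_in = torch.randn(1, 1, VOL_SIZE, VOL_SIZE, VOL_SIZE)
    k = make_gaussian_kernel(2)
    k3d = torch.einsum('i,j,k->ijk', k, k, k)
    k3d = k3d / k3d.sum()
    k1d = k[None, None, :, None, None]

    def full_3d():
        # dense 3D convolution with the full kernel
        F.conv3d(vol_in, k3d.reshape(1, 1, *k3d.shape), padding=len(k) // 2)

    def separable_1d():
        # three 1D passes via conv3d, as in the solution above
        v = vol_in
        for _ in range(3):
            v = v.permute(0, 1, 4, 2, 3)
            v = F.conv3d(v, k1d, padding=(len(k) // 2, 0, 0))

    print(f'full 3D: {time_fn(full_3d):.2e} s, separable 1D: {time_fn(separable_1d):.2e} s')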