I am trying to implement a 3D convolution using FFTs with pyfftw. I started from code posted in another answer on SO:
import numpy as np
import pyfftw


class CustomFFTConvolution(object):

    def __init__(self, A, B, threads=1):
        # Size of the full linear convolution along each axis.
        shape = (np.array(A.shape) + np.array(B.shape)) - 1
        #shape=np.array(A.shape) - np.array(B.shape)+1

        if np.iscomplexobj(A) and np.iscomplexobj(B):
            # Complex inputs: plan complex-to-complex transforms.
            self.fft_A_obj = pyfftw.builders.fftn(
                A, s=shape, threads=threads)
            self.fft_B_obj = pyfftw.builders.fftn(
                B, s=shape, threads=threads)
            self.ifft_obj = pyfftw.builders.ifftn(
                self.fft_A_obj.get_output_array(), s=shape,
                threads=threads)
        else:
            # Real inputs: plan real-to-complex transforms.
            self.fft_A_obj = pyfftw.builders.rfftn(
                A, s=shape, threads=threads)
            self.fft_B_obj = pyfftw.builders.rfftn(
                B, s=shape, threads=threads)
            self.ifft_obj = pyfftw.builders.irfftn(
                self.fft_A_obj.get_output_array(), s=shape,
                threads=threads)

    def __call__(self, A, B):
        s1 = np.array(A.shape)
        s2 = np.array(B.shape)

        fft_padded_A = self.fft_A_obj(A)
        fft_padded_B = self.fft_B_obj(B)

        ret = self.ifft_obj(fft_padded_A * fft_padded_B)

        # Keep only the part computed without zero-padded edges.
        return self._centered(ret, s1 - s2 + 1)

    def _centered(self, arr, newshape):
        # Return the center newshape portion of the array.
        newshape = np.asarray(newshape)
        currshape = np.array(arr.shape)
        startind = (currshape - newshape) // 2
        endind = startind + newshape
        myslice = [slice(startind[k], endind[k]) for k in range(len(endind))]
        return arr[tuple(myslice)]
My data A has a shape of (931, 411, 806), and my filter B has a shape of (32, 32, 32). If I run this code with 24 threads on a 24-core machine, the operation takes 263 seconds. If I run the same experiment on the same machine, but with A of shape (806, 411, 931), just a swap of axes, the code takes only 16 seconds. What is the reason for this? Is there a rule of thumb to obtain the best performance? Maybe padding one of the dimensions? Thanks!
Since padding is being considered, could the padded size be increased to an even number, or to a product of small prime factors? In the test at the end of this answer, merely choosing even sizes divides the wall-clock time by about 3.
Depending on the dimensions, some DFT algorithms may not be available or efficient. For instance, one of the most effective algorithms for computing the DFT is the Cooley-Tukey algorithm. It consists in splitting the DFT of a signal of composite size N=N1*N2 into N1 DFTs of size N2 followed by N2 DFTs of size N1. As a consequence, it works best for composite sizes that are products of small prime factors (2, 3, 5, 7), for which FFTW provides dedicated efficient algorithms. From the documentation of FFTW:
For example, the standard FFTW distribution works most efficiently for arrays whose size can be factored into small primes (2, 3, 5, and 7), and otherwise it uses a slower general-purpose routine. If you need efficient transforms of other sizes, you can use FFTW’s code generator, which produces fast C programs (“codelets”) for any particular array size you may care about. For example, if you need transforms of size 513 = 19*3^3, you can customize FFTW to support the factor 19 efficiently.
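To make the Cooley-Tukey splitting concrete, here is a minimal sketch of a single decomposition step for N=N1*N2, written in pure Python. The names `naive_dft` and `cooley_tukey_step` are made up for this illustration, and the O(N^2) `naive_dft` merely stands in for the small optimized transforms (codelets) that FFTW would use; this is not how FFTW itself is implemented.

```python
import cmath


def naive_dft(x):
    """Reference O(N^2) DFT, standing in for a small optimized transform."""
    n = len(x)
    return [sum(x[j] * cmath.exp(-2j * cmath.pi * j * k / n) for j in range(n))
            for k in range(n)]


def cooley_tukey_step(x, n1, n2):
    """One Cooley-Tukey step for a signal of length n1 * n2 (illustrative)."""
    n = n1 * n2
    if len(x) != n:
        raise ValueError("len(x) must equal n1 * n2")
    # n1 inner DFTs of size n2, one per residue a of the input index modulo n1.
    inner = [naive_dft([x[n1 * b + a] for b in range(n2)]) for a in range(n1)]
    # Twiddle factors exp(-2*pi*i * a * k2 / n) couple the two stages.
    for a in range(n1):
        for k2 in range(n2):
            inner[a][k2] *= cmath.exp(-2j * cmath.pi * a * k2 / n)
    # n2 outer DFTs of size n1; output index is k = n2 * k1 + k2.
    out = [0j] * n
    for k2 in range(n2):
        column = naive_dft([inner[a][k2] for a in range(n1)])
        for k1 in range(n1):
            out[n2 * k1 + k2] = column[k1]
    return out
```

Applying the same step recursively to the smaller DFTs is what makes sizes with only small prime factors cheap: the recursion bottoms out in tiny transforms. A large prime factor such as 37 leaves a size-37 DFT that cannot be split this way.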
Your padded sizes feature large prime factors:
931=>962=2*13*37
411=>442=2*13*17
806=>837=3*3*3*31
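The factorizations above can be checked with a few lines of Python. This is a small sketch using trial division; `prime_factors` and `padded_size` are names chosen for this example, and the filter length 32 comes from the question.

```python
def prime_factors(n):
    """Return the prime factors of n in increasing order (trial division)."""
    factors = []
    d = 2
    while d * d <= n:
        while n % d == 0:
            factors.append(d)
            n //= d
        d += 1
    if n > 1:
        factors.append(n)
    return factors


def padded_size(data_len, filter_len):
    """Length of the full linear convolution along one axis."""
    return data_len + filter_len - 1


for data_len in (931, 411, 806):
    size = padded_size(data_len, 32)
    print(data_len, "=>", size, "=", "*".join(str(f) for f in prime_factors(size)))
```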
The padding could be extended to reach sizes that factor into small primes only, such as 980, 448 and 864. Nevertheless, padding a 3D image significantly increases the memory footprint, to the point that it is not always possible.
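A brute-force way to pick such a size is to search upward for the first number whose only prime factors are 2, 3, 5 and 7. The sketch below does exactly that; `is_smooth` and `next_smooth_size` are hypothetical helper names, not part of pyfftw. Note that the smallest such size is not guaranteed to be the fastest one, so timing a few candidates is still worthwhile.

```python
def is_smooth(n, primes=(2, 3, 5, 7)):
    """True if n has no prime factor outside `primes`."""
    for p in primes:
        while n % p == 0:
            n //= p
    return n == 1


def next_smooth_size(n, primes=(2, 3, 5, 7)):
    """Smallest size >= n whose prime factors all belong to `primes`."""
    if n < 1:
        raise ValueError("n must be a positive integer")
    while not is_smooth(n, primes):
        n += 1
    return n


# Padded sizes from the question.
for size in (962, 442, 837):
    print(size, "->", next_smooth_size(size))
```

For the padded sizes 962, 442 and 837 this search returns 972, 448 and 840. These are the smallest candidates, which is why they differ from the 980 and 864 quoted above; those remain valid choices too.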
Why does changing the order of the dimensions change the computation time? The difference could be due to the input array being real. Hence, an R2C DFT is performed over one of the dimensions, then C2C DFTs over the second and the third, to compute the 3D DFT. If the size of the first dimension to be transformed is even, the R2C transform can be turned into a complex DFT of half the size, as shown here. This trick does not work for odd sizes. With A of shape (931, 411, 806), the last axis, which is the one halved in the real-to-complex output, is padded to the odd size 837; after swapping the axes, it is padded to the even size 962. As a consequence, some fast algorithms likely become available when 962 and 837 are swapped.
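The even-size trick can be sketched as follows: pack the even and odd samples of a real signal of length N=2M into one complex signal of length M, transform it once, then untangle the two half-spectra. This is an illustration in pure Python, not FFTW's actual code path; `naive_dft` and `rfft_even_length` are names invented here, and the O(N^2) `naive_dft` stands in for the half-size complex FFT.

```python
import cmath


def naive_dft(z):
    """Reference O(N^2) DFT, standing in for a fast complex FFT."""
    n = len(z)
    return [sum(z[j] * cmath.exp(-2j * cmath.pi * j * k / n) for j in range(n))
            for k in range(n)]


def rfft_even_length(x):
    """First N/2 + 1 DFT bins of a real signal of even length N."""
    n = len(x)
    if n == 0 or n % 2 != 0:
        raise ValueError("this trick requires a non-empty, even-length input")
    m = n // 2
    # Pack even samples into the real part and odd samples into the imaginary part.
    z = [complex(x[2 * j], x[2 * j + 1]) for j in range(m)]
    Z = naive_dft(z)  # a single complex DFT of half the size
    out = []
    for k in range(m):
        zc = Z[(m - k) % m].conjugate()
        even = 0.5 * (Z[k] + zc)    # DFT of the even samples
        odd = -0.5j * (Z[k] - zc)   # DFT of the odd samples
        out.append(even + cmath.exp(-2j * cmath.pi * k / n) * odd)
    # Nyquist bin: sum of even samples minus sum of odd samples.
    out.append(complex(Z[0].real - Z[0].imag, 0.0))
    return out
```

The packing step needs the samples to pair up, which is why no equivalent shortcut exists when N is odd.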
Here is some code to test it:
import pyfftw
import multiprocessing
from timeit import default_timer as timer


def listofgoodsizes():
    # Even sizes of the form 2^a * 3^b * 5^c or 2^a * 3^b * 7^d,
    # with a in 1..11, b in 0..6 and c, d in 0..1.
    listt = []
    p2 = 2
    for i2 in range(11):
        p3 = 1
        for i3 in range(7):
            p5 = 1
            for i5 in range(2):
                listt.append(p2 * p3 * p5)
                p5 *= 5
            p7 = 1
            for i7 in range(2):
                listt.append(p2 * p3 * p7)
                p7 *= 7
            p3 *= 3
        p2 *= 2
    listt.sort()
    return listt


def getgoodfftwsize(n, listt):
    # First size in the sorted list that is >= n, or n itself if none is.
    for i in range(len(listt)):
        if listt[i] >= n:
            return listt[i]
    return n


def timea3DR2CDFT(n, m, p):
    bb = pyfftw.empty_aligned((n, m, p), dtype='float64')
    # The real-to-complex output is halved along the last axis.
    bf = pyfftw.empty_aligned((n, m, p // 2 + 1), dtype='complex128')
    pyfftw.config.NUM_THREADS = 1  # multiprocessing.cpu_count()
    # The plan is created here, so planning is excluded from the timing.
    fft_object_b = pyfftw.FFTW(bb, bf, axes=(0, 1, 2))
    print(n, m, p)
    start = timer()
    fft_object_b(bb)
    end = timer()
    print(end - start)


# odd sizes with large prime factors
n = 3 * 37
m = 241
p = 5 * 19
timea3DR2CDFT(n, m, p)

# round up to even sizes
neven = 2 * ((n + 1) // 2)
meven = 2 * ((m + 1) // 2)
peven = 2 * ((p + 1) // 2)
timea3DR2CDFT(neven, meven, peven)

# round up to the next size built from small primes
listt = listofgoodsizes()
ngood = getgoodfftwsize(n, listt)
mgood = getgoodfftwsize(m, listt)
pgood = getgoodfftwsize(p, listt)
timea3DR2CDFT(ngood, mgood, pgood)
On my computer, it prints:
111 241 95
0.180601119995
112 242 96
0.0560319423676
112 252 96
0.0564918518066