
Efficiently using 1-D pyfftw on small slices of a 3-D numpy array


I have a 3D data cube of size on the order of 10,000x512x512. I want to repeatedly take a window of values (say 6) along dim[0] and compute the Fourier transforms efficiently. I think I'm doing an array copy into the pyfftw package and it's giving me massive overhead. I'm going over the documentation now since I think there is an option I need to set, but I could use some extra help on the syntax.

This code was originally written by another person with numpy.fft.rfft and accelerated with numba. But the implementation wasn't working on my workstation so I re-wrote everything and opted to go for pyfftw instead.

import numpy as np
import pyfftw as ftw
from tkinter import simpledialog
from math import ceil
import multiprocessing

ftw.config.NUM_THREADS = multiprocessing.cpu_count()
ftw.interfaces.cache.enable()

def runme():
    # normally I would load a file, but for Stack Overflow, I'm just going to generate a 3D data cube so I'll delete references to the binary saving/loading functions:
    # load the file
    dataChunk = np.random.random((1000,512,512))
    numFrames = dataChunk.shape[0]
    # select the window size
    windowSize = int(simpledialog.askstring('Window Size',
        'How many frames to demodulate a single time point?'))
    numChannels = windowSize//2+1
    # create fftw arrays
    ftwIn = ftw.empty_aligned(windowSize, dtype='complex128')
    ftwOut = ftw.empty_aligned(windowSize, dtype='complex128')
    fftObject = ftw.FFTW(ftwIn,ftwOut)
    # perform DFT on the data chunk
    demodFrames = dataChunk.shape[0]//windowSize
    channelChunks = np.zeros([numChannels,demodFrames,
        dataChunk.shape[1],dataChunk.shape[2]])
    channelChunks = getDFT(dataChunk,channelChunks,
        ftwIn,ftwOut,fftObject,windowSize,numChannels)
    return channelChunks          

def getDFT(data,channelOut,ftwIn,ftwOut,fftObject,
        windowSize,numChannels):
    frameLen = data.shape[0]
    demodFrames = frameLen//windowSize
    for yy in range(data.shape[1]):
        for xx in range(data.shape[2]):
            index = 0
            for i in range(0,frameLen-windowSize+1,windowSize):
                ftwIn[:] = data[i:i+windowSize,yy,xx]
                fftObject() 
                channelOut[:,index,yy,xx] = 2*np.abs(ftwOut[:numChannels])/windowSize
                index+=1
    return channelOut

if __name__ == '__main__':
    runme()

The result is a 4D array, the variable channelChunks. I am saving out each channel to a binary (not included in the code above, but the saving part works fine).

This process is for a demodulation project we have. The 4D data cube channelChunks is then split into numChannels 3D data cubes (movies), and from those we are able to separate a movie by color given our experimental setup. I was hoping that, with pyfftw, I could avoid writing a C++ function that calls the FFT on the matrix.

Effectively, I am taking windowSize=6 elements along axis 0 of dataChunk at each index along axes 1 and 2 and performing a 1D FFT. I need to do this throughout the entire 3D volume of dataChunk to generate the demodulated movies. Thanks.


Solution

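  • The FFTW advanced plans can be built automatically by pyfftw, so a whole batch of 1-D transforms can be planned once and run in a single call instead of one call per pixel and window.

    In the code below, the input array is real (float64) and holds every window of one row of the image, with shape (nbwindow, windowSize, dataChunk.shape[2]); the plan transforms along axis 1. I also decreased the number of frames to 100, set the seed of the random generator to check that the outcome is not modified, and commented out tkinter. The window size can be set to a power of two, or to a product of the small primes 2, 3, 5 and 7, so that the Cooley-Tukey algorithm can be applied efficiently. Avoid large prime numbers.

    Here is the modified code: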

    import numpy as np
    import pyfftw as ftw
    #from tkinter import simpledialog
    from math import ceil
    import multiprocessing
    import time
    
    
    ftw.config.NUM_THREADS = multiprocessing.cpu_count()
    ftw.interfaces.cache.enable()
    ftw.config.PLANNER_EFFORT = 'FFTW_MEASURE'
    
    def runme():
        # normally I would load a file, but for Stack Overflow, I'm just going to generate a 3D data cube so I'll delete references to the binary saving/loading functions:
        # load the file
        np.random.seed(seed=42)
        dataChunk = np.random.random((100,512,512))
        numFrames = dataChunk.shape[0]
        # select the window size
        #windowSize = int(simpledialog.askstring('Window Size',
        #    'How many frames to demodulate a single time point?'))
        windowSize=32
        numChannels = windowSize//2+1
    
        nbwindow=numFrames//windowSize
        # create fftw arrays
        ftwIn = ftw.empty_aligned((nbwindow,windowSize,dataChunk.shape[2]), dtype='float64')
        ftwOut = ftw.empty_aligned((nbwindow,windowSize//2+1,dataChunk.shape[2]), dtype='complex128')
    
        #ftwIn = ftw.empty_aligned(windowSize, dtype='complex128')
        #ftwOut = ftw.empty_aligned(windowSize, dtype='complex128')
        fftObject = ftw.FFTW(ftwIn,ftwOut, axes=(1,), flags=('FFTW_MEASURE','FFTW_DESTROY_INPUT',))
        # perform DFT on the data chunk
        demodFrames = dataChunk.shape[0]//windowSize
        channelChunks = np.zeros([numChannels,demodFrames,
            dataChunk.shape[1],dataChunk.shape[2]])
        channelChunks = getDFT(dataChunk,channelChunks,
            ftwIn,ftwOut,fftObject,windowSize,numChannels)
        return channelChunks          
    
    def getDFT(data,channelOut,ftwIn,ftwOut,fftObject,
            windowSize,numChannels):
        printed=0
        nbwindow=data.shape[0]//windowSize
        scale=1.0/windowSize
        for yy in range(data.shape[1]):
            # copy every window of one row (all xx) into the aligned input
            # and transform them all with a single call to the plan
            ftwIn[:] = np.reshape(data[0:nbwindow*windowSize,yy,:],(nbwindow,windowSize,data.shape[2]),order='C')
            fftObject()
            # ftwOut is (nbwindow, numChannels, X); channelOut expects (numChannels, nbwindow, X)
            channelOut[:,:,yy,:]=np.transpose(2*np.abs(ftwOut[:,:,:])*scale, (1,0,2))

            # print the first pixel once, to compare against the original code
            if printed==0:
                for j in range(channelOut.shape[0]):
                    print(j,channelOut[j,0,yy,0])
                printed=1

        return channelOut
    
    
    if __name__ == '__main__':
        seconds=time.time()
        runme()
        print("time: ", time.time()-seconds)
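
The reshape in getDFT is what lets one plan cover many windows at once. As a way to cross-check the output, here is a minimal NumPy-only sketch of the same computation on the whole cube. It uses numpy.fft.rfft rather than pyfftw, so it says nothing about speed, and the function name demod_reference is hypothetical, not part of the code above.

```python
import numpy as np

def demod_reference(data, window_size):
    """NumPy-only reference: amplitude spectrum of each window along axis 0.

    Hypothetical helper for checking results; not part of the answer's code.
    """
    n_windows = data.shape[0] // window_size
    # Drop trailing frames that do not fill a whole window
    trimmed = data[:n_windows * window_size]
    # Split axis 0 into (window index, position inside the window)
    windows = trimmed.reshape(n_windows, window_size, *data.shape[1:])
    # Real-to-complex FFT along the in-window axis: window_size//2 + 1 channels
    spectrum = np.fft.rfft(windows, axis=1)
    amplitude = 2 * np.abs(spectrum) / window_size
    # Reorder to (channels, windows, Y, X), the layout of channelChunks
    return np.moveaxis(amplitude, 1, 0)

rng = np.random.default_rng(42)
cube = rng.random((20, 4, 5))      # small stand-in for dataChunk
out = demod_reference(cube, 6)
print(out.shape)                   # (4, 3, 4, 5)
```

Comparing a few pixels of this output against channelChunks (for example with np.allclose) should confirm that the batched plan reads the windows in the intended order.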
    

    Let us know how much it speeds up your computations! I went from 24s to less than 2s on my computer...
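
On the window-size advice above: a quick way to test whether a candidate window size factors entirely into 2, 3, 5 and 7 is sketched below. The helper names are hypothetical, and the prime set simply mirrors the list in this answer. FFTW will still transform other sizes, just less efficiently.

```python
def is_fft_friendly(n, primes=(2, 3, 5, 7)):
    """Return True if n is a product of the given small primes only."""
    if n < 1:
        return False
    for p in primes:
        # Strip every factor of p from n
        while n % p == 0:
            n //= p
    # Anything left over is a prime factor outside the allowed set
    return n == 1

def next_fft_friendly(n):
    """Smallest size >= n that factors into the small primes above."""
    while not is_fft_friendly(n):
        n += 1
    return n

for size in (6, 32, 11, 97):
    print(size, is_fft_friendly(size), next_fft_friendly(size))
```

In the demodulation case the window size is dictated by the experiment, so this is mostly useful as a sanity check before timing a run.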