python fft pyfftw

Phase discontinuity happens during ifft operation using pyfftw and scipy.fft


When computing the inverse FFT of a 2D ndarray with pyfftw, I found that the resulting phase is discontinuous at many positions. My code is as follows:

import numpy as np
import pyfftw
from scipy.fft import ifftshift,fftshift
import matplotlib.pyplot as plt

N = 256
kx = np.linspace(-np.floor(N/2),np.ceil(N/2)-1,N)
kX,kY = np.meshgrid(kx,kx)
kR = np.sqrt(kX**2 + kY**2)
mask = np.where((kR<=15),1,0)
ifft_obj = pyfftw.builders.ifft2(ifftshift(mask))
wave = fftshift(ifft_obj())

plt.imshow(np.angle(wave),cmap='jet')
plt.colorbar()

The phase image is as follows: [image: phase of wave from the Python code]

The minimum phase value is -3.141592653589793 and the maximum is 3.141592653589793, so their difference is 2π. Using scipy.fft gives the same result. However, when I switch to MATLAB, the result looks more reasonable. My code is:

N = 256;
kx = linspace(-floor(N/2),ceil(N/2)-1,N);
[kX,kY] = meshgrid(kx,kx);
kR = sqrt(kX.^2 + kY.^2);
mask = single(kR<=15);
wave = fftshift(ifft2(ifftshift(mask)));
imshow(angle(wave));
caxis([min(angle(wave),[],'all') max(angle(wave),[],'all')]);
axis image; colormap jet;colorbar;

The phase image is: [image: phase of wave from the MATLAB code]

What causes the phase discontinuity in the Python code, and how can I correct it?


Solution

  • Your input is symmetric, which leads to a purely real transform (the phase is 0 for positive values and π for negative values). But because of numerical inaccuracies in the FFT algorithm, the result has very small imaginary values, so the phase deviates slightly from 0 and π. In your colormapped image, small deviations from 0 are not visible. The phase wraps at ±π, however, so a negative value with a tiny negative imaginary part gets a phase close to -π instead of π (as already discussed by VPfB).

    MATLAB does not show this issue because MATLAB's ifft recognizes that the input is (conjugate) symmetric and outputs a purely real image. It simply ignores those small imaginary values.

    You can do the same in Python with

    wave = np.real_if_close(wave, tol=1000)
    

    The tolerance here is np.finfo(wave.dtype).eps * tol (2.22 × 10⁻¹³ for double-precision floats). Adjust as necessary.
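To check this on the example from the question, here is a minimal sketch. It uses NumPy's own FFT in place of pyfftw (an assumption made only to keep it dependency-free), and the names imag_max, real_max, flip, and wave_real are illustrative.

```python
import numpy as np

N = 256
kx = np.arange(-(N // 2), N - N // 2)  # same grid as linspace(-128, 127, 256)
kX, kY = np.meshgrid(kx, kx)
mask = (np.hypot(kX, kY) <= 15).astype(float)

# Assumption: np.fft stands in for pyfftw / scipy.fft here
wave = np.fft.fftshift(np.fft.ifft2(np.fft.ifftshift(mask)))

# The imaginary part is rounding noise, far below the real part
imag_max = np.abs(wave.imag).max()
real_max = np.abs(wave.real).max()
print(f"max |imag| = {imag_max:.3e}, max |real| = {real_max:.3e}")

# For a negative real value, the sign of a tiny imaginary part
# decides whether the phase comes out as +pi or -pi
flip = np.angle(np.array([-1 + 1e-17j, -1 - 1e-17j]))
print(flip)

# Drop the imaginary part if every element is within tol machine epsilons of zero
wave_real = np.real_if_close(wave, tol=1000)
phase = np.angle(wave_real)
print(np.iscomplexobj(wave_real), np.unique(phase))
```

Plotting np.angle(wave_real) in place of np.angle(wave) should then show only two levels, 0 and π. Note that np.real_if_close returns the array unchanged if any imaginary part exceeds the tolerance, so it is worth checking the dtype of the result.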