Tags: python, numpy, fft

`numpy.fft.irfftn` returns incorrect shape after multiplying a `numpy.fft.rfftn`-transformed array by scalar factors


I'm writing code to perform differential operations on 3D fields, exploiting the fact that derivatives in configuration space (x, y, z) become products in Fourier space (kx, ky, kz).
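A minimal 1D sketch of that idea, assuming a periodic grid. The box length, grid size and the test function sin(3x) are illustrative choices, not taken from the question. Note that `irfft` is given the original length `n` explicitly.

```python
import numpy as np
from numpy.fft import rfft, irfft, rfftfreq

L = 2 * np.pi        # assumed box length (periodic domain)
N = 64               # assumed number of grid points
x = np.arange(N) * L / N
f = np.sin(3 * x)    # test field with a known derivative: 3*cos(3x)

# Angular wavenumbers matching the rfft output layout (non-negative only)
k = 2 * np.pi * rfftfreq(N, d=L / N)

# d/dx in real space -> multiplication by 1j*k in Fourier space
dfdx = irfft(1j * k * rfft(f), n=N)

print(np.allclose(dfdx, 3 * np.cos(3 * x)))  # True
```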

This is my code:

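import numpy as np
from numpy.fft import rfftn, irfftn

# kvecs, knorms, L, Ng and field are defined elsewhere (not shown)
ki = 2*np.pi/L*kvecs(Ng)    # (Ng, Ng, Ng, 3): (kx, ky, kz) of each pixel
k = 2*np.pi/L*knorms(ki)    # (Ng, Ng, Ng): |k| of each pixel
vel = np.zeros((*field.shape, 3))
for i in range(3):
    velk = rfftn(field)
    print(velk.shape)           # (5, 5, 3)
    for x in range(Ng):
        for y in range(Ng):
            # rfftn keeps only Ng//2 + 1 frequencies along the last axis
            for z in range(Ng//2 + 1):
                velk[x, y, z] *= -1j*ki[x, y, z, i]/k[x, y, z]**2 if k[x, y, z] != 0 else 0

    print(irfftn(velk).shape)   # (5, 5, 4), but (5, 5, 5) was expected
    vel[..., i] = irfftn(velk)  # fails: (5, 5, 4) cannot be broadcast to (5, 5, 5)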

where `field` is an array of shape (Ng, Ng, Ng) with Ng = 5, `ki` is an (Ng, Ng, Ng, 3) array holding the (kx, ky, kz) components of each pixel, and `k` is the norm of `ki` (the square root of kx^2 + ky^2 + kz^2), so it has shape (Ng, Ng, Ng).

The code clearly doesn't change the shape of `velk`; it only multiplies its values by some factors. Yet the two print statements give (5, 5, 3) (correct) and (5, 5, 4) (wrong: it should be the original shape, (5, 5, 5)), and the code fails at the last line because the inverse transform cannot be broadcast into `vel[..., i]`.
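For reference, one reading of the factor applied inside the loop, assuming NumPy's sign convention (the forward transform uses $e^{-i\mathbf k\cdot\mathbf x}$, so a derivative $\partial_j$ becomes a multiplication by $i k_j$):

$$
\hat v_i(\mathbf k) = -\frac{i\,k_i}{k^2}\,\hat f(\mathbf k) \quad (\mathbf k \neq \mathbf 0), \qquad \hat v_i(\mathbf 0) = 0 .
$$

Since $\nabla^2$ becomes $-k^2$, this is $\hat v_i = i k_i \hat\phi$ with $\hat\phi = -\hat f / k^2$, i.e. $\mathbf v = \nabla\phi$ where $\nabla^2\phi = f$ (with the mean of $f$ dropped by the $\mathbf k = \mathbf 0$ rule).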

If I comment out the three nested loops, the code works, and if I swap the real FFT for the complex one (also changing the z loop from range(Ng//2+1) to range(Ng)), the code works as expected.
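A minimal check, using a random stand-in for `field` and no k-space factors at all, suggests the multiplication is not what changes the shape: with Ng = 5 the default real round trip already loses one point on the last axis, while the complex round trip does not.

```python
import numpy as np
from numpy.fft import rfftn, irfftn, fftn, ifftn

Ng = 5
rng = np.random.default_rng(0)
field = rng.standard_normal((Ng, Ng, Ng))  # stand-in for the real field

# Real FFT: the last axis keeps only Ng//2 + 1 = 3 coefficients
fk = rfftn(field)
print(fk.shape)                   # (5, 5, 3)

# Without s, irfftn assumes an even output length 2*(3 - 1) = 4
print(irfftn(fk).shape)           # (5, 5, 4)

# Complex FFT keeps all Ng coefficients, so the shape is unambiguous
print(ifftn(fftn(field)).shape)   # (5, 5, 5)
```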

Is this a bug or am I missing something here?


Solution

  • I think you need to specify the output shape `s`, which works together with `axes`:

    (from https://docs.scipy.org/doc/scipy/reference/generated/scipy.fft.irfftn.html)

    axes: sequence of ints, optional

    Axes over which to compute the inverse FFT. If not given, the last len(s) axes are used, or all axes if s is also not specified.
    

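    In your example `axes` is not the issue: with neither `s` nor `axes` given, all three axes are transformed, as intended. The problem is the length of the last axis. A real FFT of length n keeps only n//2 + 1 coefficients, so n = 4 and n = 5 both give 3, and `irfftn` cannot tell which length you started with. Without `s`, it assumes an even length, 2*(m - 1) = 2*(3 - 1) = 4, which is why you get a (5, 5, 4) array. Passing the original shape, `irfftn(velk, s=field.shape)`, gives back (5, 5, 5). The multiplication by the k factors plays no role here; the complex FFT works because it keeps all Ng coefficients, so no length has to be guessed.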

    out: ndarray

    The **truncated or zero-padded input**, transformed along the axes indicated by axes, or by a combination of s or x, as explained in the parameters section above.
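A sketch of the fix, vectorized so the triple loop is not needed. It rests on assumptions: `field` is random stand-in data, L = 1.0, and the wavenumbers are built here with `fftfreq` on the first two axes and `rfftfreq` on the last, so `ki` already has the (Ng, Ng, Ng//2 + 1) layout of `rfftn`'s output. Your own `kvecs`/`knorms` may lay things out differently (for odd Ng the first Ng//2 + 1 entries of `fftfreq` and `rfftfreq` agree; for even Ng the Nyquist entry differs in sign).

```python
import numpy as np
from numpy.fft import rfftn, irfftn, fftfreq, rfftfreq

Ng, L = 5, 1.0                              # assumed grid size and box length
rng = np.random.default_rng(0)
field = rng.standard_normal((Ng, Ng, Ng))   # stand-in for the real field

# Angular wavenumbers laid out like rfftn's output: full range on the
# first two axes, non-negative frequencies only on the last axis.
kx = 2 * np.pi * fftfreq(Ng, d=L / Ng)
kz = 2 * np.pi * rfftfreq(Ng, d=L / Ng)
ki = np.stack(np.meshgrid(kx, kx, kz, indexing="ij"), axis=-1)  # (5, 5, 3, 3)
k2 = np.sum(ki**2, axis=-1)                                     # (5, 5, 3)

fk = rfftn(field)
# 1/k^2, with the k = 0 mode set to zero instead of dividing by zero
inv_k2 = np.divide(1.0, k2, out=np.zeros_like(k2), where=k2 != 0)

vel = np.zeros((*field.shape, 3))
for i in range(3):
    velk = -1j * ki[..., i] * inv_k2 * fk
    # s=field.shape restores the odd length (5) of the last axis
    vel[..., i] = irfftn(velk, s=field.shape)

print(vel.shape)  # (5, 5, 5, 3)
```

`np.divide(..., where=k2 != 0)` leaves the k = 0 entry at the value from `out` (zero), which reproduces the `if k != 0 else 0` branch of the original loop without a division-by-zero warning.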