Tags: python, numpy, matplotlib, scipy, fft

Inverse Fast Fourier Transform (ifft2) from SciPy not working for Fourier optics


I'm following a YouTube tutorial on Fourier optics in Python to simulate the diffraction of light through a slit.
The video in question
Source code of the video
Now I'm trying to implement the get_U(z, k) function and display the corresponding plot below it, as shown in the video (I have only barebones knowledge of this topic). However, I can't get the plot to work: it stays completely white. On inspection, I found that U consists entirely of nan+nanj values, which I don't think should be the case. I've cross-checked the formula and it looks correct. I also realise that np.sqrt() sometimes has to deal with negative values, but neither wrapping its argument in np.abs() nor using np.where() to set the negatives to zero gives me the intended output.
My code:

import numpy as np
import scipy as sp
from scipy.fft import fft2, ifft2, fftfreq, fftshift
import matplotlib.pyplot as plt
import pint

plt.style.use(['grayscale'])
u = pint.UnitRegistry()

D = 0.1 * u.mm
lam = 660 * u.mm

x = np.linspace(-2, 2, 1600) * u.mm
xv, yv = np.meshgrid(x, x)

U0 = (np.abs(xv) < D/2) * (np.abs(yv) < 0.5 * u.mm)
U0 = U0.astype(float)

A = fft2(U0)
kx = fftfreq(len(x), np.diff(x)[0]) * 2 * np.pi
kxv, kyv = np.meshgrid(kx, kx)

def get_U(z, k):
    return ifft2(A*np.exp(1j*z.magnitude*np.sqrt(k.magnitude**2 - kxv.magnitude**2 - kyv.magnitude**2)))
k = 2*np.pi/(lam)
d = 3 * u.cm
U = get_U(d, k)

plt.figure(figsize=(5, 5))
plt.pcolormesh(xv, yv, np.abs(U), cmap='inferno')
plt.xlabel('$x$ [mm]')
plt.ylabel('$y$ [mm]')
plt.title('Single slit diffraction')
plt.show()
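The all-nan result can be reproduced without pint. The sketch below uses a toy 8x8 field (sizes chosen only for illustration, not taken from the question) to show two NumPy behaviours that combine here: np.sqrt of a negative float returns nan rather than an imaginary number, and because every output sample of an inverse DFT is a weighted sum of all input samples, a single nan in the spectrum makes every element returned by ifft2 nan. np.emath.sqrt is included only for contrast, since it returns the complex root for negative input; it is not presented as the fix.

```python
import warnings

import numpy as np
from scipy.fft import fft2, ifft2

# Toy 8x8 aperture and its spectrum (sizes chosen only for illustration)
field = np.zeros((8, 8))
field[3:5, 3:5] = 1.0
spectrum = fft2(field)

# np.sqrt of a negative float is nan, not an imaginary number
with warnings.catch_warnings():
    warnings.simplefilter("ignore", RuntimeWarning)  # silence the "invalid value" warning
    root = np.sqrt(np.array([-1.0]))[0]
print(root)  # nan

# For contrast: np.emath.sqrt returns the complex root (value 1j) instead of nan
print(np.emath.sqrt(-1.0))

# A single nan in the spectrum contaminates every sample after ifft2
spectrum[2, 5] = np.nan
result = ifft2(spectrum)
print(np.isnan(result).all())  # True
```

This is why clamping or taking np.abs() only hides the symptom: the question is why the argument of np.sqrt is negative in the first place.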

Solution

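  • Your units of lam are wrong: if you intend to use pint (though I suggest that you don't), the wavelength should be in nm, not mm. With lam = 660 mm, k = 2π/λ is only about 0.0095 mm⁻¹, while kx and ky reach about π/Δx ≈ 1.3 × 10³ mm⁻¹ on this grid, so k² − kx² − ky² is negative almost everywhere, np.sqrt returns nan there, and ifft2 spreads those nans to every element of U.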

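    Once you have made that change, I suggest removing all references to pint and mixed units and working entirely in a single set of length units (here, metres). Units appear to be stripped when some of the numpy arrays are created, and get_U calls .magnitude, which discards units outright: d = 3 cm enters the exponent as 3 while k and kx are in mm⁻¹, so the propagation distance is off by a factor of 10. You can use scientific notation (e.g. 1e-9 for nm) to imply the units. With those changes you get what I think you require.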

    [Image: output plot titled "Single slit diffraction"]

    import numpy as np
    from scipy.fft import fft2, ifft2, fftfreq
    import matplotlib.pyplot as plt
    
    plt.style.use(['grayscale'])
    
    # All lengths are in metres
    D = 0.1   * 1e-3    # slit width: 0.1 mm
    lam = 660 * 1e-9    # wavelength: 660 nm
    
    x = np.linspace(-2, 2, 1600) * 1e-3   # 4 mm wide grid
    xv, yv = np.meshgrid(x, x)
    
    # Aperture: D wide in x, 1 mm tall in y
    U0 = (np.abs(xv) < D/2) * (np.abs(yv) < 0.5 * 1e-3)
    U0 = U0.astype(float)
    
    # Angular spectrum of the aperture field and its angular spatial frequencies (rad/m)
    A = fft2(U0)
    kx = fftfreq(len(x), np.diff(x)[0]) * 2 * np.pi
    kxv, kyv = np.meshgrid(kx, kx)
    
    def get_U(z, k):
        # Advance the phase of each plane-wave component by kz*z, then transform back
        return ifft2(A*np.exp(1j*z*np.sqrt(k**2 - kxv**2 - kyv**2)))
    
    k = 2*np.pi/(lam)
    d = 3 * 1e-2        # propagation distance: 3 cm
    U = get_U(d, k)
    
    plt.figure(figsize=(5, 5))
    # Convert the axes back to mm for display
    plt.pcolormesh(xv*1e3, yv*1e3, np.abs(U), cmap='inferno')
    plt.xlabel('$x$ [mm]')
    plt.ylabel('$y$ [mm]')
    plt.title('Single slit diffraction')
    plt.show()
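To check the unit diagnosis numerically, the small sketch below rebuilds the question's 1600-point, 4 mm grid in metres and counts how many grid points make the argument of np.sqrt in get_U negative for each choice of wavelength. The helper negative_fraction is hypothetical, introduced only for this check.

```python
import numpy as np
from scipy.fft import fftfreq

# Same sampling as the question, expressed in metres
x = np.linspace(-2, 2, 1600) * 1e-3
kx = fftfreq(len(x), np.diff(x)[0]) * 2 * np.pi   # angular spatial frequencies, rad/m
kxv, kyv = np.meshgrid(kx, kx)

def negative_fraction(lam):
    """Hypothetical helper: fraction of grid points where k**2 - kx**2 - ky**2 < 0,
    i.e. where np.sqrt would return nan."""
    k = 2 * np.pi / lam
    return np.mean(k**2 - kxv**2 - kyv**2 < 0)

print(negative_fraction(660e-3))  # wavelength 660 mm: nearly every point is negative
print(negative_fraction(660e-9))  # wavelength 660 nm: 0.0
```

With the wavelength in nm, every spatial frequency on this grid is below k, so no nan can appear; with it in mm, only the zero-frequency point has a non-negative argument.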