python tensorflow keras

How to convert an image to a fast Fourier transform signal?


I am trying to convert an image into a fast Fourier transform signal and used the following piece of code:

fake_A1 = tf.signal.fft2d(fake_A1)

where the input image type is <class 'numpy.ndarray'>, but I am getting the following error:

tensorflow.python.framework.errors_impl.InvalidArgumentError: Value for attr 'Tcomplex' of float is not in the list of allowed values: complex64, complex128
; NodeDef: {{node FFT2D}}; Op<name=FFT2D; signature=input:Tcomplex -> output:Tcomplex; attr=Tcomplex:type,default=DT_COMPLEX64,allowed=[DT_COMPLEX64, DT_COMPLEX128]> [Op:FFT2D]

How can I solve this?


Solution

  • P.S.: If you want to make edits, make them to your question rather than posting them as an answer.

    Now to the topic: the 2D FFT of an image. As loaded by plt.imread, a colour image has the shape:

    image.shape = (rows, columns, 3)
    

    where the trailing 3 is the number of 2D matrices, one per RGB channel (a PNG with transparency has a fourth, alpha, channel). To carry out a single 2D FFT, we first need to reduce these to one 2D matrix by converting the image to grayscale. I found a useful tutorial on ThePythonCodingBook. I'll add the code here for TL;DR purposes.

    import matplotlib.pyplot as plt
    image_filename = "Earth.png"
    # Read and process image
    image = plt.imread(image_filename)
    image = image[:, :, :3].mean(axis=2)  # Convert to grayscale
    print(image.shape)
    plt.set_cmap("gray")
    plt.imshow(image)
    plt.axis("off")
    plt.show()
    
    import numpy as np
    import matplotlib.pyplot as plt
    image_filename = "Earth.png"
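    def calculate_2dft(img):
        # Shift the image centre to the origin (index [0, 0]) before transforming
        ft = np.fft.ifftshift(img)
        ft = np.fft.fft2(ft)
        # Move the zero-frequency term back to the centre for display
        return np.fft.fftshift(ft)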
    # Read and process image
    image = plt.imread(image_filename)
    image = image[:, :, :3].mean(axis=2)  # Convert to grayscale
    plt.set_cmap("gray")
    ft = calculate_2dft(image)
    plt.subplot(121)
    plt.imshow(image)
    plt.title('Grayscale Image')
    plt.axis("off")
    plt.subplot(122)
    plt.imshow(np.log(abs(ft)))
    plt.title("2D FFT")
    plt.axis("off")
    plt.show()
    

    I'll add the grayscale and FFT plots as well.

    [Image: Grayscale and FFT plots]

    Hope this helps.
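
As for the error in the question itself: the message says that tf.signal.fft2d only accepts complex64 or complex128 input, while the array passed in is float. Casting to a complex dtype first should be enough. Below is a minimal sketch of that, not a definitive implementation. It assumes fake_A1 is a float array of shape (rows, columns, channels); the helper name fft2d_of_image and the random test image are made up for illustration. Unlike the NumPy code above, it skips the initial ifftshift, which changes only the phase of the result, not its magnitude.

```python
import numpy as np
import tensorflow as tf


def fft2d_of_image(image):
    """Return the centred 2D FFT of a (rows, columns, channels) image array.

    Hypothetical helper, not part of TensorFlow.
    """
    # Average the first three (RGB) channels into one 2D grayscale plane.
    gray = np.asarray(image, dtype=np.float32)[:, :, :3].mean(axis=2)
    # tf.signal.fft2d rejects float input, so cast to complex64 first;
    # the imaginary part is set to zero.
    gray_complex = tf.cast(gray, tf.complex64)
    # fft2d transforms the two innermost dimensions (rows, columns here).
    spectrum = tf.signal.fft2d(gray_complex)
    # Move the zero-frequency term to the centre for display.
    return tf.signal.fftshift(spectrum)


# Assumed stand-in for fake_A1: a random 64x48 RGB image.
fake_A1 = np.random.rand(64, 48, 3).astype(np.float32)
spectrum = fft2d_of_image(fake_A1)
print(spectrum.shape, spectrum.dtype)
```

If you would rather keep the colour channels, one option is to move them in front of the spatial axes (for example with tf.transpose(x, [2, 0, 1])) before casting, since fft2d always works on the last two dimensions.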