python-3.xmatplotlibjupyter-notebookgraphics2d

Show more than initial frame for animation of wave propagation with time in Jupyter


I am trying to see the wave propagation as time progresses. For this I am using finite difference method. However, with the latest jupyter notebook I am unable to obtain the figures. When I run the code, it doesn't show anything. It is only with %matplotlib inline I can see the initial plot, but it doesn't update.

%matplotlib widget
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from matplotlib import gridspec
from mpl_toolkits.axes_grid1 import make_axes_locatable
import warnings
warnings.filterwarnings("ignore")

def run_2d_wave_animation():
    # ----- Simulation Parameters -----
    nx = 200  # reduced size for smooth animation, increase later
    nz = nx
    dx = dz = 1
    c0 = 580
    isx = isz = 100
    irx = irz = 150
    nt = 502
    dt = 0.0010
    f0 = 25
    t0 = 2. / f0
    op = 3

    # ----- Stability Check -----
    eps = c0 * dt / dx
    if eps > 1:
        print('CFL condition not satisfied.')
    else:
        print('CFL condition OK.')

    # ----- Grids & Source -----
    x = dx * np.arange(nx)
    z = dz * np.arange(nz)
    time = dt * np.arange(nt)
    c = c0 * np.ones((nz, nx))

    p = np.zeros((nz, nx))
    pnew = np.zeros((nz, nx))
    pold = np.zeros((nz, nx))
    d2px = np.zeros((nz, nx))
    d2pz = np.zeros((nz, nx))
    seis = np.zeros(nt)

    src = -8 * (time - t0) * f0 * (np.exp(-1.0 * (4 * f0)**2 * (time - t0)**2))

    # ----- Analytical Green's Function -----
    G = np.zeros(nt)
    r = np.sqrt((x[isx] - x[irx])**2 + (z[isz] - z[irz])**2)
    for it in range(nt):
        if (time[it] - r / c0) >= 0:
            G[it] = (1 / (2 * np.pi * c0**2)) * (1 / np.sqrt((time[it]**2) - (r**2 / c0**2)))

    Gc = np.convolve(G, src * dt)[:nt]

    # ----- Set up Plot -----
    fig = plt.figure(figsize=(10, 6))
    gs = gridspec.GridSpec(1, 2, width_ratios=[1, 1], wspace=0.3)

    # --- Left Panel: Wave Field ---
    ax1 = plt.subplot(gs[0])
    lim = np.max(np.abs(Gc))
    im = ax1.imshow(p, vmin=-lim, vmax=lim, cmap='RdBu', interpolation='nearest')
    divider = make_axes_locatable(ax1)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    fig.colorbar(im, cax=cax)
    ax1.plot(isx, isz, 'r*', markersize=10)
    ax1.plot(irx, irz, 'k^', markersize=8)
    ax1.set_title("Wave Field at t = 0")
    ax1.set_xlim(0, nx)
    ax1.set_ylim(0, nz)

    # --- Right Panel: Seismogram ---
    ax2 = plt.subplot(gs[1])
    line_fd, = ax2.plot(time, seis, 'b-', label='Numerical (FD)')
    line_analytic, = ax2.plot(time, Gc, 'r--', label='Analytical')
    marker, = ax2.plot([0], [0], 'ko', markersize=6)
    ax2.set_xlim(time[0], time[-1])
    ax2.set_ylim(min(Gc.min(), -1e-6), max(Gc.max(), 1e-6))
    ax2.set_title('Seismogram')
    ax2.set_xlabel('Time (s)')
    ax2.set_ylabel('Amplitude')
    ax2.legend(loc='upper right')

    # ----- Animation Update Function -----
    def update(it):
        nonlocal p, pold, pnew, seis

        if op == 3:
            d2px[1:-1, :] = (p[2:, :] - 2 * p[1:-1, :] + p[0:-2, :]) / dx**2
            d2pz[:, 1:-1] = (p[:, 2:] - 2 * p[:, 1:-1] + p[:, 0:-2]) / dz**2
        else:
            raise NotImplementedError("Only 3-point FD supported in this example")

        pnew = 2 * p - pold + (c**2) * dt**2 * (d2px + d2pz)
        pnew[isz, isx] += src[it] / (dx * dz) * dt**2

        # Soft boundary conditions (optional)
        pnew[0, :] = pnew[1, :]
        pnew[-1, :] = pnew[-2, :]
        pnew[:, 0] = pnew[:, 1]
        pnew[:, -1] = pnew[:, -2]

        pold, p = p, pnew
        seis[it] = p[irz, irx]

        im.set_data(p)  # use set_data for imshow
        ax1.set_title(f"Wave Field at t = {it*dt:.3f}s")
        line_fd.set_ydata(seis)
        marker.set_data(time[it], seis[it])
        fig.canvas.draw_idle()  # force redraw for widget backend

        return im, line_fd, marker

    ani = FuncAnimation(fig, update, frames=nt, interval=10, blit=False, repeat=False)

    plt.show()
    return ani

ani = run_2d_wave_animation()

Solution

  • I think you may have a syntax error? This below runs an animation. All I did was change marker.set_data(time[it], seis[it]) to marker.set_data([time[it]], [seis[it]]). The set_data method expects sequences, not scalars

    %matplotlib ipympl
    import numpy as np
    import matplotlib.pyplot as plt
    from matplotlib.animation import FuncAnimation
    from matplotlib import gridspec
    from mpl_toolkits.axes_grid1 import make_axes_locatable
    import warnings
    warnings.filterwarnings("ignore")
    
    def run_2d_wave_animation():
        # ----- Simulation Parameters -----
        nx = 200  # reduced size for smooth animation, increase later
        nz = nx
        dx = dz = 1
        c0 = 580
        isx = isz = 100
        irx = irz = 150
        nt = 502
        dt = 0.0010
        f0 = 25
        t0 = 2. / f0
        op = 3
    
        # ----- Stability Check -----
        eps = c0 * dt / dx
        if eps > 1:
            print('CFL condition not satisfied.')
        else:
            print('CFL condition OK.')
    
        # ----- Grids & Source -----
        x = dx * np.arange(nx)
        z = dz * np.arange(nz)
        time = dt * np.arange(nt)
        c = c0 * np.ones((nz, nx))
    
        p = np.zeros((nz, nx))
        pnew = np.zeros((nz, nx))
        pold = np.zeros((nz, nx))
        d2px = np.zeros((nz, nx))
        d2pz = np.zeros((nz, nx))
        seis = np.zeros(nt)
    
        src = -8 * (time - t0) * f0 * (np.exp(-1.0 * (4 * f0)**2 * (time - t0)**2))
    
        # ----- Analytical Green's Function -----
        G = np.zeros(nt)
        r = np.sqrt((x[isx] - x[irx])**2 + (z[isz] - z[irz])**2)
        for it in range(nt):
            if (time[it] - r / c0) >= 0:
                G[it] = (1 / (2 * np.pi * c0**2)) * (1 / np.sqrt((time[it]**2) - (r**2 / c0**2)))
    
        Gc = np.convolve(G, src * dt)[:nt]
    
        # ----- Set up Plot -----
        fig = plt.figure(figsize=(10, 6))
        gs = gridspec.GridSpec(1, 2, width_ratios=[1, 1], wspace=0.3)
    
        # --- Left Panel: Wave Field ---
        ax1 = plt.subplot(gs[0])
        lim = np.max(np.abs(Gc))
        im = ax1.imshow(p, vmin=-lim, vmax=lim, cmap='RdBu', interpolation='nearest')
        divider = make_axes_locatable(ax1)
        cax = divider.append_axes("right", size="5%", pad=0.05)
        fig.colorbar(im, cax=cax)
        ax1.plot(isx, isz, 'r*', markersize=10)
        ax1.plot(irx, irz, 'k^', markersize=8)
        ax1.set_title("Wave Field at t = 0")
        ax1.set_xlim(0, nx)
        ax1.set_ylim(0, nz)
    
        # --- Right Panel: Seismogram ---
        ax2 = plt.subplot(gs[1])
        line_fd, = ax2.plot(time, seis, 'b-', label='Numerical (FD)')
        line_analytic, = ax2.plot(time, Gc, 'r--', label='Analytical')
        marker, = ax2.plot([0], [0], 'ko', markersize=6)
        ax2.set_xlim(time[0], time[-1])
        ax2.set_ylim(min(Gc.min(), -1e-6), max(Gc.max(), 1e-6))
        ax2.set_title('Seismogram')
        ax2.set_xlabel('Time (s)')
        ax2.set_ylabel('Amplitude')
        ax2.legend(loc='upper right')
    
        # ----- Animation Update Function -----
        def update(it):
            nonlocal p, pold, pnew, seis
    
            if op == 3:
                d2px[1:-1, :] = (p[2:, :] - 2 * p[1:-1, :] + p[0:-2, :]) / dx**2
                d2pz[:, 1:-1] = (p[:, 2:] - 2 * p[:, 1:-1] + p[:, 0:-2]) / dz**2
            else:
                raise NotImplementedError("Only 3-point FD supported in this example")
    
            pnew = 2 * p - pold + (c**2) * dt**2 * (d2px + d2pz)
            pnew[isz, isx] += src[it] / (dx * dz) * dt**2
    
            # Soft boundary conditions (optional)
            pnew[0, :] = pnew[1, :]
            pnew[-1, :] = pnew[-2, :]
            pnew[:, 0] = pnew[:, 1]
            pnew[:, -1] = pnew[:, -2]
    
            pold, p = p, pnew
            seis[it] = p[irz, irx]
    
            im.set_data(p)  # use set_data for imshow
            ax1.set_title(f"Wave Field at t = {it*dt:.3f}s")
            line_fd.set_ydata(seis)
            marker.set_data([time[it]], [seis[it]])
            fig.canvas.draw_idle()
    
            return im, line_fd, marker
    
        ani = FuncAnimation(fig, update, frames=nt, interval=10, blit=False, repeat=False)
    
        plt.show()
        return ani
    
    ani = run_2d_wave_animation()
    

    Try it yourself without installing anything in a session where ipympl works

    Go here and click a 'launch binder' badge to launch a session. Open a new notebook and paste in my version of the code.