Tags: matplotlib, plot, visualization

Plot many signals together in Python


I am trying to plot many signals (196) from a multi-electrode dataset (https://zenodo.org/records/1411883). I want a clear, visually appealing plot that shows all the recordings together, similar to this:

[Image: example of the desired plot — not included in this text]

Is there a Python package or approach you would advise for plotting all of this information?

I tried creating subplots, but they are too small and take too long to generate.


Solution

  • I think plotting 196 signals on the same plot is a bit too much.

    Below, the signals are stacked vertically using a computed offset, with at most 30 signals per column:

    EDIT: Added splitting across subplots and a MultiCursor (a vertical cursor shared by all subplots).

    [Image: multiple signals on multiple subplots]

    import matplotlib.pyplot as plt
    import numpy as np
    from matplotlib.widgets import MultiCursor
    
    
    def main(n=196, nmax_per_col=30):
        """Demo: plot ``n`` synthetic signals stacked in columns of subplots.

        Signals are split into ``ceil(n / nmax_per_col)`` groups of almost equal
        size; each group is drawn in its own column with ``plot_multisignals``,
        and a ``MultiCursor`` shows a red vertical cursor across all columns.

        Parameters
        ----------
        n : int, default 196
            Number of synthetic signals to generate.
        nmax_per_col : int, default 30
            Maximum number of signals stacked in a single subplot column.

        Raises
        ------
        ValueError
            If ``n`` or ``nmax_per_col`` is not positive.
        """
        if n < 1 or nmax_per_col < 1:
            raise ValueError("n and nmax_per_col must be positive")

        t = np.linspace(0, 10, 600)
        # Each row is the sum of three random sinusoids, giving less regular traces.
        signals = create_signals(t, n) + create_signals(t, n) + create_signals(t, n)

        # Ceiling division: enough columns that none holds more than nmax_per_col.
        n_cols = (n + nmax_per_col - 1) // nmax_per_col

        # Split signals into groups of almost same size
        signals_groups = np.array_split(signals, n_cols, axis=0)

        # squeeze=False always yields a 2-D array of Axes, so iterating with
        # .flat also works when everything fits in a single column.
        _, axes = plt.subplots(
            ncols=n_cols, sharex=True, squeeze=False, figsize=(2.5 * n_cols, 8)
        )
        for ax, signals_group in zip(axes.flat, signals_groups):
            plot_multisignals(t, signals_group, ax=ax)

        # Keep a reference to the MultiCursor so it stays alive while the
        # figure is shown.
        multi = MultiCursor(None, tuple(axes.flat), color="r", lw=1)

        plt.show()
    
    
    def plot_multisignals(t, signals, horizontal_lines=True, ax=None, **kwargs):
        """Plot several signals stacked vertically on one Axes.

        Each signal is centred on its own mean and shifted upward by a common
        offset: the smallest spacing that keeps every trace's minimum above the
        maximum of the trace below it.

        Parameters
        ----------
        t : array-like, shape (N,)
            Common x values (e.g. time samples) shared by all signals.
        signals : array-like, shape (M, N) or (N,)
            One signal per row; a 1-D input is treated as a single signal.
        horizontal_lines : bool, default True
            Draw faint separator lines at half-offset positions between
            (and around) the stacked traces.
        ax : matplotlib Axes, optional
            Axes to draw on; a new figure and Axes are created when omitted.
        **kwargs
            Forwarded unchanged to ``ax.plot``.

        Returns
        -------
        The Axes that was drawn on.
        """
        if ax is None:
            _, ax = plt.subplots()

        sigs = np.atleast_2d(np.asarray(signals, dtype=float))
        n_sigs = len(sigs)

        # Center each signal so every trace sits on its own baseline.
        sigs = sigs - sigs.mean(axis=1, keepdims=True)

        # Smallest spacing such that trace i+1's minimum clears trace i's maximum.
        # A single signal has no neighbour (and the pairwise max would be taken
        # over an empty array), so use its own peak-to-peak range instead.
        if n_sigs > 1:
            offset = np.max(sigs.max(axis=1)[:-1] - sigs.min(axis=1)[1:])
        else:
            offset = sigs.max() - sigs.min()
        # Constant signals give a zero offset, which would draw every trace on
        # the same line; fall back to unit spacing (also catches NaN).
        if not offset > 0:
            offset = 1.0

        sigs = sigs.T + np.arange(n_sigs) * offset

        ax.plot(t, sigs, **kwargs)
        if horizontal_lines:
            for i in range(n_sigs + 1):
                ax.axhline(i * offset - offset / 2, color="k", lw=0.5, alpha=0.3)

        # Remove y-axis labels: after offsetting, y values have no physical meaning.
        ax.tick_params(left=False, labelleft=False)
        return ax
    
    
    def create_signals(t, n):
        """Build ``n`` random dummy sinusoids sampled at ``t``.

        Frequencies come from ``np.random.rand(n) * 3 + 1`` and amplitudes from
        ``np.random.rand(n) * 3.8 + 0.2`` (drawn in that order from NumPy's
        global random state). Returns a float array of shape ``(n, len(t))``,
        one sinusoid per row.
        """
        frequencies = np.random.rand(n) * 3 + 1
        amplitudes = np.random.rand(n) * 3.8 + 0.2
        out = np.empty((n, len(t)))
        # Iterating a 2-D array yields row views, so writing into `row` fills `out`.
        for row, frequency, amplitude in zip(out, frequencies, amplitudes):
            row[:] = amplitude * np.sin(2 * np.pi * frequency * t)
        return out
    
    
    # Run the demo only when this file is executed directly, not when imported.
    if __name__ == "__main__":
        main()