python scipy signal-processing ekg

How to flatten a digital signal whose baseline jumps up and down with Python?


I'm analyzing SciPy's built-in electrocardiogram (EKG) dataset, a segment of which is shown below:

Screenshot of the SciPy EKG from a Kaggle notebook

One problem with the data above is that the baseline of the EKG jumps up and down a lot. If you're not familiar with EKGs or heartbeat analysis, they're supposed to be flat with a few spikes of the "QRS complex" (AKA the actual heartbeat), like below:

An image of a fake EKG from the American Heart Association

Medical device companies that make EKG monitors use some flattening and filtering functions to make EKGs smoother and easier to read, since natural human body movements are unavoidable and will cause the EKG signal to jump around as shown above. But I don't know which filters/functions they use.

How can I use SciPy or write a custom function to flatten the SciPy EKG dataset above?

What I have tried

I have tried reading the SciPy documentation for signal processing, but have not yet found any function that "flattens" data or brings it back to baseline. I'm a novice in digital signal processing and want to know if there is a better or official way to do this.

I have tried shifting the values up or down by some moving average, but there's no way that simple addition or subtraction of Y-values is the correct way. That is too hacky.

Thanks for your help.

Question clarification: fixing "baseline wandering"

Another user in the comments (Christoph Rackwitz) suggested a high-pass filter. While Googling high-pass filters, I found an image from a research paper on ECGs:

EKG with baseline wandering adjustment

They said the first image had "baseline wandering," which is what I'm really asking about with this question. How do I fix baseline wandering?

Two useful research papers

Considering the research paper I mentioned in the last section, as well as another one I just found about fixing baseline wandering in ECGs, I'll read these and see what I find.


Solution

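  • Here's a showcase of some lowpass filters. In each case the lowpassed signal is taken as the baseline estimate and subtracted from the original, which leaves the highpassed (flattened) signal.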

    You'll want to read up on "causal" filters vs non-causal filters, as well as the difference between FIR and IIR.

    Loading the data:

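    import numpy as np
    import matplotlib.pyplot as plt
    import scipy.datasets, scipy.ndimage, scipy.signal

    signal = scipy.datasets.electrocardiogram()
    fs = 360  # sampling rate in Hz, per the SciPy docs
    time = np.arange(signal.size) / fs  # in seconds, for plotting only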
    

    Explore the signal:

    fig, axs = plt.subplots(3, 1, figsize=(15, 15))
    axs[0].plot(time[30000:31000], signal[30000:31000])
    axs[1].plot(time[30000:40000], signal[30000:40000])
    axs[2].plot(time, signal)
    axs[0].set_xlabel('Time (s)')
    axs[1].set_xlabel('Time (s)')
    axs[2].set_xlabel('Time (s)')
    plt.show()
    

    figure 1

    Trying a few filters:

    # Butterworth, first order, 0.5 Hz cutoff
    lowpass = scipy.signal.butter(1, 0.5, btype='lowpass', fs=fs, output='sos')
    lowpassed = scipy.signal.sosfilt(lowpass, signal)
    highpassed = signal - lowpassed
    
    fig, axs = plt.subplots(2, 1, figsize=(15, 10))
    axs[0].plot(time[30000:32000], signal[30000:32000])
    axs[0].plot(time[30000:32000], lowpassed[30000:32000])
    axs[1].plot(time[30000:32000], highpassed[30000:32000])
    axs[0].set_xlabel('Time (s)')
    axs[1].set_xlabel('Time (s)')
    axs[0].set_ylim([-3, +3])
    axs[1].set_ylim([-3, +3])
    plt.show()
    

    figure 2
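A high-pass filter was suggested in the comments, and for this particular first-order Butterworth design, subtracting the lowpass output should give the same result as designing the highpass directly. The sketch below checks that on a synthetic stand-in signal (an assumption, used only to avoid the dataset download). The equivalence is not assumed to hold for higher filter orders.

```python
import numpy as np
import scipy.signal

fs = 360  # sampling rate in Hz, as in the answer

# Synthetic stand-in for the EKG (assumption): slow 0.1 Hz drift plus a 10 Hz component.
t = np.arange(10 * fs) / fs
x = 0.5 * np.sin(2 * np.pi * 0.1 * t) + np.sin(2 * np.pi * 10 * t)

# Route 1: lowpass, then subtract the lowpassed baseline from the signal.
lp = scipy.signal.butter(1, 0.5, btype='lowpass', fs=fs, output='sos')
via_subtraction = x - scipy.signal.sosfilt(lp, x)

# Route 2: design the first-order highpass directly with the same cutoff.
hp = scipy.signal.butter(1, 0.5, btype='highpass', fs=fs, output='sos')
direct = scipy.signal.sosfilt(hp, x)

# Largest difference between the two routes; expected to be at rounding-error level.
max_difference = np.max(np.abs(via_subtraction - direct))
print(max_difference)
```

Either route can be used on the real `signal` array. The subtraction form has the side benefit that the baseline estimate can be plotted on top of the raw data.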

    # take note of these coefficients:
    >>> scipy.signal.butter(1, 0.5, btype='lowpass', fs=fs, output='ba')
    (array([0.00434, 0.00434]), array([ 1.     , -0.99131]))
    # and compare to the following...
    
    # Almost the same thing, different formulation: "exponential average"
    # y += (x - y) * alpha  # applied to each value X of the signal to produce a new Y
    
    alpha = 1/100 # depends on fs and desired cutoff frequency
    lowpassed = scipy.signal.lfilter([alpha], [1, -(1-alpha)], signal)
    highpassed = signal - lowpassed
    
    fig, axs = plt.subplots(2, 1, figsize=(15, 10))
    axs[0].plot(time[30000:32000], signal[30000:32000])
    axs[0].plot(time[30000:32000], lowpassed[30000:32000])
    axs[1].plot(time[30000:32000], highpassed[30000:32000])
    axs[0].set_xlabel('Time (s)')
    axs[1].set_xlabel('Time (s)')
    axs[0].set_ylim([-3, +3])
    axs[1].set_ylim([-3, +3])
    plt.show()
    

    figure 3
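The comment on `alpha` says it depends on `fs` and the desired cutoff. One common approximation places the filter's pole where a continuous RC lowpass with that cutoff would put it, which gives `alpha = 1 - exp(-2*pi*fc/fs)`. The helper names below are hypothetical, and the test input is random noise rather than the EKG. This is a sketch of that relationship, not an exact 3 dB design.

```python
import numpy as np
import scipy.signal

def alpha_from_cutoff(fc, fs):
    """Smoothing factor whose pole matches an RC lowpass with cutoff fc (Hz)."""
    return 1.0 - np.exp(-2.0 * np.pi * fc / fs)

def cutoff_from_alpha(alpha, fs):
    """Inverse of alpha_from_cutoff: approximate cutoff (Hz) for a given alpha."""
    return -np.log(1.0 - alpha) * fs / (2.0 * np.pi)

fs = 360
alpha = alpha_from_cutoff(0.5, fs)
print(round(alpha, 5))                            # about 0.00869, so 1 - alpha is about 0.99131
print(round(cutoff_from_alpha(1 / 100, fs), 3))   # about 0.576 Hz for alpha = 1/100

# The loop form "y += (x - y) * alpha" and the lfilter call compute the same thing.
x = np.random.default_rng(0).normal(size=1000)    # arbitrary test input (assumption)
y = 0.0
looped = np.empty_like(x)
for i, xi in enumerate(x):
    y += (xi - y) * alpha
    looped[i] = y
filtered = scipy.signal.lfilter([alpha], [1, -(1 - alpha)], x)
print(np.allclose(looped, filtered))
```

With a 0.5 Hz cutoff, `1 - alpha` lands on the same pole as the Butterworth `a` coefficients printed above, which is why the two filters behave almost identically.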

    # The first two filters were "causal", i.e. they only use past samples.
    # Downside: lag (phase shift), i.e. the lowpass doesn't quite track the signal.
    # "Non-causal" filters can also use future samples.
    # That removes the phase shift, but the output is delayed because the filter has to wait for those future samples.
    # The delay is irrelevant for offline processing, or if it's considered "small enough".
    # The following filters are non-causal.
    
    # Median filter: it treats the peaks as outliers, so what remains can be considered the "baseline".
    # It can be made causal if the kernel window only covers the past, but that introduces lag (noticeable while the signal is drifting).
    # The median output might need another smoothing pass before it is subtracted.
    # Median filtering CAN be cheap with the right data structure. The SciPy implementation seems less smart and takes noticeable time.
    
    lowpassed = scipy.signal.medfilt(signal, kernel_size=fs+1)
    highpassed = signal - lowpassed
    
    fig, axs = plt.subplots(2, 1, figsize=(15, 10))
    axs[0].plot(time[30000:32000], signal[30000:32000])
    axs[0].plot(time[30000:32000], lowpassed[30000:32000])
    axs[1].plot(time[30000:32000], highpassed[30000:32000])
    axs[0].set_xlabel('Time (s)')
    axs[1].set_xlabel('Time (s)')
    axs[0].set_ylim([-3, +3])
    axs[1].set_ylim([-3, +3])
    plt.show()
    

    figure 4
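The comments above mention that the median output might need another smoothing pass before subtracting. A minimal sketch of that two-step baseline estimate follows. It uses a synthetic drift-plus-spikes signal (an assumption standing in for the EKG), and the 0.1 s Gaussian width is an arbitrary illustrative choice.

```python
import numpy as np
import scipy.ndimage
import scipy.signal

fs = 360  # sampling rate in Hz

# Synthetic stand-in (assumption): slow baseline wander plus one narrow "beat" per second.
t = np.arange(20 * fs) / fs
drift = 0.8 * np.sin(2 * np.pi * 0.05 * t)
spikes = np.zeros_like(t)
spikes[::fs] = 1.5
x = drift + spikes

# Step 1: median filter over about one second (kernel size must be odd).
baseline = scipy.signal.medfilt(x, kernel_size=fs + 1)

# Step 2: smooth the median output, which tends to be blocky; sigma is in samples.
baseline = scipy.ndimage.gaussian_filter1d(baseline, sigma=0.1 * fs)

# Step 3: subtract the baseline estimate to flatten the signal.
flattened = x - baseline

# Away from the array edges, the samples between beats should sit close to zero.
interior = slice(2 * fs, -2 * fs)
between_beats = flattened[interior][spikes[interior] == 0]
print(np.max(np.abs(between_beats)))
```

On the real data, the kernel size and the Gaussian width would need tuning by eye against plots like the ones above.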

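    # Gaussian lowpass: a symmetric weighted moving average, so also non-causal.
    # sigma is given in samples; 0.2 * fs corresponds to 0.2 seconds.
    lowpassed = scipy.ndimage.gaussian_filter1d(signal, sigma=0.2 * fs, order=0)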
    highpassed = signal - lowpassed
    
    fig, axs = plt.subplots(2, 1, figsize=(15, 10))
    axs[0].plot(time[30000:32000], signal[30000:32000])
    axs[0].plot(time[30000:32000], lowpassed[30000:32000])
    axs[1].plot(time[30000:32000], highpassed[30000:32000])
    axs[0].set_xlabel('Time (s)')
    axs[1].set_xlabel('Time (s)')
    axs[0].set_ylim([-3, +3])
    axs[1].set_ylim([-3, +3])
    plt.show()
    

    figure 5
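The causal versus non-causal distinction from the comments above can also be tried on the Butterworth filter itself. `scipy.signal.sosfiltfilt` runs the filter forward and then backward, which cancels the phase shift for offline data. The sketch below compares it with the one-pass `sosfilt` on a synthetic signal. The signal, the second-order design, and the 50-sample offset used to measure asymmetry are all assumptions for illustration.

```python
import numpy as np
import scipy.signal

fs = 360  # sampling rate in Hz

# Synthetic stand-in (assumption): slow baseline wander plus one narrow "beat" per second.
t = np.arange(20 * fs) / fs
drift = 0.8 * np.sin(2 * np.pi * 0.05 * t)
beats = np.zeros_like(t)
beats[fs // 2::fs] = 1.5
x = drift + beats

# Second-order Butterworth highpass at 0.5 Hz (order chosen for illustration).
hp = scipy.signal.butter(2, 0.5, btype='highpass', fs=fs, output='sos')

causal = scipy.signal.sosfilt(hp, x)          # single forward pass, has phase shift
zero_phase = scipy.signal.sosfiltfilt(hp, x)  # forward and backward pass, no phase shift

# Compare the output 50 samples before and after a beat in the middle of the record.
# A zero-phase filter should treat both sides of the beat almost identically.
b = fs // 2 + 10 * fs
asym_causal = abs(causal[b + 50] - causal[b - 50])
asym_zero = abs(zero_phase[b + 50] - zero_phase[b - 50])
print(asym_causal, asym_zero)
```

Because the filter is effectively applied twice, the forward-backward version also attenuates the drift more strongly than a single pass with the same design.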