Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
DominiqueMakowski committed Aug 23, 2023
1 parent c3d6d0c commit 85f63a5
Show file tree
Hide file tree
Showing 4 changed files with 152 additions and 128 deletions.
85 changes: 62 additions & 23 deletions neurokit2/ecg/ecg_peaks.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import matplotlib.pyplot as plt
import numpy as np

from ..signal import signal_fixpeaks, signal_formatpeaks
from .ecg_findpeaks import ecg_findpeaks

Expand All @@ -7,6 +10,7 @@ def ecg_peaks(
sampling_rate=1000,
method="neurokit",
correct_artifacts=False,
show=False,
**kwargs
):
"""**Find R-peaks in an ECG signal**
Expand Down Expand Up @@ -86,10 +90,9 @@ def ecg_peaks(
import neurokit2 as nk
ecg = nk.ecg_simulate(duration=10, sampling_rate=250)
signals, info = nk.ecg_peaks(ecg, sampling_rate=250, correct_artifacts=True)
@savefig p_ecg_peaks1.png scale=100%
nk.events_plot(info["ECG_R_Peaks"], ecg)
signals, info = nk.ecg_peaks(ecg, sampling_rate=250, correct_artifacts=True, show=True)
@suppress
plt.close()
Expand Down Expand Up @@ -224,27 +227,6 @@ def ecg_peaks(
* T. Koka and M. Muma, "Fast and Sample Accurate R-Peak Detection for Noisy ECG Using
Visibility Graphs," 2022 44th Annual International Conference of the IEEE Engineering in
Medicine & Biology Society (EMBC), 2022, pp. 121-126.
* ``nabian2018``
* ``gamboa2008``
* ``hamilton2002``
* ``christov2004``
* ``engzeemod2012``
* ``elgendi2010``
* ``kalidas2017``
* ``rodrigues2021``
* ``koka2022``
* ``promac``
* **Unpublished.** It runs different methods and derives a probability index using
convolution. See this discussion for more information on the method:
Expand Down Expand Up @@ -288,4 +270,61 @@ def ecg_peaks(

info["sampling_rate"] = sampling_rate # Add sampling rate in dict info

if show is True:
_ecg_peaks_plot(ecg_cleaned, sampling_rate, info)

return signals, info


# =============================================================================
# Internals
# =============================================================================
def _ecg_peaks_plot(ecg_cleaned, sampling_rate=1000, info=None, raw=None, ax=None):
x_axis = np.linspace(0, len(ecg_cleaned) / sampling_rate, len(ecg_cleaned))

# Prepare plot
if ax is None:
fig, ax = plt.subplots()

# fig, ax = plt.subplots()
ax.set_xlabel("Time (seconds)")

# Raw Signal ---------------------------------------------------------------
if raw is not None:
ax.plot(x_axis, raw, color="#B0BEC5", label="Raw signal", zorder=1)

# Peaks -------------------------------------------------------------------
ax.scatter(
x_axis[info["ECG_R_Peaks"]],
ecg_cleaned[info["ECG_R_Peaks"]],
color="#FFC107",
label="R-peaks",
zorder=2,
)

# TODO
# # Artifacts ---------------------------------------------------------------
# def _plot_artifact(artifact, color, label, ax):
# if artifact in info.keys() and len(info[artifact]) > 0:
# ax.scatter(
# x_axis[info[artifact]],
# ecg_cleaned[info[artifact]],
# color=color,
# label=label,
# marker="x",
# zorder=2,
# )

# _plot_artifact("ECG_fixpeaks_missed", "#1E88E5", "Missed Peaks", ax)
# _plot_artifact("ECG_fixpeaks_longshort", "#1E88E5", "Long/Short", ax)

# Clean Signal ------------------------------------------------------------
ax.plot(
x_axis,
ecg_cleaned,
color="#F44336",
label="Cleaned",
zorder=3,
linewidth=1,
)
ax.legend(loc="upper right")
181 changes: 80 additions & 101 deletions neurokit2/ecg/ecg_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from .ecg_segment import ecg_segment


def ecg_plot(ecg_signals, rpeaks=None, sampling_rate=1000, show_type="default"):
def ecg_plot(ecg_signals, info=None, sampling_rate=1000, show_type="default"):
"""**Visualize ECG data**
Plot ECG signals and R-peaks.
Expand All @@ -20,9 +20,8 @@ def ecg_plot(ecg_signals, rpeaks=None, sampling_rate=1000, show_type="default"):
----------
ecg_signals : DataFrame
DataFrame obtained from ``ecg_process()``.
rpeaks : dict
The samples at which the R-peak occur. Dict returned by
``ecg_process()``. Defaults to ``None``.
info : dict
The information Dict returned by ``ecg_process()``. Defaults to ``None``.
sampling_rate : int
The sampling frequency of ``ecg_cleaned`` (in Hz, i.e., samples/second). Defaults to 1000.
show_type : str
Expand Down Expand Up @@ -57,7 +56,7 @@ def ecg_plot(ecg_signals, rpeaks=None, sampling_rate=1000, show_type="default"):
# Plot
@savefig p_ecg_plot.png scale=100%
nk.ecg_plot(signals, sampling_rate=1000, show_type='default')
nk.ecg_plot(signals, info, sampling_rate=1000, show_type='default')
@suppress
plt.close()
Expand All @@ -73,99 +72,79 @@ def ecg_plot(ecg_signals, rpeaks=None, sampling_rate=1000, show_type="default"):
peaks = np.where(ecg_signals["ECG_R_Peaks"] == 1)[0]

# Prepare figure and set axes.
if show_type in ["default", "full"]:
x_axis = np.linspace(0, len(ecg_signals) / sampling_rate, len(ecg_signals))
gs = matplotlib.gridspec.GridSpec(2, 2, width_ratios=[2 / 3, 1 / 3])
fig = plt.figure(constrained_layout=False)
ax0 = fig.add_subplot(gs[0, :-1])
ax0.set_xlabel("Time (seconds)")

ax1 = fig.add_subplot(gs[1, :-1], sharex=ax0)
ax2 = fig.add_subplot(gs[:, -1])

fig.suptitle("Electrocardiogram (ECG)", fontweight="bold")

# Plot cleaned, raw ECG, R-peaks and signal quality.
ax0.set_title("Raw and Cleaned Signal")

quality = rescale(
ecg_signals["ECG_Quality"],
to=[np.min(ecg_signals["ECG_Clean"]), np.max(ecg_signals["ECG_Clean"])],
)
minimum_line = np.full(len(x_axis), quality.min())

# Plot quality area first
ax0.fill_between(
x_axis,
minimum_line,
quality,
alpha=0.12,
zorder=0,
interpolate=True,
facecolor="#4CAF50",
label="Quality",
)

# Plot signals
ax0.plot(x_axis, ecg_signals["ECG_Raw"], color="#B0BEC5", label="Raw", zorder=1)
ax0.plot(
x_axis,
ecg_signals["ECG_Clean"],
color="#F44336",
label="Cleaned",
zorder=1,
linewidth=1.5,
)
ax0.scatter(
x_axis[peaks],
ecg_signals["ECG_Clean"][peaks],
color="#FFC107",
label="R-peaks",
zorder=2,
)

# Optimize legend
handles, labels = ax0.get_legend_handles_labels()
order = [2, 0, 1, 3]
ax0.legend(
[handles[idx] for idx in order],
[labels[idx] for idx in order],
loc="upper right",
)

# Plot Heart Rate
ax1 = _signal_rate_plot(
ecg_signals["ECG_Rate"].values,
peaks,
sampling_rate=sampling_rate,
title="Heart Rate",
ytitle="Beats per minute (bpm)",
color="#FF5722",
color_mean="#FF9800",
color_points="red",
ax=ax1,
)

# Plot individual heart beats
ax2 = ecg_segment(
ecg_signals["ECG_Clean"], peaks, sampling_rate, show="return", ax=ax2
)

# Plot artifacts
if show_type in ["artifacts", "full"]:
if sampling_rate is None:
raise ValueError(
"NeuroKit error: ecg_plot(): Sampling rate must be specified for artifacts"
" to be plotted."
)

if rpeaks is None:
_, rpeaks = ecg_peaks(ecg_signals["ECG_Clean"], sampling_rate=sampling_rate)

fig = signal_fixpeaks(
rpeaks,
sampling_rate=sampling_rate,
iterative=True,
show=True,
method="Kubios",
)
x_axis = np.linspace(0, len(ecg_signals) / sampling_rate, len(ecg_signals))
gs = matplotlib.gridspec.GridSpec(2, 2, width_ratios=[2 / 3, 1 / 3])
fig = plt.figure(constrained_layout=False)
ax0 = fig.add_subplot(gs[0, :-1])
ax0.set_xlabel("Time (seconds)")

ax1 = fig.add_subplot(gs[1, :-1], sharex=ax0)
ax2 = fig.add_subplot(gs[:, -1])

fig.suptitle("Electrocardiogram (ECG)", fontweight="bold")

# Plot cleaned, raw ECG, R-peaks and signal quality.
ax0.set_title("Raw and Cleaned Signal")

quality = rescale(
ecg_signals["ECG_Quality"],
to=[np.min(ecg_signals["ECG_Clean"]), np.max(ecg_signals["ECG_Clean"])],
)
minimum_line = np.full(len(x_axis), quality.min())

# Plot quality area first
ax0.fill_between(
x_axis,
minimum_line,
quality,
alpha=0.12,
zorder=0,
interpolate=True,
facecolor="#4CAF50",
label="Quality",
)

# Plot signals
ax0.plot(x_axis, ecg_signals["ECG_Raw"], color="#B0BEC5", label="Raw", zorder=1)
ax0.plot(
x_axis,
ecg_signals["ECG_Clean"],
color="#F44336",
label="Cleaned",
zorder=1,
linewidth=1.5,
)
ax0.scatter(
x_axis[peaks],
ecg_signals["ECG_Clean"][peaks],
color="#FFC107",
label="R-peaks",
zorder=2,
)

# Optimize legend
handles, labels = ax0.get_legend_handles_labels()
order = [2, 0, 1, 3]
ax0.legend(
[handles[idx] for idx in order],
[labels[idx] for idx in order],
loc="upper right",
)

# Plot Heart Rate
ax1 = _signal_rate_plot(
ecg_signals["ECG_Rate"].values,
peaks,
sampling_rate=sampling_rate,
title="Heart Rate",
ytitle="Beats per minute (bpm)",
color="#FF5722",
color_mean="#FF9800",
color_points="red",
ax=ax1,
)

# Plot individual heart beats
ax2 = ecg_segment(
ecg_signals["ECG_Clean"], peaks, sampling_rate, show="return", ax=ax2
)
12 changes: 9 additions & 3 deletions neurokit2/signal/signal_fixpeaks.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ def signal_fixpeaks(
# Simulate ECG data and add noisy period
ecg = nk.ecg_simulate(duration=240, sampling_rate=250, noise=2, random_state=42)
ecg[20000:30000] += np.random.normal(loc=0, scale=0.4, size=10000)
ecg[20000:30000] += np.random.uniform(size=10000)
ecg[40000:43000] = 0
# Identify and Correct Peaks using "Kubios" Method
rpeaks_uncorrected = nk.ecg_findpeaks(ecg, method="pantompkins", sampling_rate=250)
Expand Down Expand Up @@ -126,7 +127,7 @@ def signal_fixpeaks(
peaks = np.sort(np.append(peaks, [1350, 11350, 18350])) # add artifacts
# Identify and Correct Peaks using 'NeuroKit' Method
peaks_corrected = nk.signal_fixpeaks(
info, peaks_corrected = nk.signal_fixpeaks(
peaks=peaks, interval_min=0.5, interval_max=1.5, method="neurokit"
)
Expand Down Expand Up @@ -201,7 +202,11 @@ def _signal_fixpeaks_neurokit(
)
peaks_clean = valid_peaks

info = {"method": "neurokit"}
info = {
"method": "neurokit",
"extra": [i for i in peaks if i not in peaks_clean],
"missed": [i for i in peaks_clean if i not in peaks],
}
return info, peaks_clean


Expand Down Expand Up @@ -620,6 +625,7 @@ def _get_which_endswith(info, string):
)
ax3.add_patch(poly3)
ax3.legend(loc="upper right")
plt.tight_layout()


# =============================================================================
Expand Down
2 changes: 1 addition & 1 deletion tests/tests_ecg.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def test_ecg_peaks():

assert signals.shape == (24000, 1)
assert np.allclose(signals["ECG_R_Peaks"].values.sum(dtype=np.int64), 136, atol=1)
assert info["ECG_fixpeaks_longshort"] == [17]
assert 17 in info["ECG_fixpeaks_longshort"]


def test_ecg_process():
Expand Down

0 comments on commit 85f63a5

Please sign in to comment.