-
Notifications
You must be signed in to change notification settings - Fork 15
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add MEGnet to make MNE-ICALabel work on MEG data #207
base: main
Are you sure you want to change the base?
Changes from all commits
ec28e4f
6f272b1
b3433c8
34c2f31
5b7dc9c
96ed02d
989cb40
bc64aa2
8af24b1
8f5e0e6
5aabe3f
bd7f8cc
59aedfb
067849c
a0da5ee
143df13
58a719a
b89c864
8465017
a0e526d
40b1074
49f39d1
4f2d43a
bbce3cc
19e0260
c47d582
686fda1
c5897e1
88d0619
7d0a7e2
25824fb
036599b
30cc94a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
import numpy as np | ||
from numpy.typing import NDArray | ||
|
||
|
||
def _cart2sph(x, y, z): | ||
xy = np.sqrt(x * x + y * y) | ||
r = np.sqrt(x * x + y * y + z * z) | ||
theta = np.arctan2(y, x) | ||
phi = np.arctan2(z, xy) | ||
return r, theta, phi | ||
|
||
|
||
def _make_head_outlines(sphere: NDArray, pos: NDArray, clip_origin: tuple) -> dict: | ||
"""Generate head outlines for topomap plotting. | ||
|
||
This is a modified version of mne.viz.topomap._make_head_outlines. | ||
The difference between this function and the original one is that | ||
head_x and head_y here are scaled by a factor of 1.01 to make topomap | ||
fit the 120x120 pixel size. | ||
Also, removed the ear and nose outlines for not needed in MEGnet. | ||
|
||
Parameters | ||
---------- | ||
sphere : NDArray | ||
The sphere parameters (x, y, z, radius). | ||
pos : NDArray | ||
The 2D sensor positions. | ||
clip_origin : tuple | ||
The origin of the clipping circle. | ||
|
||
Returns | ||
------- | ||
dict | ||
Dictionary containing the head outlines and mask positions. | ||
|
||
""" | ||
x, y, _, radius = sphere | ||
ll = np.linspace(0, 2 * np.pi, 101) | ||
head_x = np.cos(ll) * radius * 1.01 + x | ||
head_y = np.sin(ll) * radius * 1.01 + y | ||
|
||
mask_scale = max(1.0, np.linalg.norm(pos, axis=1).max() * 1.01 / radius) | ||
clip_radius = radius * mask_scale | ||
|
||
outlines_dict = { | ||
"head": (head_x, head_y), | ||
"mask_pos": (mask_scale * head_x, mask_scale * head_y), | ||
"clip_radius": (clip_radius,) * 2, | ||
"clip_origin": clip_origin, | ||
} | ||
return outlines_dict |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,217 @@ | ||
import io | ||
|
||
import matplotlib.pyplot as plt | ||
import mne | ||
import numpy as np | ||
from mne.io import BaseRaw | ||
from mne.preprocessing import ICA | ||
from mne.utils import _validate_type, warn | ||
from numpy.typing import NDArray | ||
from PIL import Image | ||
from scipy import interpolate | ||
from scipy.spatial import ConvexHull | ||
|
||
from mne_icalabel.iclabel._utils import _pol2cart | ||
|
||
from ._utils import _cart2sph, _make_head_outlines | ||
|
||
|
||
def get_megnet_features(raw: BaseRaw, ica: ICA):
    """Extract time series and topomaps for each ICA component.

    MEGNet uses topomaps from BrainStorm exported as 120x120x3 RGB images.
    Thus, we need to replicate the 'appearance'/'look' of a BrainStorm topomap.

    Parameters
    ----------
    raw : Raw
        Raw MEG recording used to fit the ICA decomposition.
        The raw instance should be bandpass filtered between
        1 and 100 Hz and notch filtered at 50 or 60 Hz to
        remove line noise, and downsampled to 250 Hz.
    ica : ICA
        ICA decomposition of the provided instance.
        The ICA decomposition should use the infomax method.

    Returns
    -------
    time_series : array of shape (n_components, n_samples)
        The time series for each ICA component.
    topomaps : array of shape (n_components, 120, 120, 3)
        The topomap RGB images for each ICA component.

    Raises
    ------
    RuntimeError
        If the recording contains no MEG channel, or fewer than
        15 000 samples.
    """
    _validate_type(raw, BaseRaw, "raw")
    _validate_type(ica, ICA, "ica")
    if not any(
        ch_type in ["mag", "grad"] for ch_type in raw.get_channel_types(unique=True)
    ):
        raise RuntimeError(
            "Could not find MEG channels in the provided Raw instance. "
            "The MEGnet model was fitted on MEG data and is not "
            "suited for other types of channels."
        )
    # n_times gives the sample count without loading the data into memory.
    if (n_samples := raw.n_times) < 15000:
        raise RuntimeError(
            f"The provided raw instance has {n_samples} points. "
            "MEGnet was designed to classify features extracted "
            "from an MEG dataset at least 60 seconds long @ 250 Hz, "
            "corresponding to at least 15 000 samples."
        )
    if not np.isclose(raw.info["sfreq"], 250, atol=1e-1):
        warn(
            "The provided raw instance is not sampled at 250 Hz "
            f"(sfreq={raw.info['sfreq']} Hz). "
            "MEGnet was designed to classify features extracted from "
            "an MEG dataset sampled at 250 Hz "
            "(see the 'resample()' method for Raw instances). "
            "The classification performance might be negatively impacted."
        )
    if raw.info["highpass"] != 1 or raw.info["lowpass"] != 100:
        warn(
            "The provided raw instance is not filtered between 1 and 100 Hz. "
            "MEGnet was designed to classify features extracted from an MEG "
            "dataset bandpass filtered between 1 and 100 Hz"
            " (see the 'filter()' method for Raw instances)."
            " The classification performance might be negatively impacted."
        )
    if _check_line_noise(raw):
        warn(
            "Line noise detected in 50/60 Hz. MEGnet was trained on "
            "MEG data without line noise. Please remove line noise "
            "before using MEGnet (see the 'notch_filter()' method "
            "for Raw instances)."
        )
    if ica.method != "infomax":
        warn(
            f"The provided ICA instance was fitted with '{ica.method}'. "
            "MEGnet was designed with infomax method. "
            "To use it, set mne.preprocessing.ICA instance with "
            "the arguments ICA(method='infomax')."
        )
    # Check the number of fitted components: the requested ``n_components``
    # may be a float (explained variance) or None, while the topomaps are
    # generated for ``n_components_`` components.
    if ica.n_components_ != 20:
        warn(
            f"The provided ICA instance has {ica.n_components_} components. "
            "MEGnet was designed with 20 components. "
            "Use mne.preprocessing.ICA instance with "
            "the arguments ICA(n_components=20)."
        )

    pos_new, outlines = _get_topomaps_data(ica)
    topomaps = _get_topomaps(ica, pos_new, outlines)
    time_series = ica.get_sources(raw).get_data()
    return time_series, topomaps
|
||
|
||
def _get_topomaps_data(ica: ICA):
    """Prepare 2D sensor positions and outlines for topomap plotting.

    Only magnetometers are used. Their 3D locations are converted to
    spherical coordinates, flattened to 2D polar coordinates, and then
    rescaled so that the sensors on the convex hull of the layout land on
    the unit circle.

    Parameters
    ----------
    ica : ICA
        ICA decomposition whose ``info`` provides the channel locations.

    Returns
    -------
    pos_new : array of shape (n_magnetometers, 2)
        The adjusted 2D sensor positions.
    outlines : dict
        The head outlines built by ``_make_head_outlines`` for a unit sphere.

    Raises
    ------
    RuntimeError
        If the ICA instance contains no magnetometer.
    """
    mags = mne.pick_types(ica.info, meg="mag")
    if len(mags) == 0:
        raise RuntimeError(
            "Could not find magnetometer channels in the provided ICA "
            "instance. The MEGnet topomaps are generated from magnetometers."
        )
    channel_info = ica.info["chs"]
    channel_locations_3d = np.array([channel_info[i]["loc"][0:3] for i in mags])

    # Convert to spherical and then to 2D: the elevation is mapped to the
    # radius of the 2D layout.
    _, TH, PHI = _cart2sph(
        channel_locations_3d[:, 0],
        channel_locations_3d[:, 1],
        channel_locations_3d[:, 2],
    )
    newR = 1 - PHI / np.pi * 2
    channel_locations_2d = np.transpose(_pol2cart(TH, newR))

    # Adjust coordinates with convex hull interpolation: compute, for the
    # sensors on the hull, the scaling that brings them to radius 1.
    hull = ConvexHull(channel_locations_2d)
    border_indices = hull.vertices
    Dborder = 1 / newR[border_indices]

    # Repeat the border angles shifted by -2*pi and +2*pi so that every angle
    # in TH falls inside the interpolation range.
    funcTh = np.hstack(
        [
            TH[border_indices] - 2 * np.pi,
            TH[border_indices],
            TH[border_indices] + 2 * np.pi,
        ]
    )
    funcD = np.hstack((Dborder, Dborder, Dborder))
    interp_func = interpolate.interp1d(funcTh, funcD)
    D = interp_func(TH)

    # Scale every sensor and clip the radius to the unit circle.
    adjusted_R = np.minimum(newR * D, 1)
    Xnew, Ynew = _pol2cart(TH, adjusted_R)
    pos_new = np.vstack((Xnew, Ynew)).T

    outlines = _make_head_outlines(np.array([0, 0, 0, 1]), pos_new, (0, 0))
    return pos_new, outlines
|
||
|
||
def _get_topomaps(ica: ICA, pos_new: NDArray, outlines: dict):
    """Generate topomap images for each ICA component.

    Each component is drawn on its own small black-background figure,
    saved to an in-memory PNG and read back as an RGB array.

    Parameters
    ----------
    ica : ICA
        Fitted ICA decomposition.
    pos_new : NDArray
        The 2D positions of the magnetometers.
    outlines : dict
        The head outlines passed to ``mne.viz.plot_topomap``.

    Returns
    -------
    topomaps : array
        The stacked RGB images, one per ICA component.
    """
    topomaps = []
    data_picks = mne.pick_types(ica.info, meg="mag")
    components = ica.get_components()

    for comp in range(ica.n_components_):
        data = components[data_picks, comp]
        fig = plt.figure(figsize=(1.3, 1.3), dpi=100, facecolor="black")
        # Always release the figure, even if plotting or saving fails.
        try:
            ax = fig.add_subplot(111)
            mne.viz.plot_topomap(
                data,
                pos_new,
                sensors=False,
                outlines=outlines,
                extrapolate="head",
                sphere=[0, 0, 0, 1],
                contours=0,
                res=120,
                axes=ax,
                show=False,
                cmap="bwr",
            )
            with io.BytesIO() as img_buf:
                fig.savefig(
                    img_buf, format="png", dpi=120, bbox_inches="tight", pad_inches=0
                )
                img_buf.seek(0)
                # Convert while the buffer is still open: PIL reads lazily.
                with Image.open(img_buf) as rgba_image:
                    rgb_image = rgba_image.convert("RGB")
        finally:
            plt.close(fig)

        topomaps.append(np.array(rgb_image))

    return np.array(topomaps)
|
||
|
||
def _check_line_noise(
    raw: BaseRaw, *, neighbor_width: int = 4, threshold_factor: int = 10
) -> bool:
    """Check if line noise is present in the MEG/EEG data.

    The power at the line frequency is compared, channel by channel, with
    the power at the neighboring frequencies.

    Parameters
    ----------
    raw : Raw
        The recording to inspect. ``raw.info["line_freq"]`` must be set,
        otherwise no check is performed.
    neighbor_width : int
        Half-width in Hz of the frequency window around the line frequency.
    threshold_factor : int
        Number of standard deviations above the mean power of the
        neighboring frequencies that the line-frequency power must exceed
        to be flagged.

    Returns
    -------
    bool
        True if at least one channel exceeds the threshold at the line
        frequency. False otherwise, including when the line frequency is
        unknown or the sampling rate is too low.
    """
    # we don't know the line frequency
    if raw.info.get("line_freq", None) is None:
        return False
    # validate the primary and first harmonic frequencies
    nyquist_freq = raw.info["sfreq"] / 2.0
    line_freqs = [raw.info["line_freq"], 2 * raw.info["line_freq"]]
    if any(nyquist_freq < lf for lf in line_freqs):
        # not raising because if we get here,
        # it means that someone provided a raw with
        # a sampling rate extremely low (100 Hz?) and (1)
        # either they missed all of the previous warnings
        # encountered or (2) they know what they are doing.
        warn(
            "The sampling rate raw.info['sfreq'] is too low "
            "to estimate line noise."
        )
        return False
    # compute the power spectrum and retrieve the frequencies of interest
    spectrum = raw.compute_psd(picks="meg", exclude="bads")
    data, freqs = spectrum.get_data(
        fmin=raw.info["line_freq"] - neighbor_width,
        fmax=raw.info["line_freq"] + neighbor_width,
        return_freqs=True,
    )  # array of shape (n_good_channel, n_freqs)
    # split the window into the line-frequency bin and its neighbors
    idx = np.argmin(np.abs(freqs - raw.info["line_freq"]))
    mask = np.ones(data.shape[1], dtype=bool)
    mask[idx] = False
    background_mean = np.mean(data[:, mask], axis=1)
    background_std = np.std(data[:, mask], axis=1)
    threshold = background_mean + threshold_factor * background_std
    # cast so that the annotated builtin bool is returned, not numpy.bool_
    return bool(np.any(data[:, idx] > threshold))
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
@@ -0,0 +1,107 @@ | ||||
from importlib.resources import files | ||||
|
||||
import numpy as np | ||||
import onnxruntime as ort | ||||
from mne.io import BaseRaw | ||||
from mne.preprocessing import ICA | ||||
from numpy.typing import NDArray | ||||
|
||||
from .features import get_megnet_features | ||||
|
||||
# Location of the ONNX model file shipped in the "assets" directory of the
# mne_icalabel.megnet package.
# NOTE(review): annotated as str, but `files(...) / ...` yields a path-like
# traversable object rather than a plain string -- confirm the annotation.
_MODEL_PATH: str = files("mne_icalabel.megnet") / "assets" / "megnet.onnx"
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this line causes issues. See:
Let us know if that doesn't work. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think he modified anything from There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes nvm. I misread the error. Hmm it seems the file is somehow not found. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh, I found the issue! After I added the |
||||
|
||||
|
||||
def megnet_label_components(raw: BaseRaw, ica: ICA) -> NDArray:
    """Label the provided ICA components with the MEGnet neural network.

    Parameters
    ----------
    raw : Raw
        Raw MEG recording used to fit the ICA decomposition.
        The raw instance should be bandpass filtered between 1 and 100 Hz
        and notch filtered at 50 or 60 Hz to remove line noise,
        and downsampled to 250 Hz.
    ica : ICA
        ICA decomposition of the provided instance.
        The ICA decomposition should use the infomax method.

    Returns
    -------
    labels_pred_proba : numpy.ndarray of shape (n_components, n_classes)
        The estimated corresponding predicted probabilities of output classes
        for each independent component. Columns are ordered with
        'brain/other', 'eye movement', 'heart beat', 'eye blink'.
    """
    time_series, topomaps = get_megnet_features(raw, ica)

    # sanity-checks
    # number of time-series <-> topos
    assert time_series.shape[0] == topomaps.shape[0]
    # topos are images of shape 120x120x3
    assert topomaps.shape[1:] == (120, 120, 3)
    # minimum time-series length
    assert 15000 <= time_series.shape[1]

    # Pass the model location as a plain string: older onnxruntime releases
    # only accept ``str`` or ``bytes`` and reject path-like objects.
    session = ort.InferenceSession(str(_MODEL_PATH))
    labels_pred_proba = _chunk_predicting(session, time_series, topomaps)
    # drop the singleton axis kept from the per-chunk model output
    return labels_pred_proba[:, 0, :]
|
||||
|
||||
def _get_chunk_votes(start_times: list, time_len: int, chunk_len: int) -> dict:
    """Compute the voting weight of every chunk.

    Each time point distributes a total vote of 1 equally among the chunks
    that contain it; the weight of a chunk is the sum of the votes it
    receives. A start time listed twice counts as two chunks, and its
    votes are accumulated under a single key.
    """
    # number of chunks covering every time point
    coverage = np.zeros(time_len)
    for start in start_times:
        coverage[start : start + chunk_len] += 1

    votes = {}
    for start in start_times:
        share = np.sum(1.0 / coverage[start : start + chunk_len])
        votes[start] = votes.get(start, 0.0) + share
    return votes


def _chunk_predicting(
    session: ort.InferenceSession,
    time_series: NDArray,
    spatial_maps: NDArray,
    chunk_len=15000,
    overlap_len=3750,
) -> NDArray:
    """MEGnet's chunk vote algorithm.

    Every component time series is split into overlapping chunks, the
    model is run on each chunk together with the component topomap, and the
    chunk predictions are combined with weights that reflect how many time
    points each chunk accounts for.

    Parameters
    ----------
    session : onnxruntime.InferenceSession
        The loaded model. Its first input receives the topomap and its
        second input receives the time-series chunk.
    time_series : NDArray
        The component time series, one row per component.
    spatial_maps : NDArray
        The component topomaps, one per component.
    chunk_len : int
        Number of samples in a chunk.
    overlap_len : int
        Number of samples shared by two consecutive chunks.

    Returns
    -------
    NDArray
        The stacked normalized predictions, one entry per component.

    Raises
    ------
    ValueError
        If a time series is shorter than ``chunk_len``.
    """
    # The input names are a property of the model: look them up only once.
    model_inputs = session.get_inputs()
    map_name, series_name = model_inputs[0].name, model_inputs[1].name

    prediction_vote = []
    for comp_series, comp_map in zip(time_series, spatial_maps):
        time_len = comp_series.shape[0]
        start_times = _get_chunk_start(time_len, chunk_len, overlap_len)
        if not start_times:
            raise ValueError(
                f"The time series has {time_len} samples, which is shorter "
                f"than the chunk length ({chunk_len})."
            )

        # Add a final chunk aligned with the end of the time series.
        # NOTE(review): this condition holds for every start produced by
        # _get_chunk_start, so the end-aligned chunk is always appended and
        # may duplicate the last start time. Kept as-is to preserve the
        # voting weights -- confirm against the reference implementation.
        if start_times[-1] + chunk_len <= time_len:
            start_times.append(time_len - chunk_len)

        chunk_votes = _get_chunk_votes(start_times, time_len, chunk_len)

        # The topomap is the same for every chunk of a component.
        map_input = np.expand_dims(comp_map, 0).astype(np.float32)
        weighted_predictions = []
        for start_time, vote in chunk_votes.items():
            chunk = comp_series[start_time : start_time + chunk_len]
            onnx_inputs = {
                map_name: map_input,
                series_name: np.expand_dims(np.expand_dims(chunk, 0), -1).astype(
                    np.float32
                ),
            }
            prediction = session.run(None, onnx_inputs)[0]
            weighted_predictions.append(prediction * vote)

        comp_prediction = np.stack(weighted_predictions).mean(axis=0)
        # normalize so that the class scores sum to 1
        comp_prediction /= comp_prediction.sum()
        prediction_vote.append(comp_prediction)

    return np.stack(prediction_vote)
|
||||
|
||||
def _get_chunk_start( | ||||
input_len: int, chunk_len: int = 15000, overlap_len: int = 3750 | ||||
) -> list: | ||||
"""Calculate start times for time series chunks with overlap.""" | ||||
start_times = [] | ||||
start_time = 0 | ||||
while start_time + chunk_len <= input_len: | ||||
start_times.append(start_time) | ||||
start_time += chunk_len - overlap_len | ||||
return start_times |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.