Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add type overloads to Epochs.average() #12769

Draft
wants to merge 10 commits into
base: main
Choose a base branch
from
4 changes: 4 additions & 0 deletions doc/changes/devel/12769.newfeature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
When creating :class:`~mne.Evoked` by averaging :class:`mne.Epochs` via the :meth:`~mne.Epochs.average`
method, static analysis tools like Pylance will now correctly infer whether a list of :class:`~mne.EvokedArray`
or a single :class:`~mne.EvokedArray` is returned that a `pathlib.Path`, enabling better editor support like
automated code completions on the returned object, by `Richard Höchenberger`_.
56 changes: 45 additions & 11 deletions mne/epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

from __future__ import annotations

import json
import operator
import os.path as op
Expand All @@ -18,8 +20,10 @@
from functools import partial
from inspect import getfullargspec
from pathlib import Path
from typing import Callable, Literal, overload

import numpy as np
from numpy.typing import NDArray
from scipy.interpolate import interp1d

from ._fiff.constants import FIFF
Expand Down Expand Up @@ -1070,21 +1074,47 @@ def subtract_evoked(self, evoked=None):

return self

@overload
def average(
self,
picks=None,
method: Literal["mean", "median"]
| Callable[[NDArray[np.float64]], NDArray[np.float64]] = "mean",
by_event_type: Literal[False] = False,
) -> EvokedArray: ...

@overload
def average(
self,
picks=None,
method: Literal["mean", "median"]
| Callable[[NDArray[np.float64]], NDArray[np.float64]] = "mean",
by_event_type: Literal[True] = True,
) -> list[EvokedArray]: ...

@fill_doc
def average(self, picks=None, method="mean", by_event_type=False):
def average(
self,
picks=None,
method: Literal["mean", "median"]
| Callable[[NDArray[np.float64]], NDArray[np.float64]] = "mean",
by_event_type: bool = False,
) -> EvokedArray | list[EvokedArray]:
"""Compute an average over epochs.

Parameters
----------
%(picks_all_data)s
method : str | callable
How to combine the data. If "mean"/"median", the mean/median
are returned.
Otherwise, must be a callable which, when passed an array of shape
(n_epochs, n_channels, n_time) returns an array of shape
(n_channels, n_time).
Note that due to file type limitations, the kind for all
these will be "average".
method : "mean" | "median" | callable
How to average the data across epochs, time-point by time-point.
Pass ``"mean"`` for the arithmeic mean, or ``"median"`` for the median.

.. note:: In typical ERP and ERF analyses, ``"mean"`` (the default) should
be used.

Can also be a function accepting an array of shape
``(n_epochs, n_channels, n_time)`` and returning an array of shape
``(n_channels, n_time)`` (i.e., collapsing the epochs dimensioon).
%(by_event_type)s

Returns
Expand All @@ -1097,7 +1127,7 @@ def average(self, picks=None, method="mean", by_event_type=False):
they correspond to different conditions. To average by condition,
do ``epochs[condition].average()`` for each condition separately.

When picks is None and epochs contain only ICA channels, no channels
When ``picks`` is ``None`` and epochs contain only ICA channels, no channels
are selected, resulting in an error. This is because ICA channels
are not considered data channels (they are of misc type) and only data
channels are selected when picks is None.
Expand All @@ -1110,6 +1140,10 @@ def average(self, picks=None, method="mean", by_event_type=False):
>>> epochs.average(method=trim) # doctest:+SKIP

This would compute the trimmed mean.

The "kind" for all these operations, including custom functions, will be
labeled as "average" when writing the data to disk, due to limitations of the
FIFF file format.
"""
self._handle_empty("raise", "average")
if by_event_type:
Expand Down Expand Up @@ -4199,7 +4233,7 @@ def _read_one_epoch_file(f, tree, preload):


@verbose
def read_epochs(fname, proj=True, preload=True, verbose=None) -> "EpochsFIF":
def read_epochs(fname, proj=True, preload=True, verbose=None) -> EpochsFIF:
"""Read epochs from a fif file.

Parameters
Expand Down
Loading