diff --git a/meegkit/asr.py b/meegkit/asr.py index 29ce33ca..f33a7b56 100755 --- a/meegkit/asr.py +++ b/meegkit/asr.py @@ -21,14 +21,12 @@ class ASR: component-based artifact removal method for removing transient or large-amplitude artifacts in multi-channel EEG recordings [1]_. + The key parameter of the method is ``cutoff``. + Parameters ---------- sfreq : float Sampling rate of the data, in Hz. - - The following are optional parameters (the key parameter of the method is - the ``cutoff``): - cutoff: float Standard deviation cutoff for rejection. X portions whose variance is larger than this threshold relative to the calibration data are @@ -58,16 +56,16 @@ class ASR: method : {'riemann', 'euclid'} Method to use. If riemann, use the riemannian-modified version of ASR [2]_. - memory : float - Memory size (s), regulates the number of covariance matrices to store. - estimator : str in {'scm', 'lwf', 'oas', 'mcd'} + memory : float | None + Memory size (samples), regulates the number of covariance matrices to + store. + If None (default), will use twice the sampling frequency. + estimator : {'scm', 'lwf', 'oas', 'mcd'} Covariance estimator (default: 'scm' which computes the sample covariance). Use 'lwf' if you need regularization (requires pyriemann). Attributes ---------- - ``state_`` : dict - Initial state of the ASR filter. ``zi_``: array, shape=(n_channels, filter_order) Filter initial conditions. ``ab_``: 2-tuple @@ -98,9 +96,9 @@ class ASR: """ - def __init__(self, sfreq=250, cutoff=5, blocksize=100, win_len=0.5, + def __init__(self, *, sfreq=250, cutoff=5, blocksize=100, win_len=0.5, win_overlap=0.66, max_dropout_fraction=0.1, - min_clean_fraction=0.25, name="asrfilter", method="euclid", + min_clean_fraction=0.25, method="euclid", memory=None, estimator="scm", **kwargs): if pyriemann is None and method == "riemann": @@ -115,7 +113,10 @@ def __init__(self, sfreq=250, cutoff=5, blocksize=100, win_len=0.5, self.min_clean_fraction = min_clean_fraction self.max_bad_chans = 0.3 self.method = method - self.memory = int(2 * sfreq) # smoothing window for covariances + if memory is None: + self.memory = int(2 * sfreq) # smoothing window for covariances + else: + self.memory = memory self.sample_weight = np.geomspace(0.05, 1, num=self.memory + 1) self.sfreq = sfreq self.estimator = estimator @@ -141,10 +142,10 @@ def fit(self, X, y=None, **kwargs): """Calibration for the Artifact Subspace Reconstruction method. The input to this data is a multi-channel time series of calibration - data. In typical uses the calibration data is clean resting EEG data of - data if the fraction of artifact content is below the breakdown point + data. In typical uses the calibration data is clean resting EEG data. + The fraction of artifact content should be below the breakdown point of the robust statistics used for estimation (50% theoretical, ~30% - practical). If the data has a proportion of more than 30-50% artifacts + practical). If the data has a proportion of more than 30-50% artifacts, then bad time windows should be removed beforehand. This data is used to estimate the thresholds that are used by the ASR processing function to identify and remove artifact components. @@ -164,6 +165,12 @@ def fit(self, X, y=None, **kwargs): reasonably clean not less than 30 seconds (this method is typically used with 1 minute or more). + Returns + ------- + clean : array, shape=(n_channels, n_samples) + Dataset with bad time periods removed. + sample_mask : boolean array, shape=(1, n_samples) + Mask of retained samples (logical array). """ if X.ndim == 3: X = X.squeeze() @@ -468,6 +475,9 @@ def asr_calibrate(X, sfreq, cutoff=5, blocksize=100, win_len=0.5, estimation (default=0.25). method : {'euclid', 'riemann'} Metric to compute the covariance matrix average. + estimator : {'scm', 'lwf', 'oas', 'mcd'} + Covariance estimator (default: 'scm' which computes the sample + covariance). Use 'lwf' if you need regularization (requires pyriemann). Returns ------- diff --git a/tests/test_asr.py b/tests/test_asr.py index f213939d..f417c2e7 100644 --- a/tests/test_asr.py +++ b/tests/test_asr.py @@ -193,7 +193,7 @@ def test_asr_class(method, reref, show=False): blah = ASR(method=method, estimator="scm") blah.fit(raw2[:, train_idx]) - asr = ASR(method=method, estimator="lwf") + asr = ASR(method=method, estimator="lwf", memory=int(2 * sfreq)) asr.fit(raw2[:, train_idx]) else: asr = ASR(method=method, estimator="scm")