From dd2aa12dbd61039ec452128cab5d8700132d1de7 Mon Sep 17 00:00:00 2001 From: usuario1marc Date: Mon, 18 May 2026 13:21:03 +0200 Subject: [PATCH] Added option for bipolar re-referencing before analysis --- erdetect/_erdetect.py | 203 +++++++++++++++++++++++++++++----------- erdetect/core/config.py | 5 + erdetect/utils/misc.py | 72 ++++++++++++++ erdetect/version.py | 2 +- 4 files changed, 228 insertions(+), 54 deletions(-) diff --git a/erdetect/_erdetect.py b/erdetect/_erdetect.py index 5f37105..1585cab 100755 --- a/erdetect/_erdetect.py +++ b/erdetect/_erdetect.py @@ -20,7 +20,7 @@ from os.path import exists from ieegprep.bids.sidecars import load_channel_info, load_elec_stim_events, load_ieeg_sidecar -from ieegprep.bids.data_epoch import load_data_epochs_averages +from ieegprep.bids.data_epoch import load_data_epochs_averages, load_data_epochs from ieegprep.bids.rereferencing import RerefStruct from ieegprep.utils.console import multi_line_list, print_progressbar from ieegprep.utils.misc import is_number @@ -29,7 +29,7 @@ from erdetect.core.config import write_config, get as cfg, get_config_dict, OUTPUT_IMAGE_SIZE, LOGGING_CAPTION_INDENT_LENGTH from erdetect.core.detection import ieeg_detect_er from erdetect.views.output_images import calc_sizes_and_fonts, calc_matrix_image_size, gen_amplitude_matrix, gen_latency_matrix -from erdetect.utils.misc import create_figure +from erdetect.utils.misc import create_figure, mount_bipolar from erdetect.core.metrics.metric_cross_proj import MetricCrossProj from erdetect.core.metrics.metric_waveform import MetricWaveform @@ -70,7 +70,14 @@ def process_subset(bids_subset_data_path, output_dir, preproc_prioritize_speed=F return # determine a subset specific output path - output_root = os.path.join(output_dir, os.path.basename(os.path.normpath(bids_subset_root))) + try: + # this expects the input file to follow the iEEG BIDS naming conventions + output_root = os.path.join(output_dir, + os.path.basename(bids_subset_data_path).split("_")[0], + os.path.basename(bids_subset_data_path).split("_")[1], + os.path.basename(bids_subset_data_path).split("_")[3]) + except IndexError: + output_root = os.path.join(output_dir, os.path.basename(os.path.normpath(bids_subset_root))) # make sure the subject output directory exists if not os.path.exists(output_root): @@ -408,9 +415,10 @@ def process_subset(bids_subset_data_path, output_dir, preproc_prioritize_speed=F if cfg('metrics', 'waveform', 'enabled'): metric_callbacks += tuple([MetricWaveform.process_callback]) + # determine which channel montage (monopolar or bipolar) should be used + bipolar_rereferencing = cfg('preprocess', 'bipolar_montage') + # read, normalize, epoch and average the trials within the condition - # Note: 'load_data_epochs_averages' is used instead of 'load_data_epochs' here because it is more memory - # efficient when only the averages are needed if len(metric_callbacks) == 0: logging.info('- Reading data...') else: @@ -419,57 +427,134 @@ def process_subset(bids_subset_data_path, output_dir, preproc_prioritize_speed=F # TODO: normalize to raw or to Z-values (return both raw and z?) # z-might be needed for detection try: - sampling_rate, averages, metrics = load_data_epochs_averages(bids_subset_data_path, channels_measured_incl, stim_pairs_onsets, - trial_epoch=cfg('trials', 'trial_epoch'), - baseline_norm=cfg('trials', 'baseline_norm'), - baseline_epoch=cfg('trials', 'baseline_epoch'), - out_of_bound_handling=cfg('trials', 'out_of_bounds_handling'), - metric_callbacks=metric_callbacks, - high_pass=cfg('preprocess', 'high_pass'), - early_reref=early_reref, - line_noise_removal=line_noise_removal, - late_reref=late_reref, - preproc_priority=('speed' if preproc_prioritize_speed else 'mem')) + if bipolar_rereferencing: + sampling_rate, data = load_data_epochs(bids_subset_data_path, channels_measured_incl, trial_onsets, + trial_epoch=cfg('trials', 'trial_epoch'), + baseline_norm=cfg('trials', 'baseline_norm'), + baseline_epoch=cfg('trials', 'baseline_epoch'), + out_of_bound_handling=cfg('trials', 'out_of_bounds_handling'), + high_pass=cfg('preprocess', 'high_pass'), + early_reref=early_reref, + line_noise_removal=line_noise_removal, + late_reref=late_reref, + preproc_priority=('speed' if preproc_prioritize_speed else 'mem')) + bipolar_data, bipolar_names = mount_bipolar(data, channels_measured_incl) + n_bipolar_channels, n_trials, n_times = bipolar_data.shape + else: + # Note: 'load_data_epochs_averages' is used instead of 'load_data_epochs' here because it is more memory + # efficient when only the averages are needed + sampling_rate, averages, metrics = load_data_epochs_averages(bids_subset_data_path, channels_measured_incl, stim_pairs_onsets, + trial_epoch=cfg('trials', 'trial_epoch'), + baseline_norm=cfg('trials', 'baseline_norm'), + baseline_epoch=cfg('trials', 'baseline_epoch'), + out_of_bound_handling=cfg('trials', 'out_of_bounds_handling'), + metric_callbacks=metric_callbacks, + high_pass=cfg('preprocess', 'high_pass'), + early_reref=early_reref, + line_noise_removal=line_noise_removal, + late_reref=late_reref, + preproc_priority=('speed' if preproc_prioritize_speed else 'mem')) except (ValueError, RuntimeError): logging.error('Could not load data (' + bids_subset_data_path + '), exiting...') raise RuntimeError('Could not load data') - - # split out the metric results - cross_proj_metrics = None - waveform_metrics = None - metric_counter = 0 - if cfg('metrics', 'cross_proj', 'enabled'): - cross_proj_metrics = np.array(metrics[:, :, metric_counter].tolist()) - metric_counter += 1 - if cfg('metrics', 'waveform', 'enabled'): - waveform_metrics = np.array(metrics[:, :, metric_counter].tolist()) - metric_counter += 1 - - # for each stimulation pair condition, NaN out the values of the measured electrodes that were stimulated - for stim_pair_index, stim_pair in enumerate(stim_pairs_onsets): - stim_pair_electrode_names = stim_pair.split('-') - - # find and clear the first electrode - try: - averages[channels_measured_incl.index(stim_pair_electrode_names[0]), stim_pair_index, :] = np.nan - if cfg('metrics', 'cross_proj', 'enabled'): - cross_proj_metrics[channels_measured_incl.index(stim_pair_electrode_names[0]), stim_pair_index, :] = np.nan - if cfg('metrics', 'waveform', 'enabled'): - waveform_metrics[channels_measured_incl.index(stim_pair_electrode_names[0]), stim_pair_index] = np.nan - - except ValueError: - pass - - # find and clear the second electrode - try: - averages[channels_measured_incl.index(stim_pair_electrode_names[1]), stim_pair_index, :] = np.nan - if cfg('metrics', 'cross_proj', 'enabled'): - cross_proj_metrics[channels_measured_incl.index(stim_pair_electrode_names[1]), stim_pair_index, :] = np.nan - if cfg('metrics', 'waveform', 'enabled'): - waveform_metrics[channels_measured_incl.index(stim_pair_electrode_names[1]), stim_pair_index] = np.nan - - except ValueError: - pass + + # Retrive trial and baseline epochs from config + trial_epoch = cfg('trials', 'trial_epoch') + baseline_epoch = cfg('trials', 'baseline_epoch') + trial_start = trial_epoch[0] + baseline_start_sample = int(round((baseline_epoch[0] - trial_start) * sampling_rate)) + baseline_end_sample = int(round((baseline_epoch[1] - trial_start) * sampling_rate)) + + if bipolar_rereferencing: + # obtain the metrics from individual traials since we don't use averages in bipolar montage + cross_proj_metrics = np.zeros((n_bipolar_channels, len(stim_pairs_onsets), 3)) + waveform_metrics = np.zeros((n_bipolar_channels, len(stim_pairs_onsets), 1)) + for channel in range(n_bipolar_channels): + for sp_idx, (stimpair, onsets) in enumerate(stim_pairs_onsets.items()): + indices = [i for i, p in enumerate(trial_pairs) if str(p[0]+'-'+p[1]) == stimpair] + epoch_data = bipolar_data[channel, indices, :] + epoch_baseline = epoch_data[:, baseline_start_sample:baseline_end_sample] + + if cfg('metrics', 'cross_proj', 'enabled'): + metric = MetricCrossProj.process_callback( + sampling_rate, + epoch_data, + epoch_baseline + ) + cross_proj_metrics[channel, sp_idx, :] = metric + + if cfg('metrics', 'waveform', 'enabled'): + metric = MetricWaveform.process_callback(sampling_rate, epoch_data, epoch_baseline) + waveform_metrics[channel, sp_idx, :] = metric + + # compute averages, but in bipolar + averages = np.zeros((n_bipolar_channels, len(stim_pairs_onsets), n_times)) + stimpair_indices = {} + + for channel in range(n_bipolar_channels): + for sp_idx, (stimpair, onsets) in enumerate(stim_pairs_onsets.items()): + indices = [i for i, p in enumerate(trial_pairs) if str(p[0]+'-'+p[1]) == stimpair] + stimpair_indices[stimpair] = indices + + epoch_data = bipolar_data[channel, indices, :] + averages[channel, sp_idx, :] = np.nanmean(epoch_data, axis=0) + + # for each stimulation pair condition, NaN out the values of the measured electrodes that were stimulated + for stim_pair_index, stim_pair in enumerate(stim_pairs_onsets): + stim_ch1, stim_ch2 = stim_pair.split('-') + + for ch_idx, ch_name in enumerate(bipolar_names): + if stim_ch1 in ch_name or stim_ch2 in ch_name: + try: + averages[ch_idx, stim_pair_index, :] = np.nan + if cfg('metrics', 'cross_proj', 'enabled'): + cross_proj_metrics[ch_idx, stim_pair_index, :] = np.nan + if cfg('metrics', 'waveform', 'enabled'): + waveform_metrics[ch_idx, stim_pair_index] = np.nan + + except ValueError: + pass + + # channel names are now bipolar + channels_measured_incl = bipolar_names + + else: + # split out the metric results + cross_proj_metrics = None + waveform_metrics = None + metric_counter = 0 + if cfg('metrics', 'cross_proj', 'enabled'): + cross_proj_metrics = np.array(metrics[:, :, metric_counter].tolist()) + metric_counter += 1 + if cfg('metrics', 'waveform', 'enabled'): + waveform_metrics = np.array(metrics[:, :, metric_counter].tolist()) + metric_counter += 1 + + # for each stimulation pair condition, NaN out the values of the measured electrodes that were stimulated + for stim_pair_index, stim_pair in enumerate(stim_pairs_onsets): + stim_pair_electrode_names = stim_pair.split('-') + + # find and clear the first electrode + try: + averages[channels_measured_incl.index(stim_pair_electrode_names[0]), stim_pair_index, :] = np.nan + if cfg('metrics', 'cross_proj', 'enabled'): + cross_proj_metrics[channels_measured_incl.index(stim_pair_electrode_names[0]), stim_pair_index, :] = np.nan + if cfg('metrics', 'waveform', 'enabled'): + waveform_metrics[channels_measured_incl.index(stim_pair_electrode_names[0]), stim_pair_index] = np.nan + + except ValueError: + pass + + # find and clear the second electrode + try: + averages[channels_measured_incl.index(stim_pair_electrode_names[1]), stim_pair_index, :] = np.nan + if cfg('metrics', 'cross_proj', 'enabled'): + cross_proj_metrics[channels_measured_incl.index(stim_pair_electrode_names[1]), stim_pair_index, :] = np.nan + if cfg('metrics', 'waveform', 'enabled'): + waveform_metrics[channels_measured_incl.index(stim_pair_electrode_names[1]), stim_pair_index] = np.nan + + except ValueError: + pass # determine the sample of stimulus onset (counting from the epoch start) onset_sample = int(round(abs(cfg('trials', 'trial_epoch')[0] * sampling_rate))) @@ -492,6 +577,18 @@ def process_subset(bids_subset_data_path, output_dir, preproc_prioritize_speed=F MetricCrossProj.append_output_dict_callback(output_dict, cross_proj_metrics) if cfg('metrics', 'waveform', 'enabled'): MetricWaveform.append_output_dict_callback(output_dict, waveform_metrics) + + if cfg('preprocess', 'bipolar_montage'): + output_dict['ccep_data'] = bipolar_data + output_dict['stimpair_indices'] = stimpair_indices + + try: + # this assumes that the input file follows the standard iEEG BIDS naming structure + output_dict['subject'] = os.path.basename(bids_subset_data_path).split("_")[0][4:] + output_dict['session'] = os.path.basename(bids_subset_data_path).split("_")[1][4:] + output_dict['run'] = os.path.basename(bids_subset_data_path).split("_")[3][4:] + except IndexError: + pass sio.savemat(os.path.join(output_root, 'erdetect_data.mat'), output_dict) diff --git a/erdetect/core/config.py b/erdetect/core/config.py index 9606603..0f9b636 100755 --- a/erdetect/core/config.py +++ b/erdetect/core/config.py @@ -55,6 +55,7 @@ def create_default_config(): config['preprocess']['late_re_referencing']['CAR_by_variance'] = -1 # config['preprocess']['late_re_referencing']['stim_excl_epoch'] = (-1.0, 2.0) config['preprocess']['late_re_referencing']['channel_types'] = ('ECOG', 'SEEG', 'DBS') # the type of channels that will be included for late re-referencing + config['preprocess']['bipolar_montage'] = False # if True, channels will be re-referenced to a bipolar setting config['trials'] = dict() config['trials']['trial_epoch'] = (-1.0, 2.0) # the time-span (in seconds) relative to the stimulus onset that will be used to extract the signal for each trial @@ -385,6 +386,9 @@ def retrieve_config_tuple(json_dict, ref_config, level1, level2, level3=None, op return False if not retrieve_config_string(json_config, config, 'preprocess', 'late_re_referencing', 'method', options=('CAR', 'CAR_headbox')): return False + + if not retrieve_config_bool(json_config, config, 'preprocess', 'bipolar_montage'): + return False if config['preprocess']['late_re_referencing']['method'] in ('CAR', 'CAR_headbox'): retrieve_config_number(json_config, config, 'preprocess', 'late_re_referencing', 'CAR_by_variance') @@ -533,6 +537,7 @@ def write_config(filepath): ' "channel_types": ' + json.dumps(_config['preprocess']['early_re_referencing']['channel_types']) + '\n' \ ' },\n' \ ' "line_noise_removal": "' + _config['preprocess']['line_noise_removal'] + '",\n' \ + ' "bipolar_montage": ' + ('true' if _config['preprocess']['bipolar_montage'] else 'false') + ',\n' \ ' "late_re_referencing": {\n' \ ' "enabled": ' + ('true' if _config['preprocess']['late_re_referencing']['enabled'] else 'false') + ',\n' \ ' "method": "' + _config['preprocess']['late_re_referencing']['method'] + '",\n' diff --git a/erdetect/utils/misc.py b/erdetect/utils/misc.py index ee31b62..f2d6a62 100644 --- a/erdetect/utils/misc.py +++ b/erdetect/utils/misc.py @@ -17,6 +17,10 @@ from matplotlib.figure import Figure from ieegprep.utils.misc import is_number +import re +import numpy as np +from collections import defaultdict + def create_figure(width=500, height=500, on_screen=False): """ @@ -116,3 +120,71 @@ def numbers_to_padded_string(values, width=0, pos_space=True, separator=', '): padded_str += padded_values[iValue] return padded_str + +def mount_bipolar(monopolar_data_epoched:np.ndarray, monopolar_channels:list, sep:str='-'): + """ + Computes a bipolar montage of the input data. + + Channels are assumed to follow a naming convention: (e.g., A1, A2, B10). + + Adjacent channels are those with: + - same prefix + - consecutive numbers + + Parameters + ---------- + + monopolar_data_epoched : ndarray (n_channels, n_epochs, n_samples) + Input monopolar data array. + + monopolar_channels : list of str (n_channels) + List of input channel names. + + sep : str, optional + Divider that separates channels in bipolar channel names. + + Returns + ------- + + bipolar_data : ndarray (n_bipolar_channels, n_samples) + Output bipolar data array. + + bipolar_channels : list of str (n_bipolar_channels) + List of output channel names. + """ + + bipolar_data_epoched = [] + bipolar_channels = [] + + ch_idx = {ch: i for i, ch in enumerate(monopolar_channels)} + + # Parse channels into their prefix + their number + parsed = [] + pattern = re.compile(r"^([A-Za-z]+)(\d+)$") + for ch in monopolar_channels: + match = pattern.match(ch) + if match: + prefix, num = match.groups() + parsed.append((prefix, int(num), ch)) + + # Group channels by prefix + groups = defaultdict(list) + for prefix, num, ch in parsed: + groups[prefix].append((num, ch)) + + # Compute bipolar channel by grouping + for prefix, items in groups.items(): + items.sort(key=lambda x: x[0]) + for (n1, ch1), (n2, ch2) in zip(items[:-1], items[1:]): + if n2 > n1: + idx1 = ch_idx[ch1] + idx2 = ch_idx[ch2] + bipolar_data_epoched.append(monopolar_data_epoched[idx1,:,:] - monopolar_data_epoched[idx2,:,:]) + bipolar_channels.append(f"{ch1}{sep}{ch2}") + + if not bipolar_data_epoched: + raise RuntimeError("No bipolar channels created.") + + bipolar_data_epoched = np.stack(bipolar_data_epoched, axis=0) + + return bipolar_data_epoched, bipolar_channels \ No newline at end of file diff --git a/erdetect/version.py b/erdetect/version.py index b482efe..2614ce9 100755 --- a/erdetect/version.py +++ b/erdetect/version.py @@ -1 +1 @@ -__version__ = "2.6.2" +__version__ = "2.7.0"