Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
203 changes: 150 additions & 53 deletions erdetect/_erdetect.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from os.path import exists

from ieegprep.bids.sidecars import load_channel_info, load_elec_stim_events, load_ieeg_sidecar
from ieegprep.bids.data_epoch import load_data_epochs_averages
from ieegprep.bids.data_epoch import load_data_epochs_averages, load_data_epochs
from ieegprep.bids.rereferencing import RerefStruct
from ieegprep.utils.console import multi_line_list, print_progressbar
from ieegprep.utils.misc import is_number
Expand All @@ -29,7 +29,7 @@
from erdetect.core.config import write_config, get as cfg, get_config_dict, OUTPUT_IMAGE_SIZE, LOGGING_CAPTION_INDENT_LENGTH
from erdetect.core.detection import ieeg_detect_er
from erdetect.views.output_images import calc_sizes_and_fonts, calc_matrix_image_size, gen_amplitude_matrix, gen_latency_matrix
from erdetect.utils.misc import create_figure
from erdetect.utils.misc import create_figure, mount_bipolar
from erdetect.core.metrics.metric_cross_proj import MetricCrossProj
from erdetect.core.metrics.metric_waveform import MetricWaveform

Expand Down Expand Up @@ -70,7 +70,14 @@ def process_subset(bids_subset_data_path, output_dir, preproc_prioritize_speed=F
return

# determine a subset specific output path
output_root = os.path.join(output_dir, os.path.basename(os.path.normpath(bids_subset_root)))
try:
# this expects the input file to follow the iEEG BIDS naming conventions
output_root = os.path.join(output_dir,
os.path.basename(bids_subset_data_path).split("_")[0],
os.path.basename(bids_subset_data_path).split("_")[1],
os.path.basename(bids_subset_data_path).split("_")[3])
except IndexError:
output_root = os.path.join(output_dir, os.path.basename(os.path.normpath(bids_subset_root)))

# make sure the subject output directory exists
if not os.path.exists(output_root):
Expand Down Expand Up @@ -408,9 +415,10 @@ def process_subset(bids_subset_data_path, output_dir, preproc_prioritize_speed=F
if cfg('metrics', 'waveform', 'enabled'):
metric_callbacks += tuple([MetricWaveform.process_callback])

# determine which channel montage (monopolar or bipolar) should be used
bipolar_rereferencing = cfg('preprocess', 'bipolar_montage')

# read, normalize, epoch and average the trials within the condition
# Note: 'load_data_epochs_averages' is used instead of 'load_data_epochs' here because it is more memory
# efficient when only the averages are needed
if len(metric_callbacks) == 0:
logging.info('- Reading data...')
else:
Expand All @@ -419,57 +427,134 @@ def process_subset(bids_subset_data_path, output_dir, preproc_prioritize_speed=F
# TODO: normalize to raw or to Z-values (return both raw and z?)
# z-might be needed for detection
try:
sampling_rate, averages, metrics = load_data_epochs_averages(bids_subset_data_path, channels_measured_incl, stim_pairs_onsets,
trial_epoch=cfg('trials', 'trial_epoch'),
baseline_norm=cfg('trials', 'baseline_norm'),
baseline_epoch=cfg('trials', 'baseline_epoch'),
out_of_bound_handling=cfg('trials', 'out_of_bounds_handling'),
metric_callbacks=metric_callbacks,
high_pass=cfg('preprocess', 'high_pass'),
early_reref=early_reref,
line_noise_removal=line_noise_removal,
late_reref=late_reref,
preproc_priority=('speed' if preproc_prioritize_speed else 'mem'))
if bipolar_rereferencing:
sampling_rate, data = load_data_epochs(bids_subset_data_path, channels_measured_incl, trial_onsets,
trial_epoch=cfg('trials', 'trial_epoch'),
baseline_norm=cfg('trials', 'baseline_norm'),
baseline_epoch=cfg('trials', 'baseline_epoch'),
out_of_bound_handling=cfg('trials', 'out_of_bounds_handling'),
high_pass=cfg('preprocess', 'high_pass'),
early_reref=early_reref,
line_noise_removal=line_noise_removal,
late_reref=late_reref,
preproc_priority=('speed' if preproc_prioritize_speed else 'mem'))
bipolar_data, bipolar_names = mount_bipolar(data, channels_measured_incl)
n_bipolar_channels, n_trials, n_times = bipolar_data.shape
else:
# Note: 'load_data_epochs_averages' is used instead of 'load_data_epochs' here because it is more memory
# efficient when only the averages are needed
sampling_rate, averages, metrics = load_data_epochs_averages(bids_subset_data_path, channels_measured_incl, stim_pairs_onsets,
trial_epoch=cfg('trials', 'trial_epoch'),
baseline_norm=cfg('trials', 'baseline_norm'),
baseline_epoch=cfg('trials', 'baseline_epoch'),
out_of_bound_handling=cfg('trials', 'out_of_bounds_handling'),
metric_callbacks=metric_callbacks,
high_pass=cfg('preprocess', 'high_pass'),
early_reref=early_reref,
line_noise_removal=line_noise_removal,
late_reref=late_reref,
preproc_priority=('speed' if preproc_prioritize_speed else 'mem'))
except (ValueError, RuntimeError):
logging.error('Could not load data (' + bids_subset_data_path + '), exiting...')
raise RuntimeError('Could not load data')

# split out the metric results
cross_proj_metrics = None
waveform_metrics = None
metric_counter = 0
if cfg('metrics', 'cross_proj', 'enabled'):
cross_proj_metrics = np.array(metrics[:, :, metric_counter].tolist())
metric_counter += 1
if cfg('metrics', 'waveform', 'enabled'):
waveform_metrics = np.array(metrics[:, :, metric_counter].tolist())
metric_counter += 1

# for each stimulation pair condition, NaN out the values of the measured electrodes that were stimulated
for stim_pair_index, stim_pair in enumerate(stim_pairs_onsets):
stim_pair_electrode_names = stim_pair.split('-')

# find and clear the first electrode
try:
averages[channels_measured_incl.index(stim_pair_electrode_names[0]), stim_pair_index, :] = np.nan
if cfg('metrics', 'cross_proj', 'enabled'):
cross_proj_metrics[channels_measured_incl.index(stim_pair_electrode_names[0]), stim_pair_index, :] = np.nan
if cfg('metrics', 'waveform', 'enabled'):
waveform_metrics[channels_measured_incl.index(stim_pair_electrode_names[0]), stim_pair_index] = np.nan

except ValueError:
pass

# find and clear the second electrode
try:
averages[channels_measured_incl.index(stim_pair_electrode_names[1]), stim_pair_index, :] = np.nan
if cfg('metrics', 'cross_proj', 'enabled'):
cross_proj_metrics[channels_measured_incl.index(stim_pair_electrode_names[1]), stim_pair_index, :] = np.nan
if cfg('metrics', 'waveform', 'enabled'):
waveform_metrics[channels_measured_incl.index(stim_pair_electrode_names[1]), stim_pair_index] = np.nan

except ValueError:
pass

# Retrive trial and baseline epochs from config
trial_epoch = cfg('trials', 'trial_epoch')
baseline_epoch = cfg('trials', 'baseline_epoch')
trial_start = trial_epoch[0]
baseline_start_sample = int(round((baseline_epoch[0] - trial_start) * sampling_rate))
baseline_end_sample = int(round((baseline_epoch[1] - trial_start) * sampling_rate))

if bipolar_rereferencing:
# obtain the metrics from individual traials since we don't use averages in bipolar montage
cross_proj_metrics = np.zeros((n_bipolar_channels, len(stim_pairs_onsets), 3))
waveform_metrics = np.zeros((n_bipolar_channels, len(stim_pairs_onsets), 1))
for channel in range(n_bipolar_channels):
for sp_idx, (stimpair, onsets) in enumerate(stim_pairs_onsets.items()):
indices = [i for i, p in enumerate(trial_pairs) if str(p[0]+'-'+p[1]) == stimpair]
epoch_data = bipolar_data[channel, indices, :]
epoch_baseline = epoch_data[:, baseline_start_sample:baseline_end_sample]

if cfg('metrics', 'cross_proj', 'enabled'):
metric = MetricCrossProj.process_callback(
sampling_rate,
epoch_data,
epoch_baseline
)
cross_proj_metrics[channel, sp_idx, :] = metric

if cfg('metrics', 'waveform', 'enabled'):
metric = MetricWaveform.process_callback(sampling_rate, epoch_data, epoch_baseline)
waveform_metrics[channel, sp_idx, :] = metric

# compute averages, but in bipolar
averages = np.zeros((n_bipolar_channels, len(stim_pairs_onsets), n_times))
stimpair_indices = {}

for channel in range(n_bipolar_channels):
for sp_idx, (stimpair, onsets) in enumerate(stim_pairs_onsets.items()):
indices = [i for i, p in enumerate(trial_pairs) if str(p[0]+'-'+p[1]) == stimpair]
stimpair_indices[stimpair] = indices

epoch_data = bipolar_data[channel, indices, :]
averages[channel, sp_idx, :] = np.nanmean(epoch_data, axis=0)

# for each stimulation pair condition, NaN out the values of the measured electrodes that were stimulated
for stim_pair_index, stim_pair in enumerate(stim_pairs_onsets):
stim_ch1, stim_ch2 = stim_pair.split('-')

for ch_idx, ch_name in enumerate(bipolar_names):
if stim_ch1 in ch_name or stim_ch2 in ch_name:
try:
averages[ch_idx, stim_pair_index, :] = np.nan
if cfg('metrics', 'cross_proj', 'enabled'):
cross_proj_metrics[ch_idx, stim_pair_index, :] = np.nan
if cfg('metrics', 'waveform', 'enabled'):
waveform_metrics[ch_idx, stim_pair_index] = np.nan

except ValueError:
pass

# channel names are now bipolar
channels_measured_incl = bipolar_names

else:
# split out the metric results
cross_proj_metrics = None
waveform_metrics = None
metric_counter = 0
if cfg('metrics', 'cross_proj', 'enabled'):
cross_proj_metrics = np.array(metrics[:, :, metric_counter].tolist())
metric_counter += 1
if cfg('metrics', 'waveform', 'enabled'):
waveform_metrics = np.array(metrics[:, :, metric_counter].tolist())
metric_counter += 1

# for each stimulation pair condition, NaN out the values of the measured electrodes that were stimulated
for stim_pair_index, stim_pair in enumerate(stim_pairs_onsets):
stim_pair_electrode_names = stim_pair.split('-')

# find and clear the first electrode
try:
averages[channels_measured_incl.index(stim_pair_electrode_names[0]), stim_pair_index, :] = np.nan
if cfg('metrics', 'cross_proj', 'enabled'):
cross_proj_metrics[channels_measured_incl.index(stim_pair_electrode_names[0]), stim_pair_index, :] = np.nan
if cfg('metrics', 'waveform', 'enabled'):
waveform_metrics[channels_measured_incl.index(stim_pair_electrode_names[0]), stim_pair_index] = np.nan

except ValueError:
pass

# find and clear the second electrode
try:
averages[channels_measured_incl.index(stim_pair_electrode_names[1]), stim_pair_index, :] = np.nan
if cfg('metrics', 'cross_proj', 'enabled'):
cross_proj_metrics[channels_measured_incl.index(stim_pair_electrode_names[1]), stim_pair_index, :] = np.nan
if cfg('metrics', 'waveform', 'enabled'):
waveform_metrics[channels_measured_incl.index(stim_pair_electrode_names[1]), stim_pair_index] = np.nan

except ValueError:
pass

# determine the sample of stimulus onset (counting from the epoch start)
onset_sample = int(round(abs(cfg('trials', 'trial_epoch')[0] * sampling_rate)))
Expand All @@ -492,6 +577,18 @@ def process_subset(bids_subset_data_path, output_dir, preproc_prioritize_speed=F
MetricCrossProj.append_output_dict_callback(output_dict, cross_proj_metrics)
if cfg('metrics', 'waveform', 'enabled'):
MetricWaveform.append_output_dict_callback(output_dict, waveform_metrics)

if cfg('preprocess', 'bipolar_montage'):
output_dict['ccep_data'] = bipolar_data
output_dict['stimpair_indices'] = stimpair_indices

try:
# this assumes that the input file follows the standard iEEG BIDS naming structure
output_dict['subject'] = os.path.basename(bids_subset_data_path).split("_")[0][4:]
output_dict['session'] = os.path.basename(bids_subset_data_path).split("_")[1][4:]
output_dict['run'] = os.path.basename(bids_subset_data_path).split("_")[3][4:]
except IndexError:
pass

sio.savemat(os.path.join(output_root, 'erdetect_data.mat'), output_dict)

Expand Down
5 changes: 5 additions & 0 deletions erdetect/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def create_default_config():
config['preprocess']['late_re_referencing']['CAR_by_variance'] = -1 #
config['preprocess']['late_re_referencing']['stim_excl_epoch'] = (-1.0, 2.0)
config['preprocess']['late_re_referencing']['channel_types'] = ('ECOG', 'SEEG', 'DBS') # the type of channels that will be included for late re-referencing
config['preprocess']['bipolar_montage'] = False # if True, channels will be re-referenced to a bipolar setting

config['trials'] = dict()
config['trials']['trial_epoch'] = (-1.0, 2.0) # the time-span (in seconds) relative to the stimulus onset that will be used to extract the signal for each trial
Expand Down Expand Up @@ -385,6 +386,9 @@ def retrieve_config_tuple(json_dict, ref_config, level1, level2, level3=None, op
return False
if not retrieve_config_string(json_config, config, 'preprocess', 'late_re_referencing', 'method', options=('CAR', 'CAR_headbox')):
return False

if not retrieve_config_bool(json_config, config, 'preprocess', 'bipolar_montage'):
return False

if config['preprocess']['late_re_referencing']['method'] in ('CAR', 'CAR_headbox'):
retrieve_config_number(json_config, config, 'preprocess', 'late_re_referencing', 'CAR_by_variance')
Expand Down Expand Up @@ -533,6 +537,7 @@ def write_config(filepath):
' "channel_types": ' + json.dumps(_config['preprocess']['early_re_referencing']['channel_types']) + '\n' \
' },\n' \
' "line_noise_removal": "' + _config['preprocess']['line_noise_removal'] + '",\n' \
' "bipolar_montage": ' + ('true' if _config['preprocess']['bipolar_montage'] else 'false') + ',\n' \
' "late_re_referencing": {\n' \
' "enabled": ' + ('true' if _config['preprocess']['late_re_referencing']['enabled'] else 'false') + ',\n' \
' "method": "' + _config['preprocess']['late_re_referencing']['method'] + '",\n'
Expand Down
72 changes: 72 additions & 0 deletions erdetect/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@
from matplotlib.figure import Figure
from ieegprep.utils.misc import is_number

import re
import numpy as np
from collections import defaultdict


def create_figure(width=500, height=500, on_screen=False):
"""
Expand Down Expand Up @@ -116,3 +120,71 @@ def numbers_to_padded_string(values, width=0, pos_space=True, separator=', '):
padded_str += padded_values[iValue]

return padded_str

def mount_bipolar(monopolar_data_epoched:np.ndarray, monopolar_channels:list, sep:str='-'):
"""
Computes a bipolar montage of the input data.

Channels are assumed to follow a naming convention: <prefix><number> (e.g., A1, A2, B10).

Adjacent channels are those with:
- same prefix
- consecutive numbers

Parameters
----------

monopolar_data_epoched : ndarray (n_channels, n_epochs, n_samples)
Input monopolar data array.

monopolar_channels : list of str (n_channels)
List of input channel names.

sep : str, optional
Divider that separates channels in bipolar channel names.

Returns
-------

bipolar_data : ndarray (n_bipolar_channels, n_samples)
Output bipolar data array.

bipolar_channels : list of str (n_bipolar_channels)
List of output channel names.
"""

bipolar_data_epoched = []
bipolar_channels = []

ch_idx = {ch: i for i, ch in enumerate(monopolar_channels)}

# Parse channels into their prefix + their number
parsed = []
pattern = re.compile(r"^([A-Za-z]+)(\d+)$")
for ch in monopolar_channels:
match = pattern.match(ch)
if match:
prefix, num = match.groups()
parsed.append((prefix, int(num), ch))

# Group channels by prefix
groups = defaultdict(list)
for prefix, num, ch in parsed:
groups[prefix].append((num, ch))

# Compute bipolar channel by grouping
for prefix, items in groups.items():
items.sort(key=lambda x: x[0])
for (n1, ch1), (n2, ch2) in zip(items[:-1], items[1:]):
if n2 > n1:
idx1 = ch_idx[ch1]
idx2 = ch_idx[ch2]
bipolar_data_epoched.append(monopolar_data_epoched[idx1,:,:] - monopolar_data_epoched[idx2,:,:])
bipolar_channels.append(f"{ch1}{sep}{ch2}")

if not bipolar_data_epoched:
raise RuntimeError("No bipolar channels created.")

bipolar_data_epoched = np.stack(bipolar_data_epoched, axis=0)

return bipolar_data_epoched, bipolar_channels
2 changes: 1 addition & 1 deletion erdetect/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "2.6.2"
__version__ = "2.7.0"