Contributor: Hao Zhu (haozhu@cuhk.edu.hk, hao.zhu.808@gmail.com)
Created: 2024/10/02
Updated: 2025/03/05
This tutorial is based on:
system: Ubuntu 24.04.2 LTS 64-bit
Python 3.12.4
dependencies:
- mne 1.8.0
- re 2.2.1
- numpy 1.26.4
- seaborn 0.13.2
- matplotlib 3.8.4
- scipy 1.13.1
- pingouin 0.5.5
- tqdm 4.66.4
- pandas 2.2.2
This tutorial contains AI-generated content
Introduction
This tutorial provides a guide to preprocessing sEEG data using Python.
We will walk through essential stages such as data loading, preprocessing, feature extraction, and visualization, utilizing widely adopted libraries including MNE, NumPy, and SciPy.
Recognizing that different laboratories adopt diverse preprocessing pipelines, this tutorial presents a hybrid pipeline that amalgamates successful methods from various labs. Our approach has been validated on sEEG datasets collected from Shenzhen, Guangzhou, and Hong Kong. Notably, this is the second version of our pipeline, which incorporates additional analysis techniques and resolves previous bugs.
For demonstration purposes, we use a sample dataset from a real subject recorded at Shenzhen Second People's Hospital. The experiment is a purely passive listening task in which the subject hears the syllable /ga4/ 600 times with a stimulus onset asynchrony (SOA) of 410 ms. A trigger pulse, which lags the actual sound onset, is delivered to a separate channel of the acquisition system to mark each event.
The raw data, originally acquired with a Nicolet system in EDF format, has already undergone initial preprocessing, including segmentation, bad-channel removal, and conversion to FIF format with MNE-Python. Events are recorded in the "TRIG" channel, where different pulse magnitudes denote different events (e.g., a high-magnitude pulse typically indicates that a marker was sent by the PC). These events are then extracted and saved in .eve format, one row per event: [sample_index, 0, event_id].
This tutorial requires only one file – the preprocessed raw data file ending in .fif.
Load modules
This section imports all the necessary libraries for the sEEG data preprocessing tutorial:
- mne: For reading, preprocessing, and analyzing EEG data.
- re: To perform regular expression operations (e.g., pattern matching in text).
- tqdm: To display progress bars during loops and long computations.
- time: For measuring the duration of operations.
- matplotlib.pyplot: For creating plots and visualizations.
- numpy: For numerical computations and handling array data.
- pandas: For DataFrame operations.
- pingouin: For advanced statistical analysis.
- scipy.signal: For signal processing functions (e.g., filtering, spectral analysis).
- scipy.stats: For statistical tests and probability distributions.
- concurrent.futures: For parallel execution of tasks using ProcessPoolExecutor.
- typing and matplotlib.figure: For type hints used in function signatures.
import mne, re, tqdm, time
import matplotlib.pyplot as plt
import numpy as np
import pingouin as pg
import pandas as pd
from typing import Tuple, Optional, Union, List, Dict
from scipy import signal, stats
from matplotlib.figure import Figure
from concurrent.futures import ProcessPoolExecutor, as_completed
%matplotlib inline
# %matplotlib qt5
Paths and predefined variables
This section defines the file paths and predefined variables used throughout the tutorial.
It specifies the data and figure directories, the list of epileptic contacts to exclude, and the filenames for data input and output.
# Define the base directory for data storage.
data_path = ''
# Define subject information and processing status.
subject_id = 'QUJ3YH24KM'
# Set the path for saving figures generated during the analysis.
img_path = f'{data_path}figures/'
# Define a list of channels identified as epileptic.
epileptic_contacts = ["H'1","H'2","H'3","H'4","H'5",
"J'1","J'2","J'3","J'4","J'5",
"A'1","A'2","A'3","A'4","A'5",
"H1","H2","H3","H4","H5",
"J1","J2","J3","J4","J5",
"A1","A2","A3","A4","A5"]
# Define filenames for various data files used in the pipeline.
subject_Datadir = f'{data_path}{subject_id}_'
eventFile_name = f'{subject_Datadir}events.eve' # File containing event information.
rawFile_name = f'{subject_Datadir}raw_ieeg.fif' # Raw sEEG data file in FIF format.
rawERPFile_name = f'{subject_Datadir}raw-erp_ieeg.fif' # Raw event-related potentials (ERP) data file.
epochERPFile_name = f'{subject_Datadir}epoch_erp-epo.fif' # Preprocessed ERP epochs.
ERP_DropEpochlLog_name = f"{subject_Datadir}erp_epoch_drop.npy" # Log file for dropped ERP epochs.
rawBBPFile_name = f'{subject_Datadir}raw-bbp_ieeg.fif' # Raw broadband power (BBP) data file.
epochBBPFile_name = f'{subject_Datadir}epoch_bbp-epo.fif' # Preprocessed BBP epochs.
BBP_DropEpochlLog_name = f"{subject_Datadir}bbp_epoch_drop.npy" # Log file for dropped BBP epochs.
rawHGPFile_name = f'{subject_Datadir}raw-hgp_ieeg.fif' # Raw high-gamma power (HGP) data file.
epochHGPFile_name = f'{subject_Datadir}epoch_hgp-epo.fif' # Preprocessed HGP epochs.
HGP_DropEpochlLog_name = f"{subject_Datadir}hgp_epoch_drop.npy" # Log file for dropped HGP epochs.
Preprocessing
Create events
This section demonstrates how to extract events from the raw data. The process includes reading the raw data, isolating the trigger channel, adjusting the trigger amplitude, extracting events, applying a time shift to align the events with the actual stimulus onset, and finally saving the events to a file.
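The core idea of event extraction can be illustrated without MNE. The sketch below is a simplified stand-in for `mne.find_events`, not its actual algorithm: it assumes a clean trigger trace that is zero between pulses, treats each zero-to-nonzero transition as an event onset, and then applies the same fixed delay correction used in this tutorial (0.194 s at 2048 Hz, i.e. `int(0.194 * 2048) = 397` samples) to move each onset back to the sound onset. The helper name `detect_onsets` and the toy pulse positions are hypothetical.

```python
import numpy as np


def detect_onsets(trig, sfreq, shift_sec):
    """Simplified trigger-onset detector (hypothetical helper, not MNE's algorithm).

    Returns an MNE-style events array [sample, 0, event_id] whose onsets are
    shifted earlier by ``shift_sec`` seconds, plus the shift in samples.
    """
    trig = np.abs(np.asarray(trig, dtype=float))
    # A rising edge is a non-zero sample preceded by a zero sample.
    onsets = np.flatnonzero((trig[1:] > 0) & (trig[:-1] == 0)) + 1
    event_ids = trig[onsets].astype(int)
    # Convert the fixed trigger delay from seconds to samples (truncated, as in the tutorial).
    shift = int(shift_sec * sfreq)
    events = np.column_stack([onsets - shift, np.zeros_like(onsets), event_ids])
    return events, shift


# Toy trigger trace: three 10-sample pulses (IDs 1, 1, 5) roughly one SOA apart.
sfreq = 2048.0
trig = np.zeros(4000)
for start, value in [(1000, 1), (1840, 1), (2682, 5)]:
    trig[start:start + 10] = value

events, shift = detect_onsets(trig, sfreq, shift_sec=0.194)
print(shift)   # 397
print(events)  # rows: [603, 0, 1], [1443, 0, 1], [2285, 0, 5]
```

Unlike this sketch, the tutorial's function also discards the first and last detected events before applying the shift.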
def extract_events_from_raw(
    raw_file: str,
    event_file: str,
    stim_channel: str = 'TRIG',
    click_shift_sec: float = 0.194
) -> Tuple[np.ndarray, int]:
    """
    Extracts and processes events from raw sEEG data with temporal alignment correction.
    Parameters:
    -----------
    raw_file : str
        Path to the raw FIF data file.
    event_file : str
        Path where the processed events file (.eve) will be saved.
    stim_channel : str, optional
        The channel name used for event triggers (default: 'TRIG').
    click_shift_sec : float, optional
        Time in seconds to shift events for stimulus alignment (default: 0.194).
    Returns:
    --------
    events : np.ndarray
        Processed event array of shape (n_events, 3)
    click_shift : int
        Applied temporal shift in samples
    Notes:
    ------
    - Removes the first and last events, which mark the start and end of the experiment
    - Applies a fixed temporal correction for stimulus alignment
    """
    # Load the full recording into memory (required to modify the trigger data)
    raw = mne.io.read_raw_fif(raw_file, preload=True)
    # Isolate the trigger channel, then rectify and rescale it
    trig_channel = raw.copy().pick([stim_channel])
    trig_channel._data = np.abs(trig_channel.get_data() * 1e3)  # Amplitude scaling for better detection
    # Event detection with default parameters
    raw_events = mne.find_events(trig_channel, stim_channel=stim_channel)
    # Remove the first and last events, which mark the start and end of the experiment
    events = raw_events[1:-1]
    # Optionally, inspect inter-event intervals (in samples) if needed
    event_diffs = np.diff(events, axis=0)
    # Debug: Uncomment the following line to inspect event differences
    # print(np.where(event_diffs[:, 0] > 844)[0], np.where(event_diffs[:, 0] < 836), event_diffs.mean(axis=0))
    # Calculate the click shift in samples using the sampling frequency
    click_shift = int(click_shift_sec * raw.info['sfreq'])
    # Adjust event onset times by subtracting the click shift.
    # This aligns the event times to the actual stimulus onset.
    events[:, 0] = events[:, 0] - click_shift
    # Write the processed events to file, overwriting if necessary
    mne.write_events(event_file, events, overwrite=True)
    return events, click_shift
# Example usage:
events, click_shift = extract_events_from_raw(rawFile_name, eventFile_name)
print("Applied click shift (in samples):", click_shift)
print("Extracted events:\n", events)
Opening raw data file QUJ3YH24KM_raw_ieeg.fif...
Isotrak not found
Range : 3706880 ... 4259840 = 1810.000 ... 2080.000 secs
Ready.
Reading 0 ... 552960 = 0.000 ... 270.000 secs...
602 events found on stim channel TRIG
Event IDs: [ 1 5 20]
Overwriting existing file.
Applied click shift (in samples): 397
Extracted events:
 [[3726090 0 1]
 [3726930 0 1]
 [3727772 0 1]
 ...
 [4227377 0 1]
 [4228214 0 1]
 [4229055 0 1]]
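Each row of the printed array is [sample, 0, event_id]; the middle column holds the trigger value just before the event, which is 0 here. The sample indices are counted from the start of the original acquisition, so they include the first sample of this cropped file (3706880 in the log above). A short sketch, using the first three printed events and the 2048 Hz sampling rate, converts them to seconds and checks that their spacing matches the 410 ms SOA:

```python
import numpy as np

sfreq = 2048.0          # sampling frequency reported by MNE
first_samp = 3706880    # first sample of the cropped recording (from the log above)
# First three rows of the events array printed above: [sample, 0, event_id]
events = np.array([[3726090, 0, 1],
                   [3726930, 0, 1],
                   [3727772, 0, 1]])

# Onset times in seconds relative to the start of this file
onset_sec = (events[:, 0] - first_samp) / sfreq
# Stimulus onset asynchrony between consecutive events, in milliseconds
soa_ms = np.diff(events[:, 0]) / sfreq * 1000

print(onset_sec)  # approx. [ 9.38  9.79 10.20]
print(soa_ms)     # approx. [410.16 411.13]
```

The 840- vs. 842-sample spacing corresponds to a jitter of about 1 ms around the nominal 410 ms SOA.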
Breakdown
In this section, the common preprocessing pipeline is broken down into individual cells to load the data, drop unwanted channels, remove line noise, re-reference, and apply filtering.
This modular approach makes each preprocessing step clear and allows for easy adaptation or expansion.
The final processed data (in the 70–150 Hz frequency range) is tailored for high gamma band power analysis.
High gamma activity is a crucial indicator of local cortical processing and task-related neural dynamics, making it highly relevant for studies on cognitive functions and clinical applications.
Load data and drop epileptic contacts
The following block loads the raw sEEG data and event information, then drops channels corresponding to epileptic contacts and the trigger channel.
# Load the raw data from the FIF file.
raw = mne.io.read_raw_fif(rawFile_name, preload=True)
# Read events from the provided event file.
events = mne.read_events(eventFile_name)
# Drop channels that correspond to epileptic contacts.
raw.drop_channels(epileptic_contacts)
# Also drop the trigger channel ("TRIG") from the data as it is not needed for further analysis.
raw.drop_channels(["TRIG"])
Opening raw data file QUJ3YH24KM_raw_ieeg.fif...
Isotrak not found
Range : 3706880 ... 4259840 = 1810.000 ... 2080.000 secs
Ready.
Reading 0 ... 552960 = 0.000 ... 270.000 secs...
General | |
---|---|
Filename(s) | QUJ3YH24KM_raw_ieeg.fif |
MNE object type | Raw |
Measurement date | 2024-12-23 at 14:34:11 UTC |
Participant | 100011917 |
Experimenter | Unknown |
Acquisition | |
Duration | 00:04:30 (HH:MM:SS) |
Sampling frequency | 2048.00 Hz |
Time points | 552,961 |
Channels | |
EEG | |
Head & sensor digitization | Not available |
Filters | |
Highpass | 0.00 Hz |
Lowpass | 1024.00 Hz |
Remove line-noise channels
def line_noise_removal(raw_object: mne.io.Raw) -> np.ndarray:
    """
    Identify channels dominated by line noise.
    Parameters:
        raw_object: instance of MNE Raw containing the data.
    Returns:
        Indices of channels whose 50 Hz power exceeds the line-noise threshold.
    """
    # Extract data from the raw object.
    raw_data = raw_object.get_data()
    # Define the line-noise frequency (50 Hz) and quality factor for the peak filter.
    f0 = 50.0  # Line-noise frequency to isolate (Hz)
    Q = 30.0   # Quality factor, controlling the bandwidth of the filter
    # Get the sampling frequency from the raw object's info.
    sfreq = raw_object.info['sfreq']
    # Create a second-order IIR peak (band-pass) filter that isolates the 50 Hz component.
    b, a = signal.iirpeak(f0 / (sfreq / 2), Q)
    # Apply the zero-phase filter along the time axis (axis=1) for each channel.
    raw_xLN = signal.filtfilt(b, a, raw_data, axis=1)
    # Calculate the 50 Hz power (mean squared value) for each channel.
    raw_LN = np.mean(raw_xLN**2, axis=1)
    # Compute the median and the mean absolute deviation of the per-channel power
    # (power must be compared with a threshold in the same units).
    raw_LN_median = np.median(raw_LN)
    raw_LN_mad = np.mean(np.abs(raw_LN - raw_LN_median))
    # Threshold: channels with power greater than the median plus 10 times the MAD.
    LN_threshold = raw_LN_median + 10 * raw_LN_mad
    # Identify indices of channels exceeding the noise threshold.
    removal = np.where(raw_LN > LN_threshold)[0]
    return removal
# Identify channels that are dominated by line noise and drop them.
line_noise_channel = line_noise_removal(raw)
raw.drop_channels([raw.ch_names[i] for i in line_noise_channel])
line_noise_channel
array([], dtype=int64)
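The screening rule can be tried on synthetic data. In this sketch (toy signals, not the tutorial's recording), 40 channels of white noise are generated at an assumed 1000 Hz sampling rate and a strong 50 Hz sinusoid is added to channel 7. Each channel's 50 Hz power is estimated with the same `iirpeak` + `filtfilt` approach, and a channel is flagged when its power exceeds the median across channels plus 10 times the mean absolute deviation of those power values. The function name `flag_line_noise_channels` is hypothetical.

```python
import numpy as np
from scipy import signal


def flag_line_noise_channels(data, sfreq, f0=50.0, Q=30.0, k=10.0):
    """Return indices of channels whose f0 power is an outlier across channels."""
    # Peak (band-pass) filter centred on the line frequency, applied with zero phase.
    b, a = signal.iirpeak(f0, Q, fs=sfreq)
    line_component = signal.filtfilt(b, a, data, axis=1)
    # Line-noise power of each channel.
    power = np.mean(line_component ** 2, axis=1)
    # Robust threshold on the per-channel power values.
    median = np.median(power)
    mad = np.mean(np.abs(power - median))
    return np.flatnonzero(power > median + k * mad)


rng = np.random.default_rng(0)
sfreq = 1000.0
t = np.arange(0, 2.0, 1 / sfreq)
data = rng.standard_normal((40, t.size))
data[7] += 5 * np.sin(2 * np.pi * 50 * t)  # channel 7 is contaminated

print(flag_line_noise_channels(data, sfreq))  # expected: [7]
```

A design note: because the mean absolute deviation includes the outlier itself, a single contaminated channel can only exceed the threshold when there are more than 10 channels. With a full sEEG montage (over 100 contacts here) this is not a limitation.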
Re-reference with ESR
The primary purpose of re-referencing is to mitigate the impact of the reference electrode on the recorded sEEG signals. In this sample dataset, the sEEG signals were referenced online to a white-matter contact (contact 1 or 2; monopolar method). This reference electrode can pick up artifacts that then appear in all other channels, potentially distorting the brain signals of interest. Re-referencing reduces the biases introduced by the original reference electrode and minimizes spurious correlations between channels.
Re-referencing can also improve overall signal quality and the signal-to-noise ratio (SNR): with an appropriate method, researchers can reduce common noise and artifacts present in the original reference. For a comprehensive comparison of re-referencing methods, see: https://pmc.ncbi.nlm.nih.gov/articles/PMC6495648/
For the dataset collected in Shenzhen, I selected the Electrode Shaft Reference (ESR), in which each channel is re-referenced to the average signal of all channels on the same shaft, after comparing the available methods.
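In formula terms, ESR replaces each contact signal x_i(t) on a shaft S with x_i(t) minus the mean of x_j(t) over all contacts j in S. A minimal NumPy sketch of that operation follows. Unlike `find_electrode_indices` in the next cell, this hypothetical `shaft_rereference` helper groups contacts by prefix without assuming they are contiguous in the channel list. The channel names and toy data are made up for illustration.

```python
import re

import numpy as np


def shaft_rereference(data, ch_names):
    """Electrode Shaft Reference (hypothetical helper): subtract each shaft's mean signal.

    data: array of shape (n_channels, n_times); ch_names: list of channel names.
    Channels whose names do not match the prefix pattern are left unchanged.
    """
    prefix_pattern = re.compile(r"^[A-Z]+'?")  # e.g. "A" for A1, "A'" for A'1
    shafts = {}
    for idx, name in enumerate(ch_names):
        match = prefix_pattern.match(name)
        if match:
            shafts.setdefault(match.group(), []).append(idx)
    rereferenced = data.copy()
    for indices in shafts.values():
        rereferenced[indices] = data[indices] - data[indices].mean(axis=0)
    return rereferenced, shafts


rng = np.random.default_rng(0)
ch_names = ["A1", "A2", "A3", "A'1", "A'2"]
common = rng.standard_normal(100)                     # noise shared by every contact
data = common + 0.1 * rng.standard_normal((5, 100))   # plus small contact-specific signals
rereferenced, shafts = shaft_rereference(data, ch_names)
print(shafts)  # {'A': [0, 1, 2], "A'": [3, 4]}
```

After re-referencing, the component shared by all contacts of a shaft cancels out, and the contacts of each shaft sum to zero at every time point.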
def find_electrode_indices(contacts):
    """
    Given a list of channel names (contacts), find index ranges corresponding to each electrode shaft.
    The electrode is identified by its prefix (uppercase letters and an optional apostrophe).
    Contacts of the same shaft are assumed to be contiguous in the channel list.
    Parameters:
        contacts: list of channel names.
    Returns:
        Dictionary mapping electrode prefix to a tuple (start_index, end_index).
    """
    electrode_indices = {}
    current_electrode = None
    start_index = 0
    # Regex to extract the prefix (uppercase letters with an optional apostrophe) from each contact.
    prefix_pattern = re.compile(r"^[A-Z]+'?")
    for i, contact in enumerate(contacts):
        # Extract the prefix using regex.
        match = prefix_pattern.match(contact)
        if match:
            prefix = match.group()
            # When encountering a new electrode, save the previous electrode's index range.
            if prefix != current_electrode:
                if current_electrode is not None:
                    electrode_indices[current_electrode] = (start_index, i - 1)
                # Update current electrode and set new start index.
                current_electrode = prefix
                start_index = i
    # Add the last electrode's index range.
    if current_electrode is not None:
        electrode_indices[current_electrode] = (start_index, len(contacts) - 1)
    return electrode_indices
# Retrieve channel names from raw data.
contacts = raw.info['ch_names']
# Find the index ranges for each electrode based on the channel names.
electrode_indices = find_electrode_indices(contacts)
electrode_indices
{'O': (0, 15), 'E': (16, 28), 'X': (29, 42), 'V': (43, 52), 'J': (53, 63), 'H': (64, 74), 'A': (75, 83), "A'": (84, 92), "H'": (93, 101), "J'": (102, 112), "V'": (113, 122), "X'": (123, 134), "E'": (135, 146), "O'": (147, 162)}
# Copy the raw data to work on and create an array for storing the electrode shaft averages.
raw_data = raw.copy().get_data()
electrode_shaft_average = np.zeros(raw_data.shape)
# Loop over each electrode group and compute the average signal across contacts in that group.
for electrode in electrode_indices.keys():
    electrode_start, electrode_end = electrode_indices[electrode]
    # Compute the average (ESR) for the current electrode shaft.
    esr = np.mean(raw_data[electrode_start:electrode_end + 1], axis=0)
    # Broadcast the average back to each channel in the electrode shaft.
    electrode_shaft_average[electrode_start:electrode_end + 1] = esr
# Re-reference the data by subtracting the electrode shaft average.
raw._data = raw_data - electrode_shaft_average
Filters
# Apply a notch filter to remove line noise and its harmonics (50, 100, 150, and 200 Hz).
raw.notch_filter([50, 100, 150, 200])
# Apply a band-pass filter from 70 Hz to 150 Hz (high gamma band) using an FIR filter.
raw.filter(l_freq=70, h_freq=150, method='fir')
Filtering raw data in 1 contiguous segment
Setting up band-stop filter
FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandstop filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower transition bandwidth: 0.50 Hz
- Upper transition bandwidth: 0.50 Hz
- Filter length: 13517 samples (6.600 s)
[Parallel(n_jobs=1)]: Done 17 tasks | elapsed: 0.1s
[Parallel(n_jobs=1)]: Done 71 tasks | elapsed: 0.6s
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 70 - 1.5e+02 Hz
FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 70.00
- Lower transition bandwidth: 17.50 Hz (-6 dB cutoff frequency: 61.25 Hz)
- Upper passband edge: 150.00 Hz
- Upper transition bandwidth: 37.50 Hz (-6 dB cutoff frequency: 168.75 Hz)
- Filter length: 387 samples (0.189 s)
[Parallel(n_jobs=1)]: Done 161 tasks | elapsed: 1.3s
[Parallel(n_jobs=1)]: Done 17 tasks | elapsed: 0.1s
[Parallel(n_jobs=1)]: Done 71 tasks | elapsed: 0.4s
[Parallel(n_jobs=1)]: Done 161 tasks | elapsed: 1.0s
General | |
---|---|
Filename(s) | QUJ3YH24KM_raw_ieeg.fif |
MNE object type | Raw |
Measurement date | 2024-12-23 at 14:34:11 UTC |
Participant | 100011917 |
Experimenter | Unknown |
Acquisition | |
Duration | 00:04:30 (HH:MM:SS) |
Sampling frequency | 2048.00 Hz |
Time points | 552,961 |
Channels | |
EEG | |
Head & sensor digitization | Not available |
Filters | |
Highpass | 70.00 Hz |
Lowpass | 150.00 Hz |
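The transition bands, -6 dB cutoffs and filter lengths printed in this tutorial's filter logs follow from MNE's default 'auto' settings. The sketch below is a hedged reconstruction of those defaults, not MNE's code. It assumes a transition band of 25% of each edge frequency, clipped to at least 2 Hz and at most the distance to 0 Hz (lower edge) or to Nyquist (upper edge), and a filter length of 3.3 divided by the narrowest transition band (the factor implied by the logs, e.g. 3.3 / 0.1 Hz = 33 s), rounded to an odd number of samples. The helper `auto_fir_params` is hypothetical.

```python
def auto_fir_params(l_freq, h_freq, sfreq, length_factor=3.3):
    """Approximate MNE's 'auto' FIR settings for a Hamming-window band-pass (assumption)."""
    # Transition bands: 25% of the edge frequency, clipped as described above.
    l_trans = min(max(0.25 * l_freq, 2.0), l_freq)
    h_trans = min(max(0.25 * h_freq, 2.0), sfreq / 2.0 - h_freq)
    # Filter length in samples, driven by the narrowest transition band.
    n_taps = int(round(length_factor * sfreq / min(l_trans, h_trans)))
    if n_taps % 2 == 0:  # an odd length gives an integer group delay, so zero phase is exact
        n_taps += 1
    # -6 dB cutoffs sit in the middle of each transition band.
    cutoffs = (l_freq - l_trans / 2.0, h_freq + h_trans / 2.0)
    return l_trans, h_trans, n_taps, cutoffs


for band in [(70, 150), (0.1, 200), (0.1, 40)]:
    print(band, auto_fir_params(*band, sfreq=2048))
```

For the 70-150 Hz band this gives 17.5 Hz and 37.5 Hz transition bands, 387 taps and cutoffs of 61.25 Hz and 168.75 Hz, matching the log above. The 0.1 Hz lower edge of the BBP and ERP filters explains their much longer 67585-sample (33 s) filters.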
# For interactive plots, adjust these two lines in the import cell at the top:
#   %matplotlib inline
#   # %matplotlib qt5
# For an interactive plot (with GUI), comment out %matplotlib inline and uncomment %matplotlib qt5.
# For a static plot (non-GUI), keep %matplotlib inline and leave %matplotlib qt5 commented out.
# Plot the raw data along with the events.
raw.plot(events=events)
Using matplotlib as 2D backend.
Save the HGP (70-150 Hz) raw file
raw.save(rawHGPFile_name, overwrite=True)
Overwriting existing file.
Writing /home/haozhu/Desktop/demo/QUJ3YH24KM_raw-hgp_ieeg.fif
Closing /home/haozhu/Desktop/demo/QUJ3YH24KM_raw-hgp_ieeg.fif [done]
Broadband Power
In this section, I wrap the step-by-step pipeline above into a single function and use it to prepare the data for broadband power analysis.
Broadband power analysis captures neural activity over a wider frequency range (0.1–200 Hz).
def process_seeg_data(raw_file: str,
                      event_file: str,
                      save_file: str,
                      epileptic_contacts: Optional[List[str]] = None,
                      trigger_channel: str = "TRIG",
                      drop_line_noise: bool = True,
                      filter_low: float = 0.1,
                      filter_high: float = 200.0,
                      plot: bool = False) -> mne.io.Raw:
    """
    Process sEEG data with options for channel removal, line noise correction, re-referencing, and filtering.
    Parameters:
        raw_file: Path to the raw FIF file.
        event_file: Path to the events file.
        save_file: Path to save the processed raw data.
        epileptic_contacts: List of channel names to drop (epileptic contacts).
            If None, no epileptic contacts are dropped.
        trigger_channel: Name of the trigger channel to drop.
        drop_line_noise: Whether to drop channels dominated by line noise.
        filter_low: Lower bound for band-pass filtering (Hz).
        filter_high: Upper bound for band-pass filtering (Hz).
        plot: If True, plot the raw data with events.
    Returns:
        Processed mne.io.Raw object.
    """
    # --- Step 1: Load the data ---
    raw = mne.io.read_raw_fif(raw_file, preload=True)
    events = mne.read_events(event_file)
    # --- Step 2: Drop unwanted channels ---
    if epileptic_contacts:
        raw.drop_channels(epileptic_contacts)
    raw.drop_channels([trigger_channel])
    # --- Step 3: Remove line noise dominated channels (optional) ---
    def line_noise_removal(raw_object: mne.io.Raw,
                           notch_freq: float = 50.0,
                           notch_Q: float = 30.0) -> np.ndarray:
        """
        Identify indices of channels dominated by line noise.
        """
        raw_data = raw_object.get_data()
        sfreq = raw_object.info['sfreq']
        # Create a second-order IIR peak filter that isolates the line-noise frequency.
        b, a = signal.iirpeak(notch_freq / (sfreq / 2), notch_Q)
        # Apply the zero-phase filter across time (axis=1).
        filtered_data = signal.filtfilt(b, a, raw_data, axis=1)
        # Calculate line-noise power for each channel.
        power = np.mean(filtered_data**2, axis=1)
        # Threshold on the per-channel power: median + 10 * mean absolute deviation.
        median_val = np.median(power)
        mad = np.mean(np.abs(power - median_val))
        threshold = median_val + 10 * mad
        removal_indices = np.where(power > threshold)[0]
        return removal_indices
    if drop_line_noise:
        ln_indices = line_noise_removal(raw)
        # Convert indices to channel names and drop these channels.
        ln_channel_names = [raw.info['ch_names'][i] for i in ln_indices]
        raw.drop_channels(ln_channel_names)
    # --- Step 4: Re-reference with Electrode Shaft Reference (ESR) ---
    def find_electrode_indices(contacts: list) -> dict:
        """
        Given a list of channel names, find index ranges corresponding to each electrode shaft.
        Electrode shafts are identified by their prefix (letters with an optional apostrophe).
        """
        electrode_indices = {}
        current_electrode = None
        start_index = 0
        prefix_pattern = re.compile(r"^[A-Z]+'?")
        for i, contact in enumerate(contacts):
            match = prefix_pattern.match(contact)
            if match:
                prefix = match.group()
                if prefix != current_electrode:
                    if current_electrode is not None:
                        electrode_indices[current_electrode] = (start_index, i - 1)
                    current_electrode = prefix
                    start_index = i
        if current_electrode is not None:
            electrode_indices[current_electrode] = (start_index, len(contacts) - 1)
        return electrode_indices
    contacts = raw.info['ch_names']
    electrode_indices = find_electrode_indices(contacts)
    print(electrode_indices)
    # Create a copy of the raw data for re-referencing.
    raw_data = raw.copy().get_data()
    electrode_shaft_average = np.zeros(raw_data.shape)
    # Compute and subtract the average signal for each electrode shaft.
    for electrode, (start, end) in electrode_indices.items():
        esr = np.mean(raw_data[start:end + 1], axis=0)
        electrode_shaft_average[start:end + 1] = esr
    raw._data = raw_data - electrode_shaft_average
    # --- Step 5: Filtering ---
    raw.notch_filter([50, 100, 150, 200])
    raw.filter(l_freq=filter_low, h_freq=filter_high, method='fir')
    # --- (Optional) Step 6: Plotting ---
    if plot:
        raw.plot(events=events)
    # --- Step 7: Save processed data ---
    raw.save(save_file, overwrite=True)
    return raw
# Call the process_seeg_data function to preprocess the raw data for broad band power analysis.
processed_raw = process_seeg_data(rawFile_name, eventFile_name, rawBBPFile_name, epileptic_contacts,filter_low=0.1,filter_high=200)
Opening raw data file QUJ3YH24KM_raw_ieeg.fif...
Isotrak not found
Range : 3706880 ... 4259840 = 1810.000 ... 2080.000 secs
Ready.
Reading 0 ... 552960 = 0.000 ... 270.000 secs...
{'O': (0, 15), 'E': (16, 28), 'X': (29, 42), 'V': (43, 52), 'J': (53, 63), 'H': (64, 74), 'A': (75, 83), "A'": (84, 92), "H'": (93, 101), "J'": (102, 112), "V'": (113, 122), "X'": (123, 134), "E'": (135, 146), "O'": (147, 162)}
Filtering raw data in 1 contiguous segment
Setting up band-stop filter
FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandstop filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower transition bandwidth: 0.50 Hz
- Upper transition bandwidth: 0.50 Hz
- Filter length: 13517 samples (6.600 s)
[Parallel(n_jobs=1)]: Done 17 tasks | elapsed: 0.1s
[Parallel(n_jobs=1)]: Done 71 tasks | elapsed: 0.5s
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 0.1 - 2e+02 Hz
FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 0.10
- Lower transition bandwidth: 0.10 Hz (-6 dB cutoff frequency: 0.05 Hz)
- Upper passband edge: 200.00 Hz
- Upper transition bandwidth: 50.00 Hz (-6 dB cutoff frequency: 225.00 Hz)
- Filter length: 67585 samples (33.000 s)
[Parallel(n_jobs=1)]: Done 161 tasks | elapsed: 1.2s
[Parallel(n_jobs=1)]: Done 17 tasks | elapsed: 0.2s
[Parallel(n_jobs=1)]: Done 71 tasks | elapsed: 1.0s
Overwriting existing file.
Writing /home/haozhu/Desktop/demo/QUJ3YH24KM_raw-bbp_ieeg.fif
[Parallel(n_jobs=1)]: Done 161 tasks | elapsed: 2.2s
Closing /home/haozhu/Desktop/demo/QUJ3YH24KM_raw-bbp_ieeg.fif [done]
ERP
In this section, we reuse the same pipeline for ERP analysis, focusing on the lower frequency band (0.1–40 Hz).
# Call the process_seeg_data function to preprocess the raw data for ERP analysis.
processed_raw = process_seeg_data(rawFile_name, eventFile_name, rawERPFile_name, epileptic_contacts,filter_low=0.1,filter_high=40)
Opening raw data file QUJ3YH24KM_raw_ieeg.fif...
Isotrak not found
Range : 3706880 ... 4259840 = 1810.000 ... 2080.000 secs
Ready.
Reading 0 ... 552960 = 0.000 ... 270.000 secs...
{'O': (0, 15), 'E': (16, 28), 'X': (29, 42), 'V': (43, 52), 'J': (53, 63), 'H': (64, 74), 'A': (75, 83), "A'": (84, 92), "H'": (93, 101), "J'": (102, 112), "V'": (113, 122), "X'": (123, 134), "E'": (135, 146), "O'": (147, 162)}
Filtering raw data in 1 contiguous segment
Setting up band-stop filter
FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandstop filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower transition bandwidth: 0.50 Hz
- Upper transition bandwidth: 0.50 Hz
- Filter length: 13517 samples (6.600 s)
[Parallel(n_jobs=1)]: Done 17 tasks | elapsed: 0.1s
[Parallel(n_jobs=1)]: Done 71 tasks | elapsed: 0.6s
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 0.1 - 40 Hz
FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 0.10
- Lower transition bandwidth: 0.10 Hz (-6 dB cutoff frequency: 0.05 Hz)
- Upper passband edge: 40.00 Hz
- Upper transition bandwidth: 10.00 Hz (-6 dB cutoff frequency: 45.00 Hz)
- Filter length: 67585 samples (33.000 s)
[Parallel(n_jobs=1)]: Done 161 tasks | elapsed: 1.3s
[Parallel(n_jobs=1)]: Done 17 tasks | elapsed: 0.2s
[Parallel(n_jobs=1)]: Done 71 tasks | elapsed: 0.9s
Overwriting existing file.
Writing /home/haozhu/Desktop/demo/QUJ3YH24KM_raw-erp_ieeg.fif
[Parallel(n_jobs=1)]: Done 161 tasks | elapsed: 2.2s
Closing /home/haozhu/Desktop/demo/QUJ3YH24KM_raw-erp_ieeg.fif [done]
Epoching
HGP
This section introduces a modular function designed to streamline sEEG data epoching.
The process_seeg_epoch function encapsulates steps such as reading raw data and events, epoch extraction, automatic artifact rejection, and manual inspection.
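The two automatic rejection criteria can be expressed compactly on an (n_epochs, n_channels, n_times) array. This vectorized sketch mirrors the loops in `process_seeg_epoch`: a peak-amplitude criterion and a mean + k * SD criterion (mean and SD taken across epochs at each channel and time point), each flagging an epoch when more than `rejection_ratio` of the channels are affected. It runs on synthetic data in volts; `flag_bad_epochs` is a hypothetical helper name.

```python
import numpy as np


def flag_bad_epochs(data, peak_threshold=350e-6, rejection_ratio=0.5, sd_factor=4.0):
    """Return indices of epochs flagged by either automatic criterion (hypothetical helper).

    data: array of shape (n_epochs, n_channels, n_times), in volts.
    """
    n_channels = data.shape[1]
    # Criterion 1: fraction of channels whose absolute peak exceeds peak_threshold.
    peak = np.abs(data).max(axis=2)                            # (n_epochs, n_channels)
    peak_ratio = (peak > peak_threshold).sum(axis=1) / n_channels
    # Criterion 2: fraction of channels with any sample above mean + sd_factor * SD,
    # where mean and SD are computed across epochs for each channel and time point.
    limit = data.mean(axis=0) + sd_factor * data.std(axis=0)   # (n_channels, n_times)
    sd_ratio = (data > limit).any(axis=2).sum(axis=1) / n_channels
    bad = (peak_ratio > rejection_ratio) | (sd_ratio > rejection_ratio)
    return np.flatnonzero(bad).tolist()


rng = np.random.default_rng(1)
data = 10e-6 * rng.standard_normal((60, 10, 50))  # 60 epochs, 10 channels, ~10 uV noise
data[3, :8] += 400e-6    # epoch 3: large offset on 8 of 10 channels, should be flagged
data[10, 8:] += 400e-6   # epoch 10: only 2 of 10 channels affected, should be kept
print(flag_bad_epochs(data))  # expected: [3]
```

Because the SD is computed with the outlier included, no value can lie more than sqrt(n_epochs - 1) SDs above the mean, so the 4-SD criterion can only fire when there are at least 18 epochs. With 600 trials, as here, this is not a constraint.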
def process_seeg_epoch(raw_file: str,
                       event_file: str,
                       epoch_save_file: str,
                       drop_log_file: str,
                       tmin: float = -1.0,
                       tmax: float = 1.0,
                       baseline: Optional[Union[Tuple[float, float], str]] = None,
                       detrend: int = 1,
                       epoch_threshold: float = 350e-6,
                       rejection_ratio: float = 0.5,
                       drop_by_sd_factor: float = 4.0,
                       do_auto_drop: bool = True,
                       do_manual_drop: bool = True,
                       do_log_drop: bool = False) -> mne.Epochs:
    """
    Process sEEG data: reads raw data and events, epochs the data,
    automatically detects and drops epochs with improbable artifacts,
    allows for manual inspection, logs dropped epochs, and finally saves the epochs.
    Parameters:
        raw_file (str): Path to the raw sEEG file.
        event_file (str): Path to the event file.
        epoch_save_file (str): Path to save the processed epoch object.
        drop_log_file (str): Path to save the drop log file.
        tmin (float): Start time of the epoch window (s).
        tmax (float): End time of the epoch window (s).
        baseline: Baseline correction to use (default: None).
        detrend (int): Detrending order passed to mne.Epochs (0: remove mean, 1: remove linear trend).
        epoch_threshold (float): Amplitude threshold (in SI units) for peak-based rejection.
        rejection_ratio (float): Ratio of channels exceeding threshold to flag an epoch.
        drop_by_sd_factor (float): Factor for standard deviation based threshold.
        do_auto_drop (bool): Whether to automatically drop epochs meeting criteria.
        do_manual_drop (bool): Whether to allow manual inspection (calls drop_bad).
        do_log_drop (bool): If True, read the saved drop log (.npy file) and drop epochs based on it;
            if False, save the indices of the dropped epochs to the drop log.
    Returns:
        epoch (mne.Epochs): The processed Epochs object.
    """
    # --- Load raw data and events ---
    raw = mne.io.read_raw(raw_file, preload=True)
    events = mne.read_events(event_file)
    epoch = mne.Epochs(raw, events, tmin=tmin, tmax=tmax, baseline=baseline,
                       reject=None, preload=True, detrend=detrend)
    # --- Auto-detection of improbable epochs ---
    epoch_data = epoch.get_data(copy=True)
    contact_number = epoch_data.shape[1]
    epoch_data_mean = epoch_data.mean(axis=0)
    epoch_data_sd = epoch_data.std(axis=0)
    epoch_peak_value = np.abs(epoch_data).max(axis=2)
    auto_drop_indices = set()
    # Criterion 1: Peak amplitude threshold
    for ind, each_epoch in enumerate(epoch_peak_value):
        ratio = len(np.where(each_epoch > epoch_threshold)[0]) / contact_number
        if ratio > rejection_ratio:
            print("Peak criterion: Epoch %d has ratio %.2f" % (ind, ratio))
            auto_drop_indices.add(ind)
    # Criterion 2: Amplitude exceeding mean + drop_by_sd_factor * SD across epochs
    for ind, each_epoch in enumerate(epoch_data):
        ratio = len(set(np.where(each_epoch > (epoch_data_mean + drop_by_sd_factor * epoch_data_sd))[0])) / contact_number
        if ratio > rejection_ratio:
            print("SD criterion: Epoch %d has ratio %.2f" % (ind, ratio))
            auto_drop_indices.add(ind)
    if do_auto_drop and auto_drop_indices:
        print("Auto-dropping epochs:", sorted(auto_drop_indices))
        epoch.drop(sorted(auto_drop_indices))
    # --- Manual inspection drop ---
    if do_manual_drop:
        epoch.plot(block=True)  # Opens an interactive plot for manual review
        epoch.drop_bad()  # Drops epochs marked as bad manually
    if do_log_drop:
        # --- Quick drop using a previously saved drop log ---
        # The log stores indices into the original event list; map them onto the
        # epochs that are still present (epoch.selection) before dropping.
        try:
            drop_list = np.load(drop_log_file)
            if drop_list.size > 0:
                epoch.drop(np.flatnonzero(np.isin(epoch.selection, drop_list)))
                print("Quick dropped epochs from saved log:", drop_list)
        except Exception as e:
            print("Error reading drop log file:", e)
    else:
        # --- Save drop log using numpy.save ---
        # Saved only when no log is re-applied, so an existing log is never
        # overwritten (e.g. with an empty array) before it is read.
        drop_epoch = np.array([m for m, log in enumerate(epoch.drop_log) if 'USER' in log], dtype=int)
        drop_stats = np.round(epoch.drop_log_stats(), 2)
        np.save(drop_log_file, drop_epoch)
        print(f"{len(drop_epoch)} epoch(s) dropped ({drop_stats}% of total epochs), saved to '{drop_log_file}'")
    # --- Save processed epochs ---
    epoch.save(epoch_save_file, overwrite=True)
    print(f"Processed epochs saved to '{epoch_save_file}'")
    return epoch
processed_epoch = process_seeg_epoch(
    raw_file=rawHGPFile_name,
    event_file=eventFile_name,
    epoch_save_file=epochHGPFile_name,
    drop_log_file=HGP_DropEpochlLog_name,
    do_auto_drop=False,
    do_manual_drop=False,
    do_log_drop=True
)
Opening raw data file QUJ3YH24KM_raw-hgp_ieeg.fif...
Isotrak not found
Range : 3706880 ... 4259840 = 1810.000 ... 2080.000 secs
Ready.
Reading 0 ... 552960 = 0.000 ... 270.000 secs...
Not setting metadata
600 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 600 events and 4097 original time points ...
0 bad epochs dropped
SD criterion: Epoch 73 has ratio 0.77
SD criterion: Epoch 74 has ratio 0.75
SD criterion: Epoch 75 has ratio 0.75
SD criterion: Epoch 76 has ratio 0.72
SD criterion: Epoch 77 has ratio 0.72
SD criterion: Epoch 175 has ratio 0.50
SD criterion: Epoch 176 has ratio 0.52
Dropped 49 epochs: 73, 74, 75, 76, 77, 93, 94, 95, 96, 97, 101, 102, 103, 104, 105, 175, 176, 206, 207, 208, 209, 210, 262, 263, 264, 265, 266, 441, 442, 443, 444, 445, 493, 494, 495, 496, 497, 498, 499, 506, 507, 508, 509, 510, 529, 530, 531, 532, 533
Quick dropped epochs from saved log: [ 73 74 75 76 77 93 94 95 96 97 101 102 103 104 105 175 176 206 207 208 209 210 262 263 264 265 266 441 442 443 444 445 493 494 495 496 497 498 499 506 507 508 509 510 529 530 531 532 533]
Overwriting existing file.
Overwriting existing file.
Processed epochs saved to 'QUJ3YH24KM_epoch_hgp-epo.fif'
BBP
processed_epoch = process_seeg_epoch(
    raw_file=rawBBPFile_name,
    event_file=eventFile_name,
    epoch_save_file=epochBBPFile_name,
    drop_log_file=BBP_DropEpochlLog_name,
    tmin=-0.1,
    tmax=0.4,
    baseline=None,
    do_auto_drop=False,
    do_manual_drop=False,
    do_log_drop=True
)
Opening raw data file QUJ3YH24KM_raw-bbp_ieeg.fif...
Isotrak not found
Range : 3706880 ... 4259840 = 1810.000 ... 2080.000 secs
Ready.
Reading 0 ... 552960 = 0.000 ... 270.000 secs...
Not setting metadata
600 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 600 events and 1025 original time points ...
0 bad epochs dropped
Dropped 49 epochs: 73, 74, 75, 76, 77, 93, 94, 95, 96, 97, 101, 102, 103, 104, 105, 175, 176, 206, 207, 208, 209, 210, 262, 263, 264, 265, 266, 441, 442, 443, 444, 445, 493, 494, 495, 496, 497, 498, 499, 506, 507, 508, 509, 510, 529, 530, 531, 532, 533
Quick dropped epochs from saved log: [ 73 74 75 76 77 93 94 95 96 97 101 102 103 104 105 175 176 206 207 208 209 210 262 263 264 265 266 441 442 443 444 445 493 494 495 496 497 498 499 506 507 508 509 510 529 530 531 532 533]
Overwriting existing file.
Overwriting existing file.
Processed epochs saved to 'QUJ3YH24KM_epoch_bbp-epo.fif'
ERP¶
processed_epoch = process_seeg_epoch(
    raw_file=rawERPFile_name,
    event_file=eventFile_name,
    epoch_save_file=epochERPFile_name,
    drop_log_file=ERP_DropEpochlLog_name,
    tmin=-0.1,
    tmax=0.4,
    baseline=(-0.1, 0),
    do_auto_drop=0,
    do_manual_drop=0,
    do_log_drop=1
)
Opening raw data file QUJ3YH24KM_raw-erp_ieeg.fif... Isotrak not found Range : 3706880 ... 4259840 = 1810.000 ... 2080.000 secs Ready. Reading 0 ... 552960 = 0.000 ... 270.000 secs... Not setting metadata 600 matching events found Applying baseline correction (mode: mean) 0 projection items activated Using data from preloaded Raw for 600 events and 1025 original time points ... 0 bad epochs dropped Dropped 49 epochs: 73, 74, 75, 76, 77, 93, 94, 95, 96, 97, 101, 102, 103, 104, 105, 175, 176, 206, 207, 208, 209, 210, 262, 263, 264, 265, 266, 441, 442, 443, 444, 445, 493, 494, 495, 496, 497, 498, 499, 506, 507, 508, 509, 510, 529, 530, 531, 532, 533 Quick dropped epochs from saved log: [ 73 74 75 76 77 93 94 95 96 97 101 102 103 104 105 175 176 206 207 208 209 210 262 263 264 265 266 441 442 443 444 445 493 494 495 496 497 498 499 506 507 508 509 510 529 530 531 532 533] Overwriting existing file. Overwriting existing file. Processed epochs saved to 'QUJ3YH24KM_epoch_erp-epo.fif'
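Only the ERP run passes `baseline=(-0.1, 0)`, and MNE reports "Applying baseline correction (mode: mean)" for it. Mean-mode baseline correction subtracts each channel's mean over the baseline interval from the whole epoch. The sketch below shows that arithmetic in NumPy. It is not MNE's implementation: the helper `baseline_correct_mean` is hypothetical, and treating both interval endpoints as inclusive is an assumption of the sketch.

```python
import numpy as np


def baseline_correct_mean(data: np.ndarray, times: np.ndarray,
                          tmin: float = -0.1, tmax: float = 0.0) -> np.ndarray:
    """Subtract, per epoch and channel, the mean over tmin <= t <= tmax (endpoints inclusive).

    data: array of shape (n_epochs, n_channels, n_times); times: seconds, shape (n_times,).
    """
    mask = (times >= tmin) & (times <= tmax)  # samples inside the baseline interval
    baseline = data[..., mask].mean(axis=-1, keepdims=True)
    return data - baseline


# Toy epochs: 2 epochs x 1 channel sampled at 1 kHz from -100 ms to 399 ms
times = np.arange(-100, 400) / 1000.0
data = np.full((2, 1, times.size), 5.0)  # constant offset of 5 units
data[:, :, times > 0] += 1.0             # step response after onset
corrected = baseline_correct_mean(data, times)
print(corrected[0, 0, 0], corrected[0, 0, -1])  # 0.0 1.0
```

After correction, the constant offset is removed and only the post-onset step remains.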
HGP¶
In this section, we analyze high gamma power (HGP) from the preprocessed sEEG epochs.
The workflow consists of the following steps:
1. Loading and resampling the epoch data.
2. Calculating high gamma (broadband) power per epoch and channel.
3. Selecting active contacts by comparing pre- and post-stimulus response windows with paired t-tests.
4. Plotting the results, including single-channel waveforms and all channels grouped by electrode.
Data Loader¶
def load_and_resample_epoch(epoch_file: str, resample_rate: int = 1000) -> Tuple[mne.Epochs, np.ndarray]:
    """
    Load epoch data from a file and resample it.
    Parameters:
    - epoch_file: string, file path to the MNE epoch file.
    - resample_rate: int, new sampling rate (default is 1000 Hz).
    Returns:
    - epoch: MNE Epochs object after resampling.
    - epoch_data: numpy array of shape [n_epochs, n_channels, n_timepoints].
    """
    # Load the epochs from file
    epoch = mne.read_epochs(epoch_file)
    # Resample the data to the desired sampling rate
    epoch.resample(resample_rate)
    # Get the data as a numpy array
    epoch_data = epoch.copy().get_data()
    return epoch, epoch_data
epoch, epoch_data = load_and_resample_epoch(epochHGPFile_name)
Reading /home/haozhu/Desktop/demo/QUJ3YH24KM_epoch_hgp-epo.fif ... Isotrak not found Found the data of interest: t = -1000.00 ... 1000.00 ms 0 CTF compensation matrices available Not setting metadata 551 matching events found No baseline correction applied 0 projection items activated
High Gamma Power Calculation¶
def calculate_broadband_power(epochs_data: np.ndarray,
                              fs: int = 1000,
                              baseline_onset: int = 1000,
                              smoothing_window_ms: int = 50,
                              test_plot: bool = False) -> np.ndarray:
    """
    Calculate the broadband (high gamma) power for each epoch and channel.
    The input is expected to be band-pass filtered to the high gamma range already (here, the HGP
    epochs); the Hilbert envelope by itself does not select a frequency band.
    Steps:
    1. Apply the Hilbert transform to get the analytic signal.
    2. Compute power as the squared magnitude of the analytic signal.
    3. Convert the power to a decibel (dB) scale.
    4. Normalize power by subtracting the mean of a 100 ms baseline (-100 ms to 0 ms relative to onset).
    5. Smooth the power series with a simple moving average.
    6. Optionally, if test_plot is True:
       a. Print the shape of the broadband_power array.
       b. Plot each channel's across-epoch average from -100 ms to 400 ms
          (indices 900:1400, which assumes fs = 1000 and baseline_onset = 1000).
    Parameters:
    - epochs_data (np.ndarray): Data array of shape [n_epochs, n_channels, n_timepoints].
    - fs (int): Sampling frequency in Hz.
    - baseline_onset (int): Sample index of stimulus onset; the baseline window ends here.
    - smoothing_window_ms (int): Smoothing window length in milliseconds.
    - test_plot (bool): If True, prints the broadband_power shape and plots each channel's averaged power.
    Returns:
    - np.ndarray: Baseline-normalized power in dB, with the same shape as epochs_data.
    """
    n_epochs, n_channels, n_time = epochs_data.shape
    broadband_power = np.zeros_like(epochs_data)
    window_size = int(fs * smoothing_window_ms / 1000.0)  # Smoothing window length in samples
    baseline_samples = int(fs * 0.1)  # 100 ms baseline length in samples (100 at 1 kHz)
    # Iterate over all epochs and channels
    for i in tqdm.tqdm(range(n_epochs), desc="Processing epochs"):
        for j in range(n_channels):
            # Apply the Hilbert transform to get the analytic signal and its instantaneous power
            analytic_signal = signal.hilbert(epochs_data[i, j])
            power = np.abs(analytic_signal) ** 2
            # Convert power to decibels; the small constant avoids log(0)
            power_series = 10 * np.log10(power + 1e-20)
            # Normalize by subtracting the mean of the baseline period (-100 ms to 0 ms)
            baseline_data = power_series[baseline_onset - baseline_samples:baseline_onset]
            power_series -= np.mean(baseline_data)
            # Smooth the power series with a moving-average filter
            power_series = np.convolve(power_series, np.ones(window_size) / window_size, mode='same')
            broadband_power[i, j] = power_series
    if test_plot:
        # Print the shape of the computed broadband power array
        print("Broadband power shape:", broadband_power.shape)
        # Plot the average high gamma power of each channel (averaged across epochs)
        avg_power = broadband_power.mean(axis=0)  # shape: [n_channels, n_timepoints]
        for n in range(avg_power.shape[0]):
            plt.plot(avg_power[n][900:1400])
        # Indices 0-500 of the plotted slice correspond to -100 ms to 400 ms relative to onset (sample 1000)
        plt.xticks([0, 100, 200, 300, 400, 500], labels=[-100, 0, 100, 200, 300, 400])
        plt.xlabel("Time (ms)")
        plt.ylabel("High Gamma Power (dB)")
        plt.title("Averaged High Gamma Power for each channel")
        plt.show()
    return broadband_power
hg_power = calculate_broadband_power(epoch_data,test_plot=True)
Processing epochs: 100%|█████████████████████████████████████████████████████████████████████████| 551/551 [00:06<00:00, 90.42it/s]
Broadband power shape: (551, 163, 2000)
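The Hilbert-envelope steps in `calculate_broadband_power` can be checked on a synthetic signal whose expected answer is known. The sketch below uses NumPy only. It builds the analytic signal with the same FFT construction that `scipy.signal.hilbert` uses, then applies the dB conversion, the 100 ms baseline subtraction, and the 50 ms moving average. The helpers `analytic_signal` and `hg_power_db` and the 80 Hz test signal are hypothetical. When a band-limited signal's amplitude doubles, its power quadruples, so the normalized value should settle near 10*log10(4) ≈ 6.02 dB.

```python
import numpy as np


def analytic_signal(x: np.ndarray) -> np.ndarray:
    """FFT-based analytic signal along the last axis (same construction as scipy.signal.hilbert)."""
    n = x.shape[-1]
    h = np.zeros(n)
    if n % 2 == 0:
        h[0] = h[n // 2] = 1.0
        h[1:n // 2] = 2.0
    else:
        h[0] = 1.0
        h[1:(n + 1) // 2] = 2.0
    return np.fft.ifft(np.fft.fft(x, axis=-1) * h, axis=-1)


def hg_power_db(x: np.ndarray, fs: int = 1000, onset: int = 1000,
                baseline_ms: int = 100, smooth_ms: int = 50) -> np.ndarray:
    """Baseline-normalized, smoothed Hilbert power (dB) of a 1D band-limited signal."""
    power = np.abs(analytic_signal(x)) ** 2
    power_db = 10 * np.log10(power + 1e-20)
    n_base = int(fs * baseline_ms / 1000)
    power_db = power_db - power_db[onset - n_base:onset].mean()  # -100 ms to 0 ms baseline
    n_smooth = int(fs * smooth_ms / 1000)
    return np.convolve(power_db, np.ones(n_smooth) / n_smooth, mode='same')


# Synthetic 80 Hz "high gamma" oscillation (2 s at 1 kHz, onset at sample 1000)
# whose amplitude doubles 100 ms after onset
fs = 1000
t = np.arange(2000) / fs
x = np.sin(2 * np.pi * 80 * t) * np.where(np.arange(2000) < 1100, 1.0, 2.0)
hg = hg_power_db(x, fs=fs)
print(hg[1250:1450].mean())  # expected to be close to 10*log10(4) ≈ 6.02 dB
```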
Active Contact Selection¶
def calculate_response_windows(data: np.ndarray, fs: int = 1000, onsets: int = 1000,
                               pre_window_ms: int = 50, post_window_ms: int = 300) -> Tuple[np.ndarray, np.ndarray]:
    """
    Calculate pre- and post-stimulus responses by averaging power within each window.
    Parameters:
    - data (np.ndarray): Data array of shape [n_epochs, n_channels, n_timepoints].
    - fs (int): Sampling frequency in Hz.
    - onsets (int): Sample index of stimulus onset.
    - pre_window_ms (int): Length in ms of the pre-stimulus (baseline) window ending at onset.
    - post_window_ms (int): Length in ms of the post-stimulus window starting at onset.
    Returns:
    - Tuple containing:
      - pre_responses (np.ndarray): Mean power in the pre-stimulus window, shape [n_epochs, n_channels].
      - post_responses (np.ndarray): Mean power in the post-stimulus window, shape [n_epochs, n_channels].
    """
    pre_samples = int(pre_window_ms * fs / 1000.0)
    post_samples = int(post_window_ms * fs / 1000.0)
    # Average across the pre-stimulus window (from onsets - pre_samples to onsets)
    pre_responses = data[:, :, onsets - pre_samples:onsets].mean(axis=2)
    # Average across the post-stimulus window (from onsets to onsets + post_samples)
    post_responses = data[:, :, onsets:onsets + post_samples].mean(axis=2)
    return pre_responses, post_responses
def perform_paired_t_test(pre_responses: np.ndarray, post_responses: np.ndarray, alpha: float = 0.001) -> pd.DataFrame:
    """
    Perform a paired t-test between pre- and post-stimulus responses for each channel.
    The test is two-sided (pingouin's default), so a significant decrease is flagged as well as an increase.
    P-values are Bonferroni-corrected for the number of channels tested.
    Parameters:
    - pre_responses (np.ndarray): Baseline responses, shape [n_epochs, n_channels].
    - post_responses (np.ndarray): Post-stimulus responses, shape [n_epochs, n_channels].
    - alpha (float): Significance threshold applied to the corrected p-values (default: 0.001).
    Returns:
    - pd.DataFrame: One row per channel with columns ['T', 'p-val', 'p-corrected', 'significant'].
    """
    n_channels = pre_responses.shape[1]
    results = pd.DataFrame(index=range(n_channels), columns=['T', 'p-val', 'p-corrected', 'significant'])
    for channel in range(n_channels):
        # Paired t-test for the current channel using pingouin; T is computed on pre - post,
        # so a post-stimulus increase gives a negative T
        res = pg.ttest(pre_responses[:, channel], post_responses[:, channel], paired=True)
        results.loc[channel, 'T'] = res['T'].values[0]
        results.loc[channel, 'p-val'] = res['p-val'].values[0]
    # Bonferroni correction: multiply p-values by the number of channels and cap at 1
    results['p-corrected'] = results['p-val'].astype(float) * n_channels
    results['p-corrected'] = results['p-corrected'].apply(lambda x: min(x, 1))
    results['significant'] = results['p-corrected'] < alpha
    return results
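The per-channel loop above can be cross-checked with a vectorized computation of the same quantities. The paired t statistic is the mean of the differences d = pre - post divided by its standard error, std(d, ddof=1) / sqrt(n). This is the convention of `scipy.stats.ttest_rel(pre, post)`, so a post-stimulus increase gives a negative T. The Bonferroni step multiplies each p-value by the number of tests and caps the result at 1. This is a NumPy-only sketch: p-values are taken as given because NumPy has no t distribution, and the helpers `paired_t_statistic` and `bonferroni` are hypothetical.

```python
import numpy as np


def paired_t_statistic(pre: np.ndarray, post: np.ndarray) -> np.ndarray:
    """Paired t statistic per channel for arrays of shape (n_epochs, n_channels), computed on pre - post."""
    d = pre - post
    n = d.shape[0]
    return d.mean(axis=0) / (d.std(axis=0, ddof=1) / np.sqrt(n))


def bonferroni(p_values) -> np.ndarray:
    """Bonferroni correction: multiply by the number of tests and cap at 1."""
    p = np.asarray(p_values, dtype=float)
    return np.minimum(p * p.size, 1.0)


# Toy responses: 4 epochs x 2 channels; channel 0 increases after onset, channel 1 does not
pre = np.array([[0.0, 1.0], [0.1, 1.2], [-0.1, 0.8], [0.0, 1.0]])
post = np.array([[1.0, 1.1], [1.2, 1.0], [0.9, 0.9], [1.1, 1.0]])
print(paired_t_statistic(pre, post))  # channel 0 strongly negative, channel 1 near 0
print(bonferroni([0.0004, 0.2]) < 0.001)  # [ True False]
```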
def select_active_contacts(test_results: pd.DataFrame, ch_names: List[str]) -> Tuple[np.ndarray, List[str]]:
    """
    Select active contacts based on significant differences between pre- and post-stimulus responses.
    Parameters:
    - test_results (pd.DataFrame): DataFrame containing t-test results.
    - ch_names (List[str]): List of channel names.
    Returns:
    - Tuple containing:
      - active_index (np.ndarray): Array of indices for significant channels.
      - active_channel (List[str]): List of channel names corresponding to significant channels.
    """
    active_index = np.where(test_results['significant'].astype(bool))[0]
    active_channel = [ch_names[n] for n in active_index]
    return active_index, active_channel
def plot_channel_waveforms(hg_power: np.ndarray, epoch: mne.Epochs, active_index: np.ndarray,
                           start: int = 900, end: int = 1400) -> None:
    """
    Plot the high gamma power of each active channel.
    For each active channel, a figure with two subplots is created:
    - Left subplot: the mean waveform (averaged across all trials).
    - Right subplot: a heatmap of single-trial high gamma power in the given time window,
      sorted by each trial's peak latency, with a short white dashed line on each row
      marking that trial's peak.
    The tick labels assume fs = 1000 Hz and start = onset - 100 samples.
    Parameters:
    - hg_power (np.ndarray): High gamma power data with shape [n_epochs, n_channels, n_timepoints].
    - epoch (mne.Epochs): MNE Epochs object (provides channel names).
    - active_index (np.ndarray): Array of channel indices for active channels.
    - start (int): Starting index for the time window to plot.
    - end (int): Ending index for the time window to plot.
    Returns:
    - None
    """
    mean_power = hg_power.mean(axis=0)  # [n_channels, n_timepoints], computed once for all channels
    for n in active_index:
        plt.figure(figsize=(10, 4))
        # Left subplot: mean waveform (averaged across trials)
        plt.subplot(1, 2, 1)
        mean_waveform = mean_power[n][start:end]
        plt.plot(mean_waveform)
        plt.title(f'{n} {epoch.ch_names[n]}')
        plt.xticks([0, 100, 200, 300, 400, 500], labels=[-100, 0, 100, 200, 300, 400])
        plt.xlabel("Time (ms)")
        plt.ylabel("High Gamma Power (dB)")
        plt.ylim(-1, 5)
        # Right subplot: heatmap of single-trial data sorted by peak latency
        plt.subplot(1, 2, 2)
        # Extract the data for the current channel in the given window
        data = hg_power[:, n, start:end]  # shape: [n_trials, n_timepoints]
        # Peak latency (sample index) of each trial
        peak_latencies = np.argmax(data, axis=1)
        # Sort the trials by peak latency
        sort_order = np.argsort(peak_latencies)
        sorted_data = data[sort_order, :]
        sorted_peak_latencies = peak_latencies[sort_order]
        # Plot the heatmap with origin="lower" so the earliest peaks are at the bottom
        ax = plt.gca()
        ax.imshow(sorted_data, aspect='auto', vmin=0, vmax=8, cmap='cividis', origin='lower')
        ax.set_xticks([0, 100, 200, 300, 400, 500])
        ax.set_xticklabels([-100, 0, 100, 200, 300, 400])
        ax.set_ylabel("Sorted Trials")
        ax.set_xlabel("Time (ms)")
        # Add a dashed vertical segment on each row marking that trial's peak latency
        for i, peak in enumerate(sorted_peak_latencies):
            # Each row spans from i - 0.5 to i + 0.5 in imshow coordinates
            ax.plot([peak, peak], [i - 0.5, i + 0.5], linestyle='--', color='white', linewidth=1)
        plt.tight_layout()
        plt.show()
hg_pre_response, hg_post_responses = calculate_response_windows(hg_power)
test_results = perform_paired_t_test(hg_pre_response, hg_post_responses)
active_index, active_channel = select_active_contacts(test_results, epoch.ch_names)
plot_channel_waveforms(hg_power, epoch, active_index)
active_index, active_channel
(array([ 16, 19, 20, 21, 22, 23, 24, 25, 26, 143, 144]), ['E1', 'E4', 'E5', 'E6', 'E7', 'E8', 'E9', 'E10', 'E11', "E'9", "E'10"])
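The right-hand heatmap in `plot_channel_waveforms` orders trials by the sample index of their maximum. The sketch below isolates that sorting step and converts the peak indices into latencies after onset. It assumes, as the tick labels do, that the window starts 100 samples before onset at 1 kHz. The helpers `sort_trials_by_peak` and `peak_latency_ms` are hypothetical. The sketch uses a stable sort so that tied trials keep their original order; the tutorial's `np.argsort` call uses the default, non-stable algorithm.

```python
import numpy as np


def sort_trials_by_peak(data: np.ndarray):
    """Sort trials (rows) by the sample index of their maximum; ties keep their original order."""
    peaks = np.argmax(data, axis=1)
    order = np.argsort(peaks, kind='stable')
    return data[order], peaks[order]


def peak_latency_ms(peaks: np.ndarray, fs: float = 1000.0, pre_samples: int = 100) -> np.ndarray:
    """Convert peak indices within a window that starts pre_samples before onset into ms after onset."""
    return (peaks - pre_samples) * 1000.0 / fs


# 3 toy trials in a 500-sample window (-100 ms to 400 ms at 1 kHz)
data = np.zeros((3, 500))
data[0, 250] = data[1, 150] = data[2, 180] = 1.0
sorted_data, sorted_peaks = sort_trials_by_peak(data)
print(sorted_peaks)  # [150 180 250]
print(peak_latency_ms(sorted_peaks))  # 50, 80 and 150 ms after onset
```

Plotting the markers with a single `ax.scatter(sorted_peaks, np.arange(len(sorted_peaks)))` call, instead of one `ax.plot` per trial, is a possible speed-up when there are hundreds of trials.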
plot active contacts within all channels¶
Plot all channels grouped by electrode with active channels highlighted
def find_electrode_indices(contacts: List[str]) -> Dict[str, Tuple[int, int]]:
    """
    Determine electrode index ranges from channel names by grouping on the electrode prefix.
    The prefix is the leading run of uppercase letters plus an optional apostrophe
    (e.g., 'E' for 'E1' and "E'" for "E'9"); consecutive channels sharing a prefix form one electrode.
    Channel names that do not match the pattern are counted in the current electrode's range, and an
    electrode whose channels are not contiguous keeps only the range of its last run.
    Parameters:
    - contacts (List[str]): List of channel names.
    Returns:
    - Dict[str, Tuple[int, int]]: Dictionary mapping electrode prefix to (start_index, end_index), both inclusive.
    """
    electrode_indices: Dict[str, Tuple[int, int]] = {}
    current_electrode: Optional[str] = None
    start_index = 0
    # Regex matching the electrode prefix: uppercase letters followed by an optional apostrophe
    prefix_pattern = re.compile(r"^[A-Z]+'?")
    for i, contact in enumerate(contacts):
        match = prefix_pattern.match(contact)
        if match:
            prefix = match.group()
            if prefix != current_electrode:
                if current_electrode is not None:
                    electrode_indices[current_electrode] = (start_index, i - 1)
                current_electrode = prefix
                start_index = i
    if current_electrode is not None:
        electrode_indices[current_electrode] = (start_index, len(contacts) - 1)
    return electrode_indices
def extract_numeric_part(s: str) -> Optional[str]:
    """
    Extract the first run of digits from a string (e.g., '17' from 'A17').
    Parameters:
    - s (str): A channel name string.
    Returns:
    - Optional[str]: The numeric part of the string, or None if no digits are found.
    """
    matches = re.findall(r'\d+', s)
    return matches[0] if matches else None
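The grouping rule in `find_electrode_indices` is easy to check on a short list of names. The sketch below is a stricter, hypothetical variant named `group_contacts`. It uses the same regex but raises an error instead of silently continuing when a name has no prefix or an electrode reappears after another one. In those two cases the tutorial's version extends or overwrites a range, which would misalign `plot_all_channels`.

```python
import re
from typing import Dict, List, Optional, Tuple

PREFIX = re.compile(r"^[A-Z]+'?")  # same pattern as find_electrode_indices


def group_contacts(contacts: List[str]) -> Dict[str, Tuple[int, int]]:
    """Map electrode prefix -> (first, last) channel index; reject unmatched or non-contiguous names."""
    groups: Dict[str, Tuple[int, int]] = {}
    current: Optional[str] = None
    start = 0
    for i, name in enumerate(contacts):
        m = PREFIX.match(name)
        if m is None:
            raise ValueError(f"Channel name {name!r} has no electrode prefix")
        prefix = m.group()
        if prefix != current:
            if current is not None:
                groups[current] = (start, i - 1)  # close the previous electrode
            if prefix in groups:
                raise ValueError(f"Electrode {prefix!r} appears in non-consecutive channels")
            current, start = prefix, i
    if current is not None:
        groups[current] = (start, len(contacts) - 1)
    return groups


print(group_contacts(['E1', 'E2', "E'9", "E'10", 'F1']))
# {'E': (0, 1), "E'": (2, 3), 'F': (4, 4)}
```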
def plot_all_channels(data: np.ndarray, electrode_index: Dict[str, Tuple[int, int]], epoch: mne.Epochs,
                      selected_active_contacts_index: np.ndarray, start: int = 900, end: int = 1400) -> Figure:
    """
    Plot each channel's high gamma power grouped by electrode.
    Channels are arranged in a grid in which the first column displays the electrode name and the
    subsequent subplots show the channel waveforms. Active contacts are highlighted with a green border.
    The function assumes that each electrode's channels are contiguous in epoch.ch_names and that
    electrode_index lists the electrodes in channel order (as find_electrode_indices produces).
    Parameters:
    - data (np.ndarray): Array of shape [n_channels, n_times] (mean high gamma power).
    - electrode_index (Dict[str, Tuple[int, int]]): Mapping of electrode names to their channel index range.
    - epoch (mne.Epochs): MNE Epochs object (provides channel names).
    - selected_active_contacts_index (np.ndarray): Array of indices for active channels.
    - start (int): Starting index for the plotting window.
    - end (int): Ending index for the plotting window.
    Returns:
    - Figure: Matplotlib figure object.
    """
    # Number of channels per electrode group
    ele_Num = [int(np.diff(n)[0]) + 1 for n in electrode_index.values()]
    ele_name = list(electrode_index.keys())
    # Grid dimensions: one extra column for the electrode labels
    col = max(ele_Num) + 1
    row = len(ele_Num)
    ymax = 6
    plot_number = 0
    # squeeze=False keeps axs two-dimensional even when there is only one electrode
    fig, axs = plt.subplots(row, col, figsize=(col * 0.7, row / 2), squeeze=False)
    for i in range(axs.shape[0]):
        for j in range(axs.shape[1]):
            ax = axs[i, j]
            ax.axis('off')  # Hide axes by default
            if j == 0:
                # First column shows the electrode name
                ax.text(0.5, 0.3, ele_name[i], fontsize=10, fontweight='bold', ha='center', va='center')
                continue
            if j <= ele_Num[i]:
                # Plot the current channel; the red dashed line marks stimulus onset (index 100)
                channel_name = epoch.ch_names[plot_number]
                channel_number = extract_numeric_part(channel_name)
                ax.text(360, 3.5, channel_number, fontsize=7)
                ax.plot(data[plot_number, start:end], color='k', alpha=0.5)
                ax.plot([100, 100], [-1, ymax], color='red', alpha=0.5, linestyle='--', linewidth=1)
                ax.set_ylim(-1, ymax)
                ax.axis('on')
                for spine in ax.spines.values():
                    spine.set_edgecolor('grey')
                    spine.set_linewidth(0.5)
                if plot_number in selected_active_contacts_index:
                    for spine in ax.spines.values():
                        spine.set_edgecolor('green')
                        spine.set_linewidth(2)
                plot_number += 1
            ax.set_xticks([])
            ax.set_yticks([])
    plt.suptitle('Significant Channels (green), high gamma power', fontsize=15)
    plt.tight_layout()
    plt.show()
    return fig
contacts = epoch.ch_names
electrode_indices = find_electrode_indices(contacts)
hg_power_mean = np.mean(hg_power, axis=0)
all_channel_figure = plot_all_channels(hg_power_mean, electrode_indices, epoch, active_index)
all_channel_figure.savefig(f"{img_path}HGP_all_channels.png", dpi=300)
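`plot_all_channels` does not store which channel goes into which cell. It increments `plot_number` while scanning the grid row by row, so the layout is correct only when each electrode's channels are contiguous and ordered as in `electrode_index`. The sketch below makes that mapping explicit, which can help when checking or annotating the figure. The helper `grid_positions` is hypothetical, and column 0 is reserved for the electrode label as in the tutorial.

```python
from typing import Dict, Tuple


def grid_positions(electrode_index: Dict[str, Tuple[int, int]]) -> Dict[int, Tuple[int, int]]:
    """Map each channel index to its (row, column) cell; column 0 holds the electrode label."""
    positions: Dict[int, Tuple[int, int]] = {}
    for row, (start, end) in enumerate(electrode_index.values()):
        for offset, channel in enumerate(range(start, end + 1)):
            positions[channel] = (row, offset + 1)
    return positions


electrode_index = {'E': (0, 2), "E'": (3, 4)}
# Grid width: largest electrode plus the label column
n_cols = max(end - start + 1 for start, end in electrode_index.values()) + 1
print(n_cols)  # 4
print(grid_positions(electrode_index))
# {0: (0, 1), 1: (0, 2), 2: (0, 3), 3: (1, 1), 4: (1, 2)}
```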
plot active contacts¶
Plot active channels (mean and SEM)
def plot_active_channels(data: np.ndarray, active_index: np.ndarray, epoch: mne.Epochs,
                         start: int = 900, end: int = 1400) -> Figure:
    """
    Plot active channels in a grid layout showing mean high gamma power and its standard error.
    For each active channel, the function plots the mean waveform with a shaded region representing the SEM.
    Parameters:
    - data (np.ndarray): High gamma power data of shape [n_epochs, n_channels, n_times].
    - active_index (np.ndarray): Array of indices for active channels.
    - epoch (mne.Epochs): MNE Epochs object (provides channel names).
    - start (int): Starting index for the time window to plot.
    - end (int): Ending index for the time window to plot.
    Returns:
    - Figure: Matplotlib figure object.
    """
    n_channels = len(active_index)
    col = 4
    row = n_channels // col + (1 if n_channels % col else 0)
    # squeeze=False keeps axs two-dimensional even when all channels fit in one row
    fig, axs = plt.subplots(row, col, figsize=(col * 2.5, row * 2), squeeze=False)
    for i in range(axs.shape[0]):
        for j in range(axs.shape[1]):
            plot_number = i * col + j
            ax = axs[i, j]
            ax.axis('off')
            if plot_number < n_channels:
                ax.axis('on')
                ax.set_xticks([])
                if j == 0:
                    ax.set_ylabel('High Gamma Power (dB)', fontsize=8)
                    ax.set_xticks(np.arange(0, 501, 100))
                    ax.set_xticklabels(np.arange(0, 501, 100) - 100)
                if plot_number + col >= n_channels:
                    ax.set_xlabel('Time (ms)')
                else:
                    ax.set_xlabel('')
                # Extract channel data and compute the mean and SEM across trials
                channel_index = active_index[plot_number]
                channel_data = data[:, channel_index, start:end]
                channel_mean = channel_data.mean(axis=0)
                channel_sem = stats.mstats.sem(channel_data, axis=0)
                channel_x = np.arange(0, end - start)
                ax.plot(channel_x, channel_mean, color='k', alpha=0.8)
                ax.fill_between(channel_x, channel_mean + channel_sem, channel_mean - channel_sem,
                                color='k', alpha=0.2)
                y_max = 6
                ax.set_yticks([0, y_max // 2, y_max])
                # Red dashed line at stimulus onset (index 100 = 0 ms)
                ax.plot([100, 100], [-1, y_max], color='red', alpha=0.5, linestyle='--', linewidth=1)
                ax.set_ylim(-1, y_max)
                ax.set_title(f"{epoch.ch_names[channel_index]}")
                for spine in ax.spines.values():
                    spine.set_edgecolor('grey')
                    spine.set_linewidth(0.5)
    plt.suptitle('Significant Channels, high gamma power', fontsize=15)
    plt.tight_layout()
    plt.show()
    return fig
active_contact_figure = plot_active_channels(hg_power, active_index, epoch)
active_contact_figure.savefig(f"{img_path}HGP_active_channel.png", dpi=300)
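The shaded band in `plot_active_channels` is the standard error of the mean across trials. `scipy.stats.mstats.sem` computes it with its default `ddof=1`, which is the sample standard deviation divided by the square root of the number of trials. The number of grid rows is the ceiling of the channel count divided by 4. The sketch below computes both with NumPy only. The helpers `mean_and_sem` and `grid_rows` are hypothetical.

```python
import numpy as np


def mean_and_sem(channel_data: np.ndarray):
    """Mean and standard error across trials (axis 0), with ddof=1 as in scipy.stats.sem."""
    n = channel_data.shape[0]
    mean = channel_data.mean(axis=0)
    sem = channel_data.std(axis=0, ddof=1) / np.sqrt(n)
    return mean, sem


def grid_rows(n_panels: int, n_cols: int = 4) -> int:
    """Ceiling division: rows needed to fit n_panels subplots into n_cols columns."""
    return -(-n_panels // n_cols)


mean, sem = mean_and_sem(np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]))  # 3 trials x 2 samples
print(mean, sem)  # means 3 and 4; both SEMs equal 2 / sqrt(3)
print(grid_rows(11))  # 3 rows for the 11 active channels found above
```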
ERP¶
This section processes the ERP data by:
- Loading and resampling the ERP epoch data.
- Plotting ERP waveforms for the active channels.
Functions:
- Data Loader: load_and_resample_epoch
- ERP Plotting: plot_active_channel_erp
Data Loader¶
def load_and_resample_epoch(epoch_file: str, resample_rate: int = 1000) -> Tuple[mne.Epochs, np.ndarray]:
    """
    Load epoch data from a file and resample it.
    Parameters:
    - epoch_file: string, file path to the MNE epoch file.
    - resample_rate: int, new sampling rate (default is 1000 Hz).
    Returns:
    - epoch: MNE Epochs object after resampling.
    - epoch_data: numpy array of shape [n_epochs, n_channels, n_timepoints].
    """
    # Load the epochs from file
    epoch = mne.read_epochs(epoch_file)
    # Resample the data to the desired sampling rate
    epoch.resample(resample_rate)
    # Get the data as a numpy array
    epoch_data = epoch.copy().get_data()
    return epoch, epoch_data
epoch, epoch_data = load_and_resample_epoch(epochERPFile_name)
Reading /home/haozhu/Desktop/demo/QUJ3YH24KM_epoch_erp-epo.fif ... Isotrak not found Found the data of interest: t = -100.10 ... 399.90 ms 0 CTF compensation matrices available Not setting metadata 551 matching events found No baseline correction applied 0 projection items activated
plot active contacts erp¶
def plot_active_channel_erp(data: np.ndarray,
                            selected_active_channel_names: list[str],
                            epoch: mne.Epochs,
                            plot_start: int = 0,
                            plot_end: int = 500) -> Figure:
    """
    Plot the ERP (voltage over time) for active channels.
    This function creates a grid of subplots in which each subplot shows the mean ERP waveform
    (with standard-error shading) of one active channel. Instead of channel indices, it accepts
    a list of active channel names; each name is looked up in epoch.ch_names to find its data index.
    Parameters:
    - data (np.ndarray): ERP epoch data with shape [n_epochs, n_channels, n_timepoints], in volts.
    - selected_active_channel_names (List[str]): List of active channel names.
    - epoch (mne.Epochs): MNE Epochs object (provides channel names).
    - plot_start (int): Starting index for the time window to plot (default is 0).
    - plot_end (int): Ending index for the time window to plot (default is 500).
    Returns:
    - Figure: Matplotlib figure object containing the ERP plots.
    """
    # Number of active channels in the provided list
    n_channels = len(selected_active_channel_names)
    # Grid with 4 columns; the number of rows is the ceiling of n_channels / 4
    col = 4
    row = n_channels // col + (1 if n_channels % col != 0 else 0)
    # squeeze=False keeps axs two-dimensional even when there is a single row
    fig, axs = plt.subplots(row, col, figsize=(col * 2.5, row * 2), squeeze=False)
    print("Axes shape:", np.shape(axs))
    # x-axis in samples within the plotted window
    channel_x = np.arange(plot_end - plot_start)
    # x-tick positions and labels (assumes 1 kHz sampling, index 0 = -100 ms and index 100 = 0 ms)
    xtick_positions = np.arange(0, (plot_end - plot_start) + 1, 100)
    xtick_labels = (xtick_positions - 100).tolist()
    ymax = 8  # y-axis limit in units of 10 µV (i.e., ±80 µV)
    # Loop over the subplot grid and plot the ERP of each active channel
    for i in range(row):
        for j in range(col):
            plot_number = i * col + j
            ax = axs[i, j]
            ax.axis('off')
            if plot_number < n_channels:
                ax.axis('on')
                if j == 0:
                    ax.set_ylabel('Voltage (10µV)')
                ax.set_xticks(xtick_positions)
                ax.set_xticklabels(xtick_labels)
                if i == row - 1:
                    ax.set_xlabel('Time (ms)')
                # Get the channel name from the provided active channel list
                channel_name = selected_active_channel_names[plot_number]
                # Find the corresponding index in epoch.ch_names
                try:
                    channel_index = epoch.ch_names.index(channel_name)
                except ValueError:
                    raise ValueError(f"Channel name {channel_name} not found in epoch.ch_names.")
                # ERP data of the current channel across all trials within the window
                channel_data = data[:, channel_index, plot_start:plot_end]
                # Mean ERP and standard error across trials, converted from V to units of 10 µV (x 1e5)
                channel_mean = channel_data.mean(axis=0) * 1e5
                channel_sem = stats.mstats.sem(channel_data, axis=0) * 1e5
                # Mean ERP waveform with shaded standard error
                ax.plot(channel_x, channel_mean, color='k', alpha=0.8)
                ax.fill_between(channel_x, channel_mean + channel_sem, channel_mean - channel_sem,
                                color='k', alpha=0.2)
                # x-axis limits (in samples)
                xlims = [-20, 525]
                # y-ticks from -ymax to ymax in steps of 4
                channel_y = np.arange(-ymax, ymax + 0.01, 4).astype(int)
                ax.set_yticks(channel_y)
                ax.set_yticklabels(channel_y, fontsize=8)
                # Vertical dashed line at time 0 (index 100)
                ax.plot([100, 100], [-ymax, ymax], color='k', linewidth=1, linestyle=(0, [1, 1]))
                ax.set_ylim(-ymax, ymax)
                ax.set_xlim(xlims)
                ax.set_title(f"{channel_name}")
                # Horizontal dashed line at 0 V
                ax.plot(xlims, [0, 0], 'k', linewidth=1, linestyle=(0, [1, 1]))
                # Hide the top and right spines for a cleaner appearance
                ax.spines['top'].set_visible(False)
                ax.spines['right'].set_visible(False)
    plt.suptitle('Significant Channels, ERP', fontsize=15)
    plt.tight_layout()
    plt.show()
    return fig
active_channel = ['E1', 'E4', 'E5', 'E6', 'E7', 'E8', 'E9', 'E10', 'E11', "E'9", "E'10"]
# Plot ERP for the active channels using the ERP epoch data
active_erp_fig = plot_active_channel_erp(epoch_data, active_channel, epoch, plot_start=0, plot_end=500)
# Save the ERP figure
active_erp_fig.savefig(f"{img_path}ERP_active_channel.png", dpi=300)
Axes shape: (3, 4)
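`plot_active_channel_erp` multiplies the voltages (in volts) by 1e5 before plotting. One plotted unit is therefore 1e-5 V, or 10 µV, which is what the "Voltage (10µV)" label means. The `ymax = 8` limit spans ±80 µV. A small worked conversion follows; the constant name `VOLTS_TO_10UV` is hypothetical.

```python
import numpy as np

VOLTS_TO_10UV = 1e5  # multiply volts by 1e5 -> 1 plotted unit = 10 µV
ymax = 8             # axis limit in plotted units

erp_volts = np.array([2e-5, -4e-5, 8e-5])
plotted = erp_volts * VOLTS_TO_10UV  # approximately [2, -4, 8] plotted units
microvolts = erp_volts * 1e6         # approximately [20, -40, 80] µV
print(ymax * 10, "µV")               # 80 µV: the ±8 axis limit
```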
BBP¶
This section processes the BBP data by:
- Loading and resampling the BBP epoch data.
- Computing baseline-normalized spectrograms.
- Plotting spectrograms for the active channels.
Data Loader¶
def load_and_resample_epoch(epoch_file: str, resample_rate: int = 1000) -> mne.Epochs:
    """
    Load epoch data from a file and resample it.
    Unlike the loaders in the HGP and ERP sections, this version returns only the Epochs object;
    the data array is extracted afterwards with extract_active_channels().
    Parameters:
    - epoch_file: string, file path to the MNE epoch file.
    - resample_rate: int, new sampling rate (default is 1000 Hz).
    Returns:
    - epoch: MNE Epochs object after resampling.
    """
    # Load the epochs from file
    epoch = mne.read_epochs(epoch_file)
    # Resample the data to the desired sampling rate
    epoch.resample(resample_rate)
    return epoch
def extract_active_channels(epoch: mne.Epochs, active_channels: List[str]) -> np.ndarray:
    """
    Extract data corresponding to active channels from the epoch object.
    Parameters
    ----------
    epoch : mne.Epochs
        The MNE epochs object containing the sEEG data.
    active_channels : List[str]
        List of channel names to extract.
    Returns
    -------
    data : np.ndarray
        NumPy array of shape (n_trials, n_active_channels, n_samples) with the extracted data.
    """
    data = epoch.copy().pick(active_channels).get_data()
    return data
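Later plots label the spectrograms by position, so the channel order of the array returned here has to line up with `active_channel`. It is worth confirming that order, for example by comparing `epoch.copy().pick(active_channel).ch_names` with `active_channel`. The NumPy sketch below shows an alternative that makes the order explicit and fails loudly on missing names. It works on a plain array plus a list of names; the helper `pick_by_name` is hypothetical and is not MNE's `pick`.

```python
import numpy as np
from typing import List


def pick_by_name(data: np.ndarray, ch_names: List[str], wanted: List[str]) -> np.ndarray:
    """Return data[:, idx, :] for the channels in `wanted`, in exactly the order given."""
    missing = [name for name in wanted if name not in ch_names]
    if missing:
        raise ValueError(f"Channels not found: {missing}")
    idx = [ch_names.index(name) for name in wanted]
    return data[:, idx, :]


ch_names = ['E1', 'E2', "E'9", "E'10"]
data = np.arange(2 * 4 * 3).reshape(2, 4, 3)  # (n_trials, n_channels, n_samples)
sub = pick_by_name(data, ch_names, ["E'10", 'E1'])
print(sub.shape)  # (2, 2, 3)
```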
active_channel = ['E1', 'E4', 'E5', 'E6', 'E7', 'E8', 'E9', 'E10', 'E11', "E'9", "E'10"]
epoch = load_and_resample_epoch(epochBBPFile_name)
epoch_data = extract_active_channels(epoch, active_channel)
Reading /home/haozhu/Desktop/demo/QUJ3YH24KM_epoch_bbp-epo.fif ... Isotrak not found Found the data of interest: t = -100.10 ... 399.90 ms 0 CTF compensation matrices available Not setting metadata 551 matching events found No baseline correction applied 0 projection items activated
Spectrogram¶
Module for computing spectrograms from sEEG trial data using a sliding window DFT.
computing spectrograms¶
def compute_spectrogram(trial_data: np.ndarray, window: int, overlap: int, nfft: np.ndarray,
                        sample_rate: float) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Compute a spectrogram for a single trial's data using a sliding-window Discrete Fourier Transform (DFT).
    This function is modeled on MATLAB's spectrogram(small_segment, window, overlap, nfft, sample_rate, 'yaxis'):
    it applies a Hamming window to each segment, evaluates the DFT at the specified frequencies,
    and computes the power. Unlike MATLAB's power output, which is scaled as a power spectral density,
    P here is the unscaled |S|**2; a per-frequency constant factor cancels after the dB baseline
    subtraction applied downstream.
    Parameters
    ----------
    trial_data : np.ndarray
        1D array containing the signal from one trial (a single channel).
    window : int
        Length of the window (in samples) for the sliding DFT.
    overlap : int
        Number of samples overlapping between adjacent windows.
    nfft : np.ndarray
        1D array of frequencies (in Hz) at which to evaluate the DFT (despite the name, not an FFT length).
    sample_rate : float
        Sampling rate in Hz.
    Returns
    -------
    S : np.ndarray
        Complex spectrogram matrix of shape (len(nfft), n_windows) with DFT values.
    F : np.ndarray
        Frequency vector (in Hz) corresponding to the rows of S (same as input nfft).
    T : np.ndarray
        Time vector (in seconds, from the start of trial_data) of the center of each window.
    P : np.ndarray
        Power spectrogram computed as the squared magnitude of S.
    """
    # Step size between window starts (non-overlapping samples)
    step = window - overlap
    if step < 1:
        step = 1  # Ensure at least one sample step
    # Number of complete windows that fit into the trial data
    n_segments = (len(trial_data) - window) // step + 1
    # Precompute the DFT exponent matrix
    n = np.arange(window)  # Sample indices within one window
    E = np.exp(-1j * 2 * np.pi * np.outer(nfft, n) / sample_rate)  # Shape: (len(nfft), window)
    # Hamming window of the specified length
    w = np.hamming(window)
    # Preallocate the spectrogram and time vector
    S = np.empty((len(nfft), n_segments), dtype=complex)
    T = np.empty(n_segments)
    # Process each segment
    for i in range(n_segments):
        start = i * step
        segment = trial_data[start: start + window]
        segment = segment * w  # Apply the Hamming window
        S[:, i] = E @ segment  # DFT at the requested frequencies
        T[i] = (start + window / 2) / sample_rate  # Center time of the window
    P = np.abs(S) ** 2  # Power spectrogram
    F = nfft  # Frequency vector
    return S, F, T, P
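The exponent matrix `E` in `compute_spectrogram` evaluates the DFT at arbitrary frequencies. When those frequencies are the FFT bin frequencies k * fs / window, the result equals `np.fft.fft` of the windowed segment, which is a quick sanity check of the construction. The random segment below is only a test input, not sEEG data.

```python
import numpy as np

fs = 1000.0
window = 40
rng = np.random.default_rng(0)
segment = rng.standard_normal(window) * np.hamming(window)  # one Hamming-windowed segment

# Explicit DFT (as in compute_spectrogram) evaluated at the FFT bin frequencies k * fs / window
bin_freqs = np.arange(window) * fs / window  # 0, 25, 50, ... Hz
E = np.exp(-1j * 2 * np.pi * np.outer(bin_freqs, np.arange(window)) / fs)
explicit_dft = E @ segment
print(np.allclose(explicit_dft, np.fft.fft(segment)))  # True

# Windows produced for a 499-sample segment with window=40 and overlap=38 (step = 2)
n_windows = (499 - window) // (window - 38) + 1
print(n_windows)  # 230
```

With a 40-sample window at 1 kHz, independent frequency estimates are about 25 Hz apart. Evaluating the DFT every 5 Hz, as the tutorial's `nfft` grid does, interpolates between them rather than adding frequency resolution.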
"""
Module for processing sEEG data per channel and per trial in parallel.
This module leverages the compute_spectrogram() function.
"""
def process_channel_wrapper(args: Tuple) -> Tuple[int, np.ndarray, np.ndarray, np.ndarray]:
    """
    Helper function to process a single channel across multiple trials.
    For each trial, a segment of data is extracted, its spectrogram is computed,
    and the power is normalized by converting it to dB and subtracting a baseline.
    Parameters
    ----------
    args : Tuple
        A tuple containing:
        - j (int): Channel index.
        - data (np.ndarray): 3D array with shape (n_trials, n_channels, n_samples).
        - task_trials (int): Number of trials to process.
        - time_start (int): Sample offset; the extracted segment covers samples 0 to time_start + 398.
        - window (int): Window length in samples for the spectrogram.
        - overlap (int): Number of overlapping samples between windows.
        - nfft (np.ndarray): Frequency points for the DFT.
        - sample_rate (float): Sampling rate in Hz.
    Returns
    -------
    j : int
        Channel index.
    local_spectro : np.ndarray
        3D array of normalized power spectrograms for the channel with shape
        (n_frequencies, n_time_bins, task_trials).
    local_T : np.ndarray
        Time vector (in seconds) computed for the channel.
    local_F : np.ndarray
        Frequency vector (in Hz) computed for the channel.
    """
    j, data, task_trials, time_start, window, overlap, nfft, sample_rate = args
    local_spectro_list = []
    local_T = None
    local_F = None
    for i in range(task_trials):
        # Extract the segment from sample 0 up to (but not including) sample time_start + 399
        small_segment = data[i, j, :time_start + 399].flatten()
        S, F, T, P = compute_spectrogram(small_segment, window, overlap, nfft, sample_rate)
        # Convert power to dB and subtract, per frequency, the mean over time bins 50-99.
        # With window=40, overlap=38 and stimulus onset at sample 100, these bins are centered
        # 20-118 ms after onset; only bins 0-30 end before onset. Check that this is the intended baseline.
        logged_P = 10 * np.log10(P + np.finfo(float).eps)
        base_db = np.mean(logged_P[:, 50:100], axis=1, keepdims=True)
        norm_P = logged_P - base_db
        local_spectro_list.append(norm_P)
        if i == 0:
            local_T = T
            local_F = F
    local_spectro = np.stack(local_spectro_list, axis=-1)
    return j, local_spectro, local_T, local_F
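The per-trial loop in `process_channel_wrapper` can also be written without Python loops. The sketch cuts all windows of all trials at once with `numpy.lib.stride_tricks.sliding_window_view` and applies the same Hamming taper and DFT matrix. It then performs the same dB conversion and per-frequency baseline subtraction, returning the same (n_frequencies, n_time_bins, n_trials) layout. The helper `spectrogram_power_db` is hypothetical. The baseline bins are exposed as a parameter; its default, slice(50, 100), mirrors the tutorial.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def spectrogram_power_db(trials: np.ndarray, window: int, overlap: int, freqs: np.ndarray,
                         fs: float, baseline_bins: slice = slice(50, 100)) -> np.ndarray:
    """Baseline-normalized dB spectrograms for all trials of one channel.

    trials: array of shape (n_trials, n_samples). Returns shape (n_freqs, n_windows, n_trials).
    """
    step = max(window - overlap, 1)
    # (n_trials, n_windows, window): every possible window start, keeping every `step`-th one
    segments = sliding_window_view(trials, window, axis=-1)[:, ::step]
    tapered = segments * np.hamming(window)
    # DFT matrix evaluated at the requested frequencies, as in compute_spectrogram
    E = np.exp(-1j * 2 * np.pi * np.outer(freqs, np.arange(window)) / fs)
    S = np.einsum('fn,tkn->fkt', E, tapered)
    P_db = 10 * np.log10(np.abs(S) ** 2 + np.finfo(float).eps)
    # Subtract the per-trial, per-frequency mean over the baseline bins
    return P_db - P_db[:, baseline_bins].mean(axis=1, keepdims=True)


rng = np.random.default_rng(1)
trials = rng.standard_normal((3, 499))  # 3 toy trials of 499 samples
out = spectrogram_power_db(trials, 40, 38, np.arange(1, 204, 5), 1000.0)
print(out.shape)  # (41, 230, 3)
```

Vectorizing removes the inner Python loop, but it holds all windows in memory at once. For long recordings, the tutorial's per-channel parallel approach may be the better trade-off.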
def process_data(data: np.ndarray, good_channels: int, task_trials: int, time_start: int,
                 window: int, overlap: int, nfft: np.ndarray, sample_rate: float
                 ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Compute baseline-normalized spectrograms for every channel and trial,
    processing channels in parallel.

    Parameters
    ----------
    data : np.ndarray
        3D array of shape (n_trials, n_channels, n_samples).
    good_channels : int
        Number of channels to process (channels 0 to good_channels-1 of data).
    task_trials : int
        Number of trials to process.
    time_start : int
        Starting sample offset (e.g., 100). Each segment is the first
        time_start+399 samples of a trial (indices 0 to time_start+398).
    window : int
        Window length (in samples) for spectrogram computation.
    overlap : int
        Number of overlapping samples between consecutive windows.
    nfft : np.ndarray
        Frequency points (in Hz) at which the DFT is computed.
    sample_rate : float
        Sampling rate in Hz.

    Returns
    -------
    spectro : np.ndarray
        4D array of normalized power spectrograms with shape
        (n_frequencies, n_time_bins, task_trials, good_channels).
    T_values : np.ndarray
        2D array where each column is the time vector (in seconds) for a channel.
    F_values : np.ndarray
        2D array where each column is the frequency vector (in Hz) for a channel.
    """
    start_time = time.time()
    # Determine output dimensions from the first trial of the first channel.
    sample_segment = data[0, 0, :time_start + 399].flatten()
    _, F_sample, T_sample, P_sample = compute_spectrogram(sample_segment, window, overlap, nfft, sample_rate)
    n_frequencies, n_time_bins = P_sample.shape
    n_T = len(T_sample)
    n_F = len(F_sample)
    # Preallocate output arrays (NaN marks any slot that is never filled).
    spectro = np.full((n_frequencies, n_time_bins, task_trials, good_channels), np.nan)
    T_values = np.full((n_T, good_channels), np.nan)
    F_values = np.full((n_F, good_channels), np.nan)
    # One argument tuple per channel for the worker function.
    args_list = [
        (j, data, task_trials, time_start, window, overlap, nfft, sample_rate)
        for j in range(good_channels)
    ]
    # Process channels in parallel. as_completed() yields results in completion order,
    # so each worker returns its channel index j to place its result correctly.
    with ProcessPoolExecutor() as executor:
        futures = [executor.submit(process_channel_wrapper, args) for args in args_list]
        for future in as_completed(futures):
            j, local_spectro, local_T, local_F = future.result()
            spectro[:, :, :, j] = local_spectro
            T_values[:, j] = local_T
            F_values[:, j] = local_F
            print(f"Processed channel {j+1} of {good_channels}")
    end_time = time.time()
    print(f"Total processing time: {end_time - start_time:.2f} seconds")
    return spectro, T_values, F_values
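`process_data` submits one task per channel and collects the results with `as_completed`. That function yields futures in the order they finish, not the order they were submitted. This is why each worker returns its channel index `j` along with the spectrogram, and why the progress messages can appear out of order. The sketch below reproduces the pattern with a hypothetical worker, `square_channel`. It uses `ThreadPoolExecutor` in place of `ProcessPoolExecutor` only so the example runs without pickling concerns; the index bookkeeping is the same.

```python
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Tuple

import numpy as np


def square_channel(j: int, data: np.ndarray) -> Tuple[int, np.ndarray]:
    """Hypothetical worker: returns its channel index together with its result."""
    return j, data[:, j, :] ** 2


def run_parallel(data: np.ndarray) -> np.ndarray:
    n_trials, n_channels, n_samples = data.shape
    # Channel on the last axis, as in spectro; NaN marks unfilled slots.
    out = np.full((n_trials, n_samples, n_channels), np.nan)
    with ThreadPoolExecutor() as executor:
        futures = [executor.submit(square_channel, j, data) for j in range(n_channels)]
        for future in as_completed(futures):  # completion order, not submission order
            j, result = future.result()
            out[:, :, j] = result  # the returned index, not the loop position, picks the slot
    return out


data = np.random.default_rng(1).standard_normal((5, 11, 20))
out = run_parallel(data)
expected = np.transpose(data ** 2, (0, 2, 1))
print(np.allclose(out, expected))  # True
```

Because results are written by the returned index, the final array does not depend on which channel happens to finish first.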
# Define spectrogram and processing parameters.
good_channels: int = epoch_data.shape[1]  # Number of channels to process.
task_trials: int = epoch_data.shape[0]    # Number of trials.
time_start: int = 100                     # Starting sample offset (100 samples = 100 ms at 1 kHz).
window: int = 40                          # Window length in samples (40 ms at 1 kHz).
overlap: int = 38                         # Overlapping samples (windows advance by 2 samples = 2 ms).
nfft: np.ndarray = np.arange(1, 204, 5)   # 41 frequency points: 1, 6, ..., 201 Hz in 5 Hz steps.
sample_rate: int = 1000                   # Sampling rate in Hz.
# epoch_data has shape (n_trials, n_channels, n_samples), with n_samples >= time_start + 399.
spectro, T_values, F_values = process_data(epoch_data, good_channels, task_trials, time_start,
                                           window, overlap, nfft, sample_rate)
Processed channel 1 of 11
Processed channel 2 of 11
Processed channel 3 of 11
Processed channel 4 of 11
Processed channel 5 of 11
Processed channel 8 of 11
Processed channel 9 of 11
Processed channel 7 of 11
Processed channel 6 of 11
Processed channel 11 of 11
Processed channel 10 of 11
Total processing time: 1.51 seconds
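At 1000 Hz, `window = 40` means each window covers 40 ms, and `overlap = 38` means consecutive windows advance by 2 samples (2 ms). The spectral resolution is therefore on the order of `fs / window` = 25 Hz, even though `nfft` samples the frequency axis every 5 Hz. The sketch below works out these numbers and the number of time bins for the 499-sample segments.

The bin count and bin centres rest on one assumption: that `compute_spectrogram` (defined earlier and not shown here) segments the signal like `scipy.signal.spectrogram`, using full windows only and reporting window centres. The SciPy call is a cross-check under that assumption, not the tutorial's implementation. The last line shows which part of the epoch the baseline bins 50 to 99 cover under the same assumption.

```python
import numpy as np
from scipy.signal import spectrogram

# Parameters copied from the tutorial.
fs = 1000          # sampling rate (Hz)
time_start = 100   # starting sample offset
window = 40        # window length (samples)
overlap = 38       # overlapping samples between consecutive windows
nfft = np.arange(1, 204, 5)
n_samples = time_start + 399  # length of each segment passed to compute_spectrogram

step = window - overlap                 # hop between windows (samples)
n_bins = (n_samples - overlap) // step  # number of full windows, no padding
print(f"window {1000 * window / fs:.0f} ms, hop {1000 * step / fs:.0f} ms, {n_bins} time bins")
print(f"resolution ~{fs / window:.0f} Hz; {len(nfft)} frequency points from {nfft[0]} to {nfft[-1]} Hz")

# Cross-check the bin count with SciPy. Assumption: compute_spectrogram segments the
# signal the same way; the window type does not change the number of segments.
x = np.random.default_rng(0).standard_normal(n_samples)
f, t, Sxx = spectrogram(x, fs=fs, window="hamming", nperseg=window, noverlap=overlap)
print(Sxx.shape[1] == n_bins)  # True

# Bin centres in ms after the same -100 ms shift used in plot_spectrogram.
t_ms = (t - time_start / fs) * 1000
print(f"bin centres from {t_ms[0]:.0f} to {t_ms[-1]:.0f} ms")
print(f"baseline bins 50-99 centred from {t_ms[50]:.0f} to {t_ms[99]:.0f} ms")
```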
Plot spectrograms
from typing import List, Optional


def plot_spectrogram(spectro: np.ndarray, T_values: np.ndarray, F_values: np.ndarray,
                     channel_titles: Optional[List[str]] = None, save: bool = False,
                     top_color: float = 2, bottom_color: float = 0, n_cols: int = 4) -> None:
    """
    Plot trial-averaged spectrograms for multiple channels in a grid.

    Parameters
    ----------
    spectro : np.ndarray
        4D array of normalized power spectrograms with shape
        (n_frequencies, n_time_bins, task_trials, n_channels).
    T_values : np.ndarray
        2D array where each column is the time vector (in seconds) for a channel.
    F_values : np.ndarray
        2D array where each column is the frequency vector (in Hz) for a channel.
    channel_titles : List[str], optional
        Channel titles. Defaults to "Channel 1", "Channel 2", etc. if not provided.
    save : bool, optional
        If True, save the figure to f"{img_path}Spectro_active_channel.png", where
        img_path is a global variable defined earlier in the tutorial (default is False).
    top_color : float, optional
        Upper limit for color scaling in dB (default is 2).
    bottom_color : float, optional
        Lower limit for color scaling in dB (default is 0).
    n_cols : int, optional
        Number of subplot columns (default is 4).

    Returns
    -------
    None
        The function plots the spectrograms and optionally saves the figure.
    """
    # Shift the time axis by -100 ms (time_start = 100 samples at 1 kHz) and convert s to ms.
    T_correct = (T_values - 0.1) * 1000  # in ms
    n_channels = spectro.shape[-1]
    n_rows = -(-n_channels // n_cols)  # Ceiling division
    # squeeze=False always returns a 2D array of Axes, so flatten() also works for a 1x1 grid.
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(n_cols * 2.5, n_rows * 2), squeeze=False)
    axs = axs.flatten()
    for i in range(n_channels):
        # Average the normalized power over trials (axis 2), ignoring NaNs.
        select_P = spectro[:, :, :, i]
        averaged_P = np.nanmean(select_P, axis=2)
        ax = axs[i]
        im = ax.imshow(averaged_P, aspect='auto',
                       extent=[T_correct.min(), T_correct.max(), F_values.min(), F_values.max()],
                       cmap='cividis', vmin=bottom_color, vmax=top_color, origin='lower')
        if channel_titles is not None and i < len(channel_titles):
            ax.set_title(channel_titles[i], fontsize=12)
        else:
            ax.set_title(f"Channel {i+1}", fontsize=12)
        # Restrict the frequency axis to 50-200 Hz.
        ax.set_ylim([50, 200])
        ax.set_yticks([50, 100, 150, 200])
        # Label the y axis only in the first column.
        if i % n_cols == 0:
            ax.set_ylabel('Frequency (Hz)')
        else:
            ax.set_ylabel('')
        ax.set_xticks([0, 100, 200, 300])
        # Label the x axis only on subplots with no channel below them.
        if i + n_cols >= n_channels:
            ax.set_xlabel('Time (ms)')
        else:
            ax.set_xlabel('')
    # Hide unused subplots.
    for j in range(n_channels, len(axs)):
        axs[j].axis('off')
    plt.suptitle('Significant Channels, Spectrogram', fontsize=15)
    # Leave room on the right for a shared colorbar.
    fig.tight_layout(rect=[0, 0, 0.9, 1])
    cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])
    cbar = fig.colorbar(im, cax=cbar_ax, ticks=[0, 1, 2, 3, 4])
    cbar.set_label('Normalized Power (dB)', fontsize=10)
    if save:
        save_filename = f"{img_path}Spectro_active_channel.png"
        plt.savefig(save_filename, dpi=300, bbox_inches='tight')
        print(f"Figure saved to {save_filename}")
    plt.show()
plot_spectrogram(spectro, T_values, F_values, channel_titles=active_channel, save=True)
Figure saved to figures/Spectro_active_channel.png
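`plot_spectrogram` lays the channels out row by row in a grid with `n_cols` columns, using `-(-n // k)` for ceiling division. It places the y label on the first column (`i % n_cols == 0`) and the x label on every subplot with no channel below it (`i + n_cols >= n_channels`). Unused slots are switched off. The sketch below isolates that bookkeeping in a hypothetical helper, `grid_layout`, so it can be checked without drawing anything.

```python
from typing import List, Tuple


def grid_layout(n_channels: int, n_cols: int = 4) -> Tuple[int, List[int], List[int], List[int]]:
    """Hypothetical helper reproducing the subplot bookkeeping in plot_spectrogram.

    Returns the number of rows, the subplot indices that get a y label, the
    indices that get an x label, and the unused indices that are switched off.
    """
    n_rows = -(-n_channels // n_cols)  # ceiling division: -(-11 // 4) == 3
    ylabel = [i for i in range(n_channels) if i % n_cols == 0]           # first column
    xlabel = [i for i in range(n_channels) if i + n_cols >= n_channels]  # nothing below
    hidden = list(range(n_channels, n_rows * n_cols))                    # empty slots
    return n_rows, ylabel, xlabel, hidden


print(grid_layout(11, 4))  # (3, [0, 4, 8], [7, 8, 9, 10], [11])
```

For the 11 channels in this run, subplot 7 (second row, last column) also receives the x label, because the slot below it is empty and switched off.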