sEEG data preprocessing pipeline
Contributor: Hao Zhu (haozhu@cuhk.edu.hk)
Date created: 2024/10/02
This tutorial is based on:
- Ubuntu 24.04.1 LTS 64-bit
- Python 3.12.4
- mne 1.8.0
- re 2.2.1
- h5py 3.11.0
- numpy 1.26.4
- seaborn 0.13.2
- matplotlib 3.8.4
- scipy 1.13.1
- pingouin 0.5.5
Introduction
This tutorial covers processing stereotactic EEG (sEEG) data in Python. It walks through several key stages, including data loading, preprocessing, feature extraction, and visualization, using libraries such as MNE, NumPy, and SciPy.
Different labs use different preprocessing pipelines and methods. This pipeline combines elements of several labs' pipelines and has been tested as a good fit for sEEG data collected in Shenzhen and Guangzhou.
In this tutorial, I use a sample dataset from a real subject (sz04) recorded at Shenzhen Second People's Hospital.
The raw data from the hospital's Nicolet acquisition system are in EDF format. They were cropped, bad channels were dropped, and the result was converted to FIF format with MNE-Python. Events are recorded in the "TRIG" channel, where pulses of different magnitudes represent different events (e.g., a high-magnitude pulse usually means a marker sent by the stimulus PC). Extract the events from the "TRIG" channel and store them in .eve format (see the example .eve file for the event format: [timestamp in samples, 0, event_id]).
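Event extraction happens before this tutorial starts, so the exact procedure is not shown here. As a rough illustration only, the sketch below turns a magnitude-coded trigger trace into an MNE-style events array `[sample, 0, event_id]`. The helper `extract_events`, the threshold, and the magnitude-to-ID mapping are hypothetical; real values depend on how the Nicolet system encodes its pulses. If the recording was cropped, sample indices should be offset by `raw.first_samp` so they match MNE's absolute sample numbering.

```python
import numpy as np

def extract_events(trig, magnitude_to_id, threshold, first_samp=0):
    """Hypothetical helper: build an events array [sample, 0, event_id] from a trigger trace.

    magnitude_to_id maps nominal pulse heights to event IDs (assumed values);
    threshold separates pulses from the baseline (assumed value).
    """
    above = trig > threshold
    # Rising edges: first sample of each run above the threshold
    onsets = np.flatnonzero(above[1:] & ~above[:-1]) + 1
    nominal = np.array(sorted(magnitude_to_id))
    events = []
    for onset in onsets:
        # Walk to the end of this pulse and take its peak height
        end = onset
        while end < len(trig) and above[end]:
            end += 1
        peak = trig[onset:end].max()
        # Assign the pulse to the closest nominal magnitude
        closest = nominal[np.argmin(np.abs(nominal - peak))]
        events.append([onset + first_samp, 0, magnitude_to_id[closest]])
    return np.array(events, dtype=int).reshape(-1, 3)

# Synthetic trace: one "high" pulse and one "low" pulse (hypothetical magnitudes)
trig = np.zeros(5000)
trig[1000:1010] = 1.0
trig[2500:2510] = 0.5
print(extract_events(trig, {1.0: 4, 0.5: 7}, threshold=0.2))
```

The resulting array could then be saved with `mne.write_events` and read back with `mne.read_events`, as done below.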
Two files are required for this tutorial:
- an event file ending in .eve
- a cropped raw data file ending in .fif
Loading modules
import mne, re, h5py
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import scipy.signal as signal
from scipy.stats import sem
from tqdm import tqdm
import pandas as pd
import pingouin as pg
File paths
data_path = '/home/haozhu/brainstormdb/Test_data/'
img_path = '/home/haozhu/brainstormdb/Test_data/Figures/'
subject_id = 'sz04'
fs_subject = 'fsaverage'
fs_path = '/home/haozhu/freesurfer/subjects/'
Preprocessing
Load data and drop epileptic contacts
subject_Datadir = f'{data_path}/{subject_id}_2022_09_17_rawMP_s1_'
eventFile_name = f'{subject_Datadir}events.eve'
rawMPFile_name = f'{subject_Datadir}ieeg.fif'
epochFile_name = f'{subject_Datadir}epoch-epo.fif'
# Epileptic contacts identified by a neuroelectrophysiologist
epileptic_contacts = ['A1', 'A2', 'A3', 'A4', 'A5',
                      'H1', 'H2', 'H3', 'H4', 'H5']
We begin by loading the raw sEEG data and the event file with MNE, and then drop the pre-determined epileptic contacts.
raw = mne.io.read_raw_fif(rawMPFile_name, preload=True)
events = mne.read_events(eventFile_name)
# Drop the "TRIG" channel, which is no longer needed: events must already be extracted before preprocessing.
raw.drop_channels(['TRIG'])
raw.drop_channels(epileptic_contacts)
Opening raw data file /home/haozhu/brainstormdb/Test_data//sz04_2022_09_17_rawMP_s1_ieeg.fif... Isotrak not found Range : 458000 ... 994000 = 458.000 ... 994.000 secs Ready. Reading 0 ... 536000 = 0.000 ... 536.000 secs...
| General | |
|---|---|
| Filename(s) | sz04_2022_09_17_rawMP_s1_ieeg.fif |
| MNE object type | Raw |
| Measurement date | 2022-09-17 at 10:57:27 UTC |
| Participant | Unknown |
| Experimenter | Unknown |
| Acquisition | |
| Duration | 00:08:56 (HH:MM:SS) |
| Sampling frequency | 1000.00 Hz |
| Time points | 536,001 |
| Channels | |
| EEG | |
| Head & sensor digitization | Not available |
| Filters | |
| Highpass | 0.00 Hz |
| Lowpass | 500.00 Hz |
Filters
sEEG data often contain power-line noise at the mains frequency (50 Hz here) and its harmonics. We apply a notch filter at 50, 100, 150, and 200 Hz to remove this noise.
We then apply a 0.1-200 Hz band-pass filter to keep the frequencies of interest, which usually lie below 200 Hz.
raw.notch_filter([50,100,150,200])
raw.filter(l_freq=0.1, h_freq=200)
Filtering raw data in 1 contiguous segment Setting up band-stop filter FIR filter parameters --------------------- Designing a one-pass, zero-phase, non-causal bandstop filter: - Windowed time-domain design (firwin) method - Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation - Lower transition bandwidth: 0.50 Hz - Upper transition bandwidth: 0.50 Hz - Filter length: 6601 samples (6.601 s)
Filtering raw data in 1 contiguous segment Setting up band-pass filter from 0.1 - 2e+02 Hz FIR filter parameters --------------------- Designing a one-pass, zero-phase, non-causal bandpass filter: - Windowed time-domain design (firwin) method - Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation - Lower passband edge: 0.10 - Lower transition bandwidth: 0.10 Hz (-6 dB cutoff frequency: 0.05 Hz) - Upper passband edge: 200.00 Hz - Upper transition bandwidth: 50.00 Hz (-6 dB cutoff frequency: 225.00 Hz) - Filter length: 33001 samples (33.001 s)
| General | |
|---|---|
| Filename(s) | sz04_2022_09_17_rawMP_s1_ieeg.fif |
| MNE object type | Raw |
| Measurement date | 2022-09-17 at 10:57:27 UTC |
| Participant | Unknown |
| Experimenter | Unknown |
| Acquisition | |
| Duration | 00:08:56 (HH:MM:SS) |
| Sampling frequency | 1000.00 Hz |
| Time points | 536,001 |
| Channels | |
| EEG | |
| Head & sensor digitization | Not available |
| Filters | |
| Highpass | 0.10 Hz |
| Lowpass | 200.00 Hz |
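To see what notching the mains harmonics does, the sketch below applies a zero-phase IIR notch (`scipy.signal.iirnotch` with `filtfilt`) at 50, 100, 150, and 200 Hz to a synthetic signal and compares the 50 Hz amplitude before and after. This is a swapped-in technique for illustration only: MNE's `notch_filter` used above designs FIR filters by default, as its log shows. The quality factor and the synthetic signal are assumptions.

```python
import numpy as np
from scipy import signal

fs = 1000.0       # sampling rate of this dataset (Hz)
line_freq = 50.0  # mains frequency
h_freq = 200.0    # low-pass edge used in this tutorial

# Mains harmonics up to the low-pass edge: 50, 100, 150, 200 Hz
harmonics = np.arange(line_freq, h_freq + 1, line_freq)

def notch_harmonics(x, fs, freqs, quality=30.0):
    """Zero-phase IIR notch at each frequency in `freqs` (quality factor is an assumption)."""
    for f0 in freqs:
        b, a = signal.iirnotch(f0, quality, fs=fs)
        x = signal.filtfilt(b, a, x)
    return x

def amplitude_at(x, fs, f0):
    """Single-sided FFT amplitude at the frequency bin closest to f0."""
    spectrum = 2 * np.abs(np.fft.rfft(x)) / len(x)
    freqs = np.fft.rfftfreq(len(x), d=1 / fs)
    return spectrum[np.argmin(np.abs(freqs - f0))]

# Synthetic 10 s signal: a 10 Hz rhythm plus 50 Hz line noise
t = np.arange(int(10 * fs)) / fs
x = np.sin(2 * np.pi * 10 * t) + 0.5 * np.sin(2 * np.pi * 50 * t)
y = notch_harmonics(x, fs, harmonics)
print(harmonics, amplitude_at(x, fs, 50), amplitude_at(y, fs, 50))
```

Deriving the harmonic list from the mains frequency, rather than typing it by hand, makes it easy to adapt the same step to 60 Hz recordings.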
Re-referencing with ESR
The primary purpose of re-referencing is to mitigate the impact of the reference electrode on the recorded sEEG signals. In this sample dataset, the sEEG signals were referenced online to a white-matter contact (contact 1 or 2; monopolar montage). This reference contact can pick up artifacts that then appear on all other channels, distorting the brain signals of interest. Re-referencing reduces the bias introduced by the original reference and minimizes spurious correlations between channels.
It can also improve overall signal quality and the signal-to-noise ratio (SNR): an appropriate re-referencing method reduces common noise and artifacts present in the original reference. For a more comprehensive comparison of methods, see this paper: https://pmc.ncbi.nlm.nih.gov/articles/PMC6495648/
After comparing the possible methods, I selected the Electrode Shaft Reference (ESR), which re-references each channel to the average signal of all channels on the same shaft, for the data collected in Shenzhen.
def find_electrode_indices(contacts):
    '''
    Loop through all contacts and return, for each electrode shaft, the indices
    of its first and last contact. Assumes that the contacts of a shaft are
    stored next to each other in `contacts`.
    '''
    electrode_indices = {}
    current_electrode = None
    start_index = 0
    # Regex to extract the shaft prefix: capital letters plus an optional apostrophe (e.g., X, OB, X')
    prefix_pattern = re.compile(r"^[A-Z]+'?")
    for i, contact in enumerate(contacts):
        # Extract the prefix using regex
        match = prefix_pattern.match(contact)
        if match:
            prefix = match.group()
            if prefix != current_electrode:
                # If it's not the first electrode, save the previous electrode's indices
                if current_electrode is not None:
                    electrode_indices[current_electrode] = (start_index, i - 1)
                # Update current electrode and start index
                current_electrode = prefix
                start_index = i
    # Add the last electrode's indices
    if current_electrode is not None:
        electrode_indices[current_electrode] = (start_index, len(contacts) - 1)
    return electrode_indices
contacts = raw.info['ch_names']
electrode_indices = find_electrode_indices(contacts)
electrode_indices
{'A': (0, 9), 'H': (10, 19), 'J': (20, 37), 'OB': (38, 52), 'V': (53, 62), 'X': (63, 77), 'Z': (78, 92), 'O': (93, 102), 'Y': (103, 118), 'E': (119, 130), 'F': (131, 142), 'W': (143, 154), 'M': (155, 170), "X'": (171, 184), "Y'": (185, 200)}
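`find_electrode_indices` returns a (start, end) range per shaft, which is only correct when the contacts of each shaft are stored contiguously in `ch_names`. A minimal sketch of an alternative that keeps a list of indices per shaft, using the same prefix rule, is shown below; `group_contacts_by_shaft` is a hypothetical name and the contact names are made up.

```python
import re
from collections import defaultdict

# Same prefix rule as find_electrode_indices: capital letters plus an optional
# apostrophe, so primed shafts such as X' are kept separate from X.
PREFIX = re.compile(r"^[A-Z]+'?")

def group_contacts_by_shaft(contacts):
    """Map each shaft prefix to the list of channel indices on that shaft.

    Unlike a start/end range, an index list stays correct even when the
    contacts of one shaft are not stored next to each other.
    """
    shafts = defaultdict(list)
    for idx, name in enumerate(contacts):
        match = PREFIX.match(name)
        if match:
            shafts[match.group()].append(idx)
    return dict(shafts)

print(group_contacts_by_shaft(["X1", "X2", "X'1", "X'2", "OB1", "X3"]))
# {'X': [0, 1, 5], "X'": [2, 3], 'OB': [4]}
```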
# Iterate over each shaft, compute the average signal of its contacts,
# and subtract it from every contact on that shaft.
raw_data = raw.copy().get_data()
electrode_shaft_average = np.zeros(raw_data.shape)
for electrode in electrode_indices.keys():
    electrode_start, electrode_end = electrode_indices[electrode]
    esr = np.mean(raw_data[electrode_start:electrode_end + 1], axis=0)
    electrode_shaft_average[electrode_start:electrode_end + 1] = esr
# Write the re-referenced data back (note: `_data` is a private MNE attribute)
raw._data = raw_data - electrode_shaft_average
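The shaft-averaging step can also be written with one label per channel instead of start/end ranges. In the sketch below (the function `shaft_reference` is hypothetical and the data are synthetic), each channel has its shaft's mean subtracted. After ESR, the contacts of each shaft sum to zero at every time point, so any signal shared by the whole shaft, such as a component coming from the original reference, is removed.

```python
import numpy as np

def shaft_reference(data, shaft_labels):
    """Electrode shaft reference: subtract each shaft's mean from its own contacts.

    data: array (n_channels, n_times); shaft_labels: one shaft name per channel.
    """
    labels = np.asarray(shaft_labels)
    referenced = np.empty_like(data, dtype=float)
    for shaft in np.unique(labels):
        rows = labels == shaft
        referenced[rows] = data[rows] - data[rows].mean(axis=0, keepdims=True)
    return referenced

# Synthetic check: a signal shared by all contacts (e.g., from the original reference)
rng = np.random.default_rng(0)
shared = rng.standard_normal(1000)
data = rng.standard_normal((5, 1000)) + shared
labels = ["A", "A", "B", "B", "B"]
out = shaft_reference(data, labels)
print(np.allclose(out[:2].sum(axis=0), 0), np.allclose(out[2:].sum(axis=0), 0))  # True True
```

One consequence of this design: a shaft with a single remaining contact becomes all zeros under ESR, so such contacts may need a different reference.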
Epoching
Event ID 4 marks the passive-listening audio trigger.
The epoch length depends on the task design and the planned analyses. For example, in time-frequency analysis the epoch should extend beyond the analysis window by at least 3-5 cycles of the lowest frequency of interest to avoid edge effects.
Baseline correction is normally applied (e.g., baseline=(None, 0) in MNE). Here it is set to None because of my specific task design.
Detrending is recommended because it removes slow drifts; detrend=1 removes a linear trend from each epoch.
epoch = mne.Epochs(raw, events,tmin=-2.0, tmax=2.0, baseline=None,detrend=1,
reject=None, preload=True)
Not setting metadata 271 matching events found No baseline correction applied 0 projection items activated Using data from preloaded Raw for 271 events and 4001 original time points ...
/tmp/ipykernel_12366/4239655886.py:1: RuntimeWarning: The events passed to the Epochs constructor are not chronologically ordered. epoch = mne.Epochs(raw, events,tmin=-2.0, tmax=2.0, baseline=None,detrend=1,
0 bad epochs dropped
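Taking the 3-5 cycle rule of thumb above at face value, the padding needed on each side of the analysis window is `n_cycles` divided by the lowest frequency (one period is 1/f seconds). The small helpers below (hypothetical names `padding_seconds` and `epoch_window`) turn that into a suggested `tmin`/`tmax`; for a lowest frequency of 4 Hz and 3 cycles, 0.75 s is added on each side.

```python
def padding_seconds(lowest_freq_hz, n_cycles=3):
    """Duration (s) of `n_cycles` periods of the lowest frequency of interest."""
    return n_cycles / lowest_freq_hz

def epoch_window(analysis_tmin, analysis_tmax, lowest_freq_hz, n_cycles=3):
    """Suggested (tmin, tmax) that pads the analysis window on both sides."""
    pad = padding_seconds(lowest_freq_hz, n_cycles)
    return analysis_tmin - pad, analysis_tmax + pad

# Example: analyse -0.2 to 0.8 s with frequencies down to 4 Hz
tmin, tmax = epoch_window(-0.2, 0.8, lowest_freq_hz=4, n_cycles=3)
print(round(tmin, 3), round(tmax, 3))  # -0.95 1.55
```

By the same rule, and assuming an analysis window of -0.2 to 0.8 s, the -2 to 2 s epochs used here leave at least 1.2 s of padding, which is 3 cycles at 2.5 Hz.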
Drop improbable epochs
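An epoch is flagged when more than half of the contacts meet either criterion:
- absolute amplitude greater than 350 µV
- a deviation of more than 4 SD from the mean across all epochs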
epoch_data = epoch.get_data(copy=True)  # shape: (n_epochs, n_channels, n_times), in volts
contact_number = epoch_data.shape[1]
epoch_data_mean = epoch_data.mean(axis=0)
epoch_data_sd = epoch_data.std(axis=0)
epoch_threshold = 350e-6  # 350 µV expressed in volts (MNE stores data in V)
epoch_peak_value = np.abs(epoch_data).max(axis=2)
# Criterion 1: more than half of the contacts exceed the amplitude threshold
for ind, each_epoch in enumerate(epoch_peak_value):
    improbable_contacts_ratio = len(set(np.where(each_epoch > epoch_threshold)[0])) / contact_number
    if improbable_contacts_ratio > 0.5:
        print("peak", ind, improbable_contacts_ratio)
# Criterion 2: more than half of the contacts deviate > 4 SD from the across-epoch mean
for ind, each_epoch in enumerate(epoch_data):
    deviant = np.abs(each_epoch - epoch_data_mean) > 4 * epoch_data_sd
    improbable_contacts_ratio = len(set(np.where(deviant)[0])) / contact_number
    if improbable_contacts_ratio > 0.5:
        print('sd', ind, improbable_contacts_ratio)
# Uncomment and fill in only if the checks above printed any epochs
# improbable_epoch = []
# epoch.drop(improbable_epoch)
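The two loops above can be condensed into a single function that returns the indices to pass to `epoch.drop`. This is a sketch under the same assumptions (data in volts, a 350 µV threshold, 4 SD, more than half of the contacts); `find_improbable_epochs` is a hypothetical helper, and the synthetic data exist only to exercise it.

```python
import numpy as np

def find_improbable_epochs(epoch_data, abs_threshold=350e-6, n_sd=4.0, contact_ratio=0.5):
    """Indices of epochs in which more than `contact_ratio` of the contacts either
    exceed `abs_threshold` (volts) or deviate more than `n_sd` SDs from the
    across-epoch mean at any time point.

    epoch_data: array (n_epochs, n_channels, n_times).
    """
    n_contacts = epoch_data.shape[1]
    # Criterion 1: any sample above the absolute threshold, per epoch and contact
    peak_bad = np.abs(epoch_data).max(axis=2) > abs_threshold
    # Criterion 2: any sample more than n_sd SDs from the across-epoch mean
    mean = epoch_data.mean(axis=0, keepdims=True)
    sd = epoch_data.std(axis=0, keepdims=True)
    sd_bad = (np.abs(epoch_data - mean) > n_sd * sd).any(axis=2)
    peak_epochs = np.flatnonzero(peak_bad.sum(axis=1) / n_contacts > contact_ratio)
    sd_epochs = np.flatnonzero(sd_bad.sum(axis=1) / n_contacts > contact_ratio)
    return np.union1d(peak_epochs, sd_epochs)

rng = np.random.default_rng(0)
data = rng.normal(scale=20e-6, size=(30, 10, 500))  # 30 epochs, 10 contacts, 20 µV noise
data[7, :8, 100] += 1e-3                             # large artifact on 8 of 10 contacts
print(find_improbable_epochs(data))
```

The returned indices can be passed directly to `epoch.drop`.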
Save to epoch file
epoch.save(epochFile_name)
Active contact selection
Not all contacts respond to auditory stimuli, so an objective criterion is needed to select the responsive ones.
Load epochs
epochFile_name = f'{subject_Datadir}epoch-epo.fif'
epoch = mne.read_epochs(epochFile_name)
Reading /home/haozhu/brainstormdb/Test_data/sz04_2022_09_17_rawMP_s1_epoch-epo.fif ... Isotrak not found Found the data of interest: t = -2000.00 ... 2000.00 ms 0 CTF compensation matrices available
/tmp/ipykernel_12366/3086389789.py:1: RuntimeWarning: The events passed to the Epochs constructor are not chronologically ordered. epoch = mne.read_epochs(epochFile_name)
Not setting metadata 271 matching events found No baseline correction applied 0 projection items activated
Filter the high gamma response
The target signal is the high gamma band (70-150 Hz). Over the past decades it has been used as a local index of cortical activity, and it correlates with underlying spiking activity as well as with the BOLD signal. Because it is robust across trials, we extract the high gamma band from the epochs.
epoch_high_gamma = epoch.copy().filter(70,150)
epoch_data = epoch_high_gamma["4"].get_data(copy=True)
Setting up band-pass filter from 70 - 1.5e+02 Hz FIR filter parameters --------------------- Designing a one-pass, zero-phase, non-causal bandpass filter: - Windowed time-domain design (firwin) method - Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation - Lower passband edge: 70.00 - Lower transition bandwidth: 17.50 Hz (-6 dB cutoff frequency: 61.25 Hz) - Upper passband edge: 150.00 Hz - Upper transition bandwidth: 37.50 Hz (-6 dB cutoff frequency: 168.75 Hz) - Filter length: 189 samples (0.189 s)
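For readers working outside MNE, a comparable 70-150 Hz band-pass can be sketched with SciPy. This swaps in a zero-phase Butterworth IIR filter (`scipy.signal.butter` with `sosfiltfilt`) for the FIR filter MNE designs above, so the two will not give identical output; the filter order is an assumption. The hypothetical `bandpass_high_gamma` works on the last axis, so it accepts an `(n_epochs, n_channels, n_times)` array like the one returned by `get_data`.

```python
import numpy as np
from scipy import signal

def bandpass_high_gamma(epochs_data, fs=1000.0, band=(70.0, 150.0), order=4):
    """Zero-phase Butterworth band-pass along the last (time) axis.

    A stand-in for MNE's FIR filter; `order` is an assumption.
    """
    sos = signal.butter(order, band, btype="bandpass", fs=fs, output="sos")
    return signal.sosfiltfilt(sos, epochs_data, axis=-1)

# Synthetic epochs: 10 Hz (outside the band) plus 100 Hz (inside the band)
fs = 1000.0
t = np.arange(4001) / fs
trace = np.sin(2 * np.pi * 10 * t) + np.sin(2 * np.pi * 100 * t)
epochs = np.broadcast_to(trace, (2, 3, t.size)).copy()  # (n_epochs, n_channels, n_times)
hg = bandpass_high_gamma(epochs, fs=fs)
print(hg.shape)
```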
Calculate high gamma (HG) power for passive listening
def get_baseline_timelag(epoch_object):
    # Specific to my task design; do NOT take this as a standard procedure.
    # Normally, the baseline is set to ~200 ms before stimulus onset.
    # For each condition, compute the lag (s) from each auditory trigger back to its
    # visual cue, which sits 1 event (conditions 2, 4) or 2 events (7, 8, 11, 12) earlier.
    baseline_timelag = {4: [], 7: [], 11: [], 2: [], 8: [], 12: []}
    epoch_event = epoch_object.events
    for c_index in tqdm(baseline_timelag.keys()):
        if c_index in [2, 4]:
            minus_number = 1
        elif c_index in [7, 11, 8, 12]:
            minus_number = 2
        for event in epoch_object[f'{c_index}'].events:
            e_abs_index = np.where(epoch_event[:, 0] == event[0])[0]
            e_visual = epoch_object.events[e_abs_index - minus_number][0]
            time_lag = (e_visual[0] - event[0]) / 1000  # samples -> seconds at 1000 Hz
            baseline_timelag[c_index].append(time_lag)
    return baseline_timelag
baseline_timelag = get_baseline_timelag(epoch)
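The idea behind `get_baseline_timelag` is that each auditory trigger is preceded by its visual cue a fixed number of rows earlier in the event list. A vectorized sketch of that lookup follows; `lag_to_preceding_event` is a hypothetical name, the event IDs in the example are made up, and it assumes, as the original does, that the cue sits exactly `n_back` rows before its trigger.

```python
import numpy as np

def lag_to_preceding_event(events, event_id, n_back, fs=1000.0):
    """Lag (s) from each `event_id` trigger back to the event `n_back` rows earlier.

    events: MNE-style array (n_events, 3) of [sample, 0, id].
    """
    rows = np.flatnonzero(events[:, 2] == event_id)
    if np.any(rows < n_back):
        raise ValueError("a trigger has no event n_back rows before it")
    return (events[rows - n_back, 0] - events[rows, 0]) / fs

# Hypothetical event list: cue (id 1) -> trigger 4; cue (id 1) -> other (id 3) -> trigger 7
events = np.array([[1000, 0, 1], [1800, 0, 4],
                   [5000, 0, 1], [5600, 0, 3], [6500, 0, 7]])
print(lag_to_preceding_event(events, 4, n_back=1))  # [-0.8]
print(lag_to_preceding_event(events, 7, n_back=2))  # [-1.5]
```

The lags are negative because the cue precedes the trigger; adding `lag * fs` to sample 2000 (the trigger position in a -2 to 2 s epoch at 1000 Hz) gives the cue position that `calculate_broadband_power` uses for its baseline.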
The envelope of the filtered signal is derived with the Hilbert transform. Its power (the squared envelope) is then log-transformed, z-scored against a pre-stimulus baseline, smoothed with a 150 ms moving average, exponentiated, and re-centered at zero by subtracting 1.
def calculate_broadband_power(epochs_data):
    # Note: relies on the global `baseline_timelag` (condition 4) computed above
    fs = 1000
    n_epochs, n_channels, n_time = epochs_data.shape
    broadband_power = np.zeros_like(epochs_data)
    baseline_power = np.zeros((n_epochs, n_channels, 200))
    for i in tqdm(range(n_epochs)):
        for j in range(n_channels):
            # Hilbert transform -> analytic signal; power is the squared envelope
            analytic_signal = signal.hilbert(epochs_data[i, j])
            power_series = np.abs(analytic_signal)**2
            # Log transformation (the small constant avoids log(0))
            power_series = np.log(power_series + 1e-20)
            # Z-score against the baseline from -250 ms to -50 ms before visual onset;
            # epochs start at -2 s, so sample 2000 is the auditory trigger
            baseline_onset = int(round(baseline_timelag[4][i] * fs)) + 2000
            baseline_data = power_series[baseline_onset-250:baseline_onset-50]
            baseline_mean = np.mean(baseline_data)
            baseline_std = np.std(baseline_data)
            power_series = (power_series - baseline_mean) / baseline_std
            # Smoothing: simple moving average with a 0.15 s (150-sample) window
            window_size = int(fs * 0.15)
            power_series = np.convolve(power_series, np.ones(window_size)/window_size, mode='same')
            # Exponentiation
            power_series = np.exp(power_series)
            # Center at zero (subtract 1)
            broadband_power[i, j] = power_series - 1
            baseline_power[i, j] = power_series[baseline_onset-250:baseline_onset-50] - 1
    return broadband_power, baseline_power
hg_power, hg_pre_response = calculate_broadband_power(epoch_data)
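The nested loop in `calculate_broadband_power` processes one trace at a time. A vectorized sketch of the same steps follows: the Hilbert transform runs over all epochs and channels at once, and only the per-epoch baseline window needs a loop. `normalized_hg_power` is a hypothetical name; unlike the original, it takes the baseline onsets as an argument instead of reading the global `baseline_timelag`, and it assumes the same 250/50 ms baseline offsets and 150 ms smoothing window.

```python
import numpy as np
from scipy import signal

def normalized_hg_power(epochs_data, baseline_onsets, fs=1000, smooth_s=0.15):
    """Hilbert power -> log -> baseline z-score -> moving average -> exp -> minus 1.

    epochs_data: band-passed array (n_epochs, n_channels, n_times).
    baseline_onsets: per-epoch sample index of the visual cue within the epoch;
    the baseline runs from 250 ms to 50 ms before that onset.
    """
    # Analytic power (squared envelope) along time for every trace at once
    power = np.abs(signal.hilbert(epochs_data, axis=-1)) ** 2
    log_power = np.log(power + 1e-20)
    b_start, b_end = int(0.25 * fs), int(0.05 * fs)
    n_win = int(fs * smooth_s)
    window = np.ones(n_win) / n_win
    out = np.empty_like(log_power)
    for i, onset in enumerate(baseline_onsets):
        base = log_power[i, :, onset - b_start:onset - b_end]
        z = (log_power[i] - base.mean(axis=1, keepdims=True)) / base.std(axis=1, keepdims=True)
        # Moving-average smoothing of each channel, same edge handling as np.convolve 'same'
        smoothed = np.apply_along_axis(np.convolve, 1, z, window, mode="same")
        out[i] = np.exp(smoothed) - 1
    return out

# Synthetic check: broadband noise with a 5x amplitude burst from sample 2000 to 2500
rng = np.random.default_rng(1)
x = rng.standard_normal((2, 3, 4001))
x[:, :, 2000:2500] *= 5
out = normalized_hg_power(x, baseline_onsets=[1500, 1600])
print(out.shape)
```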
Visualization of the epoch-averaged, normalized high gamma power; each trace represents one contact.
X-axis: time in ms; 0 denotes stimulus onset.
Y-axis: baseline-normalized power (derived from the z-scored log power)
for n in range(epoch_data.shape[1]):
    # Samples 1000-3000 cover -1000 ms to +1000 ms around stimulus onset
    plt.plot(hg_power.mean(axis=0)[n][1000:3000])
plt.xticks(ticks=np.arange(0, 2001, 250), labels=np.arange(0, 2001, 250) - 1000)
plt.show()
Select active contacts with a t-test
In the passive listening task, contacts were selected if they showed a significant increase in high gamma power (p < 0.01, paired t-test, Bonferroni corrected) from the onset of the auditory probe (0 to 0.5 s) compared with a baseline period (-250 to -50 ms before the visual cue onset).
Alternatively, a high signal-to-noise ratio (z > 1.3) during the post-stimulus period can be used as the selection criterion.
def calculate_response_windows(data, fs, onsets):
    # Post-stimulus window: 0 to 0.5 s after the onset sample
    post_start = onsets + int(0 * fs)
    post_end = onsets + int(0.5 * fs)
    # Average over time within the window for each epoch and channel -> (n_epochs, n_channels)
    post_responses = data[:, :, post_start:post_end].mean(axis=2)
    return post_responses
def perform_paired_t_test(pre_responses, post_responses, alpha=0.01):
    n_channels = pre_responses.shape[1]
    results = pd.DataFrame(index=range(n_channels), columns=['T', 'p-val', 'p-corrected', 'significant'])
    for channel in range(n_channels):
        # Two-sided paired t-test for each channel; with (pre, post) order, T < 0 means post > pre
        res = pg.ttest(pre_responses[:, channel], post_responses[:, channel], paired=True)
        results.loc[channel, 'T'] = res['T'].values[0]
        results.loc[channel, 'p-val'] = res['p-val'].values[0]
    # Bonferroni correction, with corrected p-values capped at 1
    results['p-corrected'] = results['p-val'].astype(float) * n_channels
    results['p-corrected'] = results['p-corrected'].apply(lambda x: min(x, 1))
    # Keep only significant enhancements (post > pre), as stated in the selection criterion
    results['significant'] = (results['p-corrected'] < alpha) & (results['T'].astype(float) < 0)
    return results
post_responses = calculate_response_windows(hg_power, 1000, 2000)
pre_responses = hg_pre_response.mean(axis=2)
test_results = perform_paired_t_test(pre_responses, post_responses)
active_index = np.where(test_results['significant'] == True)[0]
for ind in active_index:
    print(epoch.ch_names[ind])
Z11 E5 W2 M8 M9 M10 M11 M12 M14 M15
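As a cross-check on the pingouin loop, the same selection can be sketched with `scipy.stats.ttest_rel`, which tests all channels in one vectorized call. Note the argument order: here the test is post minus pre, so an enhancement gives t > 0, the opposite sign convention of `pg.ttest(pre, post)`. `select_active_channels` is a hypothetical helper, and the synthetic data exist only to exercise it.

```python
import numpy as np
from scipy import stats

def select_active_channels(pre, post, alpha=0.01):
    """Paired t-test per channel (post vs pre), Bonferroni-corrected, keeping only
    channels whose post-stimulus response is larger than baseline.

    pre, post: arrays (n_epochs, n_channels) of per-trial window averages.
    """
    n_channels = pre.shape[1]
    t, p = stats.ttest_rel(post, pre, axis=0)
    p_corrected = np.minimum(p * n_channels, 1.0)
    active = np.flatnonzero((p_corrected < alpha) & (t > 0))
    return active, t, p_corrected

rng = np.random.default_rng(0)
pre = rng.standard_normal((20, 8))
post = rng.standard_normal((20, 8))
post[:, 3] += 3.0  # clear enhancement
post[:, 5] -= 3.0  # clear suppression: significant, but not an enhancement
active, t, p_corr = select_active_channels(pre, post)
print(active)
```

With the tutorial's arrays, the call would be `select_active_channels(pre_responses, post_responses)`, and channel names follow from `epoch.ch_names`.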
Visually inspect each selected contact: its trial-averaged waveform over time and its distribution across trials (to check whether a few outlier trials drive the significance).
for n in active_index:
    # Trial-averaged waveform; samples 1800-3000 cover -200 ms to +1000 ms
    plt.figure(figsize=(6, 2))
    plt.plot(hg_power.mean(axis=0)[n][1800:3000])
    plt.title(f'{n} {epoch.ch_names[n]}')
    plt.ylim(-.5, 2)
    plt.show()
    # Single-trial heatmap (one row per trial) to spot outlier trials
    plt.figure(figsize=(8, 2))
    ax = sns.heatmap(hg_power[:, n, 1800:3000])
    ax.set_xticks([200, 300, 400, 500, 600, 700], labels=[0, 100, 200, 300, 400, 500])
    plt.show()
selected_active_contacts = ["Z11","E5","W2","M8","M9","M10","M11","M12","M14","M15"]
selected_active_contacts_index = [m for m,n in enumerate(epoch.ch_names) if n in selected_active_contacts]
selected_active_contacts_index
[88, 123, 144, 162, 163, 164, 165, 166, 168, 169]
Plot active contacts among all channels
def plot_channels(data):
    """
    Plot each channel on a separate subplot within a single figure,
    framing the selected active contacts in green.
    Parameters:
    - data: numpy array of shape [n_channels, n_times]
    """
    n_channels, n_times = data.shape
    col = 7
    row = n_channels // col + 1
    fig, axs = plt.subplots(row, col, figsize=(7, 10))
    print(axs.shape)
    for i in range(axs.shape[0]):
        for j in range(axs.shape[1]):
            plot_number = i*col+j
            ax = axs[i, j]
            ax.axis('off')
            if plot_number < data.shape[0]:
                ax.plot(data[plot_number], color='k', alpha=0.5)
                # Red dashed line at sample 2000 = stimulus onset
                ax.plot([2000, 2000], [-1, 5], color='red',
                        alpha=0.5, linestyle='--', linewidth=1)
                ax.set_ylim(-1, 5)
                ax.axis('on')
                for spine in ax.spines.values():
                    spine.set_edgecolor('grey')
                    spine.set_linewidth(0.5)
                if plot_number in selected_active_contacts_index:
                    for spine in ax.spines.values():
                        spine.set_edgecolor('green')
                        spine.set_linewidth(1.5)
                ax.set_xticks(ticks=[])
                ax.set_yticks(ticks=[])
    # plt.tight_layout()
    plt.show()
    return fig
# Example usage:
hg_power_mean = np.mean(hg_power,axis=0)
all_channel_figure = plot_channels(hg_power_mean)
(29, 7)
all_channel_figure.savefig(f"{img_path}{subject_id}_all_channels.png",dpi=300)
Plot active contacts
def plot_active_channels(data):
    """
    Plot each selected active channel on a separate subplot: the trial average
    (green) drawn over the single trials (faint black).
    Parameters:
    - data: numpy array of shape [n_epochs, n_channels, n_times]
    """
    n_channels = len(selected_active_contacts_index)
    plot_start, plot_end = 1800, 2800  # -200 ms to 800 ms
    col = 3
    row = n_channels // col + 1
    fig, axs = plt.subplots(row, col, figsize=(10, 8))
    print(axs.shape)
    for i in range(axs.shape[0]):
        for j in range(axs.shape[1]):
            plot_number = i*col+j
            ax = axs[i, j]
            ax.axis('off')
            if plot_number < n_channels:
                ax.axis('on')
                if j == 0:
                    ax.set_ylabel('z-score')
                ax.set_xticks(ticks=np.arange(0, 1001, 200), labels=np.arange(0, 1001, 200) - 200)
                if i == row - 1:
                    ax.set_xlabel('Time (ms)')
                channel_index = selected_active_contacts_index[plot_number]
                # Trial average in green, single trials in faint black
                ax.plot(data.mean(axis=0)[channel_index, plot_start:plot_end],
                        color='#62B197', alpha=1)
                for e in data:
                    ax.plot(e[channel_index, plot_start:plot_end], color='k', alpha=0.05)
                y_max = np.round(data[:, channel_index, plot_start:plot_end].max()) + 1
                ax.set_yticks(ticks=[0, y_max//2, y_max])
                # Red dashed line at stimulus onset (sample 200 of the plotted window)
                ax.plot([200, 200], [-1, y_max], color='red',
                        alpha=0.5, linestyle='--', linewidth=1)
                ax.set_ylim(-1, y_max)
                ax.set_title(f"{epoch.ch_names[channel_index]}")
                for spine in ax.spines.values():
                    spine.set_edgecolor('grey')
                    spine.set_linewidth(0.5)
    plt.tight_layout()
    plt.show()
    return fig
active_contact_figure = plot_active_channels(hg_power)
(4, 3)
active_contact_figure.savefig(f"{img_path}{subject_id}_active_auditory_channel.png",dpi=300)