"""
MCSData.py
Experimental data loader/handler for Multi Channel Systems (.h5) recordings.
This version is a *minimal-API-change* refactor of the previous MCSData to fit the
new architecture:
- MCSData now inherits from EData (and thus AData).
- All existing public methods + prototypes are preserved.
- Methods already provided by AData/EData are kept as thin wrappers for backward
compatibility (so user code/tests keep working).
- tracked_operation / ProcessingHistory live in processing_history.py (imported).
No intentional changes to user-facing behavior or default parameters.
"""
from __future__ import annotations
import os
import sys
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import h5py
import matplotlib.pyplot as plt
import numpy as np
import spikeinterface.extractors as se
import spikeinterface.preprocessing as pre
import spikeinterface as si
from matplotlib.widgets import CheckButtons
from spikeinterface.sortingcomponents.peak_detection import detect_peaks
from .edata import EData
from .processing_history import DatasetInfo, ProcessingHistory, tracked_operation
# ---------------------------------------------------------------------
# Helper functions (kept public; re-exported from tad.__init__)
# ---------------------------------------------------------------------
def on_delta_t(digital_recording, triggers, fsample: float, delta_t: float) -> None:
    """
    Create fixed-duration stimulation ON triggers starting at each rising edge.

    A rising edge is any sample whose (integer-cast) value is larger than the
    previous sample. This matches ``MCSData.detect_digital_rising_edge``, so
    two-level traces such as ``{0, k}`` (the output of
    ``MCSData.convert_digital``) are handled as well as ``{0, 1}`` traces.

    Parameters
    ----------
    digital_recording : array-like
        Digital trace (samples). Values are cast to int before comparison.
    triggers : Triggers
        Trigger container, updated in place through ``add_trigger``.
    fsample : float
        Sampling frequency (Hz).
    delta_t : float
        Duration (s) of each stimulation ON slot; the slot spans
        ``[edge_time, edge_time + delta_t]``.
    """
    levels = np.asarray(digital_recording).astype(int)
    # Compare neighbours directly (no np.diff) so large integer levels cannot overflow.
    rising_edges = np.flatnonzero(levels[1:] > levels[:-1]) + 1
    duration = float(delta_t)
    for idx in rising_edges:
        start = float(idx / fsample)
        triggers.add_trigger(start=start, end=start + duration, ID="stim_ON", blank=False)
def on_off_interpretor(digital_recording, triggers, fsample: float) -> None:
    """
    Interpret a two-level digital trace as ON/OFF blocks and populate triggers.

    Each ON block runs from a rising edge (value increases) to the next
    falling edge (value decreases). Any two-level trace works, e.g. ``{0, 1}``
    or the ``{0, k}`` output of ``MCSData.convert_digital``.

    Parameters
    ----------
    digital_recording : array-like
        Digital trace (samples). Values are cast to int before comparison.
    triggers : Triggers
        Trigger container, updated in place through ``add_trigger``.
    fsample : float
        Sampling frequency (Hz).
    """
    levels = np.asarray(digital_recording).astype(int)
    prev, curr = levels[:-1], levels[1:]
    rising_edges = np.flatnonzero(curr > prev) + 1
    falling_edges = np.flatnonzero(curr < prev) + 1
    if len(rising_edges) == 0 or len(falling_edges) == 0:
        return
    # A trace that starts high begins with an orphan falling edge; drop it so
    # every rising edge is paired with the falling edge that follows it.
    if falling_edges[0] < rising_edges[0]:
        falling_edges = falling_edges[1:]
    # zip() also drops a trailing rising edge whose block never ends.
    for rise, fall in zip(rising_edges, falling_edges):
        triggers.add_trigger(
            start=float(rise / fsample),
            end=float(fall / fsample),
            ID="stim_ON",
            blank=False,
        )
def _raster_artifacts(raster_obj: Any) -> Dict[str, Any]:
"""
Optional artifact extractor for tracked_operation on get_raster().
Kept small and JSON-friendly.
"""
artifacts: Dict[str, Any] = {}
try:
if raster_obj is None:
return artifacts
for key in ("tstart", "tstop", "n_spikes", "n_channels"):
if hasattr(raster_obj, key):
v = getattr(raster_obj, key)
if isinstance(v, (int, float, str, bool)) or v is None:
artifacts[key] = v
except Exception:
return {"note": "artifact extraction failed"}
return artifacts
# ---------------------------------------------------------------------
# MCSData class
# ---------------------------------------------------------------------
class MCSData(EData):
    """
    MCS recording handler.

    Parameters
    ----------
    fname : str
        Path to the `.h5` file.
    fsample : float, optional
        Sampling frequency (Hz). If None, read from the recording.
    load_recording : bool, default=True
        Load the analog recording stream into SpikeInterface.
    load_digital : bool, default=False
        Load a digital channel from the HDF5 file.
    generate_probe : bool, default=False
        Generate a probe object from probe_data (if you use this in your pipeline).
    probe_data : dict, optional
        Probe configuration.
    """

    def __init__(
        self,
        fname: str,
        fsample: Optional[float] = None,
        load_recording: bool = True,
        load_digital: bool = False,
        generate_probe: bool = False,
        probe_data: Optional[dict] = None,
    ) -> None:
        """
        Load the file (optionally), then initialise the EData/AData base.

        Raises
        ------
        FileNotFoundError
            If `fname` does not exist.
        RuntimeError
            If SpikeInterface cannot read the recording.
        ValueError
            If `generate_probe` is True but `probe_data` is None.
        """
        if not os.path.exists(fname):
            raise FileNotFoundError(f"File {fname} does not exist.")
        # Historical public/state attributes
        self.fname: str = fname
        self.fsample: Optional[float] = fsample
        self.load_digital: bool = load_digital
        # Core data
        self.recording = None
        self.traces: Optional[np.ndarray] = None
        self.peaks = None
        # Channel/probe metadata (historical public attributes)
        self.ch_ids = None
        self.electrode_labels = None
        self.mask: Optional[np.ndarray] = None
        self.probe = None
        self.probe_positions = None
        self.probe_contact_shape = None
        self.probe_shape_params = None
        self.probe_ndims = None
        # Time/selection masks
        self.time_vector: Optional[np.ndarray] = None
        self.temporal_mask: Optional[np.ndarray] = None
        self.excluded_intervals: List[Tuple[float, float]] = []
        # Digital/triggers
        self.digital_recording = None
        self.triggers = None
        # State
        self.artifact_removal_status: bool = False

        # Load the recording first so the channel axis is known before base init.
        if load_recording:
            self._load_recording()
            self.time_vector = np.arange(self.recording.get_total_samples()) / float(self.fsample)
            self.ch_ids = self.recording.channel_ids
            self.electrode_labels = self.recording.get_property("electrode_labels")
            self.mask = np.ones(self.recording.get_num_channels(), dtype=bool)
            self.temporal_mask = np.ones_like(self.time_vector, dtype=bool)
        if generate_probe:
            self._generate_probe(probe_data)

        base_fsample = float(self.fsample) if self.fsample is not None else 0.0
        if load_recording and self.recording is not None:
            rec = self.recording  # EData.__init__ resets self.recording to None
            super().__init__(
                fsample=base_fsample,
                channel_ids=list(self.ch_ids),
                electrode_labels=None if self.electrode_labels is None else list(self.electrode_labels),
                mask=None if self.mask is None else list(self.mask.astype(bool)),
                fname=fname,
                recording_system="MCS",
            )
            self.recording = rec
            # Keep the historical alias pointing at the base-class storage;
            # electrode_labels and mask are (re)set by the base class itself.
            self.ch_ids = self.channel_ids
        else:
            super().__init__(
                fsample=base_fsample,
                channel_ids=[],
                electrode_labels=None,
                mask=None,
                fname=fname,
                recording_system="MCS",
            )

        # History:
        # - AData already has `self.history` (lightweight list of dicts)
        # - Here we keep a rich ProcessingHistory for export_history_json
        if load_recording and self.recording is not None:
            channel_ids = getattr(self, "channel_ids", None)
            labels = getattr(self, "electrode_labels", None)
            ds = DatasetInfo.from_path(
                fname=self.fname,
                sampling_frequency=float(self.fsample) if self.fsample is not None else None,
                stream_id=1,
                # np.asarray(...).tolist() works whether the base stores arrays or lists.
                channel_ids=None if channel_ids is None else np.asarray(channel_ids).tolist(),
                electrode_labels=None if labels is None else np.asarray(labels).tolist(),
            )
            self.processing_history = ProcessingHistory(dataset=ds)
            self.processing_history.set_state_getter(self._history_snapshot)
        else:
            self.processing_history = None

    # -----------------------------------------------------------------
    # History handling
    # -----------------------------------------------------------------
    def _history_snapshot(self) -> Dict[str, Any]:
        """
        Create a JSON-friendly snapshot of relevant processing state.

        Returns
        -------
        dict
            Snapshot dict suitable for hashing and JSON serialization.
        """
        snap: Dict[str, Any] = {
            "fname": self.fname,
            "sampling_frequency": float(self.fsample) if self.fsample is not None else None,
        }
        # Channel mask summary: store the kept channel ids as native Python values
        # (.tolist()) so NumPy integer ids also serialize.
        if self.mask is not None and self.ch_ids is not None:
            kept = np.asarray(self.ch_ids)[np.asarray(self.mask, dtype=bool)].tolist()
            snap["mask_n_kept"] = len(kept)
            snap["mask_kept_channel_ids"] = kept
        # Temporal exclusions: store intervals rather than the per-sample boolean mask.
        intervals = list(getattr(self, "excluded_intervals", []))
        snap["excluded_intervals"] = intervals
        snap["excluded_intervals_n"] = len(intervals)
        # Trigger summary (store the slots, not the digital signal)
        if self.triggers is not None and hasattr(self.triggers, "slots"):
            snap["triggers"] = [
                {
                    "start": float(s.start),
                    "end": float(s.end),
                    "id": getattr(s, "ID", None),
                    "blank": getattr(s, "blank", None),
                }
                for s in self.triggers.slots
            ]
            snap["triggers_n"] = len(snap["triggers"])
        return snap

    def export_history_json(self, path: str, indent: int = 2) -> None:
        """
        Export processing history to a JSON file.

        Parameters
        ----------
        path : str
            Output path.
        indent : int, default=2
            JSON indentation.

        Raises
        ------
        ValueError
            If no ProcessingHistory exists (recording was not loaded).
        """
        if self.processing_history is None:
            raise ValueError("No processing history available to export.")
        self.processing_history.save_json(path, indent=indent)

    # -----------------------------------------------------------------
    # IO / setup
    # -----------------------------------------------------------------
    def _load_recording(self) -> None:
        """
        Load the MCS recording via SpikeInterface and optionally the digital signal.

        Notes
        -----
        - Uses stream_id=1.
        - Renames channels based on electrode_labels to "Ch{label}".
        - Digital loading is best-effort: failures are printed, not raised.

        Raises
        ------
        RuntimeError
            If SpikeInterface cannot read the file. (Previously the process
            was terminated with sys.exit(1).)
        """
        try:
            self.recording = se.read_mcsh5(self.fname, stream_id=1)
        except Exception as exc:
            # A library must not terminate the interpreter; let the caller decide.
            raise RuntimeError(f"Error loading recording: {exc}") from exc
        if self.load_digital:
            try:
                with h5py.File(self.fname, "r") as f:
                    stream = f["Data/Recording_0/AnalogStream/Stream_0/ChannelData"]
                    self.digital_recording = stream[0]
            except Exception as exc:
                print(f"Error loading digital recording: {exc}")
        electrode_labels = self.recording.get_property("electrode_labels")
        self.recording = self.recording.rename_channels([f"Ch{lab}" for lab in electrode_labels])
        if self.fsample is None:
            self.fsample = self.recording.get_sampling_frequency()

    def _serialize_triggers(self) -> Optional[List[Dict[str, Any]]]:
        """Return trigger slots as a list of plain dicts, or None if no triggers exist."""
        if self.triggers is None:
            return None
        return [
            {
                "start": float(getattr(slot, "start", 0.0)),
                "end": float(getattr(slot, "end", 0.0)),
                "ID": getattr(slot, "ID", None),
                "blank": getattr(slot, "blank", None),
            }
            for slot in getattr(self.triggers, "slots", [])
        ]

    def _generate_probe(self, probe_data: Optional[dict]) -> None:
        """
        Generate a probe object from probe_data.

        Parameters
        ----------
        probe_data : dict, optional
            Probe configuration dictionary (keys: positions, contact_shape,
            shape_params, ndims).

        Raises
        ------
        ValueError
            If `probe_data` is None.
        """
        # Local import to avoid circular imports
        from .mea_probe import MEAProbe
        if probe_data is None:
            raise ValueError("probe_data must be provided when generate_probe=True.")
        self.probe_positions = probe_data.get("positions", None)
        self.probe_contact_shape = probe_data.get("contact_shape", "circle")
        self.probe_shape_params = probe_data.get("shape_params", {"radius": 5})
        self.probe_ndims = probe_data.get("ndims", 2)
        self.probe = MEAProbe(
            positions=self.probe_positions,
            contact_shape=self.probe_contact_shape,
            shape_params=self.probe_shape_params,
            ndims=self.probe_ndims,
        )

    # -----------------------------------------------------------------
    # Backward-compatible wrappers for AData methods
    # -----------------------------------------------------------------
    def set_mask(self, mask) -> None:
        """Delegate to the base-class ``set_mask`` (kept for backward compatibility)."""
        return super().set_mask(mask)

    def save_mask_and_labels(self, fname: str, csv_format: bool = False) -> int:
        """Delegate to the base-class ``save_mask_and_labels`` (backward compatible)."""
        return super().save_mask_and_labels(fname=fname, csv_format=csv_format)

    def load_mask_and_labels(self, fname: str) -> int:
        """Delegate to the base-class ``load_mask_and_labels`` (backward compatible)."""
        return super().load_mask_and_labels(fname=fname)

    # -----------------------------------------------------------------
    # Plotting helpers
    # -----------------------------------------------------------------
    @staticmethod
    def _grid_cell(label) -> Tuple[int, int]:
        """Map an electrode label to the 0-based (row, col) cell of the 8x8 grid."""
        lab = int(label)
        return lab % 10 - 1, lab // 10 - 1

    def _grid_trace(self, ch, tmin: float, tmax: float) -> Tuple[np.ndarray, np.ndarray]:
        """Return (time vector, uV trace) for one channel between tmin and tmax."""
        fsample = float(self.fsample)
        trace = self.recording.get_traces(
            start_frame=int(tmin * fsample),
            end_frame=int(tmax * fsample),
            channel_ids=[ch],
            return_in_uV=True,
        )
        return np.arange(trace.shape[0]) / fsample + tmin, trace

    # -----------------------------------------------------------------
    # Processing methods (kept with same signatures/defaults)
    # -----------------------------------------------------------------
    @tracked_operation("apply_filter")
    def apply_filter(self, bandpass: Tuple[float, float] = (300, 3000), btype: str = "bandpass") -> Any:
        """
        Apply a SpikeInterface filter to the recording (replaces `self.recording`).

        Parameters
        ----------
        bandpass : array-like or float
            For 'bandpass', provide [low_freq, high_freq].
            For 'highpass', provide a single cutoff or the format expected by SpikeInterface.
        btype : {'bandpass', 'highpass'}, default='bandpass'
            Filter type.

        Returns
        -------
        recording : spikeinterface.BaseRecording
            The filtered recording.

        Raises
        ------
        ValueError
            If recording is not loaded or `bandpass` is not provided.
        """
        if self.recording is None:
            raise ValueError("Recording not loaded.")
        if bandpass is None:
            raise ValueError("Bandpass frequencies must be provided.")
        self.recording = pre.filter(self.recording, band=bandpass, btype=btype)
        return self.recording

    def get_traces(
        self,
        tstart: float = 0,
        tstop: float = 10,
        channel_ids: Optional[List[Union[int, str]]] = None,
        return_in_uV: bool = True,
    ) -> np.ndarray:
        """
        Retrieve traces from the recording.

        Parameters
        ----------
        tstart : float, default=0
            Start time in seconds (None = start of recording).
        tstop : float, default=10
            Stop time in seconds (None = end of recording).
        channel_ids : list, optional
            Subset of channel ids to extract. Defaults to the channels kept by `self.mask`.
        return_in_uV : bool, default=True
            Convert to microvolts.

        Returns
        -------
        np.ndarray
            Traces array of shape (n_samples, n_channels).
        """
        if self.recording is None:
            raise ValueError("Recording not loaded.")
        if self.time_vector is None:
            raise ValueError("Time vector not initialized.")
        if tstart is None:
            tstart = float(self.time_vector[0])
        if tstop is None:
            tstop = float(self.time_vector[-1])
        if channel_ids is None:
            # asarray: the base class may store ids/mask as lists.
            channel_ids = np.asarray(self.ch_ids)[np.asarray(self.mask, dtype=bool)]
        fsample = float(self.fsample)
        return self.recording.get_traces(
            start_frame=int(tstart * fsample),
            end_frame=int(tstop * fsample),
            channel_ids=channel_ids,
            return_in_uV=return_in_uV,
        )

    def plot_traces_in_grid(
        self,
        tmin: float = 0,
        tmax: float = 10,
        n_subsample: Optional[int] = None,
        show: bool = True,
    ) -> int:
        """
        Plot channel traces in the 8x8 electrode grid.

        Parameters
        ----------
        tmin : float, default=0
            Start time (s).
        tmax : float, default=10
            Stop time (s).
        n_subsample : int, optional
            Plot every `n_subsample`-th sample. Must be greater than 1 when given.
        show : bool, default=True
            Show plot.

        Returns
        -------
        int
            Always returns 1 (backward compatible).

        Raises
        ------
        ValueError
            If data are missing or `n_subsample` <= 1.
        """
        if self.recording is None:
            raise ValueError("Recording not loaded.")
        if self.fsample is None:
            raise ValueError("Sampling frequency not initialized.")
        if self.ch_ids is None or self.electrode_labels is None:
            raise ValueError("Channel metadata not initialized.")
        # Validate before creating the figure so no half-drawn figure is left behind.
        if n_subsample is not None and n_subsample <= 1:
            raise ValueError("n_subsample must be greater than 1.")
        step = n_subsample  # slice step; None keeps every sample
        _, axes = plt.subplots(8, 8, figsize=(8, 8))
        plt.subplots_adjust(wspace=0.1, hspace=0.1)
        for ax in axes.flat:
            ax.axis("off")
        for ch, lab in zip(self.ch_ids, self.electrode_labels):
            row, col = self._grid_cell(lab)
            ax = axes[row, col]
            local_time_vector, traces = self._grid_trace(ch, tmin, tmax)
            ax.plot(local_time_vector[::step], traces[::step], lw=0.8, color="C0")
            ax.set_title(int(lab), fontsize=6)
            ax.set_xlim(tmin, tmax)
            ax.set_ylim(-50, 50)
            ax.axis("off")
        if show:
            plt.show()
        return 1

    @tracked_operation("detect_spikes")
    def detect_spikes(
        self,
        recording: Optional[si.BaseRecording] = None,
        method: str = "by_channel",
        peak_sign: str = "neg",
        detect_threshold: float = 5,
        exclude_sweep_ms: float = 1.0,
        noise_levels=None,
        detect_noise_levels: Optional[bool] = None,
        job_kwargs=None,
    ) -> int:
        """
        Detect spikes (peaks) in the recording; results are stored in `self.peaks`.

        Parameters
        ----------
        recording : spikeinterface.BaseRecording, optional
            Recording to run detection on. If None, uses `self.recording`
            (pass e.g. an artifact-removed recording to override).
        method : str, default='by_channel'
            Peak detection method passed to `detect_peaks`.
        peak_sign : str, default='neg'
            'neg' or 'pos' depending on spike polarity.
        detect_threshold : float, default=5
            Detection threshold.
        exclude_sweep_ms : float, default=1.0
            Exclusion window in ms.
        noise_levels : np.ndarray, optional
            Precomputed per-channel noise levels, forwarded to `detect_peaks`.
            NOTE(review): the computed fallback below uses `return_in_uV=False`,
            so these are presumably expected in raw recording units.
        detect_noise_levels : bool, optional
            If True and `noise_levels` is None, noise levels are computed with
            `si.get_noise_levels` before detection. False/None leaves the
            estimation to `detect_peaks`.
        job_kwargs : dict, optional
            Parallel-processing options, e.g. {'n_jobs': n}.

        Returns
        -------
        int
            Always returns 1 (kept for backward compatibility).

        Raises
        ------
        ValueError
            If no recording is available or `noise_levels` is not a numpy array.
        """
        if recording is None:
            recording = self.recording
        if recording is None:
            raise ValueError("Recording not loaded.")
        method_kwargs: Dict[str, Any] = {
            "peak_sign": peak_sign,
            "detect_threshold": detect_threshold,
            "exclude_sweep_ms": exclude_sweep_ms,
        }
        if noise_levels is not None:
            if not isinstance(noise_levels, np.ndarray):
                raise ValueError("noise_levels must be a numpy array.")
            method_kwargs["noise_levels"] = noise_levels
        elif detect_noise_levels:
            # Only an explicit True triggers computation (False used to trigger it too).
            method_kwargs["noise_levels"] = si.get_noise_levels(recording, return_in_uV=False)
        self.peaks = detect_peaks(
            recording=recording,
            method=method,
            method_kwargs=method_kwargs,
            job_kwargs=job_kwargs,
        )
        return 1

    def plot_raster(self, ax) -> int:
        """
        Plot a raster of detected spikes into an existing axes.

        The y value of each spike is the numeric part of its channel id
        ("Ch12" -> 12). One tick is drawn per channel that has spikes.

        Parameters
        ----------
        ax : matplotlib.axes.Axes
            Target axes.

        Returns
        -------
        int
            Always returns 1 (backward compatible).
        """
        if self.peaks is None:
            raise ValueError("Spikes not detected.")
        if self.fsample is None:
            raise ValueError("Sampling frequency not initialized.")
        ch_ids = np.asarray(self.ch_ids)
        # Parse each channel label once per channel instead of once per spike.
        present, inverse = np.unique(np.asarray(self.peaks["channel_index"]), return_inverse=True)
        present_ids = ch_ids[present]
        present_nums = np.array([int(str(ch).replace("Ch", "")) for ch in present_ids], dtype=np.int64)
        y_vals = present_nums[np.ravel(inverse)]
        ax.scatter(self.peaks["sample_index"] / float(self.fsample), y_vals, s=1)
        ax.set_yticks(present_nums)
        ax.set_yticklabels(present_ids)
        return 1

    def get_probe(self):
        """
        Get the probe attached to the SpikeInterface recording.

        Prints the probe before returning it (historical behaviour).

        Returns
        -------
        probe or None
            `self.recording.get_probe()`, or None when no recording is loaded.
            NOTE(review): SpikeInterface may raise if no probe is attached — confirm.
        """
        if self.recording is None:
            return None
        probe = self.recording.get_probe()
        print(probe)
        return probe

    @tracked_operation("choose_mask")
    def choose_mask(self, tmin: float = 0, tmax: float = 10, show: bool = True) -> None:
        """
        Open a GUI to select channels to include (updates `self.mask`).

        Every channel starts selected; clicking a channel's checkbox toggles it
        and greys out its trace.

        Parameters
        ----------
        tmin : float, default=0
            Start time (s) for trace preview.
        tmax : float, default=10
            Stop time (s) for trace preview.
        show : bool, default=True
            Show plot.

        Notes
        -----
        This method keeps the original 8x8 grid layout and electrode label mapping:
        - col = lab // 10 - 1
        - row = lab % 10 - 1
        `self.mask` is the same array the checkbox callbacks mutate, so toggles
        made after this method returns (show=False) are still reflected in it.
        """
        if self.recording is None:
            raise ValueError("Recording not loaded.")
        if self.fsample is None:
            raise ValueError("Sampling frequency not initialized.")
        if self.ch_ids is None or self.electrode_labels is None:
            raise ValueError("Channel metadata not initialized.")
        fig, axes = plt.subplots(8, 8, figsize=(8, 8))
        plt.subplots_adjust(wspace=0.1, hspace=0.1)
        for ax in axes.flat:
            ax.axis("off")
        n = len(self.ch_ids)
        mask = np.ones(n, dtype=bool)
        lines = [None] * n
        checks = [None] * n

        def make_toggle(i: int, checks):
            def _toggle(_label):
                mask[i] = not mask[i]
                ln = lines[i]
                if ln is not None:
                    ln.set_color("C0" if mask[i] else "0.7")
                cb = checks[i]
                if cb is not None:
                    try:
                        # `rectangles` is not available on every matplotlib version.
                        cb.rectangles[0].set_facecolor("white" if mask[i] else "0.9")
                    except Exception:
                        pass
                fig.canvas.draw_idle()
            return _toggle

        for i, (ch, lab) in enumerate(zip(self.ch_ids, self.electrode_labels)):
            row, col = self._grid_cell(lab)
            ax = axes[row, col]
            local_time_vector, traces = self._grid_trace(ch, tmin, tmax)
            (ln,) = ax.plot(local_time_vector, traces, lw=0.8, color="C0")
            lines[i] = ln
            ax.set_title(int(lab), fontsize=6)
            ax.set_xlim(tmin, tmax)
            ax.set_ylim(-50, 50)
            ax.axis("off")
            # Small checkbox in the top-left corner of the channel's cell.
            bbox = ax.get_position()
            w = bbox.width * 0.18
            h = bbox.height * 0.18
            x0 = bbox.x0 + bbox.width * 0.02
            y0 = bbox.y1 - h - bbox.height * 0.02
            cax = fig.add_axes([x0, y0, w, h])
            cax.set_xticks([])
            cax.set_yticks([])
            for spine in cax.spines.values():
                spine.set_visible(False)
            cb = CheckButtons(cax, labels=[""], actives=[True])
            for txt in getattr(cb, "labels", []):
                txt.set_visible(False)
            # Attribute name differs across matplotlib versions.
            line_groups = getattr(cb, "lines", None) or getattr(cb, "lines_", None)
            if line_groups is not None:
                for pair in line_groups:
                    try:
                        pair[0].set_linewidth(1.0)
                        pair[1].set_linewidth(1.0)
                    except Exception:
                        pass
            cb.on_clicked(make_toggle(i, checks))
            checks[i] = cb
        if show:
            plt.show()
        self.mask = mask

    @tracked_operation("blank_period")
    def blank_period(self, tstart: float, tstop: float) -> None:
        """
        Exclude (blank) a time window from analysis.

        The interval is recorded in `excluded_intervals` and samples inside
        [min(tstart, tstop), max(tstart, tstop)] are cleared in `temporal_mask`.

        Parameters
        ----------
        tstart : float
            Start time (s).
        tstop : float
            Stop time (s).
        """
        if self.time_vector is None or self.temporal_mask is None:
            raise ValueError("Time vector not initialized.")
        a = float(min(tstart, tstop))
        b = float(max(tstart, tstop))
        self.excluded_intervals.append((a, b))
        self.temporal_mask &= ~((self.time_vector >= a) & (self.time_vector <= b))

    def convert_digital(self) -> np.ndarray:
        """
        Convert the raw digital recording into a two-level integer trace.

        Steps:
        - anchor on the first sample: ``level = log2(|x - x[0] + 1|)``, so the
          first sample maps to level 0;
        - round levels to the nearest integer;
        - keep only the most frequent non-zero level ``k``; every other sample becomes 0.

        The result (values in ``{0, k}``, all zeros for a flat trace) replaces
        `self.digital_recording` and is also returned.

        Returns
        -------
        np.ndarray
            Converted digital signal as int32.

        Raises
        ------
        ValueError
            If no digital recording is loaded or it is empty.
        """
        if self.digital_recording is None:
            raise ValueError("Digital recording not loaded.")
        raw = np.asarray(self.digital_recording, dtype=np.float64)
        if raw.size == 0:
            raise ValueError("Digital recording is empty.")
        # Float arithmetic avoids wrap-around when raw samples are unsigned ints.
        with np.errstate(divide="ignore"):
            levels = np.log2(np.abs(raw - raw[0] + 1))
        # log2(0) = -inf (sample equal to x[0] - 1) carries no level: use 0.
        levels[~np.isfinite(levels)] = 0.0
        # Negative levels (|.| < 1, only possible for non-integer input) also count as 0,
        # which keeps np.bincount valid.
        rounded = np.clip(np.round(levels), 0, None).astype(np.int64)
        counts = np.bincount(rounded)
        if counts.size > 1:
            most_common_value = int(counts[1:].argmax()) + 1
            # Emit the rounded level itself; zeroing then truncating the float
            # values produced mixed k-1/k levels and spurious edges.
            converted = np.where(rounded == most_common_value, most_common_value, 0)
        else:
            # Flat trace: there is no non-zero level to keep.
            converted = np.zeros_like(rounded)
        self.digital_recording = converted.astype(np.int32)
        return self.digital_recording

    def detect_digital_rising_edge(self) -> list[int]:
        """
        Detect rising edges (index positions) in the digital signal.

        Returns
        -------
        list of int
            Sample indices i where digital_recording[i] > digital_recording[i-1].
        """
        if self.digital_recording is None:
            raise ValueError("Digital recording not loaded.")
        signal = np.asarray(self.digital_recording)
        return (np.flatnonzero(signal[1:] > signal[:-1]) + 1).tolist()

    def detect_digital_falling_edge(self) -> list[int]:
        """
        Detect falling edges (index positions) in the digital signal.

        Returns
        -------
        list of int
            Sample indices i where digital_recording[i] < digital_recording[i-1].
        """
        if self.digital_recording is None:
            raise ValueError("Digital recording not loaded.")
        signal = np.asarray(self.digital_recording)
        return (np.flatnonzero(signal[1:] < signal[:-1]) + 1).tolist()

    @staticmethod
    def _pair_edges(rising, falling) -> List[Tuple[int, int]]:
        """Pair each rising edge with the falling edge that follows it (ON intervals)."""
        rising = np.asarray(rising, dtype=np.int64)
        falling = np.asarray(falling, dtype=np.int64)
        if rising.size == 0 or falling.size == 0:
            return []
        # A window that opens mid-event starts with an orphan falling edge.
        if falling[0] < rising[0]:
            falling = falling[1:]
        # zip() drops a trailing rising edge whose event never ends in the window.
        return [(int(r), int(f)) for r, f in zip(rising, falling)]

    def _artifact_edges(
        self, start_frame: int, end_frame: int, threshold: float, window_s: float
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Return (rising, falling) window-relative sample indices of supra-threshold artifact blocks."""
        traces = self.recording.get_traces(
            start_frame=start_frame,
            end_frame=end_frame,
            return_in_uV=False,
        )
        valid_channels = np.flatnonzero(np.asarray(self.mask, dtype=bool))
        if len(valid_channels) == 0:
            raise ValueError("No valid channels found in mask for artifact detection.")
        channel_index = int(valid_channels[0])
        # Cast before abs(): abs() of the most negative raw integer would overflow.
        trace = np.abs(traces[:, channel_index].astype(np.float64))
        n = len(trace)
        if n == 0:
            raise ValueError("No samples available in the selected recording window.")
        window = max(1, int(float(self.fsample) * window_s))
        if n < window:
            raise ValueError(
                "Selected recording window is too short for artifact detection."
            )
        moving_avg = np.convolve(trace, np.ones(window), "valid") / window
        stim_status = (moving_avg > threshold).astype(np.int64)
        transitions = np.diff(stim_status)
        rising = np.flatnonzero(transitions == 1) + window
        falling = np.flatnonzero(transitions == -1) + window
        return rising, falling

    @tracked_operation("get_triggers")
    def get_triggers(
        self,
        method: str = "artifact",
        tstart: Optional[float] = None,
        tstop: Optional[float] = None,
        interpretor: Optional[Callable] = None,
        dt_after_trigger: Optional[float] = None,
        artifact_threshold: Optional[float] = None,
        mean_noise_level: Optional[float] = None,
        refractory_trigger_period: Optional[float] = None,
        stim_on_time: Optional[float] = None,
        moving_avg_window: Optional[float] = 0.005,
    ):
        """
        Build `Triggers` for a given time window and store them in `self.triggers`.

        Parameters
        ----------
        method : str, default='artifact'
            Method to obtain triggers.
            - 'artifact': detect artifact blocks from the analog recording using a threshold.
            - 'first_passage': detect the first threshold crossing of each event and
              generate a fixed-duration slot.
            - 'digital_trigger': use the digital recording (requires load_digital=True),
              optionally with a custom `interpretor`.
        tstart : float, optional
            Start time in seconds. Defaults to the start of the recording.
        tstop : float, optional
            Stop time in seconds. Defaults to the end of the recording.
        interpretor : callable, optional
            Custom function to interpret the (windowed, converted) digital signal into
            trigger slots. It must accept `(digital_recording, triggers, fsample)` or
            `(digital_recording, triggers, fsample, dt_after_trigger)`.
            Trigger times created by the custom interpreter are assumed to be
            relative to the window start; `get_triggers` shifts them to absolute time.
        dt_after_trigger : float, optional
            Extra argument forwarded to `interpretor` when provided.
        artifact_threshold : float, optional
            Threshold multiplier used by `artifact` and `first_passage` methods.
        mean_noise_level : float, optional
            Mean noise level (raw recording units) used by `artifact` and `first_passage`.
        refractory_trigger_period : float, optional
            Required for `first_passage`. Minimum time in seconds between consecutive triggers.
        stim_on_time : float, optional
            Required for `first_passage`. Duration in seconds for the generated trigger slot.
        moving_avg_window : float, default=0.005
            Moving-average window length in **seconds** (0.005 = 5 ms) for the
            artifact/first-passage methods. None means the default.

        Returns
        -------
        Triggers
            Trigger container with the detected interval slots.

        Raises
        ------
        TypeError
            If `method` is not a string.
        ValueError
            If required data/arguments for the selected method are missing or
            invalid, or if the requested time window is outside the recording range.
        """
        if self.time_vector is None:
            raise ValueError("Time vector not initialized.")
        if self.fsample is None:
            raise ValueError("Sampling frequency not initialized.")
        if not isinstance(method, str):
            raise TypeError("method must be a string.")
        method = method.strip().lower()
        if method not in {"artifact", "first_passage", "digital_trigger"}:
            raise ValueError(
                "Unknown method. Supported values are 'artifact', 'first_passage', and 'digital_trigger'."
            )
        # Only the digital method reads the digital signal; the others use the analog recording.
        if method == "digital_trigger" and self.digital_recording is None:
            raise ValueError(
                "Digital recording not loaded. Set load_digital=True when constructing MCSData."
            )
        if method != "digital_trigger" and self.recording is None:
            raise ValueError("Recording not loaded.")
        if tstart is None:
            tstart = float(self.time_vector[0])
        if tstop is None:
            tstop = float(self.time_vector[-1])
        if tstart < float(self.time_vector[0]) or tstop > float(self.time_vector[-1]):
            raise ValueError(
                "tstart and tstop must fall within the available recording interval."
            )
        if tstart >= tstop:
            raise ValueError("tstart must be strictly less than tstop.")
        if method in {"artifact", "first_passage"}:
            if artifact_threshold is None:
                raise ValueError(
                    "artifact_threshold is required for method='artifact' or method='first_passage'."
                )
            if mean_noise_level is None:
                raise ValueError(
                    "mean_noise_level is required for method='artifact' or method='first_passage'."
                )
            if artifact_threshold <= 0 or mean_noise_level <= 0:
                raise ValueError(
                    "artifact_threshold and mean_noise_level must be positive values."
                )
        if method == "first_passage":
            if refractory_trigger_period is None:
                raise ValueError(
                    "refractory_trigger_period is required for method='first_passage'."
                )
            if stim_on_time is None:
                raise ValueError("stim_on_time is required for method='first_passage'.")
            if refractory_trigger_period <= 0:
                raise ValueError("refractory_trigger_period must be positive.")
            if stim_on_time <= 0:
                raise ValueError("stim_on_time must be positive.")
        if method == "digital_trigger" and interpretor is not None and not callable(interpretor):
            raise ValueError(
                "interpretor must be a callable function that defines how to interpret the digital signal into triggers."
            )

        fsample = float(self.fsample)
        start_frame = int(tstart * fsample)
        end_frame = int(tstop * fsample)
        if start_frame < 0 or end_frame <= start_frame:
            raise ValueError("Invalid start/stop sample frame range for trigger extraction.")

        from .Triggers import Triggers  # avoid circular import
        # Fill a local container; self.triggers is replaced only on success.
        triggers = Triggers(slots=[])

        if method in {"artifact", "first_passage"}:
            window_s = 0.005 if moving_avg_window is None else moving_avg_window
            rising_edges, falling_edges = self._artifact_edges(
                start_frame, end_frame, artifact_threshold * mean_noise_level, window_s
            )
            if method == "artifact":
                for start_idx, end_idx in self._pair_edges(rising_edges, falling_edges):
                    triggers.add_interval_slot(
                        start=float((start_frame + start_idx) / fsample),
                        end=float((start_frame + end_idx) / fsample),
                    )
            else:
                last_trigger_time = -np.inf
                for start_idx in rising_edges:
                    start_time = float((start_frame + start_idx) / fsample)
                    if start_time - last_trigger_time < refractory_trigger_period:
                        continue
                    triggers.add_interval_slot(
                        start=start_time,
                        end=float(start_time + stim_on_time),
                    )
                    last_trigger_time = start_time
        else:
            # Slice by frame so the data and the offset below use the same first sample.
            stop = min(end_frame + 1, len(self.digital_recording))
            digital_window = np.asarray(self.digital_recording[start_frame:stop], dtype=np.float64)
            if digital_window.size == 0:
                raise ValueError("No digital samples found in the requested time window.")
            backup_digital = self.digital_recording
            rising_edges: List[int] = []
            falling_edges: List[int] = []
            try:
                # convert_digital / detect_* work on self.digital_recording in place.
                self.digital_recording = digital_window
                converted = self.convert_digital()
                if interpretor is None:
                    rising_edges = self.detect_digital_rising_edge()
                    falling_edges = self.detect_digital_falling_edge()
            finally:
                self.digital_recording = backup_digital
            if interpretor is None:
                for start_idx, end_idx in self._pair_edges(rising_edges, falling_edges):
                    triggers.add_interval_slot(
                        start=float((start_frame + start_idx) / fsample),
                        end=float((start_frame + end_idx) / fsample),
                    )
            else:
                digital_signal = np.asarray(converted, dtype=np.int32)
                if dt_after_trigger is None:
                    interpretor(digital_signal, triggers, fsample)
                else:
                    interpretor(digital_signal, triggers, fsample, dt_after_trigger)
                # The interpreter works in window-relative time; shift to absolute.
                offset = float(start_frame) / fsample
                for slot in triggers.slots:
                    slot.start += offset
                    slot.end += offset
        triggers.sort_slots()
        self.triggers = triggers
        return self.triggers

    @tracked_operation("remove_artifacts_from_trigger")
    def remove_artifacts_from_trigger(
        self,
        ms_before: float = 0.1,
        ms_after: float = 0.4,
        mode: str = "zeros",
    ):
        """
        Remove stimulation artifacts around trigger starts using SpikeInterface.

        Replaces `self.recording` with the cleaned recording and sets
        `artifact_removal_status`, so it can only run once.

        Parameters
        ----------
        ms_before : float, default=0.1
            Time (ms) before each trigger to remove.
        ms_after : float, default=0.4
            Time (ms) after each trigger to remove.
        mode : str, default='zeros'
            Artifact replacement mode (as supported by `pre.remove_artifacts`).

        Returns
        -------
        int
            Always returns 1; the cleaned recording is in `self.recording`.

        Raises
        ------
        ValueError
            If artifact removal has already been performed or data/triggers are missing.
        """
        if self.artifact_removal_status:
            raise ValueError("Artifact removal already performed.")
        if self.triggers is None:
            raise ValueError("Triggers not defined. Run get_triggers() first.")
        if self.fsample is None:
            raise ValueError("Sampling frequency not initialized.")
        if self.recording is None:
            raise ValueError("Recording not loaded.")
        fsample = float(self.fsample)
        list_triggers = [int(slot.start * fsample) for slot in self.triggers.slots]
        self.recording = pre.remove_artifacts(
            self.recording,
            list_triggers=list_triggers,
            ms_before=ms_before,
            ms_after=ms_after,
            mode=mode,
        )
        self.artifact_removal_status = True
        return 1

    @tracked_operation("get_raster", include_result_artifacts=_raster_artifacts)
    def get_raster(
        self,
        tstart: Optional[float] = None,
        tstop: Optional[float] = None,
        include_amplitudes: bool = False,
        include_triggers: bool = False,
    ):
        """
        Export a Raster object using detected peaks and current selection.

        Only channels kept by `self.mask` are included, and only spikes inside
        [tstart, tstop] whose sample is not blanked in `temporal_mask`.

        Parameters
        ----------
        tstart : float, optional
            Start time in seconds. Defaults to start of recording.
        tstop : float, optional
            Stop time in seconds. Defaults to end of recording.
        include_amplitudes : bool, default=False
            If True, include spike amplitudes in the raster in uV.
        include_triggers : bool, default=False
            If True, include triggers from the recording in the raster.

        Returns
        -------
        Raster
            Raster object with amplitudes and triggers fields populated on request.
        """
        if self.peaks is None:
            raise ValueError("Spikes not detected.")
        if self.time_vector is None or self.temporal_mask is None:
            raise ValueError("Time vector / temporal mask not initialized.")
        if self.fsample is None:
            raise ValueError("Sampling frequency not initialized.")
        if self.mask is None:
            raise ValueError("Channel mask not initialized.")
        if include_amplitudes and self.recording is None:
            raise ValueError("Recording not loaded.")
        if tstart is None:
            tstart = float(self.time_vector[0])
        if tstop is None:
            tstop = float(self.time_vector[-1])
        from .raster import Raster  # avoid circular import
        fsample = float(self.fsample)
        channel_ids = np.asarray(self.ch_ids)
        channel_mask = np.asarray(self.mask, dtype=bool)
        kept_channel_ids = channel_ids[channel_mask]
        kept_channel_indices = np.flatnonzero(channel_mask)
        last_mask_index = len(self.temporal_mask) - 1
        r = Raster.empty(channels=kept_channel_ids)
        amplitudes: Dict[Union[int, str], np.ndarray] = {}
        for orig_idx, ch in zip(kept_channel_indices, kept_channel_ids):
            ch_peaks = self.peaks[self.peaks["channel_index"] == orig_idx]
            samples = ch_peaks["sample_index"].astype(np.int64)
            this_channel_times = samples / fsample
            # Index the temporal mask with the sample itself (no float round trip).
            in_unblanked = self.temporal_mask[np.clip(samples, 0, last_mask_index)]
            keep_spikes = (
                (this_channel_times >= tstart)
                & (this_channel_times <= tstop)
                & in_unblanked
            )
            r.insert_timestamparray(ch, this_channel_times[keep_spikes], assume_sorted=True)
            if include_amplitudes:
                kept_samples = samples[keep_spikes]
                if kept_samples.size:
                    start = int(max(0, kept_samples.min()))
                    end = int(kept_samples.max()) + 1
                    trace = self.recording.get_traces(
                        start_frame=start,
                        end_frame=end,
                        channel_ids=[ch],
                        return_in_uV=True,
                    ).flatten()
                    amplitudes[ch] = trace[kept_samples - start]
                else:
                    amplitudes[ch] = np.asarray([], dtype=np.float64)
        if include_amplitudes:
            r.amplitudes = amplitudes
        if include_triggers:
            r.triggers = self._serialize_triggers()
        # Attach a provenance snapshot (best-effort). The rich ProcessingHistory is
        # preferred; AData's `history` is documented as a plain list and has no to_dict.
        # NOTE(review): assumes ProcessingHistory exposes to_dict() — confirm.
        for attr in ("processing_history", "history"):
            to_dict = getattr(getattr(self, attr, None), "to_dict", None)
            if not callable(to_dict):
                continue
            try:
                r.provenance = to_dict()
                break
            except Exception:
                continue
        return r