from __future__ import annotations
from dataclasses import dataclass
from typing import Dict, List, Literal, Optional, Sequence
import numpy as np
from scipy.signal import find_peaks
from tad.raster import Raster, ChannelId
from tad.metrics.utils import _select_channels, _infer_window
[docs]
@dataclass(frozen=True)
class Burst:
    """
    Immutable record describing one detected burst.

    Attributes
    ----------
    start
        Time of the first spike belonging to the burst.
    end
        Time of the last spike belonging to the burst.
    n_spikes
        How many spikes the burst contains.
    duration
        ``end - start``.
    intra_rate_hz
        Firing rate inside the burst in Hz, computed as
        ``(n_spikes - 1) / duration`` when ``duration > 0`` and ``np.inf``
        otherwise.
    """

    # Field order defines the positional constructor; keep it stable.
    start: float  # first spike time
    end: float  # last spike time
    n_spikes: int  # spike count
    duration: float  # end - start
    intra_rate_hz: float  # (n_spikes - 1) / duration, or inf
[docs]
@dataclass(frozen=True)
class LogISIThresholdDiagnostics:
    """
    Diagnostics for adaptive ISI threshold selection using the log-ISIH.

    Attributes
    ----------
    isi_th
        Chosen ISI threshold in seconds. NaN when no burst regime could be
        identified; equal to the caller's ``fallback`` ISI on fallback paths.
    centers_log10
        Histogram bin centers in log10(ISI) space. May be empty (or None) when
        no histogram could be computed.
    hist
        Raw histogram values (counts or density depending on the call). May be
        empty (or None) when no histogram could be computed.
    hist_smooth
        Moving-average-smoothed histogram used for peak/valley detection. May be
        empty (or None) when no histogram could be computed.
    peak1_idx
        Index into ``centers_log10`` of the short-ISI (intra-burst) peak, or -1.
    peak2_idx
        Index into ``centers_log10`` of the long-ISI (inter-burst) peak, or -1.
    valley_idx
        Index into ``centers_log10`` of the valley used as threshold. It is -1
        for every outcome other than "valley_between_peaks".
    method
        Selection path taken:

        - "valley_between_peaks": threshold taken at the valley;
        - "fallback_returned", "fallback_segment_too_small": the caller's
          fallback ISI was used;
        - "no_data", "not_able_to_find_bins", "not_able_to_find_bursts",
          "no_burst_regime_no_peaks", "no_burst_regime_first_peak_high",
          "no_burst_regime_small_first_peak": no usable burst regime
          (``isi_th`` is NaN).
    """

    isi_th: float
    centers_log10: Optional[np.ndarray]
    hist: Optional[np.ndarray]
    hist_smooth: Optional[np.ndarray]
    peak1_idx: int
    peak2_idx: int
    valley_idx: int
    method: str
[docs]
@dataclass(frozen=True)
class BurstChannelResult:
    """
    Burst-detection outcome for a single channel.

    Attributes
    ----------
    channel
        Identifier of the channel these bursts belong to.
    isi_th
        ISI threshold (seconds) that was applied to this channel.
    bursts
        Bursts found on this channel.
    diagnostics
        Log-ISIH threshold diagnostics when the threshold came from
        method="logisih"; None otherwise. With threshold_scope="pooled" the
        same diagnostics object may be shared by every channel.
    """

    channel: ChannelId
    isi_th: float
    bursts: List[Burst]
    # Only populated for the log-ISIH method.
    diagnostics: Optional[LogISIThresholdDiagnostics] = None
[docs]
@dataclass(frozen=True)
class BurstDetectionResult:
    """
    Aggregate burst-detection outcome across all analysed channels.

    Attributes
    ----------
    per_channel
        Per-channel results keyed by channel id.
    method
        Human-readable summary of the detection method and its settings.
    tstart, tstop
        Analysis window used for detection.
    channels
        Channel ids that were analysed, in processing order.
    """

    per_channel: Dict[ChannelId, BurstChannelResult]
    method: str
    # Analysis window bounds.
    tstart: float
    tstop: float
    channels: List[ChannelId]
def _detect_bursts_from_times(
times: np.ndarray,
*,
isi_th: float,
min_spikes: int = 3,
) -> List[Burst]:
"""
Detect bursts from sorted spike times using a fixed ISI threshold rule.
A burst is a maximal consecutive sequence of spikes such that each consecutive
inter-spike interval (ISI) satisfies ISI <= isi_th. The sequence is accepted as
a burst if it contains at least `min_spikes` spikes.
Parameters
----------
times
Sorted spike times, shape (N,).
isi_th
ISI threshold (seconds).
min_spikes
Minimum number of spikes required to accept a burst.
Returns
-------
list of Burst
Detected bursts.
"""
t = np.asarray(times, dtype=np.float64).ravel()
if t.size < min_spikes:
return []
isi = np.diff(t)
if isi.size == 0:
return []
link = isi <= float(isi_th) # True means spike i and i+1 are within the same burst
bursts: List[Burst] = []
i = 0
n = t.size
while i < n - 1:
if not link[i]:
i += 1
continue
start_idx = i
j = i
while j < n - 1 and link[j]:
j += 1
end_idx = j
nsp = end_idx - start_idx + 1
if nsp >= min_spikes:
start = float(t[start_idx])
end = float(t[end_idx])
dur = float(end - start)
intra = float((nsp - 1) / dur) if dur > 0.0 else float("inf")
bursts.append(Burst(start=start, end=end, n_spikes=int(nsp), duration=dur, intra_rate_hz=intra))
i = end_idx
return bursts
[docs]
def choose_isi_threshold_logisih(
    isi_values: Optional[np.ndarray] = None,
    r: Optional[Raster] = None,
    channel: Optional[ChannelId] = None,
    tstart: Optional[float] = None,
    tstop: Optional[float] = None,
    inclusive_stop: bool = False,
    *,
    bins: str | int = "pasquale",
    smooth_window: str | int = "from_bins",
    density: bool = False,
    min_isi: float = 1e-6,
    fallback: float = 0.1,
) -> LogISIThresholdDiagnostics:
    """
    Choose an ISI threshold from the log10(ISI) histogram (log-ISIH).

    This is an adaptive threshold selection strategy commonly used for burst
    detection in MEA analyses: short ISIs (within bursts) and long ISIs
    (between bursts) can create a structured distribution in log10(ISI). When
    both regimes are detectable, the threshold is taken at the valley between
    them.

    Algorithm (NumPy + ``scipy.signal.find_peaks``):

    1) Keep finite ISIs > 0, clip them to ``min_isi`` and compute
       y = log10(ISI).
    2) Histogram y using ``bins``.
    3) Smooth the histogram with a moving average of length ``smooth_window``.
    4) Detect local maxima (peaks) of the smoothed histogram.
    5) Peaks at log10(ISI) <= -1 (ISI <= 100 ms) are intra-burst candidates;
       peaks above -1 are inter-burst candidates. If the first peak is above
       -1 there is no burst regime (NaN). If no peak is above -1, the
       ``fallback`` ISI is returned.
    6) peak1 = right-most intra-burst peak; peak2 = tallest inter-burst peak
       (by smoothed height). If peak1 is lower than 5% of peak2, there is no
       burst regime (NaN).
    7) ISI_th = 10**(center of the smoothed-histogram minimum between peak1
       and peak2). If that segment spans fewer than 3 bins, ``fallback`` is
       returned.

    Parameters
    ----------
    isi_values
        1D array of ISIs in seconds. Non-finite and non-positive values are
        discarded. Mutually exclusive with ``r``.
    r
        Raster to read spike times from. Mutually exclusive with
        ``isi_values``; requires ``channel``.
    channel
        Channel id in ``r.events`` whose ISIs are analysed (only with ``r``).
    tstart, tstop
        Analysis window (only with ``r``). When None, the window is resolved
        by ``_infer_window``.
    inclusive_stop
        If True, include spikes exactly at ``tstop``; otherwise the window is
        [tstart, tstop) (only with ``r``).
    bins
        "pasquale" (default): choose the bin count so each bin spans at most
        0.1 decade of log10(ISI). Otherwise any ``bins`` value accepted by
        ``np.histogram`` (an int or a method name such as "auto").
    smooth_window
        Moving-average window length in bins, or "from_bins" to use 20% of the
        bin count. The window is clamped to [1, number of bins].
    density
        If True, the histogram is a density; otherwise it holds counts.
    min_isi
        Lower clip applied to ISIs before the log transform.
    fallback
        ISI (seconds) returned when only the short-ISI regime is present or
        the peak-to-peak segment is too short for a valley.

    Returns
    -------
    LogISIThresholdDiagnostics
        Contains the chosen ISI_th (NaN when no burst regime is detected) and
        the intermediate arrays/indices for inspection.

    Raises
    ------
    ValueError
        If neither or both of ``isi_values`` and ``r`` are given, or if ``r``
        is given without ``channel``.
    """
    if isi_values is None and r is None:
        raise ValueError("Pass isi_values or r.")
    if isi_values is not None and r is not None:
        raise ValueError("Pass only isi_values or r.")
    if r is not None:
        if channel is None:
            raise ValueError("Pass channel together with r.")
        tstart_f, tstop_f = _infer_window(r, [channel], tstart, tstop)
        spikes = r.events[channel]
        left = np.searchsorted(spikes, tstart_f, side="left")
        right = np.searchsorted(spikes, tstop_f, side=("right" if inclusive_stop else "left"))
        window = spikes[left:right]
        isi_values = (
            np.diff(window).astype(np.float64, copy=False)
            if window.size >= 2
            else np.asarray([], dtype=np.float64)
        )

    def _result(
        isi_th: float,
        method: str,
        centers: Optional[np.ndarray] = None,
        hist: Optional[np.ndarray] = None,
        hist_smooth: Optional[np.ndarray] = None,
        peak1_idx: int = -1,
        peak2_idx: int = -1,
        valley_idx: int = -1,
    ) -> LogISIThresholdDiagnostics:
        # Build the diagnostics record. Missing arrays become fresh empty
        # float64 arrays so the ndarray fields never hold None.
        def _arr(a: Optional[np.ndarray]) -> np.ndarray:
            return np.asarray([], dtype=np.float64) if a is None else a

        return LogISIThresholdDiagnostics(
            isi_th=float(isi_th),
            centers_log10=_arr(centers),
            hist=_arr(hist),
            hist_smooth=_arr(hist_smooth),
            peak1_idx=int(peak1_idx),
            peak2_idx=int(peak2_idx),
            valley_idx=int(valley_idx),
            method=method,
        )

    nan = float("nan")

    x = np.asarray(isi_values, dtype=np.float64).ravel()
    x = x[np.isfinite(x) & (x > 0.0)]
    if x.size == 0:
        return _result(nan, "no_data")
    y = np.log10(np.maximum(x, float(min_isi)))

    if isinstance(bins, str) and bins == "pasquale":
        # Bin count such that each bin spans at most 0.1 decade in log10 space.
        bins = int(np.ceil((np.max(y) - np.min(y)) / 0.1))
        if bins < 1:
            # Zero range (all ISIs equal after clipping): no shape to analyse.
            return _result(nan, "not_able_to_find_bins")

    hist, edges = np.histogram(y, bins=bins, density=bool(density))
    centers = 0.5 * (edges[:-1] + edges[1:])
    h = hist.astype(np.float64, copy=False)

    # Moving-average smoothing.
    if isinstance(smooth_window, str) and smooth_window == "from_bins":
        win = int(0.20 * centers.size)  # 20% of bins
    else:
        win = int(smooth_window)
    # np.convolve(mode="same") returns max(len(h), len(kernel)) samples, so a
    # kernel longer than the histogram would misalign h_s with `centers`.
    win = max(1, min(win, h.size))
    if win > 1:
        kernel = np.ones(win, dtype=np.float64) / float(win)
        h_s = np.convolve(h, kernel, mode="same")
    else:
        h_s = h.copy()

    if h_s.size < 3:
        return _result(nan, "not_able_to_find_bursts", centers, h, h_s)

    peaks = find_peaks(h_s)[0]
    if peaks.size == 0:
        return _result(nan, "no_burst_regime_no_peaks", centers, h, h_s)

    peak_log = centers[peaks]  # log10(ISI) of each peak, ascending
    first_p1 = int(peaks[0])
    first_p2 = int(peaks[1]) if peaks.size > 1 else -1
    # Intra-burst side: ISI <= 100 ms. Using one consistent boundary makes every
    # peak belong to exactly one class (a peak exactly at -1.0 used to be in
    # neither class, which could crash below).
    is_short = peak_log <= -1.0

    if not is_short[0]:
        # Even the shortest-ISI peak is above 100 ms: no burst regime.
        return _result(nan, "no_burst_regime_first_peak_high", centers, h, h_s, first_p1, first_p2)
    if is_short.all():
        # Only the short-ISI regime is visible: no valley to separate.
        return _result(float(fallback), "fallback_returned", centers, h, h_s, first_p1, first_p2)

    short_peaks = peaks[is_short]
    long_peaks = peaks[~is_short]
    peak1_idx = int(short_peaks[-1])  # right-most intra-burst peak
    # Tallest inter-burst peak by smoothed height (argmax over the peak
    # *positions* would always select the right-most peak).
    peak2_idx = int(long_peaks[np.argmax(h_s[long_peaks])])

    if h_s[peak1_idx] < 0.05 * h_s[peak2_idx]:
        return _result(
            nan, "no_burst_regime_small_first_peak", centers, h, h_s, peak1_idx, peak2_idx
        )

    seg = h_s[peak1_idx : peak2_idx + 1]
    if seg.size < 3:
        return _result(
            float(fallback), "fallback_segment_too_small", centers, h, h_s, peak1_idx, peak2_idx
        )

    valley_idx = peak1_idx + int(np.argmin(seg))
    return _result(
        10.0 ** centers[valley_idx],
        "valley_between_peaks",
        centers,
        h,
        h_s,
        peak1_idx,
        peak2_idx,
        valley_idx,
    )
[docs]
def detect_bursts(
    r: Raster,
    *,
    method: Literal["fixed", "logisih", "fixed_from_logisih"] = "fixed",
    isi_th: float | np.ndarray = 0.1,
    min_spikes: int = 3,
    channels: Optional[Sequence[ChannelId]] = None,
    tstart: Optional[float] = None,
    tstop: Optional[float] = None,
    inclusive_stop: bool = False,
    threshold_scope: Literal["per_channel", "pooled"] = "per_channel",
    logisih_bins: str | int = "auto",
    logisih_smooth_window: str | int = "from_bins",
    fallback: float = 0.1,
) -> BurstDetectionResult:
    """
    Detect single-channel bursts in a Raster.

    Three threshold-selection methods are supported:

    1) method="fixed"
       Uses the constant threshold `isi_th` for all channels.
    2) method="logisih"
       Derives the ISI threshold from the log10(ISI) histogram (log-ISIH)
       either:
       - per channel (threshold_scope="per_channel"), or
       - from pooled ISIs across selected channels (threshold_scope="pooled").
    3) method="fixed_from_logisih"
       Uses pre-calculated ISI thresholds from log-ISIH analysis. For this
       method, `isi_th` must be an array with one threshold value per channel,
       in the same order as `channels`.

    Burst definition (segmentation):
        A burst is a maximal consecutive sequence of spikes such that every
        consecutive inter-spike interval satisfies ISI <= ISI_th, and the
        sequence contains at least `min_spikes` spikes. A channel whose
        threshold is not finite (e.g. NaN when the log-ISIH finds no burst
        regime) gets no bursts.

    Parameters
    ----------
    r
        Input raster.
    method
        Threshold selection method: "fixed", "logisih", or "fixed_from_logisih".
    isi_th
        ISI threshold(s) in seconds. For "fixed", a single float (ignored by
        "logisih"). For "fixed_from_logisih", a 1D array with one value per
        channel.
    min_spikes
        Minimum number of spikes required to accept a burst (>= 2).
    channels
        Channels to include. If None, uses all channels in the raster.
    tstart, tstop
        Optional time window. If None, inferred from data across selected
        channels.
    inclusive_stop
        If True, include events exactly at tstop; otherwise the window is
        [tstart, tstop).
    threshold_scope
        Only used when method="logisih":
        - "per_channel": compute a separate ISI_th per channel from that
          channel's ISIs;
        - "pooled": compute a single ISI_th from pooled ISIs across channels
          and reuse it.
    logisih_bins
        Number of bins in the log10(ISI) histogram, or a method for
        numpy.histogram ("auto", "fd", "doane", "sqrt", ...), or "pasquale"
        (method="logisih").
    logisih_smooth_window
        Smoothing window length for the histogram, or "from_bins" to derive
        it from the number of bins (method="logisih").
    fallback
        ISI threshold value (seconds) returned by the log-ISIH selection when
        its peak/valley detection falls back.

    Returns
    -------
    BurstDetectionResult
        Per-channel burst lists and metadata. For method="logisih", each
        channel result includes a `diagnostics` object that can be
        plotted/inspected; when threshold_scope="pooled", this diagnostic is
        shared across channels.

    Raises
    ------
    ValueError
        If parameters are inconsistent.
    """
    if method not in ("fixed", "logisih", "fixed_from_logisih"):
        raise ValueError("method must be 'fixed', 'logisih', or 'fixed_from_logisih'.")
    if threshold_scope not in ("per_channel", "pooled"):
        raise ValueError("threshold_scope must be 'per_channel' or 'pooled'.")
    if min_spikes < 2:
        raise ValueError("min_spikes must be >= 2.")

    ch_list = _select_channels(r, channels)
    tstart_f, tstop_f = _infer_window(r, ch_list, tstart, tstop)

    # Per-channel thresholds for "fixed_from_logisih", validated and converted once.
    th_array: Optional[np.ndarray] = None
    if method == "fixed_from_logisih":
        th_array = np.asarray(isi_th)
        if th_array.ndim != 1 or th_array.size != len(ch_list):
            raise ValueError(
                f"For method='fixed_from_logisih', isi_th must be a 1D array with "
                f"exactly {len(ch_list)} elements (one per channel), but got shape {th_array.shape}"
            )
    fixed_th = float(isi_th) if method == "fixed" else None

    stop_side = "right" if inclusive_stop else "left"

    def _window(ch: ChannelId) -> np.ndarray:
        # Spikes of `ch` inside [tstart_f, tstop_f) (or ] when inclusive_stop).
        arr = r.events[ch]
        left = np.searchsorted(arr, tstart_f, side="left")
        right = np.searchsorted(arr, tstop_f, side=stop_side)
        return arr[left:right]

    def _isis(w: np.ndarray) -> np.ndarray:
        # Float64 ISIs of a windowed spike train (empty when < 2 spikes).
        if w.size >= 2:
            return np.diff(w).astype(np.float64, copy=False)
        return np.asarray([], dtype=np.float64)

    # Slice each channel once; the slices are reused for pooling and detection.
    windows = [_window(ch) for ch in ch_list]

    pooled_diag: Optional[LogISIThresholdDiagnostics] = None
    if method == "logisih" and threshold_scope == "pooled":
        chunks = [_isis(w) for w in windows if w.size >= 2]
        pooled = np.concatenate(chunks) if chunks else np.asarray([], dtype=np.float64)
        pooled_diag = choose_isi_threshold_logisih(
            pooled,
            bins=logisih_bins,
            smooth_window=logisih_smooth_window,
            fallback=fallback,
        )

    out: Dict[ChannelId, BurstChannelResult] = {}
    for i, (ch, w) in enumerate(zip(ch_list, windows)):
        diag: Optional[LogISIThresholdDiagnostics] = None
        if method == "fixed":
            isi_th_ch = fixed_th
        elif method == "fixed_from_logisih":
            isi_th_ch = float(th_array[i])
        elif threshold_scope == "pooled":
            diag = pooled_diag
            isi_th_ch = float(pooled_diag.isi_th)
        else:
            diag = choose_isi_threshold_logisih(
                _isis(w),
                bins=logisih_bins,
                smooth_window=logisih_smooth_window,
                fallback=fallback,
            )
            isi_th_ch = float(diag.isi_th)

        # A non-finite threshold (e.g. NaN: no burst regime) yields no bursts.
        if np.isfinite(isi_th_ch):
            bursts = _detect_bursts_from_times(w, isi_th=isi_th_ch, min_spikes=int(min_spikes))
        else:
            bursts = []
        out[ch] = BurstChannelResult(channel=ch, isi_th=float(isi_th_ch), bursts=bursts, diagnostics=diag)

    if method == "fixed":
        method_str = f"fixed_isi_th={float(isi_th)}"
    elif method == "fixed_from_logisih":
        method_str = "fixed_from_logisih(per_channel_thresholds)"
    else:
        # `fallback` is an ISI value, not a quantile.
        method_str = (
            f"logisih(scope={threshold_scope}, bins={logisih_bins}, "
            f"smooth={logisih_smooth_window}, fallback={float(fallback)})"
        )

    return BurstDetectionResult(
        per_channel=out,
        method=method_str,
        tstart=float(tstart_f),
        tstop=float(tstop_f),
        channels=list(ch_list),
    )