from __future__ import annotations
from dataclasses import dataclass
from typing import Dict, List, Literal, Optional, Sequence, Tuple, Union
import numpy as np
from tad.raster import Raster, ChannelId
from tad.metrics.utils import _select_channels, _infer_window, pooled_spike_times
[docs]
@dataclass(frozen=True)
class ISIResult:
    """
    Immutable bundle of ISIs plus the metadata describing how they were extracted.

    Attributes
    ----------
    isi
        The extracted inter-spike intervals. When ``mode == "per_channel"`` this is
        a mapping ``{channel_id: isi_array}``; when ``mode == "pooled"`` it is a
        single 1D array.
    mode
        Extraction mode that produced ``isi``: ``"per_channel"`` or ``"pooled"``.
    tstart, tstop
        Bounds of the analysis window that was applied.
    channels
        Channel IDs that took part in the extraction, in order.

    Notes
    -----
    ``frozen=True`` only forbids reassigning attributes; the contained dict, list
    and arrays are stored as given and are not copied or made read-only.
    """

    # Per-channel mapping or a single pooled array, depending on ``mode``.
    isi: Union[Dict[ChannelId, np.ndarray], np.ndarray]
    # Tag telling callers which shape ``isi`` has.
    mode: Literal["per_channel", "pooled"]
    # Left bound of the analysis window.
    tstart: float
    # Right bound of the analysis window.
    tstop: float
    # Ordered list of the channel IDs used.
    channels: List[ChannelId]
[docs]
def isi(
    r: Raster,
    *,
    channels: Optional[Sequence[ChannelId]] = None,
    tstart: Optional[float] = None,
    tstop: Optional[float] = None,
    inclusive_stop: bool = False,
    mode: Literal["per_channel", "pooled"] = "per_channel",
) -> ISIResult:
    """
    Compute inter-spike intervals (ISIs) from a raster.

    Parameters
    ----------
    r
        Input raster.
    channels
        Channels to include. If None, uses all channels.
    tstart, tstop
        Optional analysis window. If None, inferred from data.
    inclusive_stop
        Right-bound policy: [tstart,tstop) if False, [tstart,tstop] if True.
    mode
        - "per_channel": compute ISIs separately for each channel
        - "pooled": pool spike times across channels (global ordering) then compute ISIs

    Returns
    -------
    ISIResult
        ISIs and metadata. Every ISI array is cast to ``r.dtype``. A channel (or
        the pooled train) with fewer than two spikes inside the window yields an
        empty array.

    Raises
    ------
    ValueError
        If ``mode`` is not "per_channel" or "pooled". This is checked before any
        channel selection or window inference takes place.

    Notes
    -----
    - Per-channel ISIs are computed from each channel's spike times. The window
      is located with ``np.searchsorted``, which requires ``r.events[ch]`` to be
      sorted in ascending order.
    - Pooled ISIs correspond to the IEI concept used in the references when pooling
      across electrodes (global event train).
    """
    # Fail fast on a bad mode instead of after channel selection / window
    # inference, whose own errors could otherwise mask the real problem.
    if mode not in ("per_channel", "pooled"):
        raise ValueError(f"mode must be 'per_channel' or 'pooled', got {mode!r}.")

    ch_list = _select_channels(r, channels)
    tstart_f, tstop_f = _infer_window(r, ch_list, tstart, tstop)

    intervals: Union[Dict[ChannelId, np.ndarray], np.ndarray]
    if mode == "pooled":
        # Merged, windowed spike train across all selected channels.
        pooled = pooled_spike_times(
            r,
            channels=ch_list,
            tstart=tstart_f,
            tstop=tstop_f,
            inclusive_stop=inclusive_stop,
        )
        intervals = _consecutive_intervals(pooled, r.dtype)
    else:
        # side="right" on tstop keeps a spike lying exactly at tstop (closed
        # interval); side="left" excludes it (half-open interval).
        stop_side = "right" if inclusive_stop else "left"
        per_channel: Dict[ChannelId, np.ndarray] = {}
        for ch in ch_list:
            spikes = r.events[ch]
            lo = np.searchsorted(spikes, tstart_f, side="left")
            hi = np.searchsorted(spikes, tstop_f, side=stop_side)
            per_channel[ch] = _consecutive_intervals(spikes[lo:hi], r.dtype)
        intervals = per_channel

    return ISIResult(
        isi=intervals,
        mode=mode,
        tstart=tstart_f,
        tstop=tstop_f,
        channels=list(ch_list),
    )


def _consecutive_intervals(times: np.ndarray, dtype) -> np.ndarray:
    """Return ``np.diff(times)`` cast to *dtype*, or an empty *dtype* array if fewer than 2 samples."""
    if times.size < 2:
        return np.asarray([], dtype=dtype)
    return np.diff(times).astype(dtype, copy=False)
[docs]
def isih(
    isi_values: np.ndarray,
    *,
    bins: Union[int, np.ndarray] = 50,
    density: bool = False,
    log: bool = False,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Compute an ISI histogram (ISIH), optionally in log domain.

    Parameters
    ----------
    isi_values
        Array-like of ISIs, flattened to 1D. Non-finite and non-positive values
        are discarded before binning, in both linear and log mode.
    bins
        Number of bins or explicit bin edges (passed to np.histogram). When
        ``log=True``, explicit edges must be expressed in log10(ISI) units.
    density
        If True, return a probability density (area = 1). If False, return counts.
    log
        If True, histogram log10(ISI) rather than ISI.

    Returns
    -------
    centers : ndarray
        Bin centers (in ISI units if log=False; in log10 units if log=True).
    hist : ndarray
        Histogram values (counts or density), as float64.

    Notes
    -----
    If no valid ISIs remain after filtering:

    - with a bin count (or any non-1D ``bins``) there is no data range to place
      bins in, so two empty float64 arrays are returned;
    - with explicit 1D edges the output keeps the shape implied by the edges
      (``len(edges) - 1`` centers) and every histogram value is 0. This also
      holds when ``density=True``, because the density of an empty sample is
      undefined and numpy would otherwise produce NaN.
    """
    x = np.asarray(isi_values, dtype=np.float64).ravel()
    # NaN/inf and non-positive values are not valid ISIs (and have no log10).
    x = x[np.isfinite(x) & (x > 0.0)]

    has_explicit_edges = np.ndim(bins) == 1
    if x.size == 0 and not has_explicit_edges:
        return np.asarray([], dtype=np.float64), np.asarray([], dtype=np.float64)

    if log:
        x = np.log10(x)

    # numpy's density normalisation divides by the total count. For an empty
    # sample with explicit edges, request plain (all-zero) counts instead so
    # the result is zeros rather than NaN.
    hist, edges = np.histogram(x, bins=bins, density=density and x.size > 0)
    centers = 0.5 * (edges[:-1] + edges[1:])
    return centers, hist.astype(np.float64, copy=False)