from __future__ import annotations
from dataclasses import dataclass
from typing import List, Optional, Sequence, Tuple
import numpy as np
from tad.metrics.psth import PSTHResult
from tad.raster import Raster, ChannelId
from tad.metrics.utils import _select_channels, _infer_window
@dataclass(frozen=True)
class EvokedPeakResult:
    """
    Scalar evoked-response metrics derived from a PSTH.

    As produced by ``evoked_peak_metrics``, every array attribute holds one
    value per channel, in the same order as ``channels``.

    Attributes
    ----------
    channels
        Channel IDs corresponding to rows.
    baseline_hz
        Baseline rate in Hz (mean over baseline window).
    peak_hz
        Peak rate in Hz (max over response window).
    peak_latency_s
        Time of peak in seconds (relative to stimulus). If the maximum is
        reached in several bins, the earliest such bin is reported.
    peak_minus_baseline_hz
        Peak above baseline (Hz). Negative if the response-window maximum
        stays below baseline.
    auc_above_baseline
        Area relative to baseline over the response window (spikes per
        stimulus). With rectification (the default in
        ``evoked_peak_metrics``) it is ``sum(max(rate - baseline, 0)) * dt``.
        Without it, it is the signed ``sum(rate - baseline) * dt``.

    Notes
    -----
    The instance is frozen, but the ndarray fields themselves remain mutable
    in place. The dataclass-generated ``__eq__``/``__hash__`` operate on the
    raw list/array fields. Hashing therefore raises TypeError, and equality
    between instances holding distinct multi-element arrays may raise
    ValueError (ambiguous array truth value).
    """

    channels: List[ChannelId]
    baseline_hz: np.ndarray
    peak_hz: np.ndarray
    peak_latency_s: np.ndarray
    peak_minus_baseline_hz: np.ndarray
    auc_above_baseline: np.ndarray
def evoked_peak_metrics(
    psth: PSTHResult,
    *,
    baseline_window: Tuple[float, float] = (-0.050, 0.0),
    response_window: Tuple[float, float] = (0.0, 0.050),
    rectify: bool = True,
) -> EvokedPeakResult:
    """
    Compute per-channel evoked-response scalar metrics from a PSTH.

    A PSTH bin belongs to a window ``(t0, t1)`` when its time stamp in
    ``psth.t`` satisfies ``t0 <= t < t1`` (half-open window).

    Parameters
    ----------
    psth
        PSTHResult as returned by `compute_psth`. Only ``t`` (1-D bin times),
        ``rate_hz`` (2-D, one row per channel and one column per entry of
        ``t``), ``dt`` and ``channels`` are read.
    baseline_window
        (t0, t1) window for baseline mean rate (seconds, relative to stimulus).
    response_window
        (t0, t1) window where peak/latency/AUC are computed.
    rectify
        If True, AUC uses max(rate-baseline, 0). If False, uses signed (rate-baseline).

    Returns
    -------
    EvokedPeakResult
        Per-channel scalar metrics. ``peak_latency_s`` is the time of the
        first response-window bin attaining the maximum. NaN rates are not
        filtered, so a NaN inside a window propagates into the affected
        channel's metrics.

    Raises
    ------
    ValueError
        If windows are invalid or contain no bins, or if ``rate_hz`` is not
        2-D with exactly one column per entry of ``t``.
    """
    t = np.asarray(psth.t)
    R = np.asarray(psth.rate_hz)
    dt = float(psth.dt)

    # Fail early with a clear message. A shape mismatch would otherwise
    # surface as an obscure IndexError from the boolean column masks below.
    if t.ndim != 1 or R.ndim != 2 or R.shape[1] != t.shape[0]:
        raise ValueError(
            "psth.rate_hz must be 2-D (n_channels, n_bins) with n_bins == len(psth.t); "
            f"got rate_hz shape {R.shape} and t shape {t.shape}."
        )

    b0, b1 = map(float, baseline_window)
    r0, r1 = map(float, response_window)
    # NaN bounds also fail these comparisons and are rejected here.
    if not (b0 < b1 and r0 < r1):
        raise ValueError("baseline_window and response_window must satisfy t0 < t1.")

    bmask = (t >= b0) & (t < b1)
    rmask = (t >= r0) & (t < r1)
    if not np.any(bmask):
        raise ValueError("baseline_window contains no PSTH bins. Adjust baseline_window or dt.")
    if not np.any(rmask):
        raise ValueError("response_window contains no PSTH bins. Adjust response_window or dt.")

    baseline = R[:, bmask].mean(axis=1)
    R_resp = R[:, rmask]
    t_resp = t[rmask]

    # Peak and latency: argmax returns the first occurrence of the maximum.
    peak_idx = np.argmax(R_resp, axis=1)
    peak = R_resp[np.arange(R_resp.shape[0]), peak_idx]
    latency = t_resp[peak_idx]

    delta = R_resp - baseline[:, None]
    if rectify:
        delta = np.maximum(delta, 0.0)
    auc = delta.sum(axis=1) * dt  # rate (Hz) * seconds -> spikes per stimulus

    return EvokedPeakResult(
        channels=list(psth.channels),
        baseline_hz=baseline.astype(np.float64, copy=False),
        peak_hz=peak.astype(np.float64, copy=False),
        peak_latency_s=latency.astype(np.float64, copy=False),
        peak_minus_baseline_hz=(peak - baseline).astype(np.float64, copy=False),
        auc_above_baseline=auc.astype(np.float64, copy=False),
    )
def response_probability(
    r: Raster,
    stim_times: Sequence[float],
    *,
    t0: float = 0.0,
    t1: float = 0.050,
    channels: Optional[Sequence[ChannelId]] = None,
    tstart: Optional[float] = None,
    tstop: Optional[float] = None,
    inclusive_stop: bool = False,
    inclusive_zero: bool = True,
) -> Tuple[List[ChannelId], np.ndarray]:
    """
    Estimate per-channel response probability to a stimulus.

    Definition used here:
        For each channel c and each stimulus i, define an indicator:
            I_{c,i} = 1 if channel c has >=1 spike in [stim_i + t0, stim_i + t1)
                      else 0
        Response probability per channel:
            p_c = (1/N) * sum_i I_{c,i}
        where N is the number of admissible stimuli (see tstart/tstop).
        The bracket shown is the default window. ``inclusive_zero`` and
        ``inclusive_stop`` control whether its left and right edges are
        closed (see below).

    This complements PSTH peak metrics and is a standard evoked-response quantity.

    Parameters
    ----------
    r
        Input raster. ``r.events[ch]`` is searched with ``np.searchsorted``,
        so each channel's spike times are assumed sorted ascending
        (TODO confirm Raster guarantees this).
    stim_times
        Stimulus onset times. Non-finite values are dropped. Duplicate times
        are counted as separate stimuli.
    t0, t1
        Response window bounds relative to stimulus (seconds). Must satisfy
        t0 < t1.
    channels
        Optional channels subset/order.
    tstart, tstop
        Optional admissible stimulus window, resolved via ``_infer_window``
        when omitted. A stimulus is kept only if stim+t0 >= tstart and
        stim+t1 < tstop (stim+t1 <= tstop when ``inclusive_stop``).
    inclusive_stop
        Right-bound policy. It applies both to the admissible stimulus
        window (above) and to the response window itself: when True, a spike
        exactly at stim+t1 counts as a response.
    inclusive_zero
        If True (default), a spike exactly at stim+t0 counts as a response
        (closed left edge). If False, the left edge is open. This applies for
        any t0, not only t0 == 0.

    Returns
    -------
    channels_list, p
        channels_list is list of ChannelId, p is ndarray shape (n_channels,) with probabilities.
        If no admissible stimuli remain, p is all zeros (not NaN).
    """
    t0_f = float(t0)
    t1_f = float(t1)
    if not (t0_f < t1_f):
        raise ValueError("Require t0 < t1.")

    ch_list = _select_channels(r, channels)
    tstart_f, tstop_f = _infer_window(r, ch_list, tstart, tstop)

    stim = np.asarray(list(stim_times), dtype=np.float64)
    stim = stim[np.isfinite(stim)]
    if stim.size == 0:
        return list(ch_list), np.zeros((len(ch_list),), dtype=np.float64)

    win_left = stim + t0_f
    win_right = stim + t1_f
    if inclusive_stop:
        keep = (win_left >= tstart_f) & (win_right <= tstop_f)
    else:
        keep = (win_left >= tstart_f) & (win_right < tstop_f)

    # Per-stimulus window edges for admissible stimuli. Order is irrelevant
    # because only the count of responding stimuli matters.
    lefts = win_left[keep]
    rights = win_right[keep]
    n_stim = int(lefts.size)
    if n_stim == 0:
        return list(ch_list), np.zeros((len(ch_list),), dtype=np.float64)

    side_left = "left" if inclusive_zero else "right"
    side_right = "right" if inclusive_stop else "left"

    resp = np.zeros((len(ch_list),), dtype=np.float64)
    for ci, ch in enumerate(ch_list):
        arr = np.asarray(r.events[ch])
        if arr.size == 0:
            continue
        # One vectorized searchsorted per edge replaces a Python loop over
        # stimuli. A stimulus is a "hit" when at least one spike index falls
        # between its left and right edges.
        first_idx = np.searchsorted(arr, lefts, side=side_left)
        stop_idx = np.searchsorted(arr, rights, side=side_right)
        hits = int(np.count_nonzero(stop_idx > first_idx))
        resp[ci] = hits / float(n_stim)
    return list(ch_list), resp