from __future__ import annotations
from dataclasses import dataclass
from typing import List, Optional, Sequence
import numpy as np
from tad.raster import Raster, ChannelId
from tad.metrics.utils import _select_channels, _infer_window
[docs]
@dataclass(frozen=True)
class PSTHResult:
    """
    Container for a Post-Stimulus Time Histogram (PSTH).

    Spike times are re-expressed relative to each stimulus onset, binned over
    the window that starts at -t_pre and ends at t_post, and the per-bin
    counts are averaged over all stimuli that were used.

    Attributes
    ----------
    t
        Time coordinate of each bin, shape (n_bins,): bin centers when the PSTH
        was computed with time_mode="center", left bin edges otherwise.
    bin_edges
        Bin edges in time relative to stimulus onset (seconds),
        shape (n_bins + 1,).
    counts
        Spike count per bin, averaged over stimuli,
        shape (n_channels, n_bins).
    rate_hz
        Firing rate per bin in Hz (``counts / dt``),
        shape (n_channels, n_bins).
    dt
        Bin width (seconds).
    t_pre
        Length of the pre-stimulus part of the window (seconds).
    t_post
        Length of the post-stimulus part of the window (seconds).
    stim_times_used
        Stimulus onsets that survived filtering, shape (n_stim_used,).
    channels
        Channel IDs, one per row of ``counts`` / ``rate_hz``.
    """

    t: np.ndarray
    bin_edges: np.ndarray
    counts: np.ndarray
    rate_hz: np.ndarray
    dt: float
    t_pre: float
    t_post: float
    stim_times_used: np.ndarray
    channels: List[ChannelId]

    @property
    def population_counts(self) -> np.ndarray:
        """
        Channel-summed PSTH, in counts per bin per stimulus.

        Returns
        -------
        ndarray, shape (n_bins,)
        """
        per_channel = self.counts
        return per_channel.sum(axis=0)

    @property
    def population_rate_hz(self) -> np.ndarray:
        """
        Channel-summed PSTH, in Hz.

        Returns
        -------
        ndarray, shape (n_bins,)
        """
        per_channel = self.rate_hz
        return per_channel.sum(axis=0)
[docs]
def compute_psth(
    r: Raster,
    stim_times: Sequence[float],
    *,
    dt: float,
    t_pre: float,
    t_post: float,
    channels: Optional[Sequence[ChannelId]] = None,
    tstart: Optional[float] = None,
    tstop: Optional[float] = None,
    inclusive_stop: bool = False,
    inclusive_zero: bool = True,
    time_mode: str = "center",
    dtype: np.dtype = np.dtype(np.float64),
) -> PSTHResult:
    """
    Compute PSTH aligned to stimulus onsets.

    This implements the paper definition:
        P(t) = (1/N) * sum_i ST(t - t_i^stim)
    in discrete form by binning relative spike times per stimulus and averaging.

    Parameters
    ----------
    r
        Input raster. Spike times are read from ``r.events[ch]``. These arrays
        are searched with ``np.searchsorted`` and so must be sorted ascending
        (presumably guaranteed by ``Raster`` -- confirm).
    stim_times
        Iterable of stimulus onset times (seconds). Non-finite entries are dropped.
    dt
        Bin width (seconds). Must be finite and > 0.
    t_pre
        Pre-stimulus window length (seconds), finite and >= 0.
        The window starts at -t_pre.
    t_post
        Post-stimulus window length (seconds), finite and > 0. The window ends
        at +t_post (exclusive unless inclusive_stop=True). If dt does not
        divide t_pre + t_post, the last bin is a partial bin ending exactly
        at t_post.
    channels
        Optional subset/order of channels to include.
    tstart, tstop
        Optional global time window for admissible stimuli. If None, inferred
        from the raster. A stimulus is used only if its full PSTH window lies
        within [tstart, tstop) (or [tstart, tstop] if inclusive_stop=True).
    inclusive_stop
        Right-bound policy for the admissible stimulus window. The same policy
        decides whether a spike exactly at stim + t_post is counted (it falls
        in the last bin).
    inclusive_zero
        If True, spikes exactly at stimulus time (relative time 0) are included.
        If False, they are dropped. Spikes exactly at -t_pre are always included.
    time_mode
        - "center": return bin centers in `t`
        - "left": return left bin edges in `t`
    dtype
        Floating dtype for the output ``counts`` and ``rate_hz`` arrays.

    Returns
    -------
    PSTHResult
        PSTH counts and rates per channel. ``rate_hz`` is ``counts / dt`` for
        every bin, including a partial last bin.

    Raises
    ------
    ValueError
        If dt, t_pre, t_post or time_mode is invalid.
    """
    # Validate every scalar argument before doing any work. NaN would
    # otherwise slip past plain comparisons such as `dt <= 0`.
    dt_f = float(dt)
    if not (np.isfinite(dt_f) and dt_f > 0):
        raise ValueError("dt must be finite and > 0.")
    t_pre_f = float(t_pre)
    t_post_f = float(t_post)
    if not (np.isfinite(t_pre_f) and np.isfinite(t_post_f)) or t_pre_f < 0 or t_post_f <= 0:
        raise ValueError("t_pre must be >= 0 and t_post must be > 0 (both finite).")
    if time_mode not in ("center", "left"):
        raise ValueError(f"time_mode must be 'center' or 'left', got {time_mode!r}.")

    ch_list = _select_channels(r, channels)
    tstart_f, tstop_f = _infer_window(r, ch_list, tstart, tstop)

    stim = np.asarray(list(stim_times), dtype=np.float64)
    stim = stim[np.isfinite(stim)]

    # Keep only stimuli whose whole window lies in [tstart, tstop) (or [tstart, tstop]).
    win_left = stim - t_pre_f
    win_right = stim + t_post_f
    right_ok = (win_right <= tstop_f) if inclusive_stop else (win_right < tstop_f)
    stim_used = np.sort(stim[(win_left >= tstart_f) & right_ok])

    edges = _psth_bin_edges(t_pre_f, t_post_f, dt_f)
    n_bins = edges.size - 1

    counts_acc = np.zeros((len(ch_list), n_bins), dtype=np.float64)
    n_stim = int(stim_used.size)
    if n_stim > 0:
        for ci, ch in enumerate(ch_list):
            counts_acc[ci] = _psth_channel_counts(
                r.events[ch],
                stim_used,
                edges,
                t_pre_f,
                t_post_f,
                inclusive_stop=inclusive_stop,
                inclusive_zero=inclusive_zero,
            )
        # Average over stimuli. With no admissible stimuli the PSTH stays all zeros.
        counts_acc /= float(n_stim)

    counts = counts_acc.astype(dtype, copy=False)
    rate = (counts / dt_f).astype(dtype, copy=False)

    if time_mode == "center":
        t = 0.5 * (edges[:-1] + edges[1:])
    else:
        t = edges[:-1].copy()

    return PSTHResult(
        t=t.astype(np.float64, copy=False),
        bin_edges=edges,
        counts=counts,
        rate_hz=rate,
        dt=dt_f,
        t_pre=t_pre_f,
        t_post=t_post_f,
        stim_times_used=stim_used,
        channels=list(ch_list),
    )


def _psth_bin_edges(t_pre: float, t_post: float, dt: float) -> np.ndarray:
    """Return float64 relative-time bin edges from exactly -t_pre to exactly t_post."""
    ratio = (t_pre + t_post) / dt
    nearest = round(ratio)
    # Treat a window length within rounding noise of a whole number of bins as
    # an exact multiple. Otherwise e.g. t_post=2.1, dt=0.7 (3*0.7 == 2.0999999999999996)
    # would produce a spurious ~1e-16-wide extra bin.
    exact = nearest >= 1 and abs(ratio - nearest) <= 1e-9 * ratio
    n_full = int(nearest) if exact else int(np.floor(ratio))
    edges = -t_pre + dt * np.arange(n_full + 1, dtype=np.float64)
    if exact:
        # Pin the right edge so it is exactly t_post, not t_post +/- a few ulps.
        edges[-1] = t_post
    else:
        # dt does not divide the window: close it with a partial last bin.
        edges = np.append(edges, t_post)
    return edges


def _psth_channel_counts(
    arr: np.ndarray,
    stim_used: np.ndarray,
    edges: np.ndarray,
    t_pre: float,
    t_post: float,
    *,
    inclusive_stop: bool,
    inclusive_zero: bool,
) -> np.ndarray:
    """Return one channel's relative-time spike histogram summed over all stimuli (float64)."""
    n_bins = edges.size - 1
    if arr.size == 0:
        return np.zeros(n_bins, dtype=np.float64)
    # Window membership is decided once, in absolute time. Spikes at s - t_pre
    # are always included. Spikes at s + t_post are included only when
    # inclusive_stop is True.
    lo = np.searchsorted(arr, stim_used - t_pre, side="left")
    hi = np.searchsorted(arr, stim_used + t_post, side="right" if inclusive_stop else "left")
    pieces = []
    for s, a, b in zip(stim_used, lo, hi):
        if b <= a:
            continue
        rel = arr[a:b].astype(np.float64, copy=False) - s
        if not inclusive_zero:
            # rel is exactly 0.0 iff the spike time equals the stimulus time.
            rel = rel[rel != 0.0]
        pieces.append(rel)
    if not pieces:
        return np.zeros(n_bins, dtype=np.float64)
    # Subtracting s can push a boundary spike a few ulps outside
    # [edges[0], edges[-1]], where np.histogram would silently drop it.
    # Clip so that every spike selected above is counted.
    rel_all = np.clip(np.concatenate(pieces), edges[0], edges[-1])
    hist, _ = np.histogram(rel_all, bins=edges)
    return hist.astype(np.float64)