# Source code for tad.metrics.rates

from __future__ import annotations

from dataclasses import dataclass
from typing import List, Optional, Sequence

import numpy as np

from tad.raster import Raster, ChannelId


@dataclass(frozen=True)
class FiringRateCurveResult:
    """
    Container for the output of a binned firing-rate computation.

    Attributes
    ----------
    t
        Time stamp of each bin, shape (n_bins,). Whether these are left
        bin edges or bin centers depends on the `time_mode` used when the
        result was built (centers by default).
    fr_ch
        Firing rate of each channel in each bin, shape (n_channels, n_bins).
        Units are events per unit of `dt` (i.e. Hz when `dt` is in seconds).
    fr_pop
        Population firing rate per bin, i.e. the sum over channels,
        shape (n_bins,), in the same units as `fr_ch`.
    dt
        Bin width, expressed in the raster's time unit.
    tstart
        Start of the window that was actually binned (first bin edge).
    tstop
        Stop of the window that was actually binned (last bin edge).
    channels
        Channel IDs, one per row of `fr_ch`, in row order.
    bin_edges
        Bin edges used, shape (n_bins + 1,).

    Notes
    -----
    ``frozen=True`` only prevents rebinding the attributes; the arrays and
    the channel list themselves remain mutable objects.
    """

    # Time axis of the curves.
    t: np.ndarray

    # Rate curves: per channel (rows) and summed over channels.
    fr_ch: np.ndarray
    fr_pop: np.ndarray

    # Binning metadata describing how the curves were produced.
    dt: float
    tstart: float
    tstop: float
    channels: List[ChannelId]
    bin_edges: np.ndarray
def firing_rate_curve(
    r: Raster,
    dt: float,
    *,
    tstart: Optional[float] = None,
    tstop: Optional[float] = None,
    channels: Optional[Sequence[ChannelId]] = None,
    inclusive_stop: bool = False,
    time_mode: str = "center",
) -> FiringRateCurveResult:
    """
    Compute binned firing-rate curves from a Raster.

    This is the natural “time-resolved” extension of FR = N/T from the
    references: counts in bins divided by bin width.

    Parameters
    ----------
    r
        Input raster.
    dt
        Bin width. Must be a finite, strictly positive number.
    tstart, tstop
        Optional time window. If None, inferred by Raster.bin_counts.
    channels
        Optional subset/order of channels.
    inclusive_stop
        If True, include events exactly at tstop.
    time_mode
        How to represent time points:
        - "left": left bin edges (starts), shape (n_bins,)
        - "center": bin centers, shape (n_bins,) (default)

    Returns
    -------
    FiringRateCurveResult
        Contains per-channel and population firing-rate curves. Rates are
        events per unit of `dt` (Hz when `dt` is in seconds). The returned
        `t` array never shares memory with `bin_edges`.

    Raises
    ------
    ValueError
        If `time_mode` is not recognized, or if `dt` is not a finite,
        strictly positive number. Both checks happen before any binning.
    """
    # Validate cheap arguments up front so bad calls fail before binning.
    if time_mode not in ("left", "center"):
        raise ValueError(
            f"time_mode must be 'left' or 'center', got {time_mode!r}."
        )
    dt_f = float(dt)
    # A zero, negative, or non-finite width would otherwise silently yield
    # inf / nan / negative rates from the division below.
    if not (np.isfinite(dt_f) and dt_f > 0.0):
        raise ValueError(f"dt must be a finite positive number, got {dt!r}.")

    edges, counts, ch_list = r.bin_counts(
        dt,
        tstart=tstart,
        tstop=tstop,
        channels=channels,
        inclusive_stop=inclusive_stop,
    )

    fr_ch = counts.astype(np.float64, copy=False) / dt_f  # (n_ch, n_bins)
    # Sum counts before the float conversion (exact when counts are integers).
    fr_pop = counts.sum(axis=0).astype(np.float64, copy=False) / dt_f  # (n_bins,)

    edges_f = np.asarray(edges, dtype=np.float64)
    if time_mode == "left":
        # Explicit copy: a plain slice would be a view of the edges, making
        # the returned `t` alias `bin_edges`.
        t = edges_f[:-1].copy()
    else:  # "center" (validated above)
        t = 0.5 * (edges_f[:-1] + edges_f[1:])

    # Report the window actually binned; fall back to 0.0 when no edges exist.
    tstart_used = float(edges_f[0]) if edges_f.size else 0.0
    tstop_used = float(edges_f[-1]) if edges_f.size else 0.0

    return FiringRateCurveResult(
        t=t,
        fr_ch=fr_ch,
        fr_pop=fr_pop,
        dt=dt_f,
        tstart=tstart_used,
        tstop=tstop_used,
        channels=list(ch_list),
        bin_edges=edges,
    )