# Source code for tad.metrics.utils

from __future__ import annotations

from typing import List, Optional, Sequence, Tuple

import numpy as np

from tad.raster import Raster, ChannelId


def _select_channels(r: Raster, channels: Optional[Sequence[ChannelId]]) -> List[ChannelId]:
    """
    Resolve and check the channels a computation should operate on.

    Parameters
    ----------
    r
        Input raster.
    channels
        Explicit channel selection. When None, whatever ``r.channels()``
        returns is used as-is (it is not copied here).

    Returns
    -------
    list
        Channel IDs, in iteration order. An explicit selection is returned
        as a new list built from ``channels``.

    Notes
    -----
    - Each selected channel is passed to ``r._require_channel``; any
      exception it raises propagates to the caller. (Presumably it rejects
      channels missing from the raster -- confirm against ``Raster``.)
    """
    if channels is None:
        selected = r.channels()
    else:
        selected = list(channels)

    # Check every channel up front, before any caller starts computing.
    for channel in selected:
        r._require_channel(channel)

    return selected


def _infer_window(
    r: Raster,
    ch_list: Sequence[ChannelId],
    tstart: Optional[float],
    tstop: Optional[float],
) -> Tuple[float, float]:
    """
    Resolve the analysis window, filling in missing bounds from the data.

    Parameters
    ----------
    r
        Input raster.
    ch_list
        Channels whose events are inspected when a bound has to be inferred.
    tstart, tstop
        Optional window bounds. A bound that is given is passed through
        ``r._validate_time``; a bound that is None is inferred.

    Returns
    -------
    tstart_f, tstop_f
        Resolved window bounds.

    Raises
    ------
    ValueError
        If the resolved ``tstop`` is smaller than the resolved ``tstart``.
        This can also happen when only one bound is supplied and it lies
        on the wrong side of the bound inferred from the data.

    Notes
    -----
    - The first and last element of each non-empty event array are taken
      as that channel's earliest and latest time, so the arrays are
      relied upon to be sorted in ascending order.
    - A missing ``tstart`` becomes the earliest event time, or 0.0 when no
      selected channel has events.
    - A missing ``tstop`` becomes the latest event time, or the resolved
      ``tstart`` when no selected channel has events.
    """
    # Only channels that actually contain events can contribute a bound.
    populated = [events for events in (r.events[ch] for ch in ch_list) if events.size]
    earliest = [events[0] for events in populated]
    latest = [events[-1] for events in populated]

    if tstart is not None:
        lower = r._validate_time(tstart)
    elif earliest:
        lower = float(np.min(earliest))
    else:
        lower = 0.0

    if tstop is not None:
        upper = r._validate_time(tstop)
    elif latest:
        upper = float(np.max(latest))
    else:
        # No events at all: collapse to an empty window at the lower bound.
        upper = lower

    if upper < lower:
        raise ValueError(f"tstop ({upper}) must be >= tstart ({lower}).")

    return lower, upper


def pooled_spike_times(
    r: Raster,
    *,
    channels: Optional[Sequence[ChannelId]] = None,
    tstart: Optional[float] = None,
    tstop: Optional[float] = None,
    inclusive_stop: bool = False,
) -> np.ndarray:
    """
    Merge the spike times of several channels into one sorted array.

    Parameters
    ----------
    r
        Input raster.
    channels
        Channels to include. If None, uses all channels.
    tstart, tstop
        Optional time window. Missing bounds are resolved by
        ``_infer_window``.
    inclusive_stop
        Window right-bound policy: [tstart,tstop) if False,
        [tstart,tstop] if True.

    Returns
    -------
    ndarray, shape (N,)
        Pooled times inside the window, sorted ascending and cast to
        ``r.dtype``. Empty (with dtype ``r.dtype``) when no event falls
        inside the window.

    Notes
    -----
    - Each channel's window is located with ``np.searchsorted``, so
      per-channel event arrays are relied upon to be sorted ascending.
    - When ``tstop`` is omitted it is inferred as the latest event time of
      the selected channels. With ``inclusive_stop=False`` the right bound
      is exclusive, so events at exactly that latest time are left out;
      pass ``inclusive_stop=True`` to keep them.
    """
    selected = _select_channels(r, channels)
    lower, upper = _infer_window(r, selected, tstart, tstop)

    # side="left" on the stop bound excludes events equal to it,
    # side="right" includes them.
    stop_side = "right" if inclusive_stop else "left"

    pieces = []
    for channel in selected:
        events = r.events[channel]
        if events.size == 0:
            continue
        first = np.searchsorted(events, lower, side="left")
        last = np.searchsorted(events, upper, side=stop_side)
        in_window = events[first:last]
        if in_window.size:
            pieces.append(in_window)

    if not pieces:
        return np.asarray([], dtype=r.dtype)

    # np.concatenate allocates a fresh array, so the in-place sort below
    # does not touch the raster's own event arrays.
    pooled = np.concatenate(pieces).astype(r.dtype, copy=False)
    pooled.sort()
    return pooled