Source code for tad.plotting.psth

from __future__ import annotations

from typing import Optional, Sequence

import matplotlib.pyplot as plt
import numpy as np

from tad.metrics.psth import PSTHResult


def plot_psth_lines(
    psth: PSTHResult,
    *,
    ax: Optional[plt.Axes] = None,
    channels: Optional[Sequence[int]] = None,
    kind: str = "rate",
    show: bool = True,
) -> plt.Axes:
    """
    Plot PSTH as lines for selected channels.

    Parameters
    ----------
    psth
        PSTHResult from `compute_psth`.
    ax
        Matplotlib axes. If None, a new 9x3 inch figure and axes are created.
    channels
        Row indices into psth.counts/psth.rate_hz. If None, plots up to first 10.
    kind
        "rate" (Hz) or "counts" (counts/bin/stim).
    show
        If True, calls plt.show().

    Returns
    -------
    Axes
        The axes the lines were drawn on.

    Raises
    ------
    ValueError
        If `kind` is neither "rate" nor "counts".
    """
    # Reject unknown kinds instead of silently plotting counts under a
    # possibly mistyped request (e.g. "rates").
    if kind not in ("rate", "counts"):
        raise ValueError(f"kind must be 'rate' or 'counts', got {kind!r}")

    if ax is None:
        _, ax = plt.subplots(figsize=(9, 3))

    Y = psth.rate_hz if kind == "rate" else psth.counts

    if channels is None:
        idx = np.arange(min(10, Y.shape[0]))
    else:
        idx = np.asarray(list(channels), dtype=int)

    n_labels = len(psth.channels)
    for i in idx:
        # Label with the channel identifier when one exists for this row;
        # otherwise fall back to the row index itself.
        label = str(psth.channels[i]) if i < n_labels else str(i)
        ax.plot(psth.t, Y[i, :], label=label)

    # Mark stimulus onset.
    ax.axvline(0.0, linestyle="--", alpha=0.6)
    ax.set_xlabel("Time from stimulus (s)")
    ax.set_ylabel("Rate (Hz)" if kind == "rate" else "Counts / bin / stim")
    ax.set_title("PSTH (per channel)")

    # No legend for an empty selection (nothing labelled to show, and
    # matplotlib would warn) or for large selections where it is unreadable.
    if 0 < idx.size <= 12:
        ax.legend(ncol=2, fontsize=8)

    if show:
        plt.show()
    return ax
def plot_psth_heatmap(
    psth: PSTHResult,
    *,
    ax: Optional[plt.Axes] = None,
    kind: str = "rate",
    robust: bool = True,
    show_colorbar: bool = True,
    show: bool = True,
) -> plt.Axes:
    """
    Plot PSTH as a channel × time heatmap.

    Parameters
    ----------
    psth
        PSTHResult from `compute_psth`.
    ax
        Matplotlib axes. If None, a new 9x4 inch figure and axes are created.
    kind
        "rate" or "counts".
    robust
        If True, clip color range to 2–98 percentiles for readability.
        Non-finite values (NaN/inf) are ignored when computing the limits.
    show_colorbar
        Add a colorbar.
    show
        If True, calls plt.show().

    Returns
    -------
    Axes
        The axes the heatmap was drawn on.

    Raises
    ------
    ValueError
        If `kind` is neither "rate" nor "counts".
    """
    # Reject unknown kinds instead of silently plotting counts.
    if kind not in ("rate", "counts"):
        raise ValueError(f"kind must be 'rate' or 'counts', got {kind!r}")

    if ax is None:
        _, ax = plt.subplots(figsize=(9, 4))

    Y = psth.rate_hz if kind == "rate" else psth.counts
    n_rows = Y.shape[0]

    vmin = vmax = None
    if robust:
        # Use only finite entries: a single NaN would otherwise make
        # np.percentile return NaN for both limits.
        finite = Y[np.isfinite(Y)]
        if finite.size:
            vmin, vmax = np.percentile(finite, [2, 98])

    # Horizontal extent uses the bin edges so cells span their true bins.
    x0 = float(psth.bin_edges[0])
    x1 = float(psth.bin_edges[-1])
    im = ax.imshow(
        Y,
        aspect="auto",
        origin="lower",
        # Offset by half a row so row i is centred on y == i, matching the
        # "row index" label on the y-axis.
        extent=[x0, x1, -0.5, n_rows - 0.5],
        interpolation="nearest",
        vmin=vmin,
        vmax=vmax,
    )

    # Mark stimulus onset.
    ax.axvline(0.0, linestyle="--", alpha=0.6)
    ax.set_xlabel("Time from stimulus (s)")
    ax.set_ylabel("Channel (row index)")
    ax.set_title("PSTH heatmap")

    if show_colorbar:
        # Attach to the figure that owns `ax` rather than pyplot's current
        # figure, which may differ when the caller supplies `ax`.
        cb = ax.figure.colorbar(im, ax=ax)
        cb.set_label("Rate (Hz)" if kind == "rate" else "Counts / bin / stim")

    if show:
        plt.show()
    return ax