"""Source code for tad.plotting.burst: helpers that overlay detected burst intervals on Matplotlib axes."""

from __future__ import annotations

from typing import Optional, Sequence

import matplotlib.pyplot as plt

from tad.metrics import BurstDetectionResult
from tad.raster import ChannelId


def plot_bursts_spans(
    bursts: BurstDetectionResult,
    *,
    ax: Optional[plt.Axes] = None,
    channels: Optional[Sequence[ChannelId]] = None,
    ymin: float = 0.0,
    ymax: float = 1.0,
    alpha: float = 0.15,
    show: bool = False,
    **span_kwargs,
) -> plt.Axes:
    """
    Overlay burst intervals as shaded vertical spans on an axis.

    Parameters
    ----------
    bursts
        Output of `tad.metrics.bursts.detect_bursts`. Only
        ``bursts.channels`` and ``bursts.per_channel[ch].bursts`` (items
        exposing ``.start`` / ``.end``) are read.
    ax
        Matplotlib Axes to draw on. If None, creates a new figure/axes.
    channels
        Optional subset of channels to plot. If None, plot all channels
        listed in ``bursts.channels``. Channels with no entry in
        ``bursts.per_channel`` are skipped silently.
    ymin, ymax
        Vertical extent of each span as a fraction of the axes height
        (axes coordinates), passed to ``ax.axvspan``. The default (0, 1)
        shades the full vertical extent.
    alpha
        Span transparency.
    show
        If True, calls ``plt.show()``. Default False so you can compose plots.
    **span_kwargs
        Extra keyword arguments forwarded to every ``ax.axvspan`` call
        (e.g. ``color``, ``zorder``, ``hatch``). If ``label`` is given, it is
        attached to the first span only, so a legend shows a single entry
        instead of one per burst.

    Returns
    -------
    matplotlib.axes.Axes
        Axis with spans added.

    Notes
    -----
    - Span edges (``start`` / ``end``) are in data (x) coordinates, while
      ``ymin`` / ``ymax`` are in axes coordinates. The helper therefore
      works on raster plots and rate plots alike, whatever their y scale.
    - Every span is drawn with the same keyword arguments. Overlapping
      bursts from different channels are separate semi-transparent patches,
      so their shading compounds where they overlap.
    """
    if ax is None:
        _, ax = plt.subplots()

    ch_list = bursts.channels if channels is None else list(channels)

    # Pull ``label`` out so it is applied to exactly one span (see docstring).
    label = span_kwargs.pop("label", None)

    for ch in ch_list:
        cr = bursts.per_channel.get(ch)
        if cr is None:
            continue
        for b in cr.bursts:
            kwargs = dict(span_kwargs)
            if label is not None:
                kwargs["label"] = label
                label = None  # only the first span carries the legend label
            ax.axvspan(b.start, b.end, ymin=ymin, ymax=ymax, alpha=alpha, **kwargs)

    if show:
        plt.show()
    return ax
def plot_bursts_on_raster(
    r,
    bursts: BurstDetectionResult,
    *,
    ax: Optional[plt.Axes] = None,
    tstart: Optional[float] = None,
    tstop: Optional[float] = None,
    channels: Optional[Sequence[ChannelId]] = None,
    alpha: float = 0.15,
    show: bool = False,
) -> plt.Axes:
    """
    Convenience: plot a raster, then overlay burst spans on the same axes.

    Parameters
    ----------
    r
        Raster instance (expects a ``.plot(ax=..., tstart=..., tstop=...,
        show=...)`` method).
    bursts
        BurstDetectionResult.
    ax
        Axes to draw on. If None, creates a new figure/axes of size (10, 4).
    tstart, tstop
        Time window forwarded to ``Raster.plot``. When either is given, the
        x-limits produced by the raster plot are restored after the burst
        overlay, so bursts outside the window cannot widen the view.
    channels
        Channels to overlay bursts for (burst-detection channels, not raster
        plot channels). If None, overlays all channels present in ``bursts``.
    alpha
        Span transparency.
    show
        If True, calls ``plt.show()``.

    Returns
    -------
    matplotlib.axes.Axes
        Axes containing raster + burst overlays.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(10, 4))

    r.plot(ax=ax, tstart=tstart, tstop=tstop, show=False)

    # ``axvspan`` requests x-autoscaling. If the raster left autoscale
    # enabled, spans for bursts outside the requested window would stretch
    # the x-axis past it. Remember the raster's limits and put them back.
    window_requested = tstart is not None or tstop is not None
    xlim = ax.get_xlim() if window_requested else None

    plot_bursts_spans(bursts, ax=ax, channels=channels, alpha=alpha, show=False)

    if xlim is not None:
        ax.set_xlim(xlim)

    if show:
        plt.show()
    return ax