from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Union
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import json
# Import Triggers and TimeSlot for proper trigger reconstruction
try:
from .Triggers import Triggers, TimeSlot
except ImportError:
# Graceful fallback if Triggers module is not available
Triggers = None
TimeSlot = None
try:
import h5py # optional dependency
except ImportError: # pragma: no cover
h5py = None
# Channel identifier type: channels are keyed either by an integer or by a
# string label (the only two kinds the save/load methods can serialize).
ChannelId = Union[int, str]
[docs]
@dataclass
class Raster:
    """
    Store and manipulate rasterized (event-based) signals across channels.

    This class is primarily designed for neuroscience-style spike rasters, but it
    generalizes to any multi-channel event timestamp data in dynamical systems.
    Internally, timestamps are stored per channel as **sorted 1D NumPy arrays**
    (dtype typically float). This makes time-window slicing efficient using
    ``np.searchsorted`` and keeps representation compact.

    Parameters
    ----------
    events
        Mapping from channel IDs to 1D arrays (or array-like) of event times.
        Arrays will be coerced to ``dtype``, flattened, and sorted.
    amplitudes
        Optional mapping from channel IDs to per-event amplitude values,
        coerced to flattened float64 arrays. Every key must also appear in
        ``events`` and each array must have the same length as that channel's
        event array; otherwise construction raises ``ValueError``.
    triggers
        Optional Triggers object or list of trigger dicts representing time intervals
        (e.g., stimulation periods). Can be a Triggers object from Triggers.py or
        a raw list of dicts for backward compatibility. Any other non-None
        value raises ``ValueError`` at construction.
    dtype
        NumPy dtype used to store timestamps.
    allow_new_channels
        If True, missing channels referenced by methods (e.g., ``insert``) are
        created automatically. If False, referencing a missing channel raises
        ``KeyError``.
    """

    # Channel ID -> 1D array of event timestamps (stored as ``dtype``).
    events: Dict[ChannelId, np.ndarray] = field(default_factory=dict)
    # Channel ID -> 1D float64 amplitudes, one entry per event of that channel.
    amplitudes: Dict[ChannelId, np.ndarray] = field(default_factory=dict)
    # Trigger intervals: a Triggers object, a list of dicts, or None.
    triggers: Optional[Union[Any, List[Dict[str, Any]]]] = field(default=None) # Can be Triggers object or list of dicts
    # NumPy dtype used for timestamp arrays.
    dtype: np.dtype = np.dtype(np.float64)
    # Whether methods may auto-create channels that do not exist yet.
    allow_new_channels: bool = True
[docs]
def __post_init__(self) -> None:
    """
    Normalize internal storage after dataclass initialization.

    Coerces each channel's timestamps to a fresh 1D NumPy array of
    ``self.dtype`` and sorts it ascending. Amplitudes, when supplied, are
    coerced to fresh 1D float64 arrays, validated against their channel, and
    reordered with the *same* permutation as the timestamps so that every
    amplitude stays paired with its event.

    Raises
    ------
    ValueError
        If an amplitude channel is missing from ``events``, if an amplitude
        array's length differs from its channel's event count, or if
        ``triggers`` is not a Triggers object, a list of dicts, or None.

    Notes
    -----
    - Inputs are always copied, so arrays owned by the caller are never
      sorted in place or aliased by the raster.
    - If ``events`` contains non-1D array-likes, they are flattened.
    - Empty channels are represented by empty 1D arrays.
    """
    self.dtype = np.dtype(self.dtype)

    # np.array (unlike np.asarray) always copies, so the sorting below can
    # never reorder a buffer the caller still holds.
    events: Dict[ChannelId, np.ndarray] = {
        ch: np.array(ts, dtype=self.dtype).ravel() for ch, ts in self.events.items()
    }

    amplitudes: Dict[ChannelId, np.ndarray] = {}
    for ch, amps in self.amplitudes.items():
        arr = np.array(amps, dtype=np.float64).ravel()
        if ch not in events:
            raise ValueError(f"Amplitude channel {ch!r} is not present in events.")
        if arr.shape[0] != events[ch].shape[0]:
            raise ValueError(
                f"Amplitude length mismatch for channel {ch!r}: "
                f"{arr.shape[0]} amplitudes vs {events[ch].shape[0]} events."
            )
        amplitudes[ch] = arr

    # Sort timestamps and carry amplitudes along with the same permutation;
    # a stable sort keeps the original pairing order among equal timestamps.
    for ch, times in events.items():
        if times.size > 1:
            order = np.argsort(times, kind="stable")
            events[ch] = times[order]
            if ch in amplitudes:
                amplitudes[ch] = amplitudes[ch][order]

    self.events = events
    self.amplitudes = amplitudes

    if self.triggers is not None:
        # Accept both Triggers objects and list of dicts
        if Triggers is not None and isinstance(self.triggers, Triggers):
            # Valid Triggers object, no further validation needed
            pass
        elif isinstance(self.triggers, list):
            # Validate list of dicts
            for item in self.triggers:
                if not isinstance(item, dict):
                    raise ValueError("Each trigger entry must be a dict.")
        else:
            raise ValueError("Raster.triggers must be a Triggers object, a list of dicts, or None.")
# ---------------------------------------------------------------------
# Constructors / basic accessors
# ---------------------------------------------------------------------
[docs]
@classmethod
def empty(
    cls,
    channels: Optional[Iterable[ChannelId]] = None,
    *,
    dtype: np.dtype = np.dtype(np.float64),
    allow_new_channels: bool = True,
) -> "Raster":
    """
    Build a raster that holds no events, optionally seeding empty channels.

    Parameters
    ----------
    channels
        Channel IDs to register up front; each gets an empty timestamp array.
        None registers no channels.
    dtype
        NumPy dtype used to store timestamps.
    allow_new_channels
        Whether methods may later create missing channels on demand.

    Returns
    -------
    Raster
        Freshly constructed, event-free instance.
    """
    raster = cls(events={}, dtype=dtype, allow_new_channels=allow_new_channels)
    seed = channels if channels is not None else ()
    for channel_id in seed:
        raster.events[channel_id] = np.asarray([], dtype=raster.dtype)
    return raster
[docs]
def channels(self) -> List[ChannelId]:
    """
    List the channel IDs held by this raster, in insertion order.

    Returns
    -------
    list
        A new list of channel IDs.
    """
    return [channel_id for channel_id in self.events]
[docs]
def n_channels(self) -> int:
    """
    Count the channels held by this raster.

    Returns
    -------
    int
        Number of channel IDs (including channels with no events).
    """
    stored = self.events
    return len(stored.keys())
[docs]
def copy(self) -> "Raster":
    """
    Return an independent copy of this Raster.

    Event and amplitude arrays are duplicated, so mutating the copy never
    affects the original. Triggers are deep-copied when they are a Triggers
    object and copied entry by entry (one shallow ``dict`` per entry) when
    they are a list of dicts.

    Returns
    -------
    Raster
        New raster with the same channels, dtype and ``allow_new_channels``.
    """
    clone = Raster.empty(
        channels=self.channels(),
        dtype=self.dtype,
        allow_new_channels=self.allow_new_channels,
    )
    clone.events.update((ch, ts.copy()) for ch, ts in self.events.items())
    clone.amplitudes.update((ch, amps.copy()) for ch, amps in self.amplitudes.items())

    source_triggers = self.triggers
    if source_triggers is None:
        return clone
    if Triggers is not None and isinstance(source_triggers, Triggers):
        from copy import deepcopy

        clone.triggers = deepcopy(source_triggers)
    elif isinstance(source_triggers, list):
        clone.triggers = [dict(entry) for entry in source_triggers]
    else:
        clone.triggers = source_triggers
    return clone
# ----------------------------------------------------------------------
# Methods for saving/loading from disk (e.g., json or HDF5)
# ----------------------------------------------------------------------
@staticmethod
def _convert_triggers_to_dicts(triggers: Optional[Any]) -> Optional[List[Dict[str, Any]]]:
    """
    Turn trigger storage into a JSON-friendly list of dicts.

    ``None`` and lists are passed through unchanged (the same list object is
    returned, not a copy). A Triggers object is expanded into one dict per
    slot with the keys ``start``, ``end``, ``ID``, ``description`` and
    ``blank``.

    Parameters
    ----------
    triggers
        A Triggers object, a list of trigger dicts, or None.

    Returns
    -------
    list of dict or None
        Serializable trigger records. None when ``triggers`` is None, is of
        any other type, or its slots cannot be read (best-effort: such
        errors are swallowed rather than raised).
    """
    if triggers is None or isinstance(triggers, list):
        return triggers
    if Triggers is None or not isinstance(triggers, Triggers):
        return None
    try:
        records = [
            {
                "start": float(slot.start),
                "end": float(slot.end),
                "ID": slot.ID,
                "description": slot.description,
                "blank": slot.blank,
            }
            for slot in triggers.slots
        ]
    except Exception:
        return None
    return records
@staticmethod
def _convert_trigger_dicts_to_triggers(trigger_data: Optional[List[Dict[str, Any]]]) -> Optional[Union["Triggers", List[Dict[str, Any]]]]:
    """
    Rebuild a Triggers container from serialized trigger dictionaries.

    Each dict becomes a TimeSlot. Missing keys fall back to ``start=0.0``,
    ``end=0.0``, ``ID=None``, ``description=None`` and ``blank=False``.

    Parameters
    ----------
    trigger_data
        List of trigger dictionaries (keys: start, end, ID, blank, description).

    Returns
    -------
    Triggers or list of dict
        A Triggers object when the Triggers module is importable and every
        entry converts cleanly. Otherwise ``trigger_data`` itself is returned
        unchanged (also for None or any non-list input).
    """
    if not isinstance(trigger_data, list):
        # Covers None too: anything that is not a list is handed back as-is.
        return trigger_data
    if Triggers is None or TimeSlot is None:
        return trigger_data
    try:
        rebuilt = [
            TimeSlot(
                start=float(entry.get("start", 0.0)),
                end=float(entry.get("end", 0.0)),
                ID=entry.get("ID", None),
                description=entry.get("description", None),
                blank=entry.get("blank", False),
            )
            for entry in trigger_data
        ]
        return Triggers(slots=rebuilt)
    except Exception:
        # Best-effort: fall back to the raw dicts rather than failing a load.
        return trigger_data
[docs]
def get_trigger_intervals(self) -> List[Tuple[float, float]]:
    """
    Collect trigger intervals as ``(start, end)`` float tuples.

    Works with a Triggers object (reads each slot's ``start``/``end``) and
    with a backward-compatible list of dicts (reads ``"start"``/``"end"``).
    Entries whose bounds cannot be converted to float, non-dict list items,
    and intervals with ``start >= end`` are skipped.

    Returns
    -------
    list of tuple
        Valid intervals in their stored order; empty when there are no
        triggers.
    """
    trig = self.triggers
    found: List[Tuple[float, float]] = []
    if trig is None:
        return found

    if Triggers is not None and isinstance(trig, Triggers):
        for slot in trig.slots:
            try:
                bounds = (float(slot.start), float(slot.end))
            except Exception:
                continue
            if bounds[0] < bounds[1]:
                found.append(bounds)
    elif isinstance(trig, list):
        for entry in trig:
            if not isinstance(entry, dict):
                continue
            try:
                bounds = (float(entry["start"]), float(entry["end"]))
            except (KeyError, TypeError, ValueError):
                continue
            if bounds[0] < bounds[1]:
                found.append(bounds)
    return found
[docs]
def save(
    self,
    path: Union[str, Path],
    *,
    h5: bool = True,
    group: str = "/raster",
    indent: Optional[int] = 2,
    compression: Union[None, str] = "gzip",
    compression_opts: int = 4,
    overwrite: bool = True,
    save_amplitudes: bool = False,
    save_triggers: bool = False,
) -> None:
    """
    Save this Raster to disk in either HDF5 or JSON format.

    Parameters
    ----------
    path
        Output file path. Missing parent directories are created.
    h5
        If True (default), save to HDF5 using `h5py`. If False, save to JSON.
    group
        HDF5 group path (only used if `h5=True`).
    indent
        JSON indentation level (only used if `h5=False`). Use None for compact JSON.
    compression
        HDF5 compression filter (only used if `h5=True`). None disables it.
    compression_opts
        HDF5 compression level/options (only used if `h5=True`). Ignored when
        ``compression`` is None or ``"lzf"`` (lzf accepts no options).
    overwrite
        If True, overwrite the target HDF5 group if it exists (only used if `h5=True`).
    save_amplitudes
        If True, also store amplitudes for every channel that has them.
        In HDF5, channels without amplitudes get no amplitude dataset.
    save_triggers
        If True, also store the triggers (converted to a list of dicts).

    Raises
    ------
    ImportError
        If `h5=True` but `h5py` is not installed.
    TypeError
        If channel IDs are not int or str. In HDF5 mode this is raised before
        the file is opened, so an existing group is never deleted by a
        failed save.
    FileExistsError
        If the HDF5 group already exists and ``overwrite=False``.
    """
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)

    # ---------- JSON branch ----------
    if not h5:
        channels: List[Dict[str, Any]] = []
        for ch, arr in self.events.items():
            if isinstance(ch, (int, np.integer)):
                ch_type = "int"
                ch_id: Union[int, str] = int(ch)
            elif isinstance(ch, str):
                ch_type = "str"
                ch_id = ch
            else:
                raise TypeError(
                    f"Unsupported channel id type {type(ch)} for JSON serialization. Use int or str."
                )
            record: Dict[str, Any] = {
                "id": ch_id,
                "type": ch_type,
                "times": np.asarray(arr, dtype=float).ravel().tolist(),
            }
            if save_amplitudes and ch in self.amplitudes:
                record["amplitudes"] = np.asarray(self.amplitudes[ch], dtype=float).ravel().tolist()
            channels.append(record)
        payload: Dict[str, Any] = {
            "schema": "Raster",
            "version": 1,
            "dtype": np.dtype(self.dtype).str,
            "allow_new_channels": bool(self.allow_new_channels),
            "channels": channels,
        }
        if save_triggers and self.triggers is not None:
            triggers_as_dicts = self._convert_triggers_to_dicts(self.triggers)
            if triggers_as_dicts:
                payload["triggers"] = triggers_as_dicts
        path.write_text(json.dumps(payload, indent=indent), encoding="utf-8")
        return

    # ---------- HDF5 branch ----------
    if h5py is None:
        raise ImportError("h5py is required for HDF5 I/O. Install with `pip install h5py`.")

    # Encode channel ids (failing on unsupported ones) BEFORE opening the file:
    # with overwrite=True the existing group is deleted first, so a late
    # TypeError would otherwise destroy previously saved data.
    ch_list = list(self.events.keys())
    types: List[str] = []
    ids_as_str: List[str] = []
    for ch in ch_list:
        if isinstance(ch, (int, np.integer)):
            types.append("int")
            ids_as_str.append(str(int(ch)))
        elif isinstance(ch, str):
            types.append("str")
            ids_as_str.append(ch)
        else:
            raise TypeError(f"Unsupported channel id type {type(ch)} for HDF5 serialization. Use int or str.")

    # Serialize triggers up front too, so a non-JSON-able trigger fails early.
    triggers_json: Optional[str] = None
    if save_triggers and self.triggers is not None:
        triggers_as_dicts = self._convert_triggers_to_dicts(self.triggers)
        if triggers_as_dicts:
            triggers_json = json.dumps(triggers_as_dicts)

    write_amplitudes = bool(save_amplitudes and self.amplitudes)
    # lzf accepts no options; None means "uncompressed", which takes none either.
    opts = compression_opts if compression and compression != "lzf" else None
    dataset_kwargs: Dict[str, Any] = {
        "compression": compression,
        "compression_opts": opts,
        "shuffle": bool(compression),
    }

    with h5py.File(path, "a") as f:
        if group in f:
            if not overwrite:
                raise FileExistsError(f"Group {group!r} already exists in {str(path)!r} (overwrite=False).")
            del f[group]
        g = f.create_group(group)
        g.attrs["schema"] = "Raster"
        g.attrs["version"] = 1
        g.attrs["dtype"] = np.dtype(self.dtype).str
        g.attrs["allow_new_channels"] = bool(self.allow_new_channels)
        g.attrs["has_amplitudes"] = write_amplitudes
        # Reflect what is actually written, not merely what was requested.
        g.attrs["has_triggers"] = triggers_json is not None

        # Preserve channel id types using two parallel datasets
        dt_vlen_str = h5py.string_dtype(encoding="utf-8")
        g.create_dataset("channel_types", data=np.asarray(types, dtype=object), dtype=dt_vlen_str)
        g.create_dataset("channel_ids", data=np.asarray(ids_as_str, dtype=object), dtype=dt_vlen_str)

        tg = g.create_group("times")
        ag = g.create_group("amplitudes") if write_amplitudes else None
        for i, ch in enumerate(ch_list):
            arr = np.asarray(self.events[ch], dtype=self.dtype).ravel()
            amps: Optional[np.ndarray] = None
            if ag is not None and ch in self.amplitudes:
                amps = np.asarray(self.amplitudes[ch], dtype=np.float64).ravel()
            # Keep the sorted invariant; amplitudes follow the same permutation.
            if arr.size and not np.all(arr[:-1] <= arr[1:]):
                order = np.argsort(arr, kind="stable")
                arr = arr[order]
                if amps is not None and amps.shape[0] == order.shape[0]:
                    amps = amps[order]
            tg.create_dataset(name=str(i), data=arr, dtype=np.dtype(self.dtype), **dataset_kwargs)
            # Only channels that really have amplitudes get a dataset; an empty
            # placeholder would fail the length check when loaded back.
            if amps is not None:
                ag.create_dataset(name=str(i), data=amps, dtype=np.dtype(np.float64), **dataset_kwargs)

        if triggers_json is not None:
            g.create_dataset(
                "triggers",
                data=np.asarray(triggers_json, dtype=object),
                dtype=dt_vlen_str,
            )
[docs]
@classmethod
def load(
    cls,
    path: Union[str, Path],
    *,
    h5: bool = True,
    group: str = "/raster",
) -> "Raster":
    """
    Load a Raster from disk in either HDF5 or JSON format.

    Each channel's timestamps are sorted on load. When the file also stores
    amplitudes for that channel, they are reordered with the same
    permutation, so every amplitude stays attached to its event even if the
    stored times were unsorted.

    Parameters
    ----------
    path
        Input file path.
    h5
        If True (default), load from HDF5 using `h5py`. If False, load from JSON.
    group
        HDF5 group path (only used if `h5=True`).

    Returns
    -------
    Raster
        Loaded Raster instance.

    Raises
    ------
    ImportError
        If `h5=True` but `h5py` is not installed.
    ValueError
        If schema/version are not recognized, a channel type is unknown, or
        the HDF5 ``channel_types`` and ``channel_ids`` datasets differ in length.
    KeyError
        If the HDF5 group does not exist.
    """

    def _as_text(value: Any) -> Any:
        # HDF5 strings may be returned as bytes depending on how (and by which
        # tool) they were written; decode so comparisons work uniformly.
        return value.decode("utf-8") if isinstance(value, bytes) else value

    def _sort_together(
        times: np.ndarray, amps: Optional[np.ndarray]
    ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
        # Sort times and apply the identical permutation to matching amplitudes.
        # A length mismatch is left as-is for the constructor to report.
        if times.size > 1:
            order = np.argsort(times, kind="stable")
            times = times[order]
            if amps is not None and amps.shape[0] == order.shape[0]:
                amps = amps[order]
        return times, amps

    path = Path(path)
    events: Dict[ChannelId, np.ndarray] = {}
    amplitudes: Dict[ChannelId, np.ndarray] = {}

    # ---------- JSON branch ----------
    if not h5:
        payload = json.loads(path.read_text(encoding="utf-8"))
        if payload.get("schema") != "Raster":
            raise ValueError(f"Unrecognized schema: {payload.get('schema')!r}")
        if int(payload.get("version", -1)) != 1:
            raise ValueError(f"Unrecognized Raster JSON version: {payload.get('version')!r}")
        dtype = np.dtype(payload.get("dtype", np.dtype(np.float64).str))
        allow_new_channels = bool(payload.get("allow_new_channels", True))
        for rec in payload.get("channels", []):
            ty = rec.get("type")
            cid = rec.get("id")
            if ty == "int":
                ch: ChannelId = int(cid)
            elif ty == "str":
                ch = str(cid)
            else:
                raise ValueError(f"Unrecognized channel type in JSON: {ty!r}")
            times = np.asarray(rec.get("times", []), dtype=dtype).ravel()
            amps = (
                np.asarray(rec["amplitudes"], dtype=np.float64).ravel()
                if "amplitudes" in rec
                else None
            )
            times, amps = _sort_together(times, amps)
            events[ch] = times
            if amps is not None:
                amplitudes[ch] = amps
        triggers = cls._convert_trigger_dicts_to_triggers(payload.get("triggers", None))
        return cls(events=events, amplitudes=amplitudes, triggers=triggers, dtype=dtype, allow_new_channels=allow_new_channels)

    # ---------- HDF5 branch ----------
    if h5py is None:
        raise ImportError("h5py is required for HDF5 I/O. Install with `pip install h5py`.")
    with h5py.File(path, "r") as f:
        if group not in f:
            raise KeyError(f"Group {group!r} not found in {str(path)!r}.")
        g = f[group]
        schema = _as_text(g.attrs.get("schema", None))
        version = int(g.attrs.get("version", -1))
        if schema != "Raster":
            raise ValueError(f"Unrecognized schema: {schema!r}")
        if version != 1:
            raise ValueError(f"Unrecognized Raster HDF5 version: {version!r}")
        dtype = np.dtype(_as_text(g.attrs.get("dtype", np.dtype(np.float64).str)))
        allow_new_channels = bool(g.attrs.get("allow_new_channels", True))
        types = [str(_as_text(t)) for t in g["channel_types"][...]]
        ids = [str(_as_text(x)) for x in g["channel_ids"][...]]
        if len(types) != len(ids):
            raise ValueError(
                f"Corrupt Raster HDF5 group: {len(types)} channel types vs {len(ids)} channel ids."
            )
        times_group = g["times"]
        amp_group = g.get("amplitudes", None)

        triggers = None
        if "triggers" in g:
            text = str(_as_text(g["triggers"][()]))
            triggers = cls._convert_trigger_dicts_to_triggers(json.loads(text))

        for i, (ty, cid) in enumerate(zip(types, ids)):
            if ty == "int":
                ch = int(cid)
            elif ty == "str":
                ch = cid
            else:
                raise ValueError(f"Unrecognized channel type {ty!r} in HDF5.")
            times = np.asarray(times_group[str(i)][...], dtype=dtype).ravel()
            amps = None
            if amp_group is not None and str(i) in amp_group:
                amps = np.asarray(amp_group[str(i)][...], dtype=np.float64).ravel()
            times, amps = _sort_together(times, amps)
            events[ch] = times
            if amps is not None:
                amplitudes[ch] = amps
    return cls(events=events, amplitudes=amplitudes, triggers=triggers, dtype=dtype, allow_new_channels=allow_new_channels)
# ---------------------------------------------------------------------
# Private helpers
# ---------------------------------------------------------------------
def _require_channel(self, channel: ChannelId) -> None:
    """
    Make sure ``channel`` exists, creating it empty when permitted.

    Parameters
    ----------
    channel
        Channel ID to look up.

    Raises
    ------
    KeyError
        When the channel is absent and ``allow_new_channels`` is False.
    """
    if channel in self.events:
        return
    if not self.allow_new_channels:
        raise KeyError(f"Channel {channel!r} does not exist and allow_new_channels=False.")
    self.events[channel] = np.asarray([], dtype=self.dtype)
def _validate_time(self, t: float) -> float:
    """
    Coerce a scalar timestamp to a finite Python float.

    Parameters
    ----------
    t
        Value convertible with ``float()``.

    Returns
    -------
    float
        The converted timestamp.

    Raises
    ------
    ValueError
        If the converted value is NaN or infinite.
    """
    value = float(t)
    if np.isfinite(value):
        return value
    raise ValueError(f"Timestamp must be finite, got {t!r}.")
def _validate_time_array(self, t: Union[np.ndarray, Iterable[float]]) -> np.ndarray:
    """
    Coerce an array-like of timestamps to a flat array of ``self.dtype``.

    Non-ndarray iterables (including generators) are materialized with
    ``list`` first; ndarrays are used directly.

    Parameters
    ----------
    t
        Timestamps as an ndarray or any iterable of numbers.

    Returns
    -------
    ndarray
        1D array of dtype ``self.dtype``.

    Raises
    ------
    ValueError
        If any element is NaN or infinite.
    """
    source = t if isinstance(t, np.ndarray) else list(t)
    flat = np.asarray(source, dtype=self.dtype).ravel()
    if flat.size and not np.isfinite(flat).all():
        raise ValueError("All timestamps must be finite (no NaN/+/-inf).")
    return flat
def _window_indices(
    self,
    arr: np.ndarray,
    tstart: float,
    tstop: float,
    inclusive_stop: bool,
) -> Tuple[int, int]:
    """
    Locate the slice of a sorted timestamp array that falls in a window.

    Parameters
    ----------
    arr
        Ascending-sorted 1D array of timestamps.
    tstart
        Window start (always inclusive).
    tstop
        Window end.
    inclusive_stop
        True for ``[tstart, tstop]``; False for ``[tstart, tstop)``.

    Returns
    -------
    left, right
        Python ints such that ``arr[left:right]`` holds the in-window events.

    Notes
    -----
    Two binary searches via ``np.searchsorted``; correctness requires ``arr``
    to be sorted ascending.
    """
    stop_side = "right" if inclusive_stop else "left"
    lo = int(np.searchsorted(arr, tstart, side="left"))
    hi = int(np.searchsorted(arr, tstop, side=stop_side))
    return lo, hi
# ---------------------------------------------------------------------
# Channel management
# ---------------------------------------------------------------------
[docs]
def insert_channel(
    self,
    channel: ChannelId,
    times: Optional[Union[np.ndarray, Iterable[float]]] = None,
    *,
    overwrite: bool = False,
    sort: bool = True,
) -> None:
    """
    Insert a new channel (optionally with initial events).

    Parameters
    ----------
    channel
        Channel ID to insert.
    times
        Optional array-like of timestamps for the channel. If None, the
        channel is created empty. The input is copied; the caller's array is
        never sorted in place or shared with the raster.
    overwrite
        If False (default), raises if the channel already exists.
        If True, replaces the channel's events with ``times`` (or empty) and
        discards any amplitudes stored for it, since they no longer
        correspond to the new events.
    sort
        If True (default), sorts provided times.

    Raises
    ------
    KeyError
        If channel exists and ``overwrite=False``.
    ValueError
        If any timestamp is not finite.
    """
    if channel in self.events and not overwrite:
        raise KeyError(f"Channel {channel!r} already exists (overwrite=False).")
    if times is None:
        arr = np.asarray([], dtype=self.dtype)
    else:
        # _validate_time_array may return the caller's own buffer (np.asarray
        # does not copy a matching ndarray), so take a private copy first.
        arr = self._validate_time_array(times).copy()
        if sort and arr.size:
            arr.sort()
    self.events[channel] = arr
    # Amplitudes of a previous version of this channel are now meaningless.
    self.amplitudes.pop(channel, None)
[docs]
def pop_channel(self, channel: ChannelId) -> np.ndarray:
    """
    Remove an entire channel and return its timestamps.

    Any amplitudes stored for the channel are discarded as well, so no
    orphaned amplitude entry is left behind.

    Parameters
    ----------
    channel
        Channel ID to remove.

    Returns
    -------
    ndarray
        1D array of timestamps that were stored in the channel.

    Raises
    ------
    KeyError
        If the channel does not exist.
    """
    if channel not in self.events:
        raise KeyError(f"Channel {channel!r} does not exist.")
    self.amplitudes.pop(channel, None)
    return self.events.pop(channel)
# ---------------------------------------------------------------------
# Core event manipulation
# ---------------------------------------------------------------------
[docs]
def insert(self, channel: ChannelId, t: float, *, amplitude: Optional[float] = None) -> None:
    """
    Insert a timestamp into a channel, preserving sorted order.

    Parameters
    ----------
    channel
        Channel ID. If missing, it is created empty first when
        ``allow_new_channels`` is True.
    t
        Timestamp to insert. Must be finite.
    amplitude
        Optional amplitude of the new event. If the channel already stores
        amplitudes, this value (or NaN when omitted) is inserted at the same
        position, keeping amplitudes aligned with events. If the channel has
        no amplitudes yet and a value is given, an amplitude array is created
        with NaN for every pre-existing event.

    Raises
    ------
    ValueError
        If ``t`` is not finite.
    KeyError
        If the channel is missing and ``allow_new_channels`` is False.

    Notes
    -----
    - Insertion into a NumPy array is O(n) due to shifting; for very frequent
      online insertion at large scale, a buffered strategy can be added later.
    - Duplicates are allowed; the new event goes after existing equal times.
    """
    tf = self._validate_time(t)
    self._require_channel(channel)
    arr = self.events[channel]
    idx = int(np.searchsorted(arr, tf, side="right"))
    new_events = np.insert(arr, idx, np.asarray(tf, dtype=self.dtype))

    # Build the new amplitude array before touching any state, so a failure
    # here leaves the raster unchanged.
    old_amps = self.amplitudes.get(channel)
    new_amps: Optional[np.ndarray] = None
    if old_amps is not None or amplitude is not None:
        if old_amps is None:
            old_amps = np.full(arr.shape[0], np.nan, dtype=np.float64)
        value = np.nan if amplitude is None else float(amplitude)
        new_amps = np.insert(np.asarray(old_amps, dtype=np.float64), idx, value)

    self.events[channel] = new_events
    if new_amps is not None:
        self.amplitudes[channel] = new_amps
[docs]
def insert_timestamparray(
    self,
    channel: ChannelId,
    times: Union[np.ndarray, Iterable[float]],
    *,
    assume_sorted: bool = False,
    sort_result: bool = True,
    amplitudes: Optional[Union[np.ndarray, Iterable[float]]] = None,
) -> None:
    """
    Insert multiple timestamps into a channel.

    Parameters
    ----------
    channel
        Channel ID. If missing, it is created empty first when
        ``allow_new_channels`` is True.
    times
        Array-like timestamps to insert (NumPy array or iterable). The
        caller's array is never modified.
    assume_sorted
        If True, assumes ``times`` is already sorted ascending (saves a sort).
    sort_result
        If True (default), ensures the channel remains sorted after insertion.
    amplitudes
        Optional per-event amplitudes for ``times`` (same length). Whenever
        the channel stores amplitudes, or this argument is given, amplitudes
        are kept aligned with events: missing values on either side are
        filled with NaN and the same sort permutation is applied.

    Raises
    ------
    ValueError
        If any timestamp is not finite, or if ``amplitudes`` does not have
        one value per new timestamp.

    Notes
    -----
    - Uses concatenation + (stable arg)sort as a simple, robust approach.
    - Complexity is O((n+m) log(n+m)) due to sorting; for very large n with small
      m, a merge approach could be implemented later.
    """
    self._require_channel(channel)
    new_arr = self._validate_time_array(times)

    new_amps: Optional[np.ndarray] = None
    if amplitudes is not None:
        source = amplitudes if isinstance(amplitudes, np.ndarray) else list(amplitudes)
        new_amps = np.array(source, dtype=np.float64).ravel()
        if new_amps.shape[0] != new_arr.shape[0]:
            raise ValueError(
                f"Amplitude length mismatch for channel {channel!r}: "
                f"{new_amps.shape[0]} amplitudes vs {new_arr.shape[0]} new events."
            )

    if new_arr.size == 0:
        return

    if not assume_sorted:
        # Fancy indexing returns a copy, so the caller's buffer (which
        # _validate_time_array may return as-is) is never sorted in place.
        order = np.argsort(new_arr, kind="stable")
        new_arr = new_arr[order]
        if new_amps is not None:
            new_amps = new_amps[order]

    old = self.events[channel]
    merged = np.concatenate([old, new_arr]).astype(self.dtype, copy=False)

    old_amps = self.amplitudes.get(channel)
    merged_amps: Optional[np.ndarray] = None
    if old_amps is not None or new_amps is not None:
        if old_amps is None:
            old_amps = np.full(old.shape[0], np.nan, dtype=np.float64)
        if new_amps is None:
            new_amps = np.full(new_arr.shape[0], np.nan, dtype=np.float64)
        merged_amps = np.concatenate([np.asarray(old_amps, dtype=np.float64), new_amps])

    if sort_result and merged.size:
        order = np.argsort(merged, kind="stable")
        merged = merged[order]
        if merged_amps is not None:
            merged_amps = merged_amps[order]

    self.events[channel] = merged
    if merged_amps is not None:
        self.amplitudes[channel] = merged_amps
[docs]
def pop(self, channel: ChannelId, index: int = -1) -> float:
    """
    Remove and return one timestamp from a channel.

    If the channel stores amplitudes aligned with its events, the amplitude
    at the same index is removed too, keeping the two arrays paired.

    Parameters
    ----------
    channel
        Channel ID. If missing, it is created empty first when
        ``allow_new_channels`` is True (and the pop then fails).
    index
        Index of the event to remove (default: -1, last event).

    Returns
    -------
    float
        The removed timestamp.

    Raises
    ------
    IndexError
        If the channel has no events, or ``index`` is out of range.
    """
    self._require_channel(channel)
    arr = self.events[channel]
    if arr.size == 0:
        raise IndexError(f"Cannot pop from empty channel {channel!r}.")
    t = float(arr[index])
    amps = self.amplitudes.get(channel)
    # Only delete when the arrays are aligned; otherwise the index would not
    # identify the matching amplitude.
    if amps is not None and amps.shape[0] == arr.shape[0]:
        self.amplitudes[channel] = np.delete(amps, index)
    self.events[channel] = np.delete(arr, index)
    return t
[docs]
def clear(self, channel: Optional[ChannelId] = None) -> None:
    """
    Clear events from one channel or from all channels.

    Channels are kept (with empty event arrays). Channels that store
    amplitudes keep an empty amplitude array, so amplitudes remain the same
    length as events.

    Parameters
    ----------
    channel
        If provided, clears only that channel (created empty first if missing
        and ``allow_new_channels`` is True). If None, clears all channels.
    """
    if channel is None:
        targets = list(self.events.keys())
    else:
        self._require_channel(channel)
        targets = [channel]
    for ch in targets:
        self.events[ch] = np.asarray([], dtype=self.dtype)
        if ch in self.amplitudes:
            self.amplitudes[ch] = np.asarray([], dtype=np.float64)
# ---------------------------------------------------------------------
# Time-window operations
# ---------------------------------------------------------------------
[docs]
def between(self, tstart: float, tstop: float, *, inclusive_stop: bool = False) -> "Raster":
    """
    Return a new Raster containing events in a time interval.

    Amplitudes that are aligned with a channel's events are sliced with the
    same indices, so each retained event keeps its amplitude. Triggers are
    not carried over to the result.

    Parameters
    ----------
    tstart
        Start time of the interval.
    tstop
        Stop time of the interval.
    inclusive_stop
        If False (default), returns events in [tstart, tstop).
        If True, returns events in [tstart, tstop].

    Returns
    -------
    Raster
        A new Raster with events restricted to the interval (all channels
        kept, possibly empty).

    Raises
    ------
    ValueError
        If ``tstop < tstart`` or either boundary is not finite.
    """
    tstart_f = self._validate_time(tstart)
    tstop_f = self._validate_time(tstop)
    if tstop_f < tstart_f:
        raise ValueError(f"tstop ({tstop_f}) must be >= tstart ({tstart_f}).")
    out = Raster.empty(channels=self.channels(), dtype=self.dtype, allow_new_channels=self.allow_new_channels)
    for ch, arr in self.events.items():
        left, right = self._window_indices(arr, tstart_f, tstop_f, inclusive_stop)
        out.events[ch] = arr[left:right].copy()
        amps = self.amplitudes.get(ch)
        # Slice amplitudes only when they line up one-to-one with the events.
        if amps is not None and amps.shape[0] == arr.shape[0]:
            out.amplitudes[ch] = amps[left:right].copy()
    return out
[docs]
def shift(self, dt: float, *, in_place: bool = False) -> "Raster":
    """
    Add a constant offset to every timestamp.

    Parameters
    ----------
    dt
        Offset to add (may be negative); must be finite.
    in_place
        True mutates and returns ``self``; False works on ``self.copy()``.

    Returns
    -------
    Raster
        The shifted raster.

    Notes
    -----
    Adding the same constant to every value keeps each channel sorted.
    Empty channels are left untouched.
    """
    offset = self._validate_time(dt)
    result = self if in_place else self.copy()
    for ch in list(result.events):
        ts = result.events[ch]
        if ts.size == 0:
            continue
        result.events[ch] = (ts + offset).astype(result.dtype, copy=False)
    return result
# ---------------------------------------------------------------------
# Channel operations
# ---------------------------------------------------------------------
[docs]
def relabel_channels(
    self,
    mapping: Mapping[ChannelId, ChannelId],
    *,
    on_conflict: str = "raise",
    in_place: bool = False,
) -> "Raster":
    """
    Relabel channels according to a mapping.

    Amplitudes follow their channel to its new label. When channels are
    merged, amplitudes are concatenated and reordered with the same
    permutation as the events. A source channel without amplitudes
    contributes NaN placeholders if the merged target has amplitudes.

    Parameters
    ----------
    mapping
        Dictionary-like mapping {old_channel: new_channel}.
        Channels not present in mapping are left unchanged.
    on_conflict
        What to do if multiple old channels map to the same new channel, or
        if a new channel already exists.
        - "raise": raise ValueError
        - "merge": merge events into the target channel (and keep sorted)
    in_place
        If True, modifies and returns self. If False, returns a relabeled copy.

    Returns
    -------
    Raster
        Raster with relabeled channels.

    Raises
    ------
    ValueError
        If ``on_conflict`` is not recognized, or conflicts occur and
        ``on_conflict="raise"`` (the raster is left unmodified in that case).
    """
    if on_conflict not in {"raise", "merge"}:
        raise ValueError(f"on_conflict must be 'raise' or 'merge', got {on_conflict!r}.")
    src = self if in_place else self.copy()
    new_events: Dict[ChannelId, np.ndarray] = {}
    new_amps: Dict[ChannelId, np.ndarray] = {}
    for old_ch, arr in src.events.items():
        new_ch = mapping.get(old_ch, old_ch)
        amps = src.amplitudes.get(old_ch)
        if new_ch not in new_events:
            new_events[new_ch] = arr.copy()
            if amps is not None:
                new_amps[new_ch] = amps.copy()
            continue
        if on_conflict == "raise":
            raise ValueError(f"Channel relabel conflict on target {new_ch!r}.")

        prev = new_events[new_ch]
        merged = np.concatenate([prev, arr]).astype(src.dtype, copy=False)
        prev_amps = new_amps.get(new_ch)
        merged_amps: Optional[np.ndarray] = None
        if prev_amps is not None or amps is not None:
            left = np.full(prev.shape[0], np.nan) if prev_amps is None else np.asarray(prev_amps, dtype=np.float64)
            right = np.full(arr.shape[0], np.nan) if amps is None else np.asarray(amps, dtype=np.float64)
            merged_amps = np.concatenate([left, right])
        if merged.size:
            # Stable argsort so amplitudes can follow the events' permutation.
            order = np.argsort(merged, kind="stable")
            merged = merged[order]
            if merged_amps is not None:
                merged_amps = merged_amps[order]
        new_events[new_ch] = merged
        if merged_amps is not None:
            new_amps[new_ch] = merged_amps
    src.events = new_events
    src.amplitudes = new_amps
    return src
# ---------------------------------------------------------------------
# Multi-raster operations
# ---------------------------------------------------------------------
[docs]
def merge(
    self,
    other: "Raster",
    *,
    on_missing_channels: str = "union",
    in_place: bool = False,
) -> "Raster":
    """
    Merge another Raster into this one (concatenate events per channel).

    Parameters
    ----------
    other
        Raster to merge into this raster.
    on_missing_channels
        Channel set handling:
        - "union": result contains union of channels from both rasters
        - "intersection": result contains only channels present in both
    in_place
        If True, modifies and returns self. If False, returns a merged copy.

    Returns
    -------
    Raster
        Merged raster.

    Raises
    ------
    ValueError
        If ``on_missing_channels`` is not recognized.

    Notes
    -----
    - Event times are concatenated and then (stably) sorted within each channel.
    - ``other`` is coerced to ``self.dtype`` in the result.
    - Amplitudes are merged alongside events with the same permutation. If
      only one side has amplitudes for a channel, the other side's events get
      NaN placeholders.
    - Channel order is deterministic: this raster's channels first, then
      channels only present in ``other``, each in insertion order.
    """
    if on_missing_channels not in {"union", "intersection"}:
        raise ValueError(
            f"on_missing_channels must be 'union' or 'intersection', got {on_missing_channels!r}."
        )
    target = self if in_place else self.copy()
    # Ordered lists (not sets): set iteration order of str keys varies between
    # interpreter runs, which would make channel order non-reproducible.
    if on_missing_channels == "union":
        ordered = list(target.events) + [ch for ch in other.events if ch not in target.events]
    else:
        ordered = [ch for ch in target.events if ch in other.events]

    for ch in ordered:
        a = target.events.get(ch, np.asarray([], dtype=target.dtype))
        b = other.events.get(ch, np.asarray([], dtype=other.dtype))
        b = np.asarray(b, dtype=target.dtype).ravel()
        merged = np.concatenate([a, b]).astype(target.dtype, copy=False)

        a_amp = target.amplitudes.get(ch)
        b_amp = other.amplitudes.get(ch)
        merged_amp: Optional[np.ndarray] = None
        if a_amp is not None or b_amp is not None:
            a_part = np.full(a.shape[0], np.nan) if a_amp is None else np.asarray(a_amp, dtype=np.float64).ravel()
            b_part = np.full(b.shape[0], np.nan) if b_amp is None else np.asarray(b_amp, dtype=np.float64).ravel()
            merged_amp = np.concatenate([a_part, b_part])

        # Each side is already sorted; only interleaving two non-empty sides
        # requires a sort.
        if a.size and b.size:
            order = np.argsort(merged, kind="stable")
            merged = merged[order]
            if merged_amp is not None:
                merged_amp = merged_amp[order]

        target.events[ch] = merged
        if merged_amp is not None:
            target.amplitudes[ch] = merged_amp

    if on_missing_channels == "intersection":
        keep = set(ordered)
        target.events = {ch: target.events[ch] for ch in ordered}
        target.amplitudes = {ch: amp for ch, amp in target.amplitudes.items() if ch in keep}
    return target
[docs]
def as_arrays(
    self,
    *,
    channels: Optional[Sequence[ChannelId]] = None,
    sort_by_time: bool = True,
    channel_dtype: Optional[np.dtype] = None,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Flatten the raster into parallel ``(times, channel_ids)`` arrays.

    Parameters
    ----------
    channels
        Channels to export, in the given order; None exports every channel
        in insertion order. Each listed channel is passed through
        ``_require_channel`` (so a missing one is created empty or raises
        ``KeyError``, depending on ``allow_new_channels``).
    sort_by_time
        True sorts the output by time (stable, so ties keep channel order).
        False leaves events grouped channel by channel.
    channel_dtype
        Dtype of the returned ID array. Defaults to int64 when every selected
        channel ID is an integer, otherwise object.

    Returns
    -------
    times : ndarray, shape (N,)
        Event times in ``self.dtype``.
    ch_ids : ndarray, shape (N,)
        Channel ID of each event.
    """
    selected = self.channels() if channels is None else list(channels)
    if channel_dtype is None:
        only_ints = all(isinstance(ch, (int, np.integer)) for ch in selected)
        channel_dtype = np.dtype(np.int64) if only_ints else np.dtype(object)

    time_chunks: List[np.ndarray] = []
    label_chunks: List[np.ndarray] = []
    for ch in selected:
        self._require_channel(ch)
        stamps = self.events[ch]
        if not stamps.size:
            continue
        time_chunks.append(stamps.astype(self.dtype, copy=False))
        label_chunks.append(np.full(stamps.shape, ch, dtype=channel_dtype))

    if not time_chunks:
        return (np.asarray([], dtype=self.dtype), np.asarray([], dtype=channel_dtype))

    times = np.concatenate(time_chunks)
    labels = np.concatenate(label_chunks)
    if sort_by_time and times.size:
        perm = np.argsort(times, kind="stable")
        times, labels = times[perm], labels[perm]
    return times, labels
# ---------------------------------------------------------------------
# Binning method for metrics that require binned data (e.g., firing rates)
# ---------------------------------------------------------------------
[docs]
def bin_counts(
    self,
    dt: float,
    *,
    tstart: Optional[float] = None,
    tstop: Optional[float] = None,
    channels: Optional[Sequence[ChannelId]] = None,
    inclusive_stop: bool = False,
    dtype: np.dtype = np.dtype(np.int64),
) -> Tuple[np.ndarray, np.ndarray, List[ChannelId]]:
    """
    Bin events into fixed-width time bins and return spike counts per channel.

    This is a core primitive for downstream metrics (avalanches, binned MI/TE,
    PSTH-like quantities, population activity, etc.).

    Parameters
    ----------
    dt
        Bin width (same units as timestamps). Must be positive and finite.
    tstart
        Start time of the binning window. If None, inferred as the global
        minimum event time across selected channels (or 0.0 if no events).
    tstop
        Stop time of the binning window. If None, inferred as the global
        maximum event time across selected channels (or tstart + dt if no
        events). An inferred stop is always treated as inclusive, so the
        last event(s) are counted regardless of ``inclusive_stop``.
    channels
        Subset/order of channels to bin. If None, uses all channels in this
        raster. The row order of the returned count matrix follows this list.
    inclusive_stop
        Applies to an explicitly given ``tstop``. If False (default), the
        window is [tstart, tstop) and events at exactly tstop are excluded.
        If True, the window is [tstart, tstop] and they are included.
    dtype
        Integer dtype for returned counts.

    Returns
    -------
    bin_edges : ndarray, shape (n_bins + 1,)
        Bin edge times spanning the window, cast to ``self.dtype``.
    counts : ndarray, shape (n_channels, n_bins)
        Spike counts per channel per bin.
    ch_list : list
        Channel IDs corresponding to rows of `counts`.

    Raises
    ------
    ValueError
        If `dt` is not positive/finite, or if tstop < tstart. Errors raised
        by ``self._validate_time`` for explicit bounds propagate unchanged.

    Notes
    -----
    - Edges are ``tstart + k * dt``. ``n_bins`` is
      ``ceil((tstop - tstart) / dt)``. The exception is a ratio that is an
      integer up to floating-point noise (e.g. ``1.1 / 0.1`` evaluates to
      ``11.000000000000002``): it is rounded to that integer and the last
      edge is set exactly to ``tstop``. This avoids a spurious, always-empty
      trailing bin. The last edge is never below ``tstop``.
    - If the window length is not a multiple of ``dt``, the last bin extends
      past ``tstop``. Events beyond ``tstop`` are still not counted.
    - Bins are ``[edge_j, edge_{j+1})``, except the last bin, which also
      contains its right edge (the ``np.histogram`` convention).
    - A zero-length window gives 0 bins when it is half-open. When it is
      inclusive, it gives one bin ``[tstart, tstart + dt]``, so events at
      that instant are counted.
    - Counting uses ``np.searchsorted`` on the per-channel event arrays,
      which are kept sorted by this class.
    """
    # ---- validate dt ----
    dt_f = float(dt)
    if not np.isfinite(dt_f) or dt_f <= 0.0:
        raise ValueError(f"dt must be positive and finite, got {dt!r}.")

    # ---- choose channels ----
    ch_list = list(channels) if channels is not None else self.channels()
    # Check every requested channel up front. How a missing channel is
    # handled is decided by ``_require_channel``.
    for ch in ch_list:
        self._require_channel(ch)

    # ---- infer window if needed ----
    # Per-channel arrays are sorted, so [0] / [-1] are the extrema.
    nonempty = [self.events[ch] for ch in ch_list if self.events[ch].size]
    if tstart is None:
        tstart_f = float(min(arr[0] for arr in nonempty)) if nonempty else 0.0
    else:
        tstart_f = self._validate_time(tstart)
    stop_inferred = tstop is None
    if stop_inferred:
        if nonempty:
            tstop_f = float(max(arr[-1] for arr in nonempty))
        else:
            tstop_f = tstart_f + dt_f
    else:
        tstop_f = self._validate_time(tstop)
    if tstop_f < tstart_f:
        raise ValueError(f"tstop ({tstop_f}) must be >= tstart ({tstart_f}).")

    # An inferred stop *is* the last observed event time. Counting it
    # half-open would silently drop the final event(s).
    include_stop = bool(inclusive_stop) or stop_inferred

    # ---- build bin edges: tstart + k * dt for k = 0..n_bins ----
    length = tstop_f - tstart_f
    snap_last_edge = False
    if length == 0.0:
        # [t, t) is empty, but [t, t] still holds the events at t -> one bin.
        n_bins = 1 if include_stop else 0
    else:
        ratio = length / dt_f
        nearest = int(np.rint(ratio))
        # Treat the window as an exact multiple of dt when it is within one
        # billionth of a bin width. Otherwise ceil() would turn rounding noise
        # into an extra bin lying entirely past tstop.
        if nearest >= 1 and abs(nearest * dt_f - length) <= 1e-9 * dt_f:
            n_bins = nearest
            snap_last_edge = True
        else:
            n_bins = int(np.ceil(ratio))
    bin_edges = tstart_f + dt_f * np.arange(n_bins + 1, dtype=float)
    if n_bins and (snap_last_edge or bin_edges[-1] < tstop_f):
        # Pin the window end exactly at tstop. Also ensure rounding in
        # tstart + n * dt never leaves the right edge short of tstop.
        bin_edges[-1] = tstop_f

    counts = np.zeros((len(ch_list), n_bins), dtype=dtype)
    if n_bins == 0:
        return bin_edges.astype(self.dtype, copy=False), counts, ch_list

    # ---- count per channel ----
    stop_side = "right" if include_stop else "left"
    interior_edges = bin_edges[1:-1]
    for i, ch in enumerate(ch_list):
        arr = self.events[ch]
        if arr.size == 0:
            continue
        # arr[lo:hi] is exactly the set of events inside the window.
        lo = np.searchsorted(arr, tstart_f, side="left")
        hi = np.searchsorted(arr, tstop_f, side=stop_side)
        if hi <= lo:
            continue
        # Events before interior edge e_j belong to bins < j. Every windowed
        # event lies in [edges[0], edges[-1]], so the outer bounds are lo/hi
        # and an event sitting on the last edge lands in the final bin.
        cuts = np.clip(np.searchsorted(arr, interior_edges, side="left"), lo, hi)
        bounds = np.concatenate(([lo], cuts, [hi]))
        counts[i, :] = np.diff(bounds)

    return bin_edges.astype(self.dtype, copy=False), counts, ch_list
# ---------------------------------------------------------------------
# Visualization
# ---------------------------------------------------------------------
[docs]
def plot(
    self,
    *,
    ax: Optional[plt.Axes] = None,
    channels: Optional[Sequence[ChannelId]] = None,
    tstart: Optional[float] = None,
    tstop: Optional[float] = None,
    inclusive_stop: bool = False,
    y_gap: float = 1.0,
    tick_halfheight: Optional[float] = None,
    linewidth: float = 1.0,
    sort_channels: bool = True,
    show: bool = False,
    plot_triggers: bool = False,
    trigger_color: str = "red",
    trigger_alpha: float = 0.5,
    trigger_linewidth: float = 2.0,
) -> plt.Axes:
    """
    Plot the raster using matplotlib.

    Each event is drawn as a vertical tick at its timestamp, arranged by
    channel along the y-axis.

    Parameters
    ----------
    ax
        Matplotlib Axes to draw on. If None, creates a new figure and axes.
    channels
        Subset of channels to plot. If None, plot all channels. The given
        order is kept only when ``sort_channels`` is False (or when the IDs
        cannot be compared with each other).
    tstart
        Optional time window start. If only ``tstop`` is given, it is
        inferred as the earliest event time over *all* channels (0.0 if
        there are no events).
    tstop
        Optional time window stop. If only ``tstart`` is given, it is
        inferred as the latest event time over *all* channels
        (``tstart + 1.0`` if there are no events). Events at an inferred
        stop are always drawn.
    inclusive_stop
        If True, include events at exactly an explicitly given ``tstop``.
    y_gap
        Vertical spacing between channels.
    tick_halfheight
        Half-height of the vertical tick marks. If None, defaults to
        ``0.4 * y_gap``.
    linewidth
        Line width of tick marks.
    sort_channels
        If True, try to sort channel IDs for display.
    show
        If True, calls ``plt.show()`` before returning.
    plot_triggers
        If True, shade each interval returned by
        ``self.get_trigger_intervals()`` as a translucent vertical span
        (``axvspan``, zorder 3, i.e. over the ticks). When a time window is
        active, spans are clipped to it.
    trigger_color
        Fill color of the trigger spans.
    trigger_alpha
        Transparency of the trigger spans.
    trigger_linewidth
        Currently unused: spans are drawn without an edge (``linewidth=0``).
        Kept for backward compatibility.

    Returns
    -------
    matplotlib.axes.Axes
        The Axes containing the plot.
    """
    if ax is None:
        _, ax = plt.subplots()

    ch_list = list(channels) if channels is not None else self.channels()
    if sort_channels:
        try:
            ch_list = sorted(ch_list)
        except TypeError:
            # Non-comparable IDs (e.g. a mix of int and str): keep given order.
            pass

    raster = self
    if tstart is not None or tstop is not None:
        all_times = np.concatenate(
            [v for v in self.events.values() if v.size] or [np.asarray([], dtype=self.dtype)]
        )
        # An inferred stop is the last event time itself. A half-open window
        # would hide the final event(s), so the inferred bound is inclusive.
        include_stop = inclusive_stop or tstop is None
        if tstart is None:
            tstart = float(all_times.min()) if all_times.size else 0.0
        if tstop is None:
            tstop = float(all_times.max()) if all_times.size else (tstart + 1.0)
        raster = self.between(float(tstart), float(tstop), inclusive_stop=include_stop)

    if tick_halfheight is None:
        tick_halfheight = 0.4 * y_gap
    y_positions = np.arange(len(ch_list), dtype=float) * y_gap

    for i, ch in enumerate(ch_list):
        raster._require_channel(ch)
        ts = raster.events[ch]
        if ts.size == 0:
            continue
        y = y_positions[i]
        ax.vlines(ts, y - tick_halfheight, y + tick_halfheight, linewidth=linewidth)

    if plot_triggers and self.triggers is not None:
        trigger_intervals = self.get_trigger_intervals()
        if trigger_intervals:
            for start, end in trigger_intervals:
                # Clip each interval to the visible window, if one is set.
                if tstart is not None:
                    start = max(start, float(tstart))
                if tstop is not None:
                    end = min(end, float(tstop))
                if end <= start:
                    continue
                ax.axvspan(
                    start,
                    end,
                    color=trigger_color,
                    alpha=trigger_alpha,
                    linewidth=0,
                    zorder=3,
                )

    ax.set_yticks(y_positions)
    ax.set_yticklabels([str(ch) for ch in ch_list])
    ax.set_ylabel("Channel")
    ax.set_xlabel("Time")
    ax.set_title("Raster")
    if tstart is not None and tstop is not None:
        ax.set_xlim(float(tstart), float(tstop))
    ax.margins(x=0.01)
    if show:
        plt.show()
    return ax