Source code for tad.processing_history

"""
processing_history.py

Lightweight provenance tracking for data processing steps.

Design goals
------------
- JSON-first persistence (portable, diffable)
- Minimal intrusion into existing processing code
- Captures ordered operations (B) and key state snapshots (C), persisted as content hashes
- Lightweight dataset metadata (A-lite): name, sampling rate, channels, file stats
"""

from __future__ import annotations

import base64
import dataclasses
import datetime as _dt
import functools
import hashlib
import json
import os
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional


# Type alias for string-keyed mappings intended to hold JSON-friendly values;
# used for params, summaries, artifacts and serialized records in this module.
JsonDict = Dict[str, Any]


def _utc_now_iso() -> str:
    """Return the current time in UTC as an ISO-8601 string (explicit ``+00:00`` offset)."""
    utc = _dt.timezone.utc
    moment = _dt.datetime.now(utc)
    return moment.isoformat()


def _canonical_dumps(obj: Any) -> str:
    """
    Serialize *obj* into compact, key-sorted JSON.

    The output is deterministic for equal inputs, which makes it a stable
    input for hashing. Non-ASCII characters are emitted verbatim.
    """
    compact_separators = (",", ":")
    return json.dumps(
        obj,
        ensure_ascii=False,
        separators=compact_separators,
        sort_keys=True,
    )


def _sha256_text(text: str) -> str:
    """Return the hexadecimal SHA-256 digest of *text* encoded as UTF-8."""
    hasher = hashlib.sha256()
    hasher.update(text.encode("utf-8"))
    return hasher.hexdigest()


def _safe_json_value(x: Any) -> Any:
    """
    Convert common non-JSON objects to JSON-friendly representations.

    Notes
    -----
    Keep this conservative: it should not accidentally serialize large arrays.

    Conversion rules
    ----------------
    - ``None``, ``bool``, ``int``, ``float`` and ``str`` are returned unchanged.
    - ``list``/``tuple`` become lists; ``dict`` keys are coerced with ``str``.
    - ``set``/``frozenset`` become lists sorted by each element's JSON text, so
      the result (and any hash derived from it) does not depend on set
      iteration order.
    - numpy scalars and 0-d arrays become Python scalars; numpy arrays with at
      most 256 elements become (nested) lists, even when they hold a single
      element; larger arrays are replaced by a placeholder dict.
    - Anything else falls back to ``str(x)``.

    Values obtained from numpy via ``.item()`` / ``.tolist()`` are passed
    through this function again, so e.g. ``datetime`` objects coming out of
    ``datetime64`` values are stringified instead of leaking into the output.
    """
    # Basic scalars
    if x is None or isinstance(x, (bool, int, float, str)):
        return x

    # Simple containers
    if isinstance(x, (list, tuple)):
        return [_safe_json_value(v) for v in x]
    if isinstance(x, dict):
        return {str(k): _safe_json_value(v) for k, v in x.items()}
    if isinstance(x, (set, frozenset)):
        # str(set) order depends on the hash seed; sort for reproducible hashes.
        items = [_safe_json_value(v) for v in x]
        return sorted(items, key=lambda v: json.dumps(v, sort_keys=True, ensure_ascii=False))

    # Numpy types / arrays (without importing numpy as hard dependency here)
    mod = getattr(type(x), "__module__", "")
    if mod.startswith("numpy"):
        # Only 0-d values (numpy scalars, 0-d arrays) collapse to a scalar;
        # a 1-element array must stay a list to preserve its shape.
        if getattr(x, "ndim", 0) == 0 and callable(getattr(x, "item", None)):
            try:
                scalar = x.item()
            except Exception:
                pass
            else:
                return _safe_json_value(scalar)

        # array -> small list only; otherwise require caller to pre-encode
        if hasattr(x, "shape") and hasattr(x, "size"):
            size = int(getattr(x, "size", 0))
            if size <= 256 and hasattr(x, "tolist"):
                return _safe_json_value(x.tolist())
            return {"__ndarray__": True, "note": f"array too large to inline (size={size})"}

    # Fallback
    return str(x)


def encode_bool_mask(mask_bytes: bytes) -> str:
    """
    Turn the raw bytes of a boolean mask into base64 text.

    Parameters
    ----------
    mask_bytes : bytes
        Byte payload of a boolean mask, for instance the output of numpy ``packbits``.

    Returns
    -------
    str
        The payload as an ASCII base64 string (standard alphabet, with padding).
    """
    encoded_bytes = base64.standard_b64encode(mask_bytes)
    return encoded_bytes.decode("ascii")
def decode_bool_mask(encoded: str) -> bytes:
    """
    Recover the raw mask bytes from a base64 string.

    Parameters
    ----------
    encoded : str
        Base64 text (standard alphabet); must be ASCII-only.

    Returns
    -------
    bytes
        The decoded mask payload.
    """
    ascii_payload = encoded.encode("ascii")
    return base64.standard_b64decode(ascii_payload)
def _try_snapshot_hash(ph: Any) -> tuple:
    """
    Best-effort ``(snapshot, hash)`` from a ProcessingHistory-like object.

    Returns ``(None, None)`` if either ``snapshot_state()`` or ``state_hash()``
    raises, so provenance problems never interrupt the tracked operation.
    """
    try:
        snapshot = ph.snapshot_state()
        return snapshot, ph.state_hash(snapshot)
    except Exception:
        return None, None


def tracked_operation(
    name: Optional[str] = None,
    *,
    track: bool = True,
    include_result_artifacts: Optional[Callable[[Any], Dict[str, Any]]] = None,
) -> Callable:
    """
    Decorator to record a method call into provenance.

    It supports two targets (if present on `self`):
    - self.processing_history : ProcessingHistory-like object with snapshot_state/state_hash/record
    - self.history : list[dict] (lightweight generic log)

    Usable as ``@tracked_operation``, ``@tracked_operation()`` or
    ``@tracked_operation("op_name", ...)``. The wrapper is built with
    :func:`functools.wraps`, so the original method's name, qualname, module,
    docstring and ``__wrapped__`` are preserved.

    Recording happens only after the wrapped call returns. If the call raises,
    the exception propagates and nothing is recorded. Failures inside
    provenance capture itself (snapshots, artifact extraction, recording) are
    swallowed so they never break the pipeline.

    Parameters
    ----------
    name : str, optional
        Operation name. Defaults to function name. When the decorator is applied
        bare, the decorated function arrives here and is handled transparently.
    track : bool, default=True
        Whether to track this operation. If False, the function is returned unchanged.
    include_result_artifacts : callable, optional
        Function called with the method return value; must return JSON-friendly
        dict to attach as `artifacts`.

    Returns
    -------
    callable
        A decorator, or the decorated method itself when used bare.
    """
    if callable(name):
        # Bare ``@tracked_operation``: ``name`` is actually the decorated function.
        return tracked_operation(
            track=track, include_result_artifacts=include_result_artifacts
        )(name)

    def _decorator(func: Callable) -> Callable:
        if not track:
            return func

        op_name = name or func.__name__

        @functools.wraps(func)
        def _wrapped(self, *args, **kwargs):
            # --- resolve history targets (robustly) ---
            ph = getattr(self, "processing_history", None)
            lightweight = getattr(self, "history", None)
            has_ph = ph is not None and all(
                hasattr(ph, attr) for attr in ("snapshot_state", "state_hash", "record")
            )
            has_lightweight = isinstance(lightweight, list)

            # Nothing to log to: behave like an undecorated method.
            if not (has_ph or has_lightweight):
                return func(self, *args, **kwargs)

            # Snapshots/hashes exist only for the ProcessingHistory target.
            before_hash = _try_snapshot_hash(ph)[1] if has_ph else None
            result = func(self, *args, **kwargs)
            after_snapshot, after_hash = _try_snapshot_hash(ph) if has_ph else (None, None)

            # Conservative params capture (large arrays are not inlined).
            params = _safe_json_value({"args": list(args), "kwargs": dict(kwargs)})

            artifacts: Dict[str, Any] = {}
            if include_result_artifacts is not None:
                try:
                    artifacts = include_result_artifacts(result) or {}
                except Exception:
                    artifacts = {"note": "artifact extraction failed"}
                artifacts = _safe_json_value(artifacts)

            # --- record to ProcessingHistory (rich) ---
            if has_ph:
                # Optional summary derived from snapshot if the getter provides these keys.
                summary: Dict[str, Any] = {}
                if isinstance(after_snapshot, dict):
                    for key in ("mask_n_kept", "excluded_intervals_n"):
                        if key in after_snapshot:
                            summary[key] = after_snapshot.get(key)
                try:
                    ph.record(
                        op_name,
                        params=params,
                        state_before=before_hash,
                        state_after=after_hash,
                        summary=_safe_json_value(summary),
                        artifacts=artifacts,
                    )
                except Exception:
                    # Deliberate best-effort: provenance must not break the pipeline.
                    pass

            # --- record to lightweight list[dict] ---
            if has_lightweight:
                try:
                    lightweight.append(
                        _safe_json_value(
                            {
                                "name": op_name,
                                "timestamp_utc": _utc_now_iso(),
                                "params": params,
                                "state_before": before_hash,
                                "state_after": after_hash,
                                "artifacts": artifacts,
                            }
                        )
                    )
                except Exception:
                    pass

            return result

        return _wrapped

    return _decorator
def _as_plain_scalar(value: Any) -> Any:
    """Unwrap numpy-style scalars (objects with a callable ``item()``) into Python scalars."""
    item = getattr(value, "item", None)
    if callable(item):
        try:
            return item()
        except (TypeError, ValueError):
            # e.g. multi-element arrays: leave untouched rather than fail.
            pass
    return value


def _as_plain_list(values: Any) -> Optional[List[Any]]:
    """
    Normalize an identifier sequence to a plain Python list.

    ``None`` and ``str``/``bytes`` pass through unchanged. Array-likes exposing
    ``tolist()`` (e.g. numpy arrays) are converted with it. Other iterables are
    copied element-wise, unwrapping numpy scalars, so the result serializes
    with the standard ``json`` module.
    """
    if values is None or isinstance(values, (str, bytes)):
        return values
    tolist = getattr(values, "tolist", None)
    if callable(tolist):
        return tolist()
    return [_as_plain_scalar(v) for v in values]


@dataclass
class DatasetInfo:
    """
    Lightweight dataset identity/metadata (A-lite).

    Attributes
    ----------
    fname : str
        Path (as provided, converted with ``os.fspath``) to the source file.
    basename : str
        Base filename.
    sampling_frequency : float, optional
        Sampling frequency in Hz.
    stream_id : int, optional
        Stream identifier used by the reader (if applicable).
    channel_ids : list, optional
        Channel identifiers.
    electrode_labels : list, optional
        Electrode labels (if present).
    file_size_bytes : int, optional
        File size in bytes (if the path could be stat'ed).
    file_mtime_utc : str, optional
        File modification time as a UTC ISO string (if the path could be stat'ed).
    """

    fname: str
    basename: str
    sampling_frequency: Optional[float] = None
    stream_id: Optional[int] = None
    channel_ids: Optional[List[Any]] = None
    electrode_labels: Optional[List[Any]] = None
    file_size_bytes: Optional[int] = None
    file_mtime_utc: Optional[str] = None

    @classmethod
    def from_path(
        cls,
        fname: str,
        sampling_frequency: Optional[float] = None,
        stream_id: Optional[int] = None,
        channel_ids: Optional[List[Any]] = None,
        electrode_labels: Optional[List[Any]] = None,
    ) -> "DatasetInfo":
        """
        Build a DatasetInfo from a file path, collecting size/mtime when available.

        Path-like ``fname`` values are converted to ``str``. numpy arrays and
        scalars passed as metadata are converted to plain Python values so that
        ``to_dict()`` output is directly JSON-serializable. A missing or
        unreadable path leaves ``file_size_bytes``/``file_mtime_utc`` as None.
        """
        fname = os.fspath(fname)

        size = None
        mtime = None
        try:
            # EAFP instead of exists()+stat(): no race if the file vanishes in between.
            st = os.stat(fname)
        except (OSError, ValueError):
            pass
        else:
            size = int(st.st_size)
            mtime = _dt.datetime.fromtimestamp(st.st_mtime, tz=_dt.timezone.utc).isoformat()

        return cls(
            fname=fname,
            basename=os.path.basename(fname),
            sampling_frequency=_as_plain_scalar(sampling_frequency),
            stream_id=_as_plain_scalar(stream_id),
            channel_ids=_as_plain_list(channel_ids),
            electrode_labels=_as_plain_list(electrode_labels),
            file_size_bytes=size,
            file_mtime_utc=mtime,
        )

    def to_dict(self) -> JsonDict:
        """Return a deep ``dict`` copy of all fields."""
        return dataclasses.asdict(self)
@dataclass
class OperationRecord:
    """
    One entry of the ordered operation log (B).

    Attributes
    ----------
    name : str
        Operation identifier, e.g. ``"apply_filter"``.
    timestamp_utc : str
        ISO-8601 UTC timestamp of when the entry was recorded.
    params : dict
        JSON-friendly call parameters.
    state_before : str, optional
        Hash of the state snapshot taken before the operation.
    state_after : str, optional
        Hash of the state snapshot taken after the operation.
    summary : dict
        Compact human-readable summary (counts and similar).
    artifacts : dict
        Extra result metadata (e.g. raster t-range, n_spikes).
    """

    name: str
    timestamp_utc: str
    params: JsonDict = field(default_factory=dict)
    state_before: Optional[str] = field(default=None)
    state_after: Optional[str] = field(default=None)
    summary: JsonDict = field(default_factory=dict)
    artifacts: JsonDict = field(default_factory=dict)

    def to_dict(self) -> JsonDict:
        """Return a deep ``dict`` copy of this record, in field order."""
        as_plain = dataclasses.asdict(self)
        return as_plain
def _new_session_id() -> str:
    """Return a fresh 64-hex-char session id: SHA-256 of the current UTC time plus 128 random bits."""
    # The timestamp alone collides for histories created within the same microsecond.
    return _sha256_text(f"{_utc_now_iso()}|{os.urandom(16).hex()}")


def _known_fields(cls: Any, data: JsonDict) -> JsonDict:
    """Keep only the keys of *data* that are dataclass fields of *cls*."""
    names = {f.name for f in dataclasses.fields(cls)}
    return {k: v for k, v in data.items() if k in names}


@dataclass
class ProcessingHistory:
    """
    Provenance container for a processing session.

    Parameters
    ----------
    dataset : DatasetInfo
        Lightweight dataset metadata.
    session_id : str, optional
        Stable id for this history instance. If omitted or None, a random
        hash-based id is generated.

    Notes
    -----
    The state getter attached via :meth:`set_state_getter` is not persisted.
    """

    dataset: DatasetInfo
    operations: List[OperationRecord] = field(default_factory=list)
    created_utc: str = field(default_factory=_utc_now_iso)
    session_id: str = field(default_factory=_new_session_id)
    # Not persisted: function to snapshot state from a live object
    _state_getter: Optional[Callable[[], JsonDict]] = field(default=None, repr=False, compare=False)

    def __post_init__(self) -> None:
        # Honour the documented contract: ``session_id=None`` means "generate one".
        if self.session_id is None:
            self.session_id = _new_session_id()

    def set_state_getter(self, getter: Callable[[], JsonDict]) -> None:
        """Attach a callback that returns a JSON-friendly state snapshot."""
        self._state_getter = getter

    def snapshot_state(self) -> JsonDict:
        """Return a state snapshot using the attached getter, or empty dict."""
        if self._state_getter is None:
            return {}
        snap = self._state_getter()
        return _safe_json_value(snap)

    def state_hash(self, snapshot: Optional[JsonDict] = None) -> str:
        """
        Compute a stable hash for a given snapshot (or current snapshot if None).
        """
        if snapshot is None:
            snapshot = self.snapshot_state()
        return _sha256_text(_canonical_dumps(snapshot))

    def record(
        self,
        name: str,
        params: Optional[JsonDict] = None,
        *,
        state_before: Optional[str] = None,
        state_after: Optional[str] = None,
        summary: Optional[JsonDict] = None,
        artifacts: Optional[JsonDict] = None,
    ) -> None:
        """
        Append an operation record stamped with the current UTC time.

        Parameters
        ----------
        name : str
            Operation name.
        params : dict, optional
            Operation parameters (JSON-friendly).
        state_before : str, optional
            State hash before op.
        state_after : str, optional
            State hash after op.
        summary : dict, optional
            Human-friendly summary.
        artifacts : dict, optional
            Output-related metadata.
        """
        rec = OperationRecord(
            name=name,
            timestamp_utc=_utc_now_iso(),
            params=_safe_json_value(params or {}),
            state_before=state_before,
            state_after=state_after,
            summary=_safe_json_value(summary or {}),
            artifacts=_safe_json_value(artifacts or {}),
        )
        self.operations.append(rec)

    def to_dict(self) -> JsonDict:
        """Serialize to a JSON-friendly dict."""
        return {
            "schema": "tad.processing_history.v1",
            "created_utc": self.created_utc,
            "session_id": self.session_id,
            "dataset": self.dataset.to_dict(),
            "operations": [op.to_dict() for op in self.operations],
        }

    def to_json(self, indent: int = 2) -> str:
        """Serialize to JSON string."""
        return json.dumps(self.to_dict(), indent=indent, ensure_ascii=False)

    def save_json(self, path: str, indent: int = 2) -> None:
        """
        Write history JSON to disk.

        The JSON text is rendered before the file is opened, so a serialization
        error leaves an existing file at *path* intact instead of truncating it.
        """
        text = self.to_json(indent=indent)
        with open(path, "w", encoding="utf-8") as f:
            f.write(text)

    @classmethod
    def from_dict(cls, d: JsonDict) -> "ProcessingHistory":
        """
        Rebuild a history from :meth:`to_dict` output.

        Keys in the dataset or operation entries that are not fields of
        DatasetInfo / OperationRecord (e.g. written by a newer schema) are
        ignored instead of raising ``TypeError``. The state getter is not restored.
        """
        dataset = DatasetInfo(**_known_fields(DatasetInfo, d["dataset"]))
        hist = cls(
            dataset=dataset,
            created_utc=d.get("created_utc", _utc_now_iso()),
            session_id=d.get("session_id", ""),
        )
        hist.operations.extend(
            OperationRecord(**_known_fields(OperationRecord, op))
            for op in d.get("operations", [])
        )
        return hist

    @classmethod
    def load_json(cls, path: str) -> "ProcessingHistory":
        """Load history JSON from disk."""
        with open(path, "r", encoding="utf-8") as f:
            d = json.load(f)
        return cls.from_dict(d)