Source code for mllm_shap.shap.core.telemetry

"""Telemetry and metrics collection for SHAP explainability."""

from abc import ABC, abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass, field
from json import dumps
from logging import Logger
from time import perf_counter
from typing import Any, Generator

from ...utils.logger import get_logger

logger: Logger = get_logger(__name__)


@dataclass(frozen=True)
class CacheMetrics:
    """Immutable hit/miss counters for cache lookups."""

    hits: int = 0
    """Count of lookups that were served from the cache."""

    misses: int = 0
    """Count of lookups that were not found in the cache."""

    @property
    def total(self) -> int:
        """Number of recorded lookups (hits plus misses)."""
        return self.hits + self.misses

    @property
    def hit_rate(self) -> float:
        """Fraction of lookups that were hits; 0.0 when nothing was recorded."""
        lookups = self.total
        # Guard the division: an empty record has no meaningful rate.
        return self.hits / lookups if lookups else 0.0

    def to_dict(self) -> dict[str, Any]:
        """Return the raw counters together with the derived values."""
        return dict(
            hits=self.hits,
            misses=self.misses,
            total=self.total,
            hit_rate=self.hit_rate,
        )
@dataclass(frozen=True)
class MaskMetrics:
    """Metrics for mask generation and deduplication.

    The sinks in this module count an invalid mask under ``invalid`` only,
    never under ``unique``, so the number of duplicate masks is
    ``generated - unique - invalid``.
    """

    generated: int = 0
    """Total masks generated (unique, duplicate and invalid together)."""

    unique: int = 0
    """Total unique masks (after dedup)."""

    invalid: int = 0
    """Total invalid/empty masks."""

    @property
    def dedup_rate(self) -> float:
        """Fraction of generated masks that were duplicates.

        Returns:
            ``(generated - unique - invalid) / generated``, or 0.0 when no
            mask has been generated.
        """
        if self.generated == 0:
            return 0.0
        # Invalid masks are neither unique nor duplicates; leaving them in
        # would report rejected masks as deduplication savings.
        duplicates = self.generated - self.unique - self.invalid
        return duplicates / self.generated

    def to_dict(self) -> dict[str, Any]:
        """Convert the counters and the derived dedup rate to a dictionary."""
        return {
            "generated": self.generated,
            "unique": self.unique,
            "invalid": self.invalid,
            "dedup_rate": self.dedup_rate,
        }
@dataclass(frozen=True)
class TimingMetrics:
    """Immutable per-stage wall-clock totals, all in milliseconds."""

    sampling_ms: float = 0.0
    """Milliseconds spent in mask sampling."""

    dedup_ms: float = 0.0
    """Milliseconds spent in deduplication checks."""

    masking_ms: float = 0.0
    """Milliseconds spent in mask preparation."""

    model_ms: float = 0.0
    """Milliseconds spent in model inference."""

    scoring_ms: float = 0.0
    """Milliseconds spent in SHAP value computation."""

    @property
    def total_ms(self) -> float:
        """Sum of the five stage totals."""
        # Accumulate left to right so float rounding matches a plain
        # ``a + b + c + d + e`` expression.
        elapsed = self.sampling_ms
        elapsed += self.dedup_ms
        elapsed += self.masking_ms
        elapsed += self.model_ms
        elapsed += self.scoring_ms
        return elapsed

    def to_dict(self) -> dict[str, Any]:
        """Return every stage total plus the overall total as a dictionary."""
        stages: dict[str, Any] = dict(
            sampling_ms=self.sampling_ms,
            dedup_ms=self.dedup_ms,
            masking_ms=self.masking_ms,
            model_ms=self.model_ms,
            scoring_ms=self.scoring_ms,
        )
        stages["total_ms"] = self.total_ms
        return stages
class StageTimer:
    """Times a block of code and reports the elapsed milliseconds to a probe."""

    def __init__(self, probe: "TelemetryProbe", stage: str) -> None:
        """
        Set up a timer bound to one probe and one stage name.

        Args:
            probe: Object whose ``record_timing(stage, elapsed_ms)`` receives the result.
            stage: Name of the stage being timed (sampling, dedup, masking, model, scoring).
        """
        self.probe = probe
        self.stage = stage
        # Stays None until the timer is entered.
        self.start_time: float | None = None

    def __enter__(self) -> "StageTimer":
        """Remember the current ``perf_counter`` reading and return the timer."""
        self.start_time = perf_counter()
        return self

    def __exit__(self, *args: Any) -> None:
        """Report the elapsed time to the probe; does nothing if never entered."""
        started = self.start_time
        if started is None:
            return
        # perf_counter() is in seconds; the probe expects milliseconds.
        self.probe.record_timing(self.stage, (perf_counter() - started) * 1000)

    @contextmanager
    def measure(self) -> Generator[None, None, None]:
        """Generator-based context manager that wraps ``with self``."""
        with self:
            yield
@dataclass
class TelemetryData:
    """Bundle of every metric group gathered during one run."""

    cache_metrics: CacheMetrics = field(default_factory=CacheMetrics)
    """Hit/miss counters for cache lookups."""

    mask_metrics: MaskMetrics = field(default_factory=MaskMetrics)
    """Counters for generated, unique and invalid masks."""

    timing_metrics: TimingMetrics = field(default_factory=TimingMetrics)
    """Accumulated per-stage timings."""

    custom_metrics: dict[str, Any] = field(default_factory=dict)
    """Free-form metrics keyed by name."""

    def to_dict(self) -> dict[str, Any]:
        """Return all metric groups as one nested dictionary.

        The ``"custom"`` entry is the ``custom_metrics`` dictionary itself,
        not a copy.
        """
        sections = (
            ("cache", self.cache_metrics),
            ("masks", self.mask_metrics),
            ("timing", self.timing_metrics),
        )
        payload: dict[str, Any] = {label: group.to_dict() for label, group in sections}
        payload["custom"] = self.custom_metrics
        return payload
class ProbeSink(ABC):
    """Interface that every telemetry sink has to implement."""

    @abstractmethod
    def record_cache_operation(self, is_hit: bool) -> None:
        """Register one cache lookup; ``is_hit`` tells hit from miss."""

    @abstractmethod
    def record_mask_generated(self, is_unique: bool, is_invalid: bool = False) -> None:
        """Register one generated mask along with its unique/invalid flags."""

    @abstractmethod
    def record_timing(self, stage: str, elapsed_ms: float) -> None:
        """Register elapsed milliseconds for a stage (sampling, dedup, masking, model, scoring)."""

    @abstractmethod
    def record_custom_metric(self, key: str, value: Any) -> None:
        """Register a free-form metric under ``key``."""

    @abstractmethod
    def get_metrics(self) -> TelemetryData:
        """Return the telemetry gathered so far."""

    def reset(self) -> None:
        """Clear gathered metrics. The base implementation does nothing; overriding is optional."""
        return None
class LogProbeSink(ProbeSink):
    """Sink that keeps counters in memory and can echo each event to the module logger."""

    # Stage names that have a timing counter, mapped to the wording used in
    # verbose log lines.
    _STAGE_DESCRIPTIONS = {
        "sampling": "sampling mask generation strategy",
        "dedup": "mask deduplication and caching",
        "masking": "applying mask to chat",
        "model": "model inference on masked input",
        "scoring": "scoring computation",
    }

    def __init__(self, verbose: bool = False) -> None:
        """
        Create a sink with all counters at zero.

        Args:
            verbose: When True, every recorded event is also written at DEBUG
                level. When False, the methods of this class log nothing.
        """
        self.verbose = verbose
        self._zero_counters()
        self._custom_metrics: dict[str, Any] = {}

    def _zero_counters(self) -> None:
        """Set every numeric counter back to zero."""
        self._cache_hits = 0
        self._cache_misses = 0
        self._masks_generated = 0
        self._masks_unique = 0
        self._masks_invalid = 0
        self._sampling_ms = 0.0
        self._dedup_ms = 0.0
        self._masking_ms = 0.0
        self._model_ms = 0.0
        self._scoring_ms = 0.0

    def record_cache_operation(self, is_hit: bool) -> None:
        """Count one cache lookup as a hit or a miss."""
        if not is_hit:
            self._cache_misses += 1
            if self.verbose:
                logger.debug(
                    "Mask cache miss: mask not in cache, will compute SHAP values "
                    "(cumulative misses: %d)",
                    self._cache_misses,
                )
            return

        self._cache_hits += 1
        if self.verbose:
            logger.debug(
                "Mask cache hit: reusing previously computed SHAP values "
                "(cumulative hits: %d)",
                self._cache_hits,
            )

    def record_mask_generated(self, is_unique: bool, is_invalid: bool = False) -> None:
        """Count one generated mask as invalid, unique, or duplicate.

        ``is_invalid`` takes precedence: an invalid mask is never counted as unique.
        """
        self._masks_generated += 1
        generated = self._masks_generated

        if is_invalid:
            self._masks_invalid += 1
            message = (
                "Invalid mask generated: failed validation (e.g., no tokens to explain). "
                "Cumulative invalid masks: %d/%d"
            )
            count = self._masks_invalid
        elif is_unique:
            self._masks_unique += 1
            message = (
                "Unique mask generated: new mask pattern added to cache. "
                "Cumulative unique masks: %d/%d"
            )
            count = self._masks_unique
        else:
            message = (
                "Duplicate mask skipped: mask pattern already cached, "
                "will use existing results (duplicates save computation). "
                "Cumulative duplicates: %d/%d"
            )
            # Duplicates have no counter of their own; derive them.
            count = generated - self._masks_unique - self._masks_invalid

        if self.verbose:
            logger.debug(message, count, generated)

    def record_timing(self, stage: str, elapsed_ms: float) -> None:
        """Add ``elapsed_ms`` to the counter of a known stage.

        Unknown stage names are not accumulated, but are still logged in verbose mode.
        """
        if stage in self._STAGE_DESCRIPTIONS:
            counter = f"_{stage}_ms"
            setattr(self, counter, getattr(self, counter) + elapsed_ms)

        if self.verbose:
            logger.debug(
                "Stage completed: %s took %.2f ms",
                self._STAGE_DESCRIPTIONS.get(stage, stage),
                elapsed_ms,
            )

    def record_custom_metric(self, key: str, value: Any) -> None:
        """Store ``value`` under ``key``, replacing any earlier value."""
        self._custom_metrics[key] = value
        if not self.verbose:
            return
        logger.debug(
            "Custom telemetry metric recorded: %s = %s (type: %s)",
            key,
            value,
            type(value).__name__,
        )

    def get_metrics(self) -> TelemetryData:
        """Build a snapshot of the current counters; custom metrics are shallow-copied."""
        return TelemetryData(
            cache_metrics=CacheMetrics(hits=self._cache_hits, misses=self._cache_misses),
            mask_metrics=MaskMetrics(
                generated=self._masks_generated,
                unique=self._masks_unique,
                invalid=self._masks_invalid,
            ),
            timing_metrics=TimingMetrics(
                sampling_ms=self._sampling_ms,
                dedup_ms=self._dedup_ms,
                masking_ms=self._masking_ms,
                model_ms=self._model_ms,
                scoring_ms=self._scoring_ms,
            ),
            custom_metrics=self._custom_metrics.copy(),
        )

    def reset(self) -> None:
        """Zero every counter and empty the custom-metric dictionary in place."""
        self._zero_counters()
        self._custom_metrics.clear()
class JSONProbeSink(ProbeSink):
    """Sink that silently accumulates counters and can dump them as JSON."""

    # Stage names that have a matching ``_<stage>_ms`` counter.
    _TIMED_STAGES = ("sampling", "dedup", "masking", "model", "scoring")

    def __init__(self) -> None:
        """Create a sink with all counters at zero and no custom metrics."""
        self._zero_counters()
        self._custom_metrics: dict[str, Any] = {}

    def _zero_counters(self) -> None:
        """Set every numeric counter back to zero."""
        self._cache_hits = 0
        self._cache_misses = 0
        self._masks_generated = 0
        self._masks_unique = 0
        self._masks_invalid = 0
        self._sampling_ms = 0.0
        self._dedup_ms = 0.0
        self._masking_ms = 0.0
        self._model_ms = 0.0
        self._scoring_ms = 0.0

    def record_cache_operation(self, is_hit: bool) -> None:
        """Count one cache lookup as a hit or a miss."""
        if is_hit:
            self._cache_hits += 1
            return
        self._cache_misses += 1

    def record_mask_generated(self, is_unique: bool, is_invalid: bool = False) -> None:
        """Count one generated mask; ``is_invalid`` takes precedence over ``is_unique``."""
        self._masks_generated += 1
        if is_invalid:
            self._masks_invalid += 1
            return
        if is_unique:
            self._masks_unique += 1

    def record_timing(self, stage: str, elapsed_ms: float) -> None:
        """Add ``elapsed_ms`` to the counter of a known stage; other names are ignored."""
        if stage not in self._TIMED_STAGES:
            return
        counter = f"_{stage}_ms"
        setattr(self, counter, getattr(self, counter) + elapsed_ms)

    def record_custom_metric(self, key: str, value: Any) -> None:
        """Store ``value`` under ``key``, replacing any earlier value."""
        self._custom_metrics[key] = value

    def get_metrics(self) -> TelemetryData:
        """Build a snapshot of the current counters; custom metrics are shallow-copied."""
        return TelemetryData(
            cache_metrics=CacheMetrics(hits=self._cache_hits, misses=self._cache_misses),
            mask_metrics=MaskMetrics(
                generated=self._masks_generated,
                unique=self._masks_unique,
                invalid=self._masks_invalid,
            ),
            timing_metrics=TimingMetrics(
                sampling_ms=self._sampling_ms,
                dedup_ms=self._dedup_ms,
                masking_ms=self._masking_ms,
                model_ms=self._model_ms,
                scoring_ms=self._scoring_ms,
            ),
            custom_metrics=self._custom_metrics.copy(),
        )

    def to_json(self) -> str:
        """Serialize the current snapshot to a JSON string indented by two spaces."""
        return dumps(self.get_metrics().to_dict(), indent=2)

    def reset(self) -> None:
        """Zero every counter and empty the custom-metric dictionary in place."""
        self._zero_counters()
        self._custom_metrics.clear()
class TelemetryProbe:
    """Main interface for collecting telemetry during SHAP computation.

    Every recording method forwards to the configured sink and is a no-op
    when the probe has no sink. "No sink" means ``sink is None``: the checks
    deliberately avoid truthiness so that a sink defining ``__len__`` or
    ``__bool__`` (for example one that is falsy while empty) still receives
    its records.
    """

    def __init__(self, sink: ProbeSink | None = None) -> None:
        """
        Initialize TelemetryProbe.

        Args:
            sink: Optional ProbeSink to collect metrics. If None, no metrics collected.
        """
        self.sink = sink

    def cache_operation(self, is_hit: bool) -> None:
        """Record a cache operation (hit or miss)."""
        if self.sink is not None:
            self.sink.record_cache_operation(is_hit)

    def mask_generated(self, is_unique: bool, is_invalid: bool = False) -> None:
        """Record a generated mask and whether it was unique/invalid."""
        if self.sink is not None:
            self.sink.record_mask_generated(is_unique, is_invalid)

    def record_timing(self, stage: str, elapsed_ms: float) -> None:
        """Record timing for a stage (sampling, dedup, masking, model, scoring)."""
        if self.sink is not None:
            self.sink.record_timing(stage, elapsed_ms)

    def timing(self, stage: str) -> StageTimer:
        """Create a context manager that times a stage and reports to this probe.

        A timer is returned even when the probe has no sink; the measurement
        is then dropped by ``record_timing``.
        """
        return StageTimer(self, stage)

    def custom_metric(self, key: str, value: Any) -> None:
        """Record a custom metric."""
        if self.sink is not None:
            self.sink.record_custom_metric(key, value)

    def get_metrics(self) -> TelemetryData | None:
        """Get collected telemetry data, or None if no sink."""
        if self.sink is not None:
            return self.sink.get_metrics()
        return None

    def reset(self) -> None:
        """Reset collected metrics."""
        if self.sink is not None:
            self.sink.reset()

    @staticmethod
    def noop() -> "TelemetryProbe":
        """Create a no-op probe (no sink, no collection)."""
        return TelemetryProbe(sink=None)

    @staticmethod
    def with_log_sink(verbose: bool = False) -> "TelemetryProbe":
        """Create a probe backed by a LogProbeSink.

        Args:
            verbose: Passed through to ``LogProbeSink``.
        """
        return TelemetryProbe(sink=LogProbeSink(verbose=verbose))

    @staticmethod
    def with_json_sink() -> "TelemetryProbe":
        """Create a probe backed by a JSONProbeSink."""
        return TelemetryProbe(sink=JSONProbeSink())