# Source code for mllm_shap.shap.pipeline.pipeline
"""Generic explain pipeline executor."""
from dataclasses import dataclass
from time import perf_counter
from uuid import uuid4
from ..core.telemetry import TelemetryProbe
from ...observability.events import TraceEvent
from ...observability.sink import ObservabilitySink
from .context import ExplainContext, ExplainState
from .contracts import PipelineStage
# [docs]
@dataclass(frozen=True)
class ExplainPipeline:
    """Sequential runner that applies each configured stage to a shared state."""

    stages: tuple[PipelineStage, ...]
    """Ordered tuple of pipeline stages to execute. Each stage should be
    an instance of a class that implements the PipelineStage protocol,
    meaning it should have a run method that takes
    an ExplainContext, ExplainState, and optional TelemetryProbe as
    keyword arguments (``context``, ``state``, ``probe``)."""

    def run(
        self,
        context: ExplainContext,
        state: ExplainState,
        probe: TelemetryProbe | None = None,
    ) -> ExplainState:
        """Run every stage in declaration order against ``state``.

        Tracing is enabled only when ``context.params["observability_sink"]``
        is an ``ObservabilitySink``. In that case a run id is resolved (state
        metadata first, then context params, then a fresh ``uuid4``), written
        back to the state via ``add_metadata``, and ``stage_start`` /
        ``stage_end`` / ``stage_error`` trace events are emitted around each
        stage together with a stage span on the telemetry bridge.

        Args:
            context: Explain context; its ``params`` mapping is consulted for
                ``"observability_sink"`` and ``"run_id"``.
            state: Explain state handed to every stage.
            probe: Optional telemetry probe forwarded to each stage and to the
                telemetry bridge.

        Returns:
            The same ``state`` object that was passed in.

        Raises:
            Exception: Whatever a stage raises is re-raised unchanged after the
                ``stage_error`` event (if tracing is enabled) has been emitted.
        """
        telemetry = None
        candidate_sink = context.params.get("observability_sink")
        if isinstance(candidate_sink, ObservabilitySink):
            # Imported lazily, as in the original layout (presumably to avoid
            # an import cycle or to keep the dependency optional -- confirm).
            from ...observability.bridge import TelemetryBridge

            resolved_id = (
                state.metadata.get("run_id")
                or context.params.get("run_id")
                or uuid4()
            )
            run_id = str(resolved_id)
            state.add_metadata("run_id", run_id)
            telemetry = TelemetryBridge.from_probe(
                probe=probe, run_id=run_id, sink=candidate_sink
            )

        def emit(event_name, /, **attrs):
            # Call sites guard on ``telemetry is not None`` so that attribute
            # values (e.g. ``str(exc)``) are only computed when tracing is on.
            event = TraceEvent(
                name=event_name,
                run_id=telemetry.run_id,
                attrs=attrs,
            )
            telemetry.sink.emit_event(event)

        for position, stage in enumerate(self.stages):
            label = type(stage).__name__
            if telemetry is not None:
                emit("stage_start", stage=label, index=position)

            t0 = perf_counter()
            try:
                stage.run(context=context, state=state, probe=probe)
            except Exception as exc:
                duration_ms = (perf_counter() - t0) * 1000.0
                if telemetry is not None:
                    emit(
                        "stage_error",
                        stage=label,
                        index=position,
                        elapsed_ms=duration_ms,
                        error=str(exc),
                        error_type=type(exc).__name__,
                    )
                raise
            else:
                duration_ms = (perf_counter() - t0) * 1000.0
                if telemetry is not None:
                    # Span first, then the end event -- same order as before.
                    telemetry.stage_span(stage=label, elapsed_ms=duration_ms)
                    emit(
                        "stage_end",
                        stage=label,
                        index=position,
                        elapsed_ms=duration_ms,
                    )

        return state