# Source code for mllm_shap.shap.core.contracts
"""Core contracts (sampling, estimation, stopping) for composition-based SHAP internals."""
from abc import ABC, abstractmethod
from dataclasses import dataclass
import torch
from torch import Tensor
[docs]
class SamplingStrategy(ABC):
"""Contract for split-sampling strategies."""
[docs]
@abstractmethod
def get_next_split(
self,
n: int,
device: torch.device,
generated_masks_num: int,
existing_masks: list[Tensor] | None = None,
) -> Tensor | None:
"""Return next split or None when strategy is exhausted."""
[docs]
@abstractmethod
def get_num_splits(self, n: int) -> int | None:
"""Return expected number of generated masks, if known."""
[docs]
@dataclass(frozen=True)
class EstimationResult:
"""Container for estimator outputs.
Attributes:
values: Estimated attribution values.
uncertainty: Optional uncertainty proxy (for example CI width or variance).
``None`` means uncertainty is not provided by this estimator.
"""
values: Tensor
"""Estimated attribution values returned by the estimator."""
uncertainty: Tensor | None = None
"""Optional uncertainty proxy aligned with ``values`` when available."""
class Estimator(ABC):
    """Abstract interface for attribution estimators that can be composed."""

    @abstractmethod
    def estimate(self, masks: Tensor, payoffs: Tensor) -> EstimationResult:
        """Turn sampled masks and their payoffs into attribution estimates.

        Args:
            masks: Sampled masks (presumably one mask per row -- confirm
                against implementations).
            payoffs: Payoffs corresponding to ``masks`` (presumably one per mask).

        Returns:
            An :class:`EstimationResult` holding the estimated values and, optionally,
            an uncertainty proxy.
        """
        ...
@dataclass(frozen=True)
class StopDecision:
    """Immutable verdict produced by a stopping policy.

    Attributes:
        should_stop: ``True`` when the iterative estimation loop should end now.
        reason: Human-readable explanation of why the loop stops or continues.
    """

    # Terminate the iterative estimation loop when True.
    should_stop: bool
    # Free-form explanation accompanying the decision.
    reason: str
class StoppingPolicy(ABC):
    """Abstract interface for stopping policies that can be composed."""

    @abstractmethod
    def should_stop(self, iteration: int, estimation: EstimationResult) -> StopDecision:
        """Decide whether the pipeline should stop, and explain why.

        Args:
            iteration: Presumably the index of the current loop iteration
                (confirm whether it is 0- or 1-based against callers).
            estimation: The estimation result to base the decision on.

        Returns:
            A :class:`StopDecision` carrying the stop flag and a reason string.
        """
        ...