# Source code for mllm_shap.shap.core.estimation

"""Composable estimation and stopping adapters for SHAP pipelines."""

from collections.abc import Callable

from torch import Tensor

from .contracts import EstimationResult, Estimator, StopDecision, StoppingPolicy


class CallableEstimator(Estimator):
    """Wrap a plain callable so it satisfies the ``Estimator`` contract.

    The wrapped callable is invoked as ``estimate_fn(masks, payoffs)`` and its
    output is normalized into an ``EstimationResult``:

    * a non-tuple value becomes ``EstimationResult(values=value, uncertainty=None)``;
    * a tuple is unpacked as ``(values, uncertainty)``.
    """

    def __init__(
        self,
        estimate_fn: Callable[[Tensor, Tensor], Tensor | tuple[Tensor, Tensor]],
    ) -> None:
        """Store the callable that performs the actual estimation.

        Args:
            estimate_fn: Function called as ``estimate_fn(masks, payoffs)``. It
                may return either the attribution tensor alone or a
                two-element tuple ``(attributions, uncertainty)``.
        """
        self._estimate_fn = estimate_fn

    def estimate(self, masks: Tensor, payoffs: Tensor) -> EstimationResult:
        """Run the wrapped callable and package its output.

        Args:
            masks: Forwarded unchanged as the callable's first argument.
            payoffs: Forwarded unchanged as the callable's second argument.

        Returns:
            An ``EstimationResult`` whose ``uncertainty`` is ``None`` unless the
            callable returned a tuple.
        """
        output = self._estimate_fn(masks, payoffs)
        if not isinstance(output, tuple):
            # Bare result: only attributions were produced.
            return EstimationResult(values=output, uncertainty=None)
        # Tuple result: unpack as (attributions, uncertainty).
        attributions, spread = output
        return EstimationResult(values=attributions, uncertainty=spread)


class CallableStoppingPolicy(StoppingPolicy):
    """Wrap a plain callable so it satisfies the ``StoppingPolicy`` contract."""

    def __init__(
        self,
        should_stop_fn: Callable[[int, EstimationResult], bool | StopDecision],
        default_reason: str = "callable-policy",
    ) -> None:
        """Store the decision callable and the fallback reason.

        Args:
            should_stop_fn: Called as ``should_stop_fn(iteration, estimation)``.
                It may return a ``StopDecision`` directly, or any value that is
                interpreted via ``bool()``.
            default_reason: Reason attached when the callable does not return a
                ``StopDecision`` itself.
        """
        self._should_stop_fn = should_stop_fn
        self._default_reason = default_reason

    def should_stop(self, iteration: int, estimation: EstimationResult) -> StopDecision:
        """Ask the wrapped callable whether to stop.

        Args:
            iteration: Forwarded unchanged to the callable.
            estimation: Forwarded unchanged to the callable.

        Returns:
            The callable's ``StopDecision`` as-is, or a new ``StopDecision``
            built from ``bool(result)`` and the configured default reason.
        """
        verdict = self._should_stop_fn(iteration, estimation)
        if not isinstance(verdict, StopDecision):
            # Plain truthy/falsy result: wrap it with the default reason.
            verdict = StopDecision(
                should_stop=bool(verdict),
                reason=self._default_reason,
            )
        return verdict


class FixedThresholdStoppingPolicy(StoppingPolicy):
    """Stop when scalar uncertainty falls below a configured threshold.

    The uncertainty tensor of the current estimation is reduced to a scalar by
    taking its mean, and stopping is triggered when that mean is at or below
    the threshold.
    """

    def __init__(self, threshold: float) -> None:
        """Initialize fixed-threshold policy.

        Args:
            threshold: Uncertainty threshold at or below which stopping is
                triggered. Coerced with ``float()``.
        """
        self._threshold = float(threshold)

    def should_stop(self, iteration: int, estimation: EstimationResult) -> StopDecision:
        """Apply the fixed uncertainty threshold rule.

        Args:
            iteration: Accepted for contract compatibility; not used by this rule.
            estimation: Current estimation; only its ``uncertainty`` is read.

        Returns:
            ``StopDecision(should_stop=False, reason="uncertainty-missing")`` when
            no usable uncertainty is available (``None`` or an empty tensor).
            Otherwise, a decision with ``should_stop = mean <= threshold`` and
            reason ``"uncertainty<={threshold}"`` or
            ``"uncertainty-above-threshold"``.
        """
        uncertainty = estimation.uncertainty
        # An empty tensor carries no information (its mean would be NaN), so it
        # is treated exactly like a missing uncertainty rather than being
        # reported as "above threshold".
        if uncertainty is None or uncertainty.numel() == 0:
            return StopDecision(
                should_stop=False,
                reason="uncertainty-missing",
            )

        uncertainty = uncertainty.detach()
        # Tensor.mean() rejects integer/bool dtypes; promote those to float64 so
        # the rule still applies. Floating (and complex) tensors are left as-is,
        # so their existing behavior is unchanged.
        if not uncertainty.is_floating_point() and not uncertainty.is_complex():
            uncertainty = uncertainty.double()
        uncertainty_value = float(uncertainty.mean().item())

        # A NaN mean compares False here, so NaN uncertainty never triggers a stop.
        should_stop = uncertainty_value <= self._threshold
        reason = (
            f"uncertainty<={self._threshold}"
            if should_stop
            else "uncertainty-above-threshold"
        )
        return StopDecision(should_stop=should_stop, reason=reason)