# Source code for mllm_shap.shap.core.engine

"""Sampling engines extracted from the explainer inheritance flow."""

from contextlib import nullcontext
from dataclasses import dataclass
from logging import Logger
from time import perf_counter
from typing import TYPE_CHECKING, Generator

import torch
from torch import Tensor

from ...utils.logger import get_logger
from ..base._mask_generator import MaskGenerator
from ..base._masks_manager import MasksManager
from .contracts import SamplingStrategy

if TYPE_CHECKING:
    from .telemetry import TelemetryProbe

logger: Logger = get_logger(__name__)


@dataclass
class SamplingStats:
    """Counters and timing collected while a mask generator runs.

    Every counter starts at zero and is incremented by the sampling engine's
    generator as candidate splits are proposed, rejected or yielded.
    """

    #: How many candidate splits the sampling strategy proposed.
    candidate_splits: int = 0
    #: How many masks were handed on to the response-generation pipeline.
    yielded_masks: int = 0
    #: How many candidates were dropped for being all-zero or all-one.
    skipped_full_or_empty: int = 0
    #: How many candidates were dropped because mask preparation gave no mask.
    skipped_invalid_masks: int = 0
    #: How many candidates were dropped as duplicates of an already seen mask.
    skipped_duplicates: int = 0
    #: Wall-clock run time of the generator, in milliseconds.
    elapsed_ms: float = 0.0
class SamplingEngine:
    """Turn a sampling strategy into a filtered stream of masks.

    The engine repeatedly asks its strategy for candidate splits, optionally
    discards all-zero / all-one splits and duplicate masks, and exposes the
    surviving masks through a :class:`MaskGenerator`.
    """

    def __init__(
        self,
        strategy: SamplingStrategy,
        allow_mask_duplicates: bool = False,
        allow_full_or_empty: bool = False,
        probe: "TelemetryProbe | None" = None,
    ) -> None:
        """Store the strategy together with the filtering and telemetry options.

        Args:
            strategy: Source of candidate splits; consulted for the expected
                split count and for every next split.
            allow_mask_duplicates: When ``True``, masks whose hash the masks
                manager reports as already seen are still yielded.
            allow_full_or_empty: When ``True``, splits for which
                ``split.any()`` is false or ``split.all()`` is true
                (all-zero / all-one) are not filtered out.
            probe: Optional telemetry probe; it is only used when its
                ``sink`` attribute is truthy.
        """
        self._strategy = strategy
        self._probe = probe
        self._allow_mask_duplicates = allow_mask_duplicates
        self._allow_full_or_empty = allow_full_or_empty

    def create_generator(
        self,
        mask_manager: MasksManager,
        device: torch.device,
        masks: list[Tensor],
    ) -> MaskGenerator:
        """Build a :class:`MaskGenerator` that streams filtered masks.

        The engine's attributes (strategy, filtering flags, probe) are read
        once here, so reassigning them afterwards does not affect the
        returned generator.

        Args:
            mask_manager: Manager used to prepare, hash and de-duplicate masks;
                its ``n`` attribute is forwarded to the strategy.
            device: Device forwarded to the strategy and to
                ``mask_manager.prepare_mask``.
            masks: List handed to the strategy as ``existing_masks`` on every
                proposal. It is passed by reference, so changes the caller makes
                between yields are visible to the strategy (presumably the
                caller appends accepted masks here — confirm against callers).

        Returns:
            A generator whose ``stats`` attribute is a :class:`SamplingStats`
            updated as candidates are proposed, rejected and yielded, and whose
            ``__len__`` reports ``strategy.get_num_splits(mask_manager.n)``.

        Note:
            Iteration stops only when the strategy returns ``None``. Rejected
            candidates never end the loop, so termination is entirely up to the
            strategy.
        """
        expected_len = self._strategy.get_num_splits(mask_manager.n)
        split_source = self._strategy
        keep_duplicates = self._allow_mask_duplicates
        keep_degenerate = self._allow_full_or_empty
        active_probe = self._probe
        if active_probe is not None and not active_probe.sink:
            # Telemetry is only used when the probe has a truthy sink.
            active_probe = None
        shared_stats = SamplingStats()

        def _timed(stage: str):
            """Return a timing context for ``stage``, or a no-op one without a probe."""
            if active_probe:
                return active_probe.timing(stage)
            return nullcontext()

        def _propose(gen: MaskGenerator):
            """Ask the strategy for the next candidate split, timed as "sampling"."""
            with _timed("sampling"):
                return split_source.get_next_split(
                    n=mask_manager.n,
                    device=device,
                    generated_masks_num=gen.generated_masks,
                    existing_masks=masks,
                )

        def _screen(
            gen: MaskGenerator, split: Tensor
        ) -> tuple[Tensor, int] | None:
            """Filter ``split``; return ``(mask, mask_hash)`` or ``None`` when rejected.

            Each rejection is logged at debug level and counted on ``gen.stats``.
            """
            # Checked on the raw split, before any masking work is spent on it.
            if not keep_degenerate and (not split.any() or split.all()):
                logger.debug("Generated zero or all-ones mask, skipping.")
                gen.stats.skipped_full_or_empty += 1
                return None

            with _timed("masking"):
                prepared = mask_manager.prepare_mask(split=split, device=device)
            if prepared is None:
                logger.debug("Generated mask has no True values, skipping.")
                gen.stats.skipped_invalid_masks += 1
                return None

            with _timed("dedup"):
                digest = mask_manager.get_hash(prepared)
                # ``seen`` is only consulted when duplicates are disallowed.
                if not keep_duplicates and mask_manager.seen(mask_hash=digest):
                    logger.debug("Generated duplicate mask, skipping.")
                    gen.stats.skipped_duplicates += 1
                    return None
                # The hash is recorded even when duplicates are allowed.
                mask_manager.mark_seen(mask_hash=digest)
            return prepared, digest

        class _MasksGenerator(MaskGenerator):
            """Mask generator backed by the engine's strategy and filters."""

            def __init__(self) -> None:
                """Initialise the base generator and attach the shared stats."""
                super().__init__()
                self.stats = shared_stats

            def _mask_iter(
                self,
            ) -> Generator[tuple[Tensor | None, int], None, None]:
                """Yield ``(mask, mask_hash)`` pairs until the strategy runs out.

                ``stats.elapsed_ms`` is written in ``finally``. It measures
                wall-clock time from the first step of this iterator until it
                finishes, raises or is closed, including time the consumer
                spends between yields.
                """
                started = perf_counter()
                try:
                    while (split := _propose(self)) is not None:
                        self.stats.candidate_splits += 1
                        accepted = _screen(self, split)
                        if accepted is None:
                            continue
                        self.generated_masks += 1
                        self.stats.yielded_masks += 1
                        yield accepted
                finally:
                    self.stats.elapsed_ms = (perf_counter() - started) * 1000.0

            def __len__(self) -> int | None:
                """Report the strategy's expected split count, or ``None`` if unknown.

                NOTE(review): built-in ``len()`` raises ``TypeError`` when this
                returns ``None``; callers presumably call ``__len__`` directly.
                """
                return expected_len

        return _MasksGenerator()