Source code for mllm_shap.shap.core.sampling
"""Sampling strategy adapters used by the composition-based sampling engines."""
from collections.abc import Callable
from logging import Logger
from typing import Protocol

import torch
from torch import Tensor

from ...utils.logger import get_logger
from .contracts import SamplingStrategy
# Module-level logger obtained from the project's ``get_logger`` helper and
# named after this module. NOTE(review): not referenced by the code visible in
# this file — confirm whether it is still needed.
logger: Logger = get_logger(__name__)
[docs]
class _NextSplitFn(Protocol):
    """Call signature required of the callable wrapped as ``get_next_split``.

    ``CallableAdapterStrategy.get_next_split`` invokes this callable with
    keyword arguments only, so the parameter *names* below are part of the
    contract, not just their types. ``Callable[[...], ...]`` cannot express
    parameter names, which is why a callback protocol is used here.
    """

    def __call__(
        self,
        *,
        n: int,
        device: torch.device,
        generated_masks_num: int,
        existing_masks: list[Tensor] | None,
    ) -> Tensor | None: ...


class CallableAdapterStrategy(SamplingStrategy):
    """Adapter that bridges plain callables to the ``SamplingStrategy`` contract.

    Both strategy methods are pure pass-throughs: their arguments are forwarded
    to the wrapped callables and the results are returned unchanged.
    """

    def __init__(
        self,
        get_next_split: _NextSplitFn,
        get_num_splits: Callable[[int], int | None],
    ) -> None:
        """Initialize the callable-backed sampling strategy.

        The callables are stored as given; they are not validated here.

        Args:
            get_next_split: Callable used to generate the next split mask.
                It is invoked with the keyword arguments ``n``, ``device``,
                ``generated_masks_num`` and ``existing_masks``, so its
                parameters must use exactly those names.
            get_num_splits: Callable used to report the expected number of
                splits. It is invoked with ``n`` as its single positional
                argument.
        """
        self._get_next_split = get_next_split
        self._get_num_splits = get_num_splits

    def get_next_split(
        self,
        n: int,
        device: torch.device,
        generated_masks_num: int,
        existing_masks: list[Tensor] | None = None,
    ) -> Tensor | None:
        """Delegate split generation to the wrapped callable.

        Args:
            n: Number of explainable features.
            device: Device on which the split should be created.
            generated_masks_num: Number of masks generated so far.
            existing_masks: Existing masks available to the strategy.

        Returns:
            The next split mask, or ``None`` when generation should stop.
            The wrapped callable's result is returned unchanged.
        """
        # Forwarded by keyword, which is why ``_NextSplitFn`` pins the
        # parameter names rather than only their positions.
        return self._get_next_split(
            n=n,
            device=device,
            generated_masks_num=generated_masks_num,
            existing_masks=existing_masks,
        )

    def get_num_splits(self, n: int) -> int | None:
        """Return the expected number of splits for ``n`` explainable features.

        Args:
            n: Number of explainable features, passed positionally to the
                wrapped callable.

        Returns:
            The wrapped callable's result, unchanged. ``None`` is passed
            through as-is; presumably it means the count is not known in
            advance — confirm against the ``SamplingStrategy`` contract.
        """
        return self._get_num_splits(n)