"""Base class for SHAP explainers using approximation methods."""
from abc import ABC
from logging import Logger
from typing import Any
import torch
from torch import Tensor
from ...utils.logger import get_logger
from .shap_explainer import BaseShapExplainer
logger: Logger = get_logger(__name__)
[docs]
class BaseShapApproximation(BaseShapExplainer, ABC):
    """
    Base class for SHAP explainers using approximation methods.

    Rather than enumerating every feature coalition, subclasses draw a sampling
    budget of random boolean masks, optionally seeded with the "minimal" masks
    (the all-False mask plus each mask with exactly one feature switched off).
    """

    num_samples: int | None
    """
    Number of random masks to generate. If None, uses fraction.
    -1 stands for minimal number of samples (only single-feature masks and empty mask).
    """

    fraction: float | None
    """Fraction of total possible masks to generate if num_samples is None."""

    include_minimal_masks: bool = True
    """Whether to include minimal masks (single-feature and empty masks) in the sampling."""

    _zero_mask_skipped: bool
    """Indicates if the zero mask was skipped."""

    _base_masks: Tensor | None
    """Holds the base masks if :attr:`include_minimal_masks` is True."""

    _base_calls_num: int
    """Number of base masks already generated."""

    def __init__(
        self,
        *args: Any,
        num_samples: int | None = None,
        fraction: float = 0.6,
        **kwargs: Any,
    ) -> None:
        """
        Args:
            num_samples: Number of random masks to generate. If None, uses fraction.
                Pass -1 to request only the minimal masks.
            fraction: Fraction of total possible masks to generate if num_samples is None.

        Raises:
            ValueError: If the sampling parameters are invalid
                (see :meth:`_validate_sampling_params`).
        """
        super().__init__(*args, **kwargs)
        BaseShapApproximation._validate_sampling_params(
            num_samples=num_samples,
            fraction=fraction,
        )
        self.num_samples = num_samples
        self.fraction = fraction

    def _initialize_state(self) -> None:
        """
        Initialize internal state before starting mask generation.
        """
        super()._initialize_state()
        self._zero_mask_skipped = False
        self._base_masks = None
        self._base_calls_num = 0

    def _get_next_split_base(
        self,
        n: int,
        device: torch.device,
        generated_masks_num: int,
        existing_masks: list[Tensor] | None = None,
    ) -> Tensor | None:
        """
        Get the next mask split for SHAP value calculation
        from the base minimal masks, if applicable.

        Args:
            n: Length of the masks
            device: Torch device to create the tensor on
            generated_masks_num: Number of masks already generated
            existing_masks: Unused here; accepted for signature compatibility
                with :meth:`_get_next_split`.

        Returns:
            Next mask tensor or None if no more masks can be generated.

        Raises:
            RuntimeError: If there are inconsistencies in mask generation logic.
        """
        if self.include_minimal_masks:
            if generated_masks_num == 0:
                if self._first_call:
                    self._base_masks = BaseShapApproximation._generate_minimal_splits(
                        n=n,
                        device=device,
                    )
                    # Defensive: an overriding generator may yield nothing.
                    if self._base_masks is None:
                        return None
                    self._first_call = False
                elif (
                    not self._zero_mask_skipped
                ):  # 0 mask was rejected, so start from 1
                    # base masks here cannot be None
                    self._base_masks = self._base_masks[1:]
                    self._zero_mask_skipped = True
                else:  # another mask was rejected, raise
                    raise RuntimeError("Multiple base masks were rejected.")
            if self._base_masks is None:
                raise RuntimeError("Base masks are not present.")
            num_splits = self._get_num_splits(n)
            # The sampling budget must at least cover the minimal masks.
            if num_splits is not None and num_splits < self._base_masks.shape[0]:
                raise RuntimeError(
                    f"Not enough sampling budget, up to {num_splits} "
                    f"calls allowed with required {self._base_masks.shape[0]} for minimal masks."
                )
            if generated_masks_num < self._base_masks.shape[0]:
                # Consistency check: each base call must correspond to exactly
                # one accepted mask (plus the skipped zero mask, if any).
                if self._base_calls_num != generated_masks_num + int(
                    self._zero_mask_skipped
                ):
                    raise RuntimeError("Multiple base masks were rejected.")
                self._base_calls_num += 1
                return self._base_masks[generated_masks_num, ...].squeeze(0)
        return None

    def _get_next_split(
        self,
        n: int,
        device: torch.device,
        generated_masks_num: int,
        existing_masks: list[Tensor] | None = None,
    ) -> Tensor | None:
        """
        Get the next mask split: minimal base masks first (if enabled),
        then random masks until the sampling budget is exhausted.

        Args:
            n: Length of the masks
            device: Torch device to create the tensor on
            generated_masks_num: Number of masks already generated
            existing_masks: Masks generated so far; forwarded to the base step.

        Returns:
            Next mask tensor or None once the budget is exhausted.
        """
        r = self._get_next_split_base(
            n=n,
            device=device,
            generated_masks_num=generated_masks_num,
            existing_masks=existing_masks,
        )
        self._first_call = False
        if r is not None:  # if base mask was generated
            return r
        # BUGFIX: _get_num_splits may return None (no fixed budget — see the
        # `is not None` guard in _get_next_split_base); comparing against None
        # would raise TypeError. Treat None as an unlimited budget.
        num_splits = self._get_num_splits(n=n)
        if num_splits is None or generated_masks_num < num_splits:
            return self._get_random_split(n=n, device=device)
        return None

    @staticmethod
    def _generate_minimal_splits(n: int, device: torch.device) -> torch.Tensor:
        """
        Generate a minimal set of boolean masks as a batched tensor.
        Shape: (n + 1, n)

        Row 0 is the all-False (empty) mask; row i (1-based) keeps every
        feature except feature i-1.
        """
        masks = torch.ones((n + 1, n), dtype=torch.bool, device=device)
        masks[0, :] = False
        masks[torch.arange(1, n + 1), torch.arange(n)] = False
        return masks

    @staticmethod
    def _get_random_split(
        n: int,
        device: torch.device,
        true_values_num: int | None = None,
        include_token: int | None = None,
    ) -> Tensor:
        """
        Generate a random split mask of shape [1, n].

        Args:
            n: Length of the mask
            device: The device to create the mask on
            true_values_num: Optional number of True values in the mask
            include_token: Optional index of a token that must be included in the mask

        Returns:
            Tensor of shape [1, n], dtype=torch.bool, representing the random split mask.
        """
        if true_values_num is None:
            return torch.randint(0, 2, (1, n), dtype=torch.bool, device=device)
        # one token is already included
        if include_token is not None:
            n -= 1
            # BUGFIX: clamp at 0 — with true_values_num == 0 the decrement gave
            # -1 and randperm(n)[:-1] silently selected n-1 indices.
            true_values_num = max(true_values_num - 1, 0)
        mask = torch.zeros((1, n), dtype=torch.bool, device=device)
        true_indices = torch.randperm(n, device=device)[:true_values_num]
        mask[0, true_indices] = True
        if include_token is not None:
            new_mask = torch.zeros((1, n + 1), dtype=torch.bool, device=device)
            new_mask[..., include_token] = True
            # Scatter the n sampled values into the n positions that are still
            # False (i.e. every position except include_token).
            new_mask[~new_mask] = mask
            mask = new_mask
        return mask

    @staticmethod
    def _validate_sampling_params(
        num_samples: int | None,
        fraction: float | None,
    ) -> None:
        """
        Validate sampling parameters for SHAP approximation.

        Args:
            num_samples: Number of samples to generate (positive int or -1).
            fraction: Fraction of total possible samples to generate.

        Raises:
            ValueError: If both parameters are None or invalid.
        """
        if num_samples is None and fraction is None:
            raise ValueError("Either num_samples or fraction must be provided.")
        if fraction is not None and (
            not isinstance(fraction, float) or not 0 < fraction <= 1
        ):
            raise ValueError("fraction must be a float in the range (0, 1].")
        if num_samples is not None and (
            not isinstance(num_samples, int) or (num_samples <= 0 and num_samples != -1)
        ):
            # BUGFIX: the message contradicted the check above, which permits -1.
            raise ValueError("num_samples must be a positive integer or -1.")