# Source: mllm_shap.shap.precise
"""Precise SHAP explainer implementation."""
from itertools import product
from typing import Generator
import torch
from torch import Tensor
from .base.shap_explainer import BaseShapExplainer
# pylint: disable=too-few-public-methods
class PreciseShapExplainer(BaseShapExplainer):
    """Exact SHAP implementation that enumerates all possible coalition masks.

    Generates every boolean subset mask over ``n`` features except the
    all-true mask (which the base class is expected to handle separately)
    and evaluates the exact Shapley-value formula over the results.
    """

    # Lazily-built generator over all coalition masks; (re)created on the
    # first call of each run (``self._first_call`` comes from the base class).
    __splits_generator: Generator[Tensor, None, None] | None = None

    def _get_next_split(
        self,
        n: int,
        device: torch.device,
        generated_masks_num: int,
        existing_masks: list[Tensor] | None = None,
    ) -> Tensor | None:
        """Return the next coalition mask, or ``None`` when all are exhausted.

        Args:
            n (int): Number of features (mask length).
            device (torch.device): Device on which masks are created.
            generated_masks_num (int): Unused here; part of the base-class contract.
            existing_masks (list[Tensor] | None): Unused here; part of the
                base-class contract.

        Returns:
            Tensor | None: A boolean tensor of shape ``(n,)``, or ``None``
            once all ``2**n - 1`` masks have been produced.

        Raises:
            RuntimeError: If the generator is missing after initialization.
        """
        if self._first_call:
            self._first_call = False
            self.__splits_generator = PreciseShapExplainer.__get_splits_generator(
                n=n,
                device=device,
            )
        if self.__splits_generator is None:
            raise RuntimeError("Splits generator is not present.")
        try:
            return next(self.__splits_generator)
        except StopIteration:
            # All subsets have been yielded; signal completion to the caller.
            return None

    def _get_num_splits(self, n: int) -> int:
        """Return the number of masks generated for ``n`` features.

        All ``2**n`` subsets minus the excluded all-true mask.
        """
        return int(2**n - 1)  # exclude all-true mask

    # pylint: disable=too-many-locals
    def _calculate_shap_values(
        self,
        masks: Tensor,
        similarities: Tensor,
        device: torch.device,
    ) -> Tensor:
        """Compute exact Shapley values from the evaluated coalitions.

        Args:
            masks (Tensor): Boolean tensor of shape ``(num_subsets, num_features)``,
                one row per evaluated coalition S.
            similarities (Tensor): Tensor of shape ``(num_subsets,)`` holding the
                model output f(S) for each row of ``masks``.
            device (torch.device): Device on which to allocate the result.

        Returns:
            Tensor: Shape ``(num_features,)`` — one Shapley value per feature.
        """
        num_features = masks.shape[1]
        shap_values = torch.zeros(num_features, dtype=similarities.dtype, device=device)
        # Precompute factorial terms for efficiency using a! = (a - 1)! * a.
        # float64 is required here: float32 overflows to inf at 35!, which
        # would turn every weight into NaN for num_features >= 35.
        indices = torch.arange(num_features + 1, dtype=torch.float64, device=device)
        indices[0] = 1.0  # 0! = 1
        factorials = torch.cumprod(indices, dim=0)
        # Powers of two used for subset hashing. Loop-invariant: computed once
        # here instead of once per feature inside the loop below (the original
        # also mixed ``device`` and ``masks.device``; the two are assumed equal).
        powers = 2 ** torch.arange(num_features, device=device)
        # Precompute hash values for all subsets (each subset maps to a unique
        # integer, enabling O(log n) lookup via searchsorted).
        subset_hashes = (masks * powers).sum(dim=1)
        sorted_hashes, sort_idx = subset_hashes.sort()
        sorted_outputs = similarities[sort_idx]
        # Precompute subset sizes
        subset_sizes = masks.sum(dim=1)
        # formula: \phi_i = \sum_{S ⊆ N \ {i}} [ |S|! * (|N| - |S| - 1)! / |N|! * (f(S ∪ {i}) - f(S)) ]
        for i in range(num_features):
            # Select subsets that include feature i
            include_mask = masks[:, i]
            # All subsets that include i - IN = {S : i ∈ S}
            included_subsets = masks[include_mask]
            included_outputs = similarities[include_mask]  # f(IN)
            # Corresponding subsets with i removed - OUT = {S \ {i} : S ∈ IN}
            excluded_subsets = included_subsets.clone()
            excluded_subsets[:, i] = False
            excluded_hash = (excluded_subsets * powers).sum(dim=1)
            # NOTE(review): assumes every S \ {i} hash is present in
            # ``sorted_hashes`` — holds when ``masks`` covers all subsets.
            excluded_outputs = sorted_outputs[torch.searchsorted(sorted_hashes, excluded_hash)]  # f(OUT)
            # Corresponding subset sizes - |S| for S ∈ OUT
            excluded_subset_sizes = subset_sizes[include_mask] - 1
            weights = (
                factorials[excluded_subset_sizes]
                * factorials[num_features - excluded_subset_sizes - 1]
                / factorials[num_features]
            ).to(similarities.dtype)  # cast the exact float64 ratio back
            shap_values[i] = torch.sum(weights * (included_outputs - excluded_outputs))
        return shap_values

    @staticmethod
    def __get_splits_generator(n: int, device: torch.device) -> Generator[Tensor, None, None]:
        """
        Generates all possible binary masks of a given length, excluding the all-ones mask.

        Args:
            n (int): The length of the binary masks to generate.
            device (torch.device): The device on which to create the tensors.

        Yields:
            Tensor: A boolean mask tensor of shape (n,).
        """
        for split in product([0, 1], repeat=n):
            split_tensor = torch.tensor(split, dtype=torch.bool, device=device)
            # Skip the all-ones (full coalition) mask; it is excluded by design.
            if split_tensor.sum() == n:
                continue
            yield split_tensor