Source code for mllm_shap.shap.base.complementary

# pylint: disable=invalid-name
"""Base class for SHAP explainers using approximation methods."""

from abc import ABC
from functools import lru_cache
from logging import Logger
from typing import Any, Generator

import torch
from torch import Tensor

from ...utils.logger import get_logger
from ._mask_generator import MaskGenerator
from ._masks_manager import MasksManager
from .approx import BaseShapApproximation

logger: Logger = get_logger(__name__)


# pylint: disable=too-few-public-methods
[docs] class BaseComplementaryShapApproximation(BaseShapApproximation, ABC): """Complementary SHAP implementation class.""" _M: Tensor | None """ Matrix M used in Complementary calculations - number of times feature i appears in coalitions of size j+1. """ _C: Tensor | None """ Matrix C used in Complementary calculations - C[i, j] = sum of complementary contributions for feature i in coalitions of size j+1. """ @lru_cache(maxsize=1) def _get_num_splits(self, n: int) -> int: return BaseComplementaryShapApproximation._get_num_splits_static( n=n, num_samples=self.num_samples, fraction=self.fraction, include_minimal_masks=self.include_minimal_masks, ) def _initialize_state(self) -> None: super()._initialize_state() self._get_num_splits.cache_clear() self._zero_mask_skipped = True # this algorithm cannot use zero mask self._M = None self._C = None # pylint: disable=too-many-arguments,too-many-positional-arguments,duplicate-code def _get_masks_generator( self, mask_manager: MasksManager, device: torch.device, masks: list[Tensor], allow_full_or_empty: bool = False, ) -> MaskGenerator: n = mask_manager.n num_splits = self._get_num_splits(mask_manager.n) get_next_split = self._get_next_split allow_mask_duplicates = self.allow_mask_duplicates # Initialize M matrix if self._M is None: self._M = torch.zeros((n, n + 1), dtype=torch.int16, device=device) M = self._M # We can generate only pairs --> no space for zero mask # that will be a pair to existing all-ones mask if self._get_num_splits(mask_manager.n) % 2 == 0: self._zero_mask_skipped = True class _MasksGenerator(MaskGenerator): """Generator class for masks.""" def __init__(self) -> None: """Initialize the MaskGenerator.""" super().__init__() self._next_result: tuple[Tensor | None, int] | None = None def _mask_iter(self) -> Generator[tuple[Tensor | None, int], None, None]: while True: if self._next_result is not None: yield self._next_result self._next_result = None continue new_split = get_next_split( n=mask_manager.n, device=device, generated_masks_num=self.generated_masks, existing_masks=masks, ) if new_split is None: break coalition_size = int(new_split.sum().item()) if not allow_full_or_empty and (not new_split.any() or new_split.all()): logger.debug( "Generated zero or all-ones mask of size %d, skipping.", coalition_size, ) continue new_split_neg = ~new_split new_mask = mask_manager.prepare_mask(split=new_split, device=device) new_mask_neg = mask_manager.prepare_mask(split=new_split_neg, device=device) if new_mask is None or new_mask_neg is None: logger.info( "Generated mask of size %d (or its negation) has no True values, skipping.", coalition_size, ) continue new_mask_hash = mask_manager.get_hash(new_mask) new_mask_neg_hash = mask_manager.get_hash(new_mask_neg) if not allow_mask_duplicates: if mask_manager.seen(mask_hash=new_mask_hash) or mask_manager.seen(mask_hash=new_mask_neg_hash): logger.debug( "Generated duplicate mask of size %d, skipping.", coalition_size, ) continue mask_manager.mark_seen(mask_hash=new_mask_hash) mask_manager.mark_seen(mask_hash=new_mask_neg_hash) self.generated_masks += 2 for split in (new_split, new_split_neg): coalition_size = int(split.sum().item()) logger.debug( "new_split: %s, coalition_size: %d", split.squeeze(0), coalition_size, ) BaseComplementaryShapApproximation._increment_coalition_val( # pylint: disable=protected-access M, split.squeeze(0), coalition_size, 1 ) self._next_result = (new_mask_neg, new_mask_neg_hash) yield new_mask, new_mask_hash def __len__(self) -> int | None: return num_splits return _MasksGenerator() def _calculate_C_matrix(self, masks: Tensor, similarities: Tensor, device: torch.device) -> None: """ Calculate the C matrix used in Complementary SHAP calculations. Args: masks: Tensor of shape [m, n] representing the generated masks. similarities: Tensor of shape [m, ] representing the similarities for each mask. device: The device to perform calculations on. Raises: ValueError: If masks are not in complementary pairs. RuntimeError: If M matrix is not initialized. """ if self._M is None: raise RuntimeError("M matrix must be initialized before calculating C matrix.") if self._C is None: self._C = torch.zeros_like(self._M, dtype=similarities.dtype, device=device) m = masks.shape[0] // 2 if 2 * m != masks.shape[0]: raise ValueError("Masks should be in complementary pairs.") for i in range(m): if not torch.all(masks[2 * i] == ~masks[2 * i + 1]): raise ValueError("Masks are not complementary pairs.") S = masks[2 * i] NS = masks[2 * i + 1] s_size = int(S.sum().item()) ns_size = masks.shape[1] - s_size u = similarities[2 * i] - similarities[2 * i + 1] BaseComplementaryShapApproximation._increment_coalition_val(self._C, S, s_size, u) BaseComplementaryShapApproximation._increment_coalition_val(self._C, NS, ns_size, -u) @staticmethod def _get_num_splits_static( n: int, num_samples: int | None = None, fraction: float | None = None, force_minimal: bool = True, include_minimal_masks: bool = False, ) -> int: if num_samples is not None: if num_samples == -1: if include_minimal_masks: # Minimal: pairs of single-feature masks return 2 * n raise ValueError("num_samples cannot be -1 when include_minimal_masks is False.") if force_minimal and num_samples < 2 * n: raise ValueError("num_samples must be at least equal to the number of features times two.") if num_samples > (2**n - 2): return int(2**n - 2) # maximum possible masks excluding all-ones and all-zeros mask if num_samples % 2 == 1: raise ValueError("num_samples must not be odd to account for complementary masks (in pairs).") return num_samples # use fraction total_masks = int(2**n - 2) # exclude all-ones and all-zeros mask r = int(total_masks * fraction) # type: ignore[operator] if r < 2 * n: r = 2 * n # minimal: pairs of single-feature masks logger.warning( ( "Calculated number of samples (%d) is less than " "minimal required (%d). Using minimal number of samples." ), r, 2 * n, ) if r % 2 == 0: return r return r - 1 # ensure even number of samples @staticmethod def _increment_coalition_val(tensor: Tensor, indices: Tensor, coalition_size: int, value: Any) -> None: """ Increment the value in the tensor for the given coalition. If coalition_size is 0, update the first column. Args: tensor: The tensor to update. indices: The indices of the features in the coalition. coalition_size: The size of the coalition. value: The value to add. """ if coalition_size == 0: tensor[:, 0] += value else: tensor[indices, coalition_size] += value