
"""Base class for SHAP-based explanations."""

import gc
from abc import ABC, abstractmethod
from logging import Logger
from typing import Any, Generator

import torch
from torch import Tensor

from ...connectors.base.chat import BaseMllmChat
from ...connectors.base.explainer_cache import ExplainerCache
from ...connectors.base.model import BaseMllmModel
from ...connectors.base.model_response import ModelResponse
from ...utils.logger import get_logger
from ...utils.other import extend_tensor
from ..embeddings import MeanReducer
from ..enums import Mode
from ..normalizers import PowerShiftNormalizer
from ..similarity import CosineSimilarity
from ._cache_manager import CacheManager
from ._generate_responses import generate_responses
from ._mask_generator import MaskGenerator
from ._masks_manager import MasksManager
from ._validators import BaseShapCallConfig, BaseShapConfig
from .embeddings import BaseEmbeddingReducer, BaseExternalEmbedding
from .normalizers import BaseNormalizer
from .similarity import BaseEmbeddingSimilarity

logger: Logger = get_logger(__name__)


class NotEnoughTokensToExplainError(Exception):
    """Raised when there are not enough tokens to explain in the chat."""


# pylint: disable=too-few-public-methods,too-many-instance-attributes
class BaseShapExplainer(ABC):
    """Base class for SHAP-based explanations."""

    mode: Mode
    """The SHAP mode, either `STATIC` or `CONTEXTUAL`. Used if no :attr:`embedding_model` is provided."""

    embedding_model: BaseExternalEmbedding | None
    """The external embedding model to use. If provided, overrides :attr:`mode`."""

    embedding_reducer: BaseEmbeddingReducer
    """The embedding reduction strategy to use."""

    similarity_measure: BaseEmbeddingSimilarity
    """The embedding similarity measure to use."""

    normalizer: BaseNormalizer
    """The SHAP value normalizer to use."""

    allow_mask_duplicates: bool
    """Whether to allow duplicate masks during generation."""

    total_n_calls: int = 0
    """Total number of MLLM calls made for the last explanation."""

    _first_call: bool
    """Indicates whether it is the first call to generate masks."""

    # pylint: disable=too-many-arguments,too-many-positional-arguments
    def __init__(
        self,
        mode: Mode = Mode.CONTEXTUAL,
        embedding_model: BaseExternalEmbedding | None = None,
        embedding_reducer: BaseEmbeddingReducer | None = None,
        similarity_measure: BaseEmbeddingSimilarity | None = None,
        normalizer: BaseNormalizer | None = None,
        allow_mask_duplicates: bool = False,
    ):
        """
        Initialize the SHAP base class.

        Args:
            mode: The SHAP mode, either STATIC or CONTEXTUAL. Used if no embedding_model is provided.
            embedding_model: The external embedding model to use. If provided, overrides mode.
            embedding_reducer: The embedding reduction strategy to use. Defaults to MeanReducer.
            similarity_measure: The embedding similarity measure to use. Defaults to CosineSimilarity.
            normalizer: The SHAP value normalizer to use. Defaults to PowerShiftNormalizer.
            allow_mask_duplicates: Whether to allow duplicate masks during generation.
        """
        # validation
        __config = BaseShapConfig(
            mode=mode,
            embedding_model=embedding_model,
            embedding_reducer=embedding_reducer if embedding_reducer is not None else MeanReducer(),
            similarity_measure=similarity_measure if similarity_measure is not None else CosineSimilarity(),
            normalizer=normalizer if normalizer is not None else PowerShiftNormalizer(),
            allow_mask_duplicates=allow_mask_duplicates,
        )
        self.mode = __config.mode
        self.embedding_model = __config.embedding_model
        self.embedding_reducer = __config.embedding_reducer
        self.similarity_measure = __config.similarity_measure
        self.normalizer = __config.normalizer
        self.allow_mask_duplicates = __config.allow_mask_duplicates

    @abstractmethod
    def _get_next_split(
        self,
        n: int,
        device: torch.device,
        generated_masks_num: int,
        existing_masks: list[Tensor] | None = None,
    ) -> Tensor | None:
        """
        Get the next split to evaluate.

        Args:
            n: Length of the splits.
            device: The device to create the masks on.
            generated_masks_num: Number of masks generated so far.
            existing_masks: List of existing masks.

        Returns:
            Tensor of shape [1, n], dtype=torch.bool, representing the next split
            to evaluate, or None if no more splits are to be generated.
        """

    @abstractmethod
    def _get_num_splits(self, n: int) -> int:
        """
        Determine the number of masks to generate based on num_samples and fraction.

        Args:
            n: Length of the splits.

        Returns:
            Number of masks to generate.
        """

    @abstractmethod
    def _calculate_shap_values(
        self,
        masks: Tensor,
        similarities: Tensor,
        device: torch.device,
    ) -> Tensor:
        """
        Calculate SHAP values based on similarity between base and masked embeddings.

        Args:
            masks (Tensor): 2D boolean tensor [num_masks, num_tokens], each row
                indicates which tokens are included in that mask. The first mask
                (index 0) represents the base mask with all tokens included
                (all True values).
            similarities (Tensor): 1D tensor [num_masks], similarity score for each mask.
            device: The device to create the SHAP values on.

        Returns:
            Tensor: 1D tensor [num_tokens] with SHAP values (NaN where base_mask=False).
        """

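    # A minimal sketch of the subclass contract (illustrative only; the
    # `num_samples` attribute and the Monte Carlo sampling below are
    # assumptions, not the estimators shipped with the package):
    #
    #     def _get_num_splits(self, n: int) -> int:
    #         return self.num_samples
    #
    #     def _get_next_split(self, n, device, generated_masks_num, existing_masks=None):
    #         if generated_masks_num >= self.num_samples:
    #             return None  # signal that no more splits should be generated
    #         return torch.rand(1, n, device=device) < 0.5
    #
    #     def _calculate_shap_values(self, masks, similarities, device):
    #         # mean-difference attribution: average similarity when a token
    #         # is present minus when it is absent (a crude Shapley estimate)
    #         masks_f = masks.float()
    #         present = similarities @ masks_f / masks_f.sum(dim=0).clamp(min=1)
    #         absent_f = 1.0 - masks_f
    #         absent = similarities @ absent_f / absent_f.sum(dim=0).clamp(min=1)
    #         return (present - absent).to(device)
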
    def _initialize_state(self) -> None:
        """Initialize internal state before starting mask generation."""
        self.total_n_calls = 0
        self._first_call = True

    def _get_masks_generator(
        self,
        mask_manager: MasksManager,
        device: torch.device,
        masks: list[Tensor],
    ) -> MaskGenerator:
        """
        Build a generator that yields masks one by one.

        Args:
            mask_manager: The masks manager instance.
            device: The device to create the masks on.
            masks: List of existing masks.

        Returns:
            A generator yielding tuples of (mask, mask_hash).
        """
        num_splits = self._get_num_splits(mask_manager.n)
        get_next_split = self._get_next_split
        allow_mask_duplicates = self.allow_mask_duplicates

        class _MasksGenerator(MaskGenerator):
            """Generator class for masks."""

            def _mask_iter(self) -> Generator[tuple[Tensor | None, int], None, None]:
                while True:
                    new_split = get_next_split(
                        n=mask_manager.n,
                        device=device,
                        generated_masks_num=self.generated_masks,
                        existing_masks=masks,
                    )
                    if new_split is None:
                        break
                    if not new_split.any() or new_split.all():
                        logger.debug("Generated zero or all-ones mask, skipping.")
                        continue
                    new_mask = mask_manager.prepare_mask(split=new_split, device=device)
                    if new_mask is None:
                        logger.debug("Generated mask has no True values, skipping.")
                        continue
                    new_mask_hash = mask_manager.get_hash(new_mask)
                    if not allow_mask_duplicates and mask_manager.seen(mask_hash=new_mask_hash):
                        logger.debug("Generated duplicate mask, skipping.")
                        continue
                    mask_manager.mark_seen(mask_hash=new_mask_hash)
                    self.generated_masks += 1
                    yield new_mask, new_mask_hash

            def __len__(self) -> int | None:
                return num_splits

        return _MasksGenerator()

    def _generate_step(
        self,
        mask_manager: MasksManager,
        device: torch.device,
        masks: list[Tensor],
        **generate_kwargs: Any,
    ) -> tuple[int, list[tuple[Tensor, int, BaseMllmChat | None, ModelResponse]] | None]:
        """
        Generate a step of masks and get model responses.

        Args:
            mask_manager: The masks manager instance.
            device: The device to create the masks on.
            masks: List of existing masks.
            generate_kwargs: Additional keyword arguments for the model's generate method.

        Returns:
            A tuple containing:
                - Number of chats skipped due to being empty.
                - History of chats and masks used during explanation.
        """
        gen = self._get_masks_generator(mask_manager=mask_manager, device=device, masks=masks)
        r = generate_responses(
            masks=masks,
            gen=gen,
            **generate_kwargs,
        )
        # retrieve generated masks from the generator
        self.total_n_calls = gen.generated_masks
        return r

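    # Example of the split -> mask pipeline (illustrative values): for n == 4,
    # a split of tensor([[True, False, True, False]]) keeps tokens 0 and 2;
    # MasksManager.prepare_mask() maps it onto the full input-token layout,
    # and the hash lets the manager drop duplicates whenever
    # allow_mask_duplicates is False.
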
""" if self.similarity_measure.operates_on_embeddings: # get embeddings for the response embeddings = self.__get_embeddings( responses=responses, model=model, ) # calculate similarities between original response embeddings return self.similarity_measure(base=embeddings[0], other=embeddings) # If not operating on embeddings, handle raw responses return self.similarity_measure(base=responses[0], other=responses) # pylint: disable=too-many-locals def _get_shap_values( self, model: BaseMllmModel, masks: Tensor, responses: list[ModelResponse], source_chat: BaseMllmChat, device: torch.device, similarities: Tensor | None = None, ) -> tuple[Tensor, Tensor]: """ Get SHAP values for the given mask. Args: model: The model instance. masks: 2D boolean tensor [num_masks, num_tokens], each row indicates which tokens are included in that mask. responses: The model responses corresponding to the masks. source_chat: The source chat instance used to search for external group ids. device: The device to create the SHAP values on. similarities: Precomputed similarities between base and masked responses. If None, will be computed. Returns: A tuple containing: - The calculated SHAP values. - The normalized SHAP values. """ shap_values_mask = source_chat.shap_values_mask if similarities is None: similarities = self._get_similarities(responses=responses, model=model) # Pre-allocate SHAP values with NaNs shap_values = torch.full_like( shap_values_mask, float("nan"), device=device, dtype=similarities.dtype, ) # Calculate SHAP values only for relevant parts calculated_shap_values = self._calculate_shap_values( masks=masks[..., shap_values_mask], # only pass relevant parts of masks similarities=similarities, device=device, ) shap_values[shap_values_mask] = calculated_shap_values # Normalize only calculated SHAP values normalized_shap_values = shap_values.clone() normalized_shap_values[shap_values_mask] = self.normalizer(calculated_shap_values) # duplicate if external group ids are used if source_chat.external_group_ids is not None: for group_id, group_shap_value, group_normalized_shap_value in zip( source_chat.external_group_ids[source_chat.external_group_ids_first_positions], shap_values[source_chat.external_group_ids_first_positions], normalized_shap_values[source_chat.external_group_ids_first_positions], ): mask = source_chat.external_group_ids == group_id shap_values[mask] = group_shap_value normalized_shap_values[mask] = group_normalized_shap_value return shap_values, normalized_shap_values def _save_to_cache( self, chat: BaseMllmChat, source_chat: BaseMllmChat, responses: list[ModelResponse], masks: Tensor, shap_values: Tensor, normalized_shap_values: Tensor, ) -> None: """ Save the SHAP explainer cache in the full chat. Args: chat: The chat instance to save the cache for. source_chat: The original chat instance from which SHAP values were derived. responses: The model responses used for SHAP calculations. masks: The masks used for SHAP calculations. shap_values: The SHAP values calculated. normalized_shap_values: The normalized SHAP values calculated. Raises: ValueError: If cache already exists for the provided chat. 
""" logger.debug("Saving SHAP explainer cache for chat %s.", chat) if chat.cache is not None: raise ValueError("SHAP cache already exists for the provided chat.") # translate it for reference to group ids shap_values_mask = source_chat.translate_groups_ids_mask(source_chat.shap_values_mask) # extend mask with False to match new response length shap_values_mask = extend_tensor( shap_values_mask, target_length=chat.input_tokens_num, fill_value=False, ) chat.cache = ExplainerCache.create( chat=chat, explainer_hash=hash(self), responses=responses, masks=masks, values=shap_values, normalized_values=normalized_shap_values, shap_values_mask=shap_values_mask, ) def __get_embeddings(self, responses: list[ModelResponse], model: BaseMllmModel) -> Tensor: """ Get embeddings for the given chat state. Args: responses: The model responses to get embeddings for. chat: The current chat state. Returns: The embeddings tensor. """ if self.embedding_model is not None: return self.embedding_reducer(self.embedding_model(responses=responses)) if self.mode == Mode.STATIC: return self.embedding_reducer(model.get_static_embeddings(responses=responses)) return self.embedding_reducer(model.get_contextual_embeddings(responses=responses)) # keep the logic in one method for readability # pylint: disable=too-many-arguments,too-many-positional-arguments # pylint: disable=too-many-locals,too-many-statements,too-many-branches def __call__( self, model: BaseMllmModel, source_chat: BaseMllmChat, response: ModelResponse, progress_bar: bool = True, verbose: bool = False, n_generator_jobs: int = 1, **generate_kwargs: Any, ) -> list[tuple[Tensor, int, BaseMllmChat | None, ModelResponse]] | None: """ Generate splits of the input tokens in the chat state. Args: model: The model instance. source_chat: Chat to get explained (without base response). response: The model response generated from source_chat. progress_bar: Whether to display a progress bar during processing. verbose: Whether to save data generated during processing. n_generator_jobs: Number of parallel calls to the model's generate method. generate_kwargs: Additional keyword arguments for the model's generate method. Returns: If verbose is True, returns the history of chats and masks used during explanation. History has entries of the form (mask, mask_hash, masked_chat, model_response). If cache was used, masked_chat will be None. If verbose is False, returns None. Raises: NotEnoughTokensToExplainError: If there are not enough tokens to explain after filtering out empty chats. ValueError: If existing cache is invalid. 
""" __config = BaseShapCallConfig( model=model, source_chat=source_chat, response=response, progress_bar=progress_bar, verbose=verbose, ) self._initialize_state() # validated within BaseShapCallConfig response_chat: BaseMllmChat = __config.response.chat # type: ignore[assignment] source_chat = __config.source_chat device = source_chat.torch_device mask_manager = MasksManager(chat=source_chat, log_stats=True) cache_manager = CacheManager( chat=response_chat, explainer_hash=hash(self), ) masks = [mask_manager.get_initial_mask(device=device)] responses = [__config.response] chats_skipped, history = self._generate_step( mask_manager=mask_manager, masks=masks, device=device, responses=responses, source_chat=source_chat, model=__config.model, cache_manager=cache_manager, n_generator_jobs=n_generator_jobs, progress_bar=__config.progress_bar, verbose=__config.verbose, **generate_kwargs, ) if cache_manager.extracted_num > 0: logger.info( "Deduplicated %d/%d masks using existing cache.", cache_manager.extracted_num, len(masks) - 1, # exclude base mask ) # edge case - all chats were empty after filtering yet shap_values_mask had True values # this can happen only if shap_values_mask has one True value # for simplicity we just raise an error here. # - 1 because masks will always have at least the base mask if len(masks) - 1 <= chats_skipped: raise NotEnoughTokensToExplainError( "Not enough tokens to explain after filtering out empty chats. " "Ensure that shap_values_mask has at least two True values.", ) masks_tensor = torch.stack(masks, dim=0) # clean up del mask_manager del cache_manager del masks gc.collect() # calculate SHAP values (relative to source_chat) shap_values, normalized_shap_values = self._get_shap_values( model=__config.model, masks=masks_tensor, responses=responses, source_chat=source_chat, device=device, ) # cache results self._save_to_cache( chat=response_chat, source_chat=source_chat, responses=responses, masks=masks_tensor, shap_values=shap_values, normalized_shap_values=normalized_shap_values, ) return history def __hash__(self) -> int: """ Get the hash of the explainer instance. Returns: The hash value. """ return hash( ( self.mode, self.embedding_reducer, self.similarity_measure, self.normalizer, ) )