# Source code for mllm_shap.connectors.base.explainer_cache

"""Cache for explainer computations."""

from typing import TYPE_CHECKING, Any, cast

import torch
from pydantic import BaseModel, ConfigDict
from torch import Tensor
from .model_response import ModelResponse

if TYPE_CHECKING:
    from .chat import BaseMllmChat


# pylint: disable=too-many-instance-attributes
class ExplainerCache(BaseModel):
    """
    Cache for explainer computations associated with a chat.

    Saves and validates calculated SHAP values, masks, and reduced embeddings.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    calculated_by: int
    """Hash of the explainer that calculated the SHAP values."""
    chat: "BaseMllmChat"
    """The chat instance the cache is for."""
    n: int
    """Index of last token used for SHAP calculations."""
    responses: list[ModelResponse]
    """The model responses used for SHAP calculations."""
    masks: Tensor
    """The masks used for SHAP calculations."""
    shap_values_mask: Tensor
    """The mask indicating which SHAP values are relevant."""
    had_different_masks: bool = False
    """Whether the masks used for SHAP calculations differed from chat's masks."""

    _values: Tensor | None = None
    """The SHAP values calculated."""
    _normalized_values: Tensor | None = None
    """The normalized SHAP values calculated."""

    def __init__(
        self,
        chat: "BaseMllmChat",
        responses: list[ModelResponse],
        masks: Tensor,
        shap_values_mask: Tensor,
        **kwargs: Any,
    ) -> None:
        """
        Initialize the ExplainerCache instance.

        The masks are right-padded with ``False`` columns up to the chat's
        number of input tokens, and ``had_different_masks`` records whether
        ``shap_values_mask`` differs from the chat's current SHAP values mask.

        Args:
            chat: The chat instance the cache is for.
            responses: The model responses used for SHAP calculations.
            masks: The masks used for SHAP calculations; one row per response.
            shap_values_mask: The mask indicating which SHAP values are relevant.
            **kwargs: Remaining model fields (e.g. ``calculated_by`` and ``n``).

        Raises:
            ValueError: If the number of mask rows does not match the number of
                responses, or if the masks have more columns than the chat has
                input tokens.
        """
        super().__init__(
            chat=chat,
            masks=masks,
            responses=responses,
            shap_values_mask=shap_values_mask,
            **kwargs,
        )
        if masks.shape[0] != len(responses):
            raise ValueError("Masks size does not match the number of responses in the chat.")

        if chat is None:
            return

        if chat.input_tokens_num < masks.shape[1]:
            raise ValueError("Masks size is larger than the number of tokens in the chat.")

        # Extend masks to match chat length (tokens past the masks are never masked in).
        self.masks = ExplainerCache.extend_values(
            values=masks,
            shape=(masks.shape[0], chat.input_tokens_num - masks.shape[1]),
            dim=1,
            fill_value=False,
            device=chat.torch_device,
        )
        self.had_different_masks = bool(torch.any(chat.shap_values_mask != shap_values_mask))

    # pylint: disable=too-many-positional-arguments,too-many-arguments
    @classmethod
    def create(
        cls,
        chat: "BaseMllmChat",
        explainer_hash: int,
        responses: list[ModelResponse],
        masks: Tensor,
        normalized_values: Tensor,
        shap_values_mask: Tensor,
        values: Tensor | None = None,
    ) -> "ExplainerCache":
        """
        Create a new ExplainerCache instance.

        Args:
            chat: The chat instance the cache is for.
            explainer_hash: Hash of the explainer that calculated the SHAP values.
            responses: The model responses used for SHAP calculations.
            masks: The masks used for SHAP calculations. Their column count
                (before padding) becomes ``n``.
            normalized_values: The normalized SHAP values calculated.
            shap_values_mask: The mask indicating which SHAP values are relevant.
            values: The SHAP values calculated, or ``None`` if not available.

        Returns:
            A new ExplainerCache instance.

        Raises:
            ValueError: If the masks or the (normalized) SHAP values fail validation.
        """
        instance = cls(
            calculated_by=explainer_hash,
            chat=chat,
            n=masks.shape[1],
            responses=responses,
            masks=masks,
            shap_values_mask=shap_values_mask,
        )
        instance.normalized_values = normalized_values
        if values is not None:
            instance.values = values
        return instance

    @property
    def normalized_values(self) -> Tensor:
        """
        Normalized SHAP values.

        Raises:
            ValueError: If SHAP values are no longer valid (their length no longer
                matches the chat's number of input tokens) or have not been
                computed yet.
        """
        self.__validate_values_getter("_normalized_values")
        return cast(Tensor, self._normalized_values)

    @normalized_values.setter
    def normalized_values(self, values: Tensor) -> None:
        """
        Set the normalized SHAP values.

        Args:
            values: The normalized SHAP values to set.

        Raises:
            ValueError: If normalized SHAP values are not valid.
        """
        self.__values_setter("_normalized_values", values)

    @property
    def values(self) -> Tensor | None:
        """
        SHAP values.

        Can be ``None`` if :class:`HierarchicalExplainer` is used (no raw values
        were stored). In that case ``None`` is returned instead of raising.

        Raises:
            ValueError: If stored SHAP values are no longer valid (their length no
                longer matches the chat's number of input tokens).
        """
        if self._values is None:
            return None
        self.__validate_values_getter("_values")
        return self._values

    @values.setter
    def values(self, values: Tensor | None) -> None:
        """
        Set the SHAP values.

        Args:
            values: The SHAP values to set, or ``None`` to clear them.

        Raises:
            ValueError: If SHAP values are not valid.
        """
        if values is None:
            self._values = None
            return
        self.__values_setter("_values", values)

    def extend_masks(self) -> None:
        """
        Extend masks to match the chat length.

        Raises:
            ValueError: If the masks already have more columns than the chat has
                input tokens.
        """
        if self.chat.input_tokens_num < self.masks.shape[1]:
            raise ValueError("Masks size is larger than the number of tokens in the chat.")
        self.masks = ExplainerCache.extend_values(
            values=self.masks,
            shape=(
                self.masks.shape[0],
                self.chat.input_tokens_num - self.masks.shape[1],
            ),
            dim=1,
            fill_value=False,
            device=self.chat.torch_device,
        )

    def __values_setter(self, name: str, values: Tensor) -> None:
        """
        Validate and store SHAP values.

        The values are right-padded with NaN up to the chat's number of input
        tokens. Within the first ``n`` tokens, positions selected by
        ``shap_values_mask`` must be non-NaN. All other positions must be NaN.

        Args:
            name: The name of the SHAP values attribute to set.
            values: The SHAP values to set.

        Raises:
            ValueError: If SHAP values size is larger than the number of tokens in
                the chat, if they contain NaN values for tokens they should
                explain, or if they contain non-NaN values for tokens they should
                not explain.
        """
        if self.chat.input_tokens_num < values.shape[0]:
            raise ValueError("Values size is larger than the number of tokens in the chat.")
        values = ExplainerCache.extend_values(
            values,
            shape=torch.Size((self.chat.input_tokens_num - values.shape[0],)),
            dim=0,
            fill_value=float("nan"),
            device=self.chat.torch_device,
        )

        mask = self.shap_values_mask.clone()
        # only validate up to n
        mask[self.n :] = False  # noqa: E203
        if values[mask].isnan().any():
            raise ValueError("SHAP values contain NaN values for text tokens they should explain.")
        if not values[~mask].isnan().all():
            raise ValueError("SHAP values contain non-NaN values for text tokens they should not explain.")

        setattr(self, name, values)

    def __validate_values_getter(self, name: str) -> None:
        """
        Validate stored SHAP values before returning them.

        Args:
            name: The name of the SHAP values attribute to validate.

        Raises:
            ValueError: If the attribute is ``None`` or its length does not match
                the number of tokens in the chat.
        """
        if getattr(self, name) is None:
            raise ValueError("SHAP values have not been computed yet.")
        if cast(Tensor, getattr(self, name)).shape[0] != self.chat.input_tokens_num:
            raise ValueError(
                "SHAP values size does not match the number of tokens in the chat. Recalculate SHAP values to update."
            )

    @staticmethod
    def extend_values(
        values: Tensor,
        shape: tuple[int, ...],
        dim: int,
        fill_value: Any,
        device: torch.device,
    ) -> Tensor:
        """
        Append a block filled with ``fill_value`` to ``values`` along ``dim``.

        Args:
            values: The tensor to extend.
            shape: Shape of the appended padding block (not of the result).
            dim: The dimension along which to extend.
            fill_value: The value to use for extension.
            device: The device on which the result is created.

        Returns:
            The extended tensor, on ``device`` and with the dtype of ``values``.
        """
        # torch.cat requires all inputs on one device; the padding lives on ``device``.
        values = values.to(device)
        padding = torch.full(shape, fill_value, dtype=values.dtype, device=device)
        return torch.cat([values, padding], dim=dim)

    def __del__(self) -> None:
        """
        Cleanup on deletion.

        Disconnect the chat to avoid circular references and clear other
        references. Best-effort: assignments that raise ``AttributeError``
        (e.g. on an instance whose pydantic initialization failed) are skipped.
        """
        # needs explicit None (chat first, to disconnect it)
        for attr in (
            "chat",
            "calculated_by",
            "n",
            "responses",
            "masks",
            "_values",
            "_normalized_values",
        ):
            try:
                setattr(self, attr, None)
            except AttributeError:
                # Partially-initialized instance; exceptions cannot escape __del__ anyway.
                continue