# Source code for mllm_shap.shap.compact

"""Compact SHAP explainer implementation."""

from logging import Logger
from time import time
from typing import Any


from ..connectors.base.chat import BaseMllmChat
from ..utils.logger import get_logger
from .base.explainer import BaseExplainer
from .base.shap_explainer import BaseShapExplainer
from .explainer_result import ExplainerResult
from .precise import PreciseShapExplainer

# Module-level logger, obtained from the project's ``get_logger`` helper
# under this module's ``__name__``.
logger: Logger = get_logger(__name__)


# pylint: disable=too-few-public-methods
class Explainer(BaseExplainer):
    """
    Convenience client class for SHAP explanation.

    It generates the full response from the model and then uses the provided
    SHAP explainer to compute SHAP values. Uses :class:`PreciseShapExplainer`
    as the default SHAP explainer.
    """

    def __init__(
        self,
        shap_explainer: BaseShapExplainer | None = None,
        **kwargs: Any,
    ) -> None:
        """
        Initialize the explainer.

        Args:
            shap_explainer: SHAP explainer used to compute the explanation.
                When ``None``, a new :class:`PreciseShapExplainer` is created.
            **kwargs: Forwarded unchanged to the :class:`BaseExplainer`
                initializer.
        """
        # Test against None explicitly rather than by truthiness, so that a
        # caller-supplied explainer which happens to evaluate as falsy
        # (e.g. one defining ``__len__`` or ``__bool__``) is not silently
        # swapped for the default.
        if shap_explainer is None:
            shap_explainer = PreciseShapExplainer()
        super().__init__(shap_explainer=shap_explainer, **kwargs)

    def __call__(
        self,
        *_: Any,
        chat: BaseMllmChat,
        generation_kwargs: dict[str, Any] | None = None,
        **explanation_kwargs: Any,
    ) -> ExplainerResult:
        """
        Generate the full model response for ``chat`` and explain it.

        Args:
            *_: Positional arguments are accepted and ignored.
            chat: Chat to generate a response for and to explain. Its
                ``cache`` attribute is deleted after generation.
            generation_kwargs: Extra keyword arguments passed to
                ``self.model.generate`` and also to the SHAP explainer.
                ``None`` is treated as an empty dict.
            **explanation_kwargs: Extra keyword arguments passed to the base
                class ``__call__`` and to the SHAP explainer.

        Returns:
            An :class:`ExplainerResult` holding the source chat, the chat
            attached to the generated response, the value returned by the
            SHAP explainer (``history``) and the explainer's
            ``total_n_calls``.

        Raises:
            TypeError: If the same key appears in both ``generation_kwargs``
                and ``explanation_kwargs``, since both are unpacked into the
                same SHAP explainer call.
        """
        generation_kwargs = generation_kwargs or {}

        # validation
        super().__call__(
            chat=chat,
            generation_kwargs=generation_kwargs,
            **explanation_kwargs,
        )

        t0 = time()
        logger.info("Generating full response from the model...")
        response = self.model.generate(
            chat=chat,
            keep_history=True,
            **generation_kwargs,
        )
        logger.debug("Generation took %.2f seconds.", time() - t0)

        del chat.cache  # free memory

        t0 = time()
        history = self.shap_explainer(
            model=self.model,
            source_chat=chat,
            response=response,
            **explanation_kwargs,
            **generation_kwargs,
        )
        logger.debug("Explanation took %.2f seconds.", time() - t0)

        # Read the counter once and reuse it for both the instance attribute
        # and the returned result.
        total_n_calls = self.shap_explainer.total_n_calls
        self.total_n_calls = total_n_calls

        return ExplainerResult(
            source_chat=chat,
            # chat is set as generate was called with keep_history=True
            full_chat=response.chat,  # type: ignore[arg-type]
            history=history,
            total_n_calls=total_n_calls,
        )