Source code for mllm_shap.shap.base.explainer
"""Base class for SHAP-based explanations."""
from abc import ABC, abstractmethod
from logging import Logger
from typing import TYPE_CHECKING, Any
from pydantic import BaseModel, ConfigDict
from ...connectors.base.chat import BaseMllmChat
from ...connectors.base.model import BaseMllmModel
from ...utils.logger import get_logger
from ..explainer_result import ExplainerResult
from .shap_explainer import BaseShapExplainer
if TYPE_CHECKING:
from ..core.telemetry import TelemetryProbe
# Module-level logger, obtained from the package's `get_logger` helper under this module's name.
logger: Logger = get_logger(__name__)
class _ExplainerConfig(BaseModel):
    """
    Validation model for the arguments of ``BaseExplainer.__init__``.

    Used just for validation and type checking: building an instance makes
    pydantic check ``shap_explainer`` and ``model`` against the declared
    field types, raising a validation error on mismatch.
    """

    # Lets pydantic accept field types it has no built-in schema for; such
    # fields are validated with isinstance(). Presumably needed because
    # BaseShapExplainer / BaseMllmModel are plain (non-pydantic) classes — verify.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # SHAP explainer instance to validate.
    shap_explainer: BaseShapExplainer
    # Model connector instance to validate.
    model: BaseMllmModel
[docs]
class BaseExplainer(ABC):
    """
    Abstract base for explainers that pair an MLLM connector with a SHAP explainer.

    Subclasses implement :meth:`__call__`. The base implementation only checks
    the keyword arguments and resets :attr:`total_n_calls`.
    """

    shap_explainer: BaseShapExplainer
    """SHAP explainer used to attribute the generated response."""

    model: BaseMllmModel
    """Model connector used for generation."""

    total_n_calls: int = 0
    """How many MLLM calls the most recent explanation required."""

    def __init__(
        self,
        model: BaseMllmModel,
        shap_explainer: BaseShapExplainer,
    ) -> None:
        """
        Validate and store the model connector and the SHAP explainer.

        Args:
            model: Model connector instance.
            shap_explainer: SHAP explainer instance.

        Raises:
            pydantic.ValidationError: If either argument fails validation
                by :class:`_ExplainerConfig`.
        """
        # Round-trip through the pydantic schema purely for its type checks.
        validated = _ExplainerConfig(shap_explainer=shap_explainer, model=model)
        self.shap_explainer, self.model = validated.shap_explainer, validated.model

    @abstractmethod
    def __call__(
        self,
        *_: Any,
        chat: BaseMllmChat,
        generation_kwargs: dict[str, Any] | None = None,
        probe: "TelemetryProbe | None" = None,
        **explanation_kwargs: Any,
    ) -> ExplainerResult:
        """
        Generate a response for :attr:`chat` with :attr:`model` and explain it
        with :attr:`shap_explainer`.

        This base version performs no generation. It rejects reserved or
        overlapping keyword arguments, then resets :attr:`total_n_calls`.
        Subclasses may call it via ``super().__call__(...)`` to reuse these checks.

        Args:
            chat: Chat instance to generate from.
            generation_kwargs: Extra keyword arguments for ``model.generate``.
                Must not contain ``chat`` or ``keep_history``.
            probe: Optional telemetry probe forwarded to the underlying SHAP explainer.
            explanation_kwargs: Extra keyword arguments for the SHAP explainer.
                Must not contain ``chat``, ``base_chat`` or ``model``, nor any
                key that also appears in ``generation_kwargs``.

        Returns:
            The ExplainerResult instance.

        Raises:
            ValueError: If a reserved key is present, or if the two kwarg
                groups share a key.
        """
        gen_kwargs = generation_kwargs or {}

        # Keys reserved by the explainer itself in each kwargs group.
        if any(key in gen_kwargs for key in ("chat", "keep_history")):
            raise ValueError(
                "generation_kwargs should not contain 'chat' or 'keep_history' keys."
            )
        if any(key in explanation_kwargs for key in ("chat", "base_chat", "model")):
            raise ValueError(
                "explanation_kwargs should not contain 'chat', 'base_chat' or 'model' keys."
            )

        # Reject keys that appear in both groups.
        shared = sorted(set(gen_kwargs.keys()).intersection(explanation_kwargs))
        if shared:
            raise ValueError(
                f"Duplicate keys found in generation_kwargs and explanation_kwargs: {shared}"
            )

        self.total_n_calls = 0