Source code for mllm_shap.shap.pipeline.presets
"""Pipeline presets for explainers."""
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any

from .pipeline import ExplainPipeline
from .stages import (
    AttributionStage,
    FinalizeStage,
    SamplingStage,
    SimilarityStage,
)
[docs]
@dataclass(frozen=True)
class PipelinePreset:
    """Factory that wires explainer callbacks into a composition pipeline.

    The preset is an immutable bundle of explainer callbacks and sampling
    options; :meth:`build` turns it into an :class:`ExplainPipeline`.
    Callbacks are stored and forwarded to the stages unchanged; their exact
    call signatures are defined by the stage classes that invoke them.
    """

    get_next_split: Callable[..., Any]
    """Callback that returns the next split/mask proposal."""
    get_num_splits: Callable[..., Any]
    """Callback that reports the expected number of splits, if known."""
    get_similarities: Callable[..., Any]
    """Callback computing similarities from sampled responses."""
    get_shap_values: Callable[..., Any]
    """Callback converting similarities into SHAP attributions."""
    save_to_cache: Callable[..., Any]
    """Callback persisting final values into the explainer cache."""
    allow_mask_duplicates: bool
    """Whether duplicate masks are allowed in the sampling stage."""
    allow_full_or_empty: bool
    """Whether all-on/all-off masks are allowed during sampling."""
    n_generator_jobs: int
    """Number of parallel jobs used for generation."""
    progress_bar: bool
    """Whether to show progress bars in pipeline stages."""
    verbose: bool
    """Whether to enable verbose diagnostic logging."""
    tqdm_desc: str
    """Progress bar label used by generation stages."""
    # NOTE: ``frozen=True`` only prevents rebinding this field; the dict
    # itself stays mutable and is handed to SamplingStage by reference.
    generate_kwargs: dict[str, Any]
    """Extra keyword arguments forwarded into generation routines."""
    masks_manager_factory: Callable[..., Any] | None = None
    """Optional factory overriding mask manager construction."""
    cache_manager_factory: Callable[..., Any] | None = None
    """Optional factory overriding cache manager construction."""
    generate_step: Callable[..., Any] | None = None
    """Optional custom generation callable replacing the default sampling adapter."""
    include_sampling: bool = True
    """Whether to include the sampling stage in the built pipeline."""
    include_attribution: bool = True
    """Whether to include the attribution stage in the built pipeline."""
    include_finalize: bool = True
    """Whether to include the finalize/cache stage in the built pipeline."""

    def build(self) -> ExplainPipeline:
        """Create an executable pipeline for the current explainer.

        Stages are appended in a fixed order:

        1. ``SamplingStage`` -- only if ``include_sampling``;
        2. ``SimilarityStage`` -- always;
        3. ``AttributionStage`` -- only if ``include_attribution``;
        4. ``FinalizeStage`` -- only if ``include_finalize``.

        Returns:
            A new :class:`ExplainPipeline` whose ``stages`` is a tuple of the
            selected stages in the order above.
        """
        stages: list[Any] = []
        if self.include_sampling:
            stages.append(
                SamplingStage(
                    get_next_split=self.get_next_split,
                    get_num_splits=self.get_num_splits,
                    allow_mask_duplicates=self.allow_mask_duplicates,
                    allow_full_or_empty=self.allow_full_or_empty,
                    n_generator_jobs=self.n_generator_jobs,
                    progress_bar=self.progress_bar,
                    verbose=self.verbose,
                    tqdm_desc=self.tqdm_desc,
                    generate_kwargs=self.generate_kwargs,
                    masks_manager_factory=self.masks_manager_factory,
                    cache_manager_factory=self.cache_manager_factory,
                    generate_step=self.generate_step,
                )
            )
        # The similarity stage has no include_* toggle: it is always present.
        stages.append(SimilarityStage(get_similarities=self.get_similarities))
        if self.include_attribution:
            stages.append(AttributionStage(get_shap_values=self.get_shap_values))
        if self.include_finalize:
            stages.append(FinalizeStage(save_to_cache=self.save_to_cache))
        return ExplainPipeline(stages=tuple(stages))