mllm_shap.shap.pipeline package

Subpackages

Submodules

mllm_shap.shap.pipeline.context module

Pipeline context/state for composition-first SHAP execution.

class mllm_shap.shap.pipeline.context.ExplainContext(model: BaseMllmModel, source_chat: BaseMllmChat, response_chat: BaseMllmChat, base_response: ModelResponse, device: device, budget: int | None = None, seed: int | None = None, params: MappingProxyType = <factory>)

Bases: object

Immutable input context for one explanation run.

base_response: ModelResponse

Reference model response for the unmasked input.

budget: int | None = None

Optional cap on the number of generated masks/samples.

device: device

Torch device where SHAP computations are executed.

model: BaseMllmModel

Model connector used to generate responses and embeddings.

params: MappingProxyType

Immutable free-form parameters passed to pipeline stages.

response_chat: BaseMllmChat

Chat instance representing the reference/response branch.

seed: int | None = None

Optional random seed controlling stochastic sampling behavior.

source_chat: BaseMllmChat

Original chat instance to explain.

class mllm_shap.shap.pipeline.context.ExplainState(masks: list[Tensor] = <factory>, responses: list[ModelResponse] = <factory>, similarities: Tensor | None = None, shap_values: Tensor | None = None, normalized_shap_values: Tensor | None = None, history: list[tuple[Tensor, int, BaseMllmChat | None, ModelResponse]] | None = None, metadata: dict[str, Any] = <factory>)

Bases: object

Mutable run state passed across pipeline stages.

add_metadata(key: str, value: Any) → None

Attach stage-local metadata for diagnostics/telemetry.

history: list[tuple[Tensor, int, BaseMllmChat | None, ModelResponse]] | None = None

Optional per-iteration generation history for debugging/analysis.

masks: list[Tensor]

Accumulated masks generated across sampling stages.

metadata: dict[str, Any]

Stage-level telemetry and diagnostic metadata.

normalized_shap_values: Tensor | None = None

Post-processed SHAP values after normalization.

responses: list[ModelResponse]

Model responses aligned with the generated masks.

shap_values: Tensor | None = None

Raw SHAP attribution values before normalization.

similarities: Tensor | None = None

Similarity scores computed for current responses/masks.

mllm_shap.shap.pipeline.contracts module

Contracts for composition-first SHAP pipeline stages and policies.

class mllm_shap.shap.pipeline.contracts.EstimationPolicy(*args, **kwargs)

Bases: PipelineStage, Protocol

Policy estimating SHAP values from masks/similarities.

abstractmethod run(context: ExplainContext, state: ExplainState, probe: TelemetryProbe | None = None) → None

Execute stage logic and mutate state.

class mllm_shap.shap.pipeline.contracts.NormalizationPolicy(*args, **kwargs)

Bases: PipelineStage, Protocol

Policy normalizing raw SHAP values.

abstractmethod run(context: ExplainContext, state: ExplainState, probe: TelemetryProbe | None = None) → None

Execute stage logic and mutate state.

class mllm_shap.shap.pipeline.contracts.PersistencePolicy(*args, **kwargs)

Bases: PipelineStage, Protocol

Policy persisting results (for example, cache storage).

abstractmethod run(context: ExplainContext, state: ExplainState, probe: TelemetryProbe | None = None) → None

Execute stage logic and mutate state.

class mllm_shap.shap.pipeline.contracts.PipelineStage(*args, **kwargs)

Bases: Protocol

Single pipeline stage that can mutate ExplainState.

abstractmethod run(context: ExplainContext, state: ExplainState, probe: TelemetryProbe | None = None) → None

Execute stage logic and mutate state.

class mllm_shap.shap.pipeline.contracts.SamplingPolicy(*args, **kwargs)

Bases: PipelineStage, Protocol

Policy producing masks and response samples.

abstractmethod run(context: ExplainContext, state: ExplainState, probe: TelemetryProbe | None = None) → None

Execute stage logic and mutate state.

class mllm_shap.shap.pipeline.contracts.SimilarityPolicy(*args, **kwargs)

Bases: PipelineStage, Protocol

Policy computing similarities from responses.

abstractmethod run(context: ExplainContext, state: ExplainState, probe: TelemetryProbe | None = None) → None

Execute stage logic and mutate state.

mllm_shap.shap.pipeline.pipeline module

Generic explain pipeline executor.

class mllm_shap.shap.pipeline.pipeline.ExplainPipeline(stages: tuple[PipelineStage, ...])

Bases: object

Ordered stage executor for SHAP runs.

run(context: ExplainContext, state: ExplainState, probe: TelemetryProbe | None = None) → ExplainState

Execute stages in order and return the mutated state.

stages: tuple[PipelineStage, ...]

Ordered tuple of pipeline stages to execute. Each stage should be an instance of a class that implements the PipelineStage protocol. It must provide a run method that takes an ExplainContext, an ExplainState, and an optional TelemetryProbe as arguments and mutates the state in place.

mllm_shap.shap.pipeline.presets module

Pipeline presets for explainers.

class mllm_shap.shap.pipeline.presets.PipelinePreset(get_next_split: Any, get_num_splits: Any, get_similarities: Any, get_shap_values: Any, save_to_cache: Any, allow_mask_duplicates: bool, allow_full_or_empty: bool, n_generator_jobs: int, progress_bar: bool, verbose: bool, tqdm_desc: str, generate_kwargs: dict[str, Any], masks_manager_factory: Any | None = None, cache_manager_factory: Any | None = None, generate_step: Any | None = None, include_sampling: bool = True, include_attribution: bool = True, include_finalize: bool = True)

Bases: object

Factory that wires explainer callbacks into a composition pipeline.

allow_full_or_empty: bool

Whether all-on/all-off masks are allowed during sampling.

allow_mask_duplicates: bool

Whether duplicate masks are allowed in the sampling stage.

build() → ExplainPipeline

Create an executable pipeline for the current explainer.

cache_manager_factory: Any | None = None

Optional factory overriding cache manager construction.

generate_kwargs: dict[str, Any]

Extra keyword arguments forwarded to generation routines.

generate_step: Any | None = None

Optional custom generation callable replacing the default sampling adapter.

get_next_split: Any

Callback that returns the next split/mask proposal.

get_num_splits: Any

Callback that reports the expected number of splits, if known.

get_shap_values: Any

Callback converting similarities into SHAP attributions.

get_similarities: Any

Callback computing similarities from sampled responses.

include_attribution: bool = True

Whether to include the attribution stage in the built pipeline.

include_finalize: bool = True

Whether to include the finalize/cache stage in the built pipeline.

include_sampling: bool = True

Whether to include the sampling stage in the built pipeline.

masks_manager_factory: Any | None = None

Optional factory overriding mask manager construction.

n_generator_jobs: int

Number of parallel jobs used for generation.

progress_bar: bool

Whether to show progress bars in pipeline stages.

save_to_cache: Any

Callback persisting final values into the explainer cache.

tqdm_desc: str

Progress bar label used by generation stages.

verbose: bool

Whether to enable verbose diagnostic logging.
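The three include_* flags determine which stage groups build() places in the returned pipeline. In this sketch, stage names stand in for the stage objects that the real build() wires from callbacks such as get_next_split, get_similarities, and save_to_cache.

PresetSketch is a hypothetical name. The fixed order (sampling, then attribution, then finalize) is an assumption based on the attribute descriptions.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class PresetSketch:
    """Hypothetical stand-in for PipelinePreset that keeps only the three stage toggles."""

    include_sampling: bool = True
    include_attribution: bool = True
    include_finalize: bool = True

    def build(self) -> tuple[str, ...]:
        # Stage names stand in for stage objects; order is fixed, disabled stages are skipped.
        candidates = (
            ("sampling", self.include_sampling),
            ("attribution", self.include_attribution),
            ("finalize", self.include_finalize),
        )
        return tuple(name for name, enabled in candidates if enabled)


print(PresetSketch().build())  # ('sampling', 'attribution', 'finalize')
print(PresetSketch(include_finalize=False).build())  # ('sampling', 'attribution')
print(PresetSketch(include_sampling=False).build())  # ('attribution', 'finalize')
```

Skipping sampling would presumably require an ExplainState that already holds masks and responses.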

Module contents

Composition-first SHAP execution pipeline.

The package namespace re-exports the following classes, so each can be imported directly from mllm_shap.shap.pipeline:

* EstimationPolicy
* ExplainContext
* ExplainPipeline
* ExplainState
* NormalizationPolicy
* PersistencePolicy
* PipelinePreset
* PipelineStage
* SamplingPolicy
* SimilarityPolicy

Their documentation is identical to the submodule entries above.