mllm_shap.shap.pipeline.stages package
Submodules
mllm_shap.shap.pipeline.stages.attribution_stage module
Attribution stage adapters.
- class mllm_shap.shap.pipeline.stages.attribution_stage.AttributionStage(get_shap_values: Callable[[...], tuple[Tensor, Tensor]])
Bases: object
Stage that uses a high-level SHAP computation callback.
- get_shap_values: Callable[[...], tuple[Tensor, Tensor]]
Callback that computes SHAP values. It receives the model, masks, responses, source chat, device, and similarities, and returns a tuple (raw_shap_values, normalized_shap_values). raw_shap_values holds the unnormalized SHAP values for each mask; normalized_shap_values holds the SHAP values rescaled to sum to the difference between the base response and the masked response. The exact normalization depends on the SHAP variant being implemented. Keeping this logic in a callback lets different SHAP estimation methods plug into the same stage.
- run(context: ExplainContext, state: ExplainState, probe: TelemetryProbe | None = None) None
Compute raw and normalized SHAP values from current pipeline state.
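To illustrate the callback contract, the sketch below implements a hypothetical `get_shap_values` callback. It assumes the stage passes its inputs as keyword arguments, of which only `masks` and `similarities` are used (the model, device, and so on are accepted and ignored), and it uses plain Python lists in place of `torch.Tensor`. The estimator is a crude difference of means (average similarity with a feature kept minus average similarity with it masked), not an exact Shapley computation and not necessarily what mllm_shap does; `base_gap` is a hypothetical stand-in for the base-vs-masked difference the normalized values should sum to.

```python
from statistics import mean
from typing import Any, Sequence


def toy_get_shap_values(
    *,
    masks: Sequence[Sequence[int]],
    similarities: Sequence[float],
    base_gap: float = 1.0,
    **_unused: Any,
) -> tuple[list[float], list[float]]:
    """Hypothetical get_shap_values callback using a difference-of-means estimate.

    Lists stand in for torch.Tensor; model, device, etc. arrive via **_unused.
    """
    n_features = len(masks[0])
    raw: list[float] = []
    for i in range(n_features):
        kept = [s for m, s in zip(masks, similarities) if m[i]]
        dropped = [s for m, s in zip(masks, similarities) if not m[i]]
        # A feature never seen in both states cannot be compared: give it 0.
        raw.append(mean(kept) - mean(dropped) if kept and dropped else 0.0)
    total = sum(raw)
    # Rescale so the normalized values sum to base_gap (all zeros if total is 0).
    scale = base_gap / total if total else 0.0
    normalized = [r * scale for r in raw]
    return raw, normalized
```

With masks [1, 0, 1], [0, 1, 1], [1, 1, 0], [0, 0, 1] and similarities 0.9, 0.4, 0.7, 0.3, feature 0 gets a raw score of 0.8 - 0.35 = 0.45, and the normalized scores sum to whatever `base_gap` is passed.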
mllm_shap.shap.pipeline.stages.finalize_stage module
Normalization and cache persistence stages.
- class mllm_shap.shap.pipeline.stages.finalize_stage.FinalizeStage(save_to_cache: Callable[[...], None])
Bases: object
Finalize stage for normalization and cache persistence.
- run(context: ExplainContext, state: ExplainState, probe: TelemetryProbe | None = None) None
Persist run outputs via explainer-compatible cache callback.
- save_to_cache: Callable[[...], None]
Callback that saves results to the cache. It receives the chat, source chat, responses, masks, raw SHAP values, and normalized SHAP values, and writes them to the configured cache storage. Its implementation depends on the caching backend, but it should store results so they can be retrieved efficiently for later analysis.
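A minimal `save_to_cache` callback might write each run into an in-memory dictionary. The factory name, the keyword argument names, and the SHA-256 keying scheme below are all hypothetical; the real callback is supplied by the explainer, and its cache backend may differ.

```python
import hashlib
from typing import Any, Callable


def make_memory_cache_saver(store: dict[str, dict[str, Any]]) -> Callable[..., None]:
    """Return a hypothetical save_to_cache callback that writes into ``store``."""

    def save_to_cache(
        *,
        chat: Any,
        source_chat: Any,
        responses: list[Any],
        masks: list[Any],
        shap_values: list[float],
        normalized_shap_values: list[float],
        **_unused: Any,
    ) -> None:
        # Key each entry by a digest of the source chat's text form
        # (assumption: str(source_chat) is stable for the same input).
        key = hashlib.sha256(str(source_chat).encode("utf-8")).hexdigest()
        store[key] = {
            "chat": str(chat),
            "responses": [str(r) for r in responses],
            "masks": [list(m) for m in masks],
            "raw": list(shap_values),
            "normalized": list(normalized_shap_values),
        }

    return save_to_cache
```

Because entries are keyed by the source chat, re-running on the same input overwrites the previous entry rather than accumulating duplicates.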
mllm_shap.shap.pipeline.stages.sampling_adapter module
Shared sampling adapter for split generation paths.
- mllm_shap.shap.pipeline.stages.sampling_adapter.build_masks_generator(get_next_split: Callable[[...], Tensor | None], get_num_splits: Callable[[int], int], mask_manager: MasksManager, device: device, masks: list[Tensor], allow_mask_duplicates: bool, allow_full_or_empty: bool) MaskGenerator
Build a mask generator using split callbacks.
- mllm_shap.shap.pipeline.stages.sampling_adapter.run_sampling_generation(get_next_split: Callable[[...], Tensor | None], get_num_splits: Callable[[int], int], mask_manager: MasksManager, device: device, masks: list[Tensor], allow_mask_duplicates: bool, allow_full_or_empty: bool, logger: Logger, tqdm_bar: Any | None = None, tqdm_desc: str = 'Calculating SHAP values', responses: list[ModelResponse] | None = None, source_chat: BaseMllmChat | None = None, model: BaseMllmModel | None = None, generate_responses_fn: Callable[[...], tuple[int, Any]] = <function generate_responses>, get_masks_generator: Callable[[MasksManager, device, list[Tensor], bool], MaskGenerator] | None = None, **generate_kwargs: Any) tuple[tuple[int, list[tuple[Tensor, int, BaseMllmChat | None, ModelResponse]] | None], int]
Run response generation and return the generation result together with the number of masks generated.
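To make the two filtering flags concrete, here is a hypothetical, stdlib-only generator in the spirit of `build_masks_generator`. It pulls splits from a callback, assuming that a `None` return signals exhaustion, and drops full/empty and duplicate masks unless the corresponding flag allows them. Tuples of 0/1 stand in for `torch.Tensor` masks, the zero-argument callback is an assumption (the documented type is `Callable[[...], Tensor | None]`), and the real function also takes `get_num_splits`, a `MasksManager`, and a device, which this sketch omits.

```python
from typing import Callable, Iterator, Optional, Sequence, Tuple

Mask = Tuple[int, ...]


def toy_masks_generator(
    get_next_split: Callable[[], Optional[Sequence[int]]],
    allow_mask_duplicates: bool = False,
    allow_full_or_empty: bool = False,
) -> Iterator[Mask]:
    """Yield masks from a split callback, filtering them per the two flags.

    Stops when the callback returns None; a callback that never does so
    makes this generator infinite.
    """
    seen: set = set()
    while True:
        split = get_next_split()
        if split is None:
            return
        mask: Mask = tuple(int(bool(v)) for v in split)
        if not allow_full_or_empty and (all(mask) or not any(mask)):
            continue  # full (all-ones) or empty (all-zeros) mask: skipped
        if not allow_mask_duplicates:
            if mask in seen:
                continue  # already yielded once
            seen.add(mask)
        yield mask
```

For example, feeding the splits [1, 0, 1], [1, 1, 1], [1, 0, 1], [0, 0, 0], [0, 1, 0] with both flags left at False yields only (1, 0, 1) and (0, 1, 0).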
mllm_shap.shap.pipeline.stages.sampling_stage module
Sampling stage adapters over existing SHAP generation flow.
- exception mllm_shap.shap.pipeline.stages.sampling_stage.InsufficientMasksError
Bases: RuntimeError
Raised when no explainable masks remain after filtering.
- add_note()
Exception.add_note(note) -- add a note to the exception
- args
- with_traceback()
Exception.with_traceback(tb) -- set self.__traceback__ to tb and return self.
- class mllm_shap.shap.pipeline.stages.sampling_stage.SamplingStage(get_next_split: Any, get_num_splits: Any, allow_mask_duplicates: bool = False, allow_full_or_empty: bool = False, n_generator_jobs: int = 1, progress_bar: bool = True, verbose: bool = False, tqdm_desc: str = 'SHAP', generate_kwargs: dict[str, Any] | None = None, masks_manager_factory: Callable[[...], MasksManager] = <class 'mllm_shap.shap.base._masks_manager.MasksManager'>, cache_manager_factory: Callable[[...], CacheManager] = <class 'mllm_shap.shap.base._cache_manager.CacheManager'>, generate_step: Callable[[...], tuple[int, Any]] | None = None)
Bases: object
Stage that executes split callbacks and response generation.
- allow_full_or_empty: bool = False
Whether full (all-ones) and empty (all-zeros) masks are permitted.
- allow_mask_duplicates: bool = False
Whether duplicate masks are allowed during generation.
- cache_manager_factory
Factory used to construct the cache manager instance.
alias of CacheManager
- generate_kwargs: dict[str, Any] | None = None
Optional extra keyword arguments forwarded to generation routines.
- generate_step: Callable[[...], tuple[int, Any]] | None = None
Optional custom generation step overriding the default adapter flow.
- get_next_split: Any
Callable returning the next sampling split/mask specification.
- get_num_splits: Any
Callable returning total planned splits, if known.
- masks_manager_factory
Factory used to construct the masks manager instance.
alias of MasksManager
- n_generator_jobs: int = 1
Number of parallel jobs used for response generation.
- progress_bar: bool = True
Whether to display a progress bar while generating responses.
- run(context: ExplainContext, state: ExplainState, probe: TelemetryProbe | None = None) None
Generate masks and model responses and store them in the pipeline state.
- tqdm_desc: str = 'SHAP'
Label shown on the progress bar.
- verbose: bool = False
Whether to enable verbose logging for generation internals.
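The stage is driven by its two split callbacks. The sketch below builds a hypothetical pair for Monte Carlo sampling, assuming `get_num_splits` receives the number of features and `get_next_split` is called without arguments (both are assumptions; the documented types are `Any`). It also defines a local stand-in for `InsufficientMasksError` and a hypothetical `require_masks` check showing one way the documented condition, "no explainable masks remain after filtering", can occur. Lists of 0/1 stand in for tensors.

```python
import random
from typing import Callable, Optional, Sequence


class InsufficientMasksError(RuntimeError):
    """Local stand-in mirroring the documented exception (a RuntimeError subclass)."""


def make_random_split_callbacks(
    n_features: int, budget: int, seed: int = 0
) -> tuple[Callable[[], Optional[list[int]]], Callable[[int], int]]:
    """Build hypothetical (get_next_split, get_num_splits) callbacks."""
    rng = random.Random(seed)

    def get_num_splits(n: int) -> int:
        # 2**n masks exist; excluding all-ones and all-zeros leaves 2**n - 2.
        return max(0, min(budget, 2 ** n - 2))

    remaining = get_num_splits(n_features)

    def get_next_split() -> Optional[list[int]]:
        nonlocal remaining
        if remaining <= 0:
            return None  # signal that sampling is exhausted
        remaining -= 1
        # Draws may repeat or be full/empty; allow_mask_duplicates and
        # allow_full_or_empty decide downstream what is kept.
        return [rng.randint(0, 1) for _ in range(n_features)]

    return get_next_split, get_num_splits


def require_masks(
    masks: Sequence[Sequence[int]], allow_full_or_empty: bool = False
) -> list[Sequence[int]]:
    """Hypothetical check: drop full/empty masks unless allowed; fail if none remain."""
    kept = [m for m in masks if allow_full_or_empty or (any(m) and not all(m))]
    if not kept:
        raise InsufficientMasksError(
            f"no explainable masks remain after filtering {len(masks)} candidate(s)"
        )
    return kept
```

If the zero-argument assumption holds, the callbacks could be passed as `SamplingStage(get_next_split=..., get_num_splits=...)`. With a single feature, the only possible masks are [0] and [1], both empty or full, so with `allow_full_or_empty=False` the check raises.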
mllm_shap.shap.pipeline.stages.similarity_stage module
Similarity stage adapters.
- class mllm_shap.shap.pipeline.stages.similarity_stage.SimilarityStage(get_similarities: Callable[[...], Tensor])
Bases: object
Stage that uses the configured similarity callback.
- get_similarities: Callable[[...], Tensor]
Callback that computes similarities. It receives the generated responses and the model, and returns a tensor of similarity scores. The similarity metric depends on the SHAP variant being implemented; common choices include cosine similarity or negative L2 distance between response embeddings. Keeping this logic in a callback lets each explainer choose a similarity measure suited to its requirements and to the nature of the model responses.
- run(context: ExplainContext, state: ExplainState, probe: TelemetryProbe | None = None) None
Compute similarities from generated responses.
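The two metrics named above can be sketched without any ML dependencies. In the example below, `embed(model, response)` is a hypothetical function returning an embedding vector, `base_response` is the unmasked response the others are compared with, and the `(responses, model)` argument order is assumed from the description; lists of floats stand in for tensors.

```python
import math
from typing import Any, Callable, Sequence

Vector = Sequence[float]


def cosine_similarity(u: Vector, v: Vector) -> float:
    """Cosine of the angle between u and v (0.0 if either is the zero vector)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.hypot(*u) * math.hypot(*v)
    return dot / norm if norm else 0.0


def negative_l2(u: Vector, v: Vector) -> float:
    """Negative Euclidean distance: values closer to 0 mean more similar."""
    return -math.dist(u, v)


def make_get_similarities(
    embed: Callable[[Any, Any], Vector],
    base_response: Any,
    metric: Callable[[Vector, Vector], float] = cosine_similarity,
) -> Callable[[Sequence[Any], Any], list[float]]:
    """Build a hypothetical get_similarities(responses, model) callback."""

    def get_similarities(responses: Sequence[Any], model: Any) -> list[float]:
        base = embed(model, base_response)
        # Score every masked response against the unmasked (base) response.
        return [metric(base, embed(model, r)) for r in responses]

    return get_similarities
```

Passing `metric=negative_l2` switches to distance-based scoring without changing the callback shape, which is the kind of flexibility the callback design allows.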
Module contents
Pipeline stage adapters.
- class mllm_shap.shap.pipeline.stages.AttributionStage(get_shap_values: Callable[[...], tuple[Tensor, Tensor]])
Bases: object
Stage that uses a high-level SHAP computation callback.
- get_shap_values: Callable[[...], tuple[Tensor, Tensor]]
Callback that computes SHAP values. It receives the model, masks, responses, source chat, device, and similarities, and returns a tuple (raw_shap_values, normalized_shap_values). raw_shap_values holds the unnormalized SHAP values for each mask; normalized_shap_values holds the SHAP values rescaled to sum to the difference between the base response and the masked response. The exact normalization depends on the SHAP variant being implemented. Keeping this logic in a callback lets different SHAP estimation methods plug into the same stage.
- run(context: ExplainContext, state: ExplainState, probe: TelemetryProbe | None = None) None
Compute raw and normalized SHAP values from current pipeline state.
- class mllm_shap.shap.pipeline.stages.FinalizeStage(save_to_cache: Callable[[...], None])
Bases: object
Finalize stage for normalization and cache persistence.
- run(context: ExplainContext, state: ExplainState, probe: TelemetryProbe | None = None) None
Persist run outputs via explainer-compatible cache callback.
- save_to_cache: Callable[[...], None]
Callback that saves results to the cache. It receives the chat, source chat, responses, masks, raw SHAP values, and normalized SHAP values, and writes them to the configured cache storage. Its implementation depends on the caching backend, but it should store results so they can be retrieved efficiently for later analysis.
- exception mllm_shap.shap.pipeline.stages.InsufficientMasksError
Bases: RuntimeError
Raised when no explainable masks remain after filtering.
- add_note()
Exception.add_note(note) -- add a note to the exception
- args
- with_traceback()
Exception.with_traceback(tb) -- set self.__traceback__ to tb and return self.
- class mllm_shap.shap.pipeline.stages.SamplingStage(get_next_split: Any, get_num_splits: Any, allow_mask_duplicates: bool = False, allow_full_or_empty: bool = False, n_generator_jobs: int = 1, progress_bar: bool = True, verbose: bool = False, tqdm_desc: str = 'SHAP', generate_kwargs: dict[str, Any] | None = None, masks_manager_factory: Callable[[...], MasksManager] = <class 'mllm_shap.shap.base._masks_manager.MasksManager'>, cache_manager_factory: Callable[[...], CacheManager] = <class 'mllm_shap.shap.base._cache_manager.CacheManager'>, generate_step: Callable[[...], tuple[int, Any]] | None = None)
Bases: object
Stage that executes split callbacks and response generation.
- allow_full_or_empty: bool = False
Whether full (all-ones) and empty (all-zeros) masks are permitted.
- allow_mask_duplicates: bool = False
Whether duplicate masks are allowed during generation.
- cache_manager_factory
Factory used to construct the cache manager instance.
alias of CacheManager
- generate_kwargs: dict[str, Any] | None = None
Optional extra keyword arguments forwarded to generation routines.
- generate_step: Callable[[...], tuple[int, Any]] | None = None
Optional custom generation step overriding the default adapter flow.
- get_next_split: Any
Callable returning the next sampling split/mask specification.
- get_num_splits: Any
Callable returning total planned splits, if known.
- masks_manager_factory
Factory used to construct the masks manager instance.
alias of MasksManager
- n_generator_jobs: int = 1
Number of parallel jobs used for response generation.
- progress_bar: bool = True
Whether to display a progress bar while generating responses.
- run(context: ExplainContext, state: ExplainState, probe: TelemetryProbe | None = None) None
Generate masks and model responses and store them in the pipeline state.
- tqdm_desc: str = 'SHAP'
Label shown on the progress bar.
- verbose: bool = False
Whether to enable verbose logging for generation internals.
- class mllm_shap.shap.pipeline.stages.SimilarityStage(get_similarities: Callable[[...], Tensor])
Bases: object
Stage that uses the configured similarity callback.
- get_similarities: Callable[[...], Tensor]
Callback that computes similarities. It receives the generated responses and the model, and returns a tensor of similarity scores. The similarity metric depends on the SHAP variant being implemented; common choices include cosine similarity or negative L2 distance between response embeddings. Keeping this logic in a callback lets each explainer choose a similarity measure suited to its requirements and to the nature of the model responses.
- run(context: ExplainContext, state: ExplainState, probe: TelemetryProbe | None = None) None
Compute similarities from generated responses.
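All four stage classes share a `run(context, state, probe)` method and communicate through the shared state. From the descriptions above, a plausible order is sampling (masks and responses), then similarity (scores from responses), then attribution (SHAP values, which take similarities as input), then finalize (persisting outputs). The sketch below shows that composition with placeholder stages and a hypothetical `ToyState`; it does not use the real `ExplainContext`/`ExplainState` classes, whose fields are not documented here, and the order is inferred rather than confirmed.

```python
from dataclasses import dataclass, field
from typing import Any, Optional, Protocol, Sequence


class Stage(Protocol):
    """Structural type matching the documented run(context, state, probe) method."""

    def run(self, context: Any, state: Any, probe: Optional[Any] = None) -> None: ...


@dataclass
class ToyState:
    """Hypothetical mutable state; only records which stages have run."""

    completed: list = field(default_factory=list)


@dataclass
class NamedStage:
    """Placeholder stage that only records that it ran."""

    name: str

    def run(self, context: Any, state: ToyState, probe: Optional[Any] = None) -> None:
        state.completed.append(self.name)


def run_stages(
    stages: Sequence[Stage], context: Any, state: Any, probe: Optional[Any] = None
) -> Any:
    # Each stage reads what earlier stages wrote and mutates the shared state in place.
    for stage in stages:
        stage.run(context, state, probe)
    return state
```

Because each stage only touches the shared state, a real pipeline could swap in an `AttributionStage` with a different `get_shap_values` callback without changing the other stages.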