mllm_shap.shap.core package

Submodules

mllm_shap.shap.core.contracts module

Core sampling contracts for composition-based SHAP internals.

class mllm_shap.shap.core.contracts.EstimationResult(values: Tensor, uncertainty: Tensor | None = None)

Bases: object

Container for estimator outputs.

values: Tensor

Estimated attribution values returned by the estimator.

uncertainty: Tensor | None = None

Optional uncertainty proxy aligned with values (for example, a confidence-interval width or a variance). None means the estimator does not provide uncertainty.

class mllm_shap.shap.core.contracts.Estimator

Bases: ABC

Contract for composable attribution estimators.

abstractmethod estimate(masks: Tensor, payoffs: Tensor) → EstimationResult

Estimate attributions from sampled masks and corresponding payoffs.
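A minimal sketch of a custom estimator, assuming masks is an (m, n) 0/1 matrix with one row per sampled coalition and payoffs holds one scalar per row. To stay runnable without PyTorch or mllm_shap, it uses plain lists in place of torch.Tensor and local stand-ins for EstimationResult and Estimator; in real code you would subclass mllm_shap.shap.core.contracts.Estimator instead. The rule shown (mean payoff with a feature kept minus mean payoff with it masked) is a simple difference-of-means heuristic chosen for illustration, not the Shapley or Kernel SHAP estimator.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class EstimationResult:
    # Local stand-in mirroring the documented dataclass; lists replace torch.Tensor.
    values: List[float]
    uncertainty: Optional[List[float]] = None


class Estimator(ABC):
    # Local stand-in for the documented abstract contract.
    @abstractmethod
    def estimate(self, masks: List[List[int]], payoffs: List[float]) -> EstimationResult:
        ...


class DifferenceOfMeansEstimator(Estimator):
    """Hypothetical estimator: mean payoff with feature i kept minus mean payoff with it masked."""

    def estimate(self, masks, payoffs):
        n = len(masks[0])
        values = []
        for i in range(n):
            kept = [p for row, p in zip(masks, payoffs) if row[i]]
            masked = [p for row, p in zip(masks, payoffs) if not row[i]]
            # A feature never observed in one of the two states gets 0.0.
            if kept and masked:
                values.append(sum(kept) / len(kept) - sum(masked) / len(masked))
            else:
                values.append(0.0)
        # This heuristic provides no uncertainty estimate.
        return EstimationResult(values=values, uncertainty=None)


result = DifferenceOfMeansEstimator().estimate(
    masks=[[1, 0], [0, 1], [1, 1], [0, 0]],
    payoffs=[1.0, 2.0, 3.0, 0.0],
)
print(result.values)  # [1.0, 2.0]
```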

class mllm_shap.shap.core.contracts.SamplingStrategy

Bases: ABC

Contract for split-sampling strategies.

abstractmethod get_next_split(n: int, device: device, generated_masks_num: int, existing_masks: list[Tensor] | None = None) → Tensor | None

Return the next split, or None when the strategy is exhausted.

abstractmethod get_num_splits(n: int) → int | None

Return the expected number of generated masks, or None if it is not known in advance.
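To illustrate the contract, here is a minimal sketch of an exhaustive strategy that enumerates every coalition of n features, treating generated_masks_num as a bitmask. It is hypothetical: lists of 0/1 stand in for torch.Tensor, the device argument is accepted but ignored, and a local SamplingStrategy stand-in replaces mllm_shap.shap.core.contracts.SamplingStrategy. Because the total (2**n) is known up front, get_num_splits returns it instead of None.

```python
from abc import ABC, abstractmethod
from typing import List, Optional


class SamplingStrategy(ABC):
    # Local stand-in mirroring the documented abstract contract.
    @abstractmethod
    def get_next_split(self, n: int, device, generated_masks_num: int,
                       existing_masks: Optional[List[List[int]]] = None) -> Optional[List[int]]:
        ...

    @abstractmethod
    def get_num_splits(self, n: int) -> Optional[int]:
        ...


class ExhaustiveStrategy(SamplingStrategy):
    """Hypothetical strategy: the k-th split is the binary expansion of k."""

    def get_next_split(self, n, device, generated_masks_num, existing_masks=None):
        if generated_masks_num >= 2 ** n:
            return None  # every coalition has been proposed: strategy exhausted
        # Bit i of the counter decides whether feature i is kept (1) or masked (0).
        return [(generated_masks_num >> i) & 1 for i in range(n)]

    def get_num_splits(self, n):
        return 2 ** n


strategy = ExhaustiveStrategy()
splits = []
# "cpu" is a placeholder; the sketch never uses the device.
while (split := strategy.get_next_split(2, "cpu", len(splits))) is not None:
    splits.append(split)
print(splits)  # [[0, 0], [1, 0], [0, 1], [1, 1]]
```

Note that SamplingEngine defaults to allow_full_or_empty=False, so if such a strategy were plugged into the engine, the all-masked and all-kept splits would presumably be filtered out and counted in SamplingStats.skipped_full_or_empty.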

class mllm_shap.shap.core.contracts.StopDecision(should_stop: bool, reason: str)

Bases: object

Decision returned by stopping policies.

reason: str

Human-readable explanation for the stop/continue decision.

should_stop: bool

Whether the iterative estimation loop should terminate now.

class mllm_shap.shap.core.contracts.StoppingPolicy

Bases: ABC

Contract for composable stopping policies.

abstractmethod should_stop(iteration: int, estimation: EstimationResult) → StopDecision

Return whether the pipeline should stop, and why.
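A minimal sketch of a custom stopping policy, assuming iteration counts completed estimation rounds starting at 0. The MaxIterationsPolicy name and its max_iterations parameter are hypothetical, and the dataclasses are local stand-ins for the documented StopDecision, EstimationResult, and StoppingPolicy (lists replace tensors).

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class EstimationResult:
    values: List[float]
    uncertainty: Optional[List[float]] = None


@dataclass
class StopDecision:
    should_stop: bool
    reason: str


class StoppingPolicy(ABC):
    @abstractmethod
    def should_stop(self, iteration: int, estimation: EstimationResult) -> StopDecision:
        ...


class MaxIterationsPolicy(StoppingPolicy):
    """Hypothetical policy: stop once a fixed iteration budget is used up."""

    def __init__(self, max_iterations: int) -> None:
        self.max_iterations = max_iterations

    def should_stop(self, iteration, estimation):
        # Iterations are assumed to be 0-based, so iteration + 1 rounds have run.
        if iteration + 1 >= self.max_iterations:
            return StopDecision(True, f"reached {self.max_iterations} iterations")
        return StopDecision(False, f"iteration {iteration + 1} of {self.max_iterations}")


policy = MaxIterationsPolicy(max_iterations=3)
estimate = EstimationResult(values=[0.1, 0.2])
print([policy.should_stop(i, estimate).should_stop for i in range(3)])  # [False, False, True]
```

Returning a StopDecision rather than a bare bool keeps the reason string available for logging when the loop terminates.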

mllm_shap.shap.core.engine module

Sampling engines extracted from the explainer inheritance flow.

class mllm_shap.shap.core.engine.SamplingEngine(strategy: SamplingStrategy, allow_mask_duplicates: bool = False, allow_full_or_empty: bool = False, probe: TelemetryProbe | None = None)

Bases: object

Generic mask-generation engine driven by a sampling strategy.

create_generator(mask_manager: MasksManager, device: device, masks: list[Tensor]) → MaskGenerator

Create a mask generator compatible with the existing SHAP flow.

class mllm_shap.shap.core.engine.SamplingStats(candidate_splits: int = 0, yielded_masks: int = 0, skipped_full_or_empty: int = 0, skipped_invalid_masks: int = 0, skipped_duplicates: int = 0, elapsed_ms: float = 0.0)

Bases: object

Sampling runtime and filtering counters.

candidate_splits: int = 0

Number of candidate splits proposed by the sampling strategy.

elapsed_ms: float = 0.0

Total generator runtime in milliseconds.

skipped_duplicates: int = 0

Number of masks skipped because they duplicated an already-seen mask.

skipped_full_or_empty: int = 0

Number of masks skipped because they were full or empty.

skipped_invalid_masks: int = 0

Number of masks skipped because they produced invalid prepared masks.

yielded_masks: int = 0

Number of masks yielded to the response-generation pipeline.
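The SamplingStats counters suggest how SamplingEngine filters candidate splits. The sketch below reproduces that bookkeeping under stated assumptions: filter_splits is a hypothetical function, not the engine's create_generator (which also involves a MasksManager and the invalid-mask check omitted here); masks are tuples of 0/1 instead of tensors; and the full-or-empty check runs before the duplicate check.

```python
import time
from dataclasses import dataclass
from typing import Iterable, Iterator, Tuple


@dataclass
class SamplingStats:
    # Same field names and defaults as the documented SamplingStats.
    candidate_splits: int = 0
    yielded_masks: int = 0
    skipped_full_or_empty: int = 0
    skipped_invalid_masks: int = 0
    skipped_duplicates: int = 0
    elapsed_ms: float = 0.0


def filter_splits(
    candidates: Iterable[Tuple[int, ...]],
    stats: SamplingStats,
    allow_mask_duplicates: bool = False,
    allow_full_or_empty: bool = False,
) -> Iterator[Tuple[int, ...]]:
    """Hypothetical filter loop: yields accepted masks and updates stats in place."""
    start = time.perf_counter()
    seen = set()
    for mask in candidates:
        stats.candidate_splits += 1
        # Skip the all-kept (full) and all-masked (empty) masks unless allowed.
        if not allow_full_or_empty and (all(mask) or not any(mask)):
            stats.skipped_full_or_empty += 1
            continue
        if not allow_mask_duplicates and mask in seen:
            stats.skipped_duplicates += 1
            continue
        seen.add(mask)
        stats.yielded_masks += 1
        yield mask
    # Wall-clock time until exhaustion, including time the consumer spends between yields.
    stats.elapsed_ms = (time.perf_counter() - start) * 1000.0


stats = SamplingStats()
candidates = [(1, 1), (0, 0), (1, 0), (1, 0), (0, 1)]
accepted = list(filter_splits(candidates, stats))
print(accepted)  # [(1, 0), (0, 1)]
print(stats.candidate_splits, stats.skipped_full_or_empty, stats.skipped_duplicates, stats.yielded_masks)  # 5 2 1 2
```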

mllm_shap.shap.core.estimation module

Composable estimation and stopping adapters for SHAP pipelines.

class mllm_shap.shap.core.estimation.CallableEstimator(estimate_fn: Callable[[Tensor, Tensor], Tensor | tuple[Tensor, Tensor]])

Bases: Estimator

Adapter that bridges callables to the Estimator contract.

estimate(masks: Tensor, payoffs: Tensor) → EstimationResult

Delegate estimation to the wrapped callable.
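The estimate_fn signature allows the wrapped callable to return either a single values tensor or a pair of tensors. The sketch below shows one plausible normalization, assuming the pair is ordered (values, uncertainty); check the source before relying on that order. CallableEstimatorSketch and the two toy callables are hypothetical, plain lists stand in for tensors, and EstimationResult is a local stand-in.

```python
from dataclasses import dataclass
from typing import Callable, List, Optional, Tuple, Union

Values = List[float]


@dataclass
class EstimationResult:
    # Local stand-in for the documented dataclass; lists replace tensors.
    values: Values
    uncertainty: Optional[Values] = None


class CallableEstimatorSketch:
    """Hypothetical re-creation of the adapter's normalization step."""

    def __init__(self, estimate_fn: Callable[[list, Values], Union[Values, Tuple[Values, Values]]]) -> None:
        self.estimate_fn = estimate_fn

    def estimate(self, masks: list, payoffs: Values) -> EstimationResult:
        out = self.estimate_fn(masks, payoffs)
        if isinstance(out, tuple):
            values, uncertainty = out  # assumed order: (values, uncertainty)
            return EstimationResult(values=values, uncertainty=uncertainty)
        return EstimationResult(values=out)  # a bare result carries no uncertainty


def mean_payoff(masks, payoffs):
    # Toy callable: a single "attribution" equal to the mean payoff.
    return [sum(payoffs) / len(payoffs)]


def mean_and_spread(masks, payoffs):
    # Toy callable that also reports the payoff range as an uncertainty proxy.
    return [sum(payoffs) / len(payoffs)], [max(payoffs) - min(payoffs)]


plain = CallableEstimatorSketch(mean_payoff).estimate([[1], [0]], [1.0, 3.0])
paired = CallableEstimatorSketch(mean_and_spread).estimate([[1], [0]], [1.0, 3.0])
print(plain)   # EstimationResult(values=[2.0], uncertainty=None)
print(paired)  # EstimationResult(values=[2.0], uncertainty=[2.0])
```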

class mllm_shap.shap.core.estimation.CallableStoppingPolicy(should_stop_fn: Callable[[int, EstimationResult], bool | StopDecision], default_reason: str = 'callable-policy')

Bases: StoppingPolicy

Adapter that bridges callables to the StoppingPolicy contract.

should_stop(iteration: int, estimation: EstimationResult) → StopDecision

Delegate the stopping decision to the wrapped callable.
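Since should_stop_fn may return either a bare bool or a full StopDecision, the adapter has to normalize the result. A plausible sketch, assuming a bare bool is paired with default_reason ('callable-policy' unless overridden); CallableStoppingPolicySketch and the two toy callables are hypothetical, and StopDecision is a local stand-in.

```python
from dataclasses import dataclass
from typing import Any, Callable, Union


@dataclass
class StopDecision:
    should_stop: bool
    reason: str


class CallableStoppingPolicySketch:
    """Hypothetical re-creation of the adapter's normalization step."""

    def __init__(self, should_stop_fn: Callable[[int, Any], Union[bool, StopDecision]],
                 default_reason: str = "callable-policy") -> None:
        self.should_stop_fn = should_stop_fn
        self.default_reason = default_reason

    def should_stop(self, iteration: int, estimation: Any) -> StopDecision:
        out = self.should_stop_fn(iteration, estimation)
        if isinstance(out, StopDecision):
            return out  # already carries its own reason
        # Assumed: a bare bool is wrapped with the configured default reason.
        return StopDecision(should_stop=bool(out), reason=self.default_reason)


def after_five(iteration, estimation):
    # Toy callable returning a bare bool.
    return iteration >= 5


def explained(iteration, estimation):
    # Toy callable returning a full decision with its own reason.
    return StopDecision(iteration >= 5, f"iteration={iteration}")


policy = CallableStoppingPolicySketch(after_five)
print(policy.should_stop(5, None))  # StopDecision(should_stop=True, reason='callable-policy')
```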

class mllm_shap.shap.core.estimation.FixedThresholdStoppingPolicy(threshold: float)

Bases: StoppingPolicy

Stop when the scalar uncertainty falls below a configured threshold.

should_stop(iteration: int, estimation: EstimationResult) → StopDecision

Apply the fixed uncertainty-threshold rule.
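The documentation does not state how the uncertainty tensor is reduced to a scalar, nor what happens when uncertainty is None. The sketch below assumes the maximum over features is compared with a strict less-than and that missing uncertainty means "continue"; treat both as assumptions. threshold_rule is a hypothetical function, lists stand in for tensors, and StopDecision is a local stand-in.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class StopDecision:
    should_stop: bool
    reason: str


def threshold_rule(uncertainty: Optional[List[float]], threshold: float) -> StopDecision:
    """Hypothetical fixed-threshold rule on the largest per-feature uncertainty."""
    if uncertainty is None:
        return StopDecision(False, "no uncertainty provided")
    worst = max(uncertainty)  # assumed reduction: the least certain feature
    if worst < threshold:
        return StopDecision(True, f"uncertainty {worst:.3g} < {threshold:.3g}")
    return StopDecision(False, f"uncertainty {worst:.3g} >= {threshold:.3g}")


print(threshold_rule([0.05, 0.01], 0.1))  # StopDecision(should_stop=True, reason='uncertainty 0.05 < 0.1')
```

Reducing with the maximum makes the rule conservative: every feature's uncertainty must fall below the threshold before the loop stops. A mean-based reduction would stop earlier but could leave individual features poorly estimated.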

mllm_shap.shap.core.sampling module

Sampling strategy adapters used by the composition-based sampling engines.

class mllm_shap.shap.core.sampling.CallableAdapterStrategy(get_next_split: Callable[[int, device, int, list[Tensor] | None], Tensor | None], get_num_splits: Callable[[int], int | None])

Bases: SamplingStrategy

Adapter that bridges callable functions to the SamplingStrategy contract.

get_next_split(n: int, device: device, generated_masks_num: int, existing_masks: list[Tensor] | None = None) → Tensor | None

Delegate split generation to the wrapped callable.

Parameters:
  • n – Number of explainable features.

  • device – Device on which the split should be created.

  • generated_masks_num – Number of masks generated so far.

  • existing_masks – Existing masks available to the strategy.

Returns:

The next split mask, or None when generation should stop.

get_num_splits(n: int) → int | None

Return the expected number of splits for n explainable features, or None if it is not known.
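A sketch of how the adapter might be used, assuming it simply forwards each call to the matching callable. The leave_one_out split function is a hypothetical example: on call k it masks feature k, and it reports n splits in total. CallableAdapterSketch is a local re-implementation with lists in place of tensors so the snippet runs without mllm_shap; with the real package you would pass the same two functions as the get_next_split= and get_num_splits= arguments of CallableAdapterStrategy.

```python
from typing import Callable, List, Optional


class CallableAdapterSketch:
    """Local stand-in that forwards calls to two wrapped functions."""

    def __init__(self, get_next_split: Callable, get_num_splits: Callable) -> None:
        self._next = get_next_split
        self._num = get_num_splits

    def get_next_split(self, n, device, generated_masks_num, existing_masks=None):
        return self._next(n, device, generated_masks_num, existing_masks)

    def get_num_splits(self, n):
        return self._num(n)


def leave_one_out(n: int, device, generated_masks_num: int,
                  existing_masks: Optional[List[List[int]]]) -> Optional[List[int]]:
    """Hypothetical split: keep every feature except the one at index generated_masks_num."""
    if generated_masks_num >= n:
        return None  # one split per feature, then stop
    return [0 if i == generated_masks_num else 1 for i in range(n)]


strategy = CallableAdapterSketch(get_next_split=leave_one_out, get_num_splits=lambda n: n)
print(strategy.get_next_split(3, "cpu", 0))  # [0, 1, 1]
```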

mllm_shap.shap.core.telemetry module

Telemetry and metrics collection for SHAP explainability.

class mllm_shap.shap.core.telemetry.CacheMetrics(hits: int = 0, misses: int = 0)

Bases: object

Metrics for cache operations.

property hit_rate: float

Cache hit rate (0 to 1).

hits: int = 0

Number of cache hits.

misses: int = 0

Number of cache misses.

to_dict() → dict[str, Any]

Convert metrics to a dictionary.

property total: int

Total number of cache operations.
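The hit_rate and total properties can be derived from the two counters. A minimal sketch, assuming total = hits + misses and a hit rate of 0.0 when no operation has been recorded (the empty case is not specified in the documentation). CacheMetricsSketch is a hypothetical stand-in, and the key layout of its to_dict is an assumption.

```python
from dataclasses import asdict, dataclass
from typing import Any, Dict


@dataclass
class CacheMetricsSketch:
    hits: int = 0
    misses: int = 0

    @property
    def total(self) -> int:
        return self.hits + self.misses

    @property
    def hit_rate(self) -> float:
        # Assumed: report 0.0 instead of dividing by zero when nothing was recorded.
        return self.hits / self.total if self.total else 0.0

    def to_dict(self) -> Dict[str, Any]:
        # Assumed layout: raw counters plus the derived values.
        return {**asdict(self), "total": self.total, "hit_rate": self.hit_rate}


metrics = CacheMetricsSketch(hits=3, misses=1)
print(metrics.to_dict())  # {'hits': 3, 'misses': 1, 'total': 4, 'hit_rate': 0.75}
```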

class mllm_shap.shap.core.telemetry.JSONProbeSink

Bases: ProbeSink

Telemetry sink that collects metrics for JSON serialization.

get_metrics() → TelemetryData

Return the collected telemetry data.

record_cache_operation(is_hit: bool) → None

Record a cache operation.

record_custom_metric(key: str, value: Any) → None

Record a custom metric.

record_mask_generated(is_unique: bool, is_invalid: bool = False) → None

Record a generated mask.

record_timing(stage: str, elapsed_ms: float) → None

Record timing for a stage.

reset() → None

Reset collected metrics.

to_json() → str

Serialize the collected metrics to a JSON string.
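A minimal sketch of an in-memory sink with a to_json method, assuming the collected data is a tree of dataclasses serialized with dataclasses.asdict and json.dumps. JSONSinkSketch and TelemetryDataSketch are hypothetical; the latter keeps only two of the four documented TelemetryData fields, and the resulting key layout is an assumption rather than the documented output format.

```python
import json
from dataclasses import asdict, dataclass, field
from typing import Any, Dict


@dataclass
class CacheCounts:
    hits: int = 0
    misses: int = 0


@dataclass
class TelemetryDataSketch:
    # Trimmed stand-in: two of the four documented TelemetryData fields.
    cache_metrics: CacheCounts = field(default_factory=CacheCounts)
    custom_metrics: Dict[str, Any] = field(default_factory=dict)


class JSONSinkSketch:
    """Hypothetical sink that accumulates metrics in memory, then dumps them as JSON."""

    def __init__(self) -> None:
        self.data = TelemetryDataSketch()

    def record_cache_operation(self, is_hit: bool) -> None:
        if is_hit:
            self.data.cache_metrics.hits += 1
        else:
            self.data.cache_metrics.misses += 1

    def record_custom_metric(self, key: str, value: Any) -> None:
        self.data.custom_metrics[key] = value

    def reset(self) -> None:
        self.data = TelemetryDataSketch()

    def to_json(self) -> str:
        # asdict recurses into nested dataclasses, so the result is JSON-ready.
        return json.dumps(asdict(self.data), sort_keys=True)


sink = JSONSinkSketch()
sink.record_cache_operation(True)
sink.record_cache_operation(False)
sink.record_custom_metric("n_features", 12)
print(sink.to_json())  # {"cache_metrics": {"hits": 1, "misses": 1}, "custom_metrics": {"n_features": 12}}
```

Custom metric values must themselves be JSON-serializable for this approach to work; a tensor value, for example, would need converting first.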

class mllm_shap.shap.core.telemetry.LogProbeSink(verbose: bool = False)

Bases: ProbeSink

Telemetry sink that logs metrics to a logger.

get_metrics() → TelemetryData

Return the collected telemetry data.

record_cache_operation(is_hit: bool) → None

Record a cache operation.

record_custom_metric(key: str, value: Any) → None

Record a custom metric.

record_mask_generated(is_unique: bool, is_invalid: bool = False) → None

Record a generated mask.

record_timing(stage: str, elapsed_ms: float) → None

Record timing for a stage.

reset() → None

Reset collected metrics.

class mllm_shap.shap.core.telemetry.MaskMetrics(generated: int = 0, unique: int = 0, invalid: int = 0)

Bases: object

Metrics for mask generation and deduplication.

property dedup_rate: float

Rate of duplicates caught (0 to 1).

generated: int = 0

Total number of masks generated.

invalid: int = 0

Total number of invalid or empty masks.

to_dict() → dict[str, Any]

Convert metrics to a dictionary.

unique: int = 0

Total number of unique masks (after deduplication).
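The documentation describes dedup_rate only as the rate of duplicates caught. A minimal sketch, assuming it is the share of generated masks that were not unique, (generated - unique) / generated, with 0.0 when nothing has been generated. MaskMetricsSketch and its record helper are hypothetical; record mirrors the is_unique and is_invalid flags of record_mask_generated.

```python
from dataclasses import dataclass


@dataclass
class MaskMetricsSketch:
    generated: int = 0
    unique: int = 0
    invalid: int = 0

    def record(self, is_unique: bool, is_invalid: bool = False) -> None:
        # Mirrors the flags of ProbeSink.record_mask_generated.
        self.generated += 1
        self.unique += int(is_unique)
        self.invalid += int(is_invalid)

    @property
    def dedup_rate(self) -> float:
        # Assumed formula: fraction of generated masks rejected as duplicates.
        if self.generated == 0:
            return 0.0
        return (self.generated - self.unique) / self.generated


m = MaskMetricsSketch()
for flag in (True, True, False, True):
    m.record(is_unique=flag)
print(m.dedup_rate)  # 0.25
```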

class mllm_shap.shap.core.telemetry.ProbeSink

Bases: ABC

Abstract base class for telemetry sinks.

abstractmethod get_metrics() → TelemetryData

Return the collected telemetry data.

abstractmethod record_cache_operation(is_hit: bool) → None

Record a cache operation (hit or miss).

abstractmethod record_custom_metric(key: str, value: Any) → None

Record a custom metric.

abstractmethod record_mask_generated(is_unique: bool, is_invalid: bool = False) → None

Record a generated mask and whether it was unique or invalid.

abstractmethod record_timing(stage: str, elapsed_ms: float) → None

Record timing for a stage (sampling, dedup, masking, model, scoring).

reset() → None

Reset collected metrics. Overriding this method is optional.

class mllm_shap.shap.core.telemetry.StageTimer(probe: TelemetryProbe, stage: str)

Bases: object

Context manager for timing code blocks with probe integration.

measure() → Generator[None, None, None]

Alternative context-manager interface.

start_time: float | None

class mllm_shap.shap.core.telemetry.TelemetryData(cache_metrics: CacheMetrics = <factory>, mask_metrics: MaskMetrics = <factory>, timing_metrics: TimingMetrics = <factory>, custom_metrics: dict[str, Any] = <factory>)

Bases: object

Container for all telemetry data collected during a run.

cache_metrics: CacheMetrics

Cache operation metrics.

custom_metrics: dict[str, Any]

Custom metrics dictionary.

mask_metrics: MaskMetrics

Mask generation metrics.

timing_metrics: TimingMetrics

Per-stage timing metrics.

to_dict() → dict[str, Any]

Convert all telemetry to a dictionary.

class mllm_shap.shap.core.telemetry.TelemetryProbe(sink: ProbeSink | None = None)

Bases: object

Main interface for collecting telemetry during SHAP computation.

cache_operation(is_hit: bool) → None

Record a cache operation (hit or miss).

custom_metric(key: str, value: Any) → None

Record a custom metric.

get_metrics() → TelemetryData | None

Return the collected telemetry data, or None if the probe has no sink.

mask_generated(is_unique: bool, is_invalid: bool = False) → None

Record a generated mask and whether it was unique or invalid.

static noop() → TelemetryProbe

Create a no-op probe (no sink, no collection).

record_timing(stage: str, elapsed_ms: float) → None

Record timing for a stage (sampling, dedup, masking, model, scoring).

reset() → None

Reset collected metrics.

timing(stage: str) → StageTimer

Create a context manager for timing a stage.

static with_json_sink() → TelemetryProbe

Create a probe with a JSON sink.

static with_log_sink(verbose: bool = False) → TelemetryProbe

Create a probe with a log sink.
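TelemetryProbe.timing returns a StageTimer, a context manager that reports elapsed time to the probe, so real usage would presumably look like `with probe.timing("model"): ...`. A minimal sketch of that pattern, assuming the timer measures with time.perf_counter and forwards milliseconds through record_timing when the block exits, even if it raises. ProbeSketch and StageTimerSketch are hypothetical stand-ins, and the real classes may differ in detail.

```python
import time
from typing import Dict, List, Optional


class ProbeSketch:
    """Hypothetical probe that stores elapsed times per stage."""

    def __init__(self) -> None:
        self.timings: Dict[str, List[float]] = {}

    def record_timing(self, stage: str, elapsed_ms: float) -> None:
        self.timings.setdefault(stage, []).append(elapsed_ms)

    def timing(self, stage: str) -> "StageTimerSketch":
        return StageTimerSketch(self, stage)


class StageTimerSketch:
    """Context manager: measure the wrapped block and report it to the probe."""

    def __init__(self, probe: ProbeSketch, stage: str) -> None:
        self.probe = probe
        self.stage = stage
        self.start_time: Optional[float] = None  # mirrors the documented attribute

    def __enter__(self) -> "StageTimerSketch":
        self.start_time = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc, tb) -> bool:
        elapsed_ms = (time.perf_counter() - self.start_time) * 1000.0
        self.probe.record_timing(self.stage, elapsed_ms)
        return False  # never swallow exceptions from the timed block


probe = ProbeSketch()
with probe.timing("sampling"):
    sum(range(10_000))  # stand-in workload
print(list(probe.timings))  # ['sampling']
```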

class mllm_shap.shap.core.telemetry.TimingMetrics(sampling_ms: float = 0.0, dedup_ms: float = 0.0, masking_ms: float = 0.0, model_ms: float = 0.0, scoring_ms: float = 0.0)

Bases: object

Metrics for per-stage timing.

dedup_ms: float = 0.0

Time spent in deduplication checks (milliseconds).

masking_ms: float = 0.0

Time spent in mask preparation (milliseconds).

model_ms: float = 0.0

Time spent in model inference (milliseconds).

sampling_ms: float = 0.0

Time spent in mask sampling (milliseconds).

scoring_ms: float = 0.0

Time spent in SHAP value computation (milliseconds).

to_dict() → dict[str, Any]

Convert metrics to a dictionary.

property total_ms: float

Total time across all stages (milliseconds).
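A minimal sketch of total_ms, assuming it is the plain sum of the five per-stage fields. Under that definition, stages that overlap in wall-clock time would be counted twice. TimingMetricsSketch is a hypothetical stand-in with the documented field names.

```python
from dataclasses import dataclass, fields


@dataclass
class TimingMetricsSketch:
    sampling_ms: float = 0.0
    dedup_ms: float = 0.0
    masking_ms: float = 0.0
    model_ms: float = 0.0
    scoring_ms: float = 0.0

    @property
    def total_ms(self) -> float:
        # Assumed definition: sum of every per-stage *_ms field.
        return sum(getattr(self, f.name) for f in fields(self))


t = TimingMetricsSketch(sampling_ms=1.5, model_ms=40.0, scoring_ms=2.5)
print(t.total_ms)  # 44.0
```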

Module contents

Core composable building blocks for SHAP internals.

The package namespace exposes the following classes under the mllm_shap.shap.core prefix (for example, mllm_shap.shap.core.SamplingEngine). Their signatures and member documentation are identical to the submodule entries above:

  • From mllm_shap.shap.core.contracts: EstimationResult, Estimator, SamplingStrategy, StopDecision, StoppingPolicy.

  • From mllm_shap.shap.core.engine: SamplingEngine, SamplingStats.

  • From mllm_shap.shap.core.estimation: CallableEstimator, CallableStoppingPolicy, FixedThresholdStoppingPolicy.

  • From mllm_shap.shap.core.sampling: CallableAdapterStrategy.

  • From mllm_shap.shap.core.telemetry: CacheMetrics, JSONProbeSink, LogProbeSink, MaskMetrics, ProbeSink, StageTimer, TelemetryData, TelemetryProbe, TimingMetrics.