mllm_shap.shap package#
Subpackages#
- mllm_shap.shap.base package
- Submodules
- mllm_shap.shap.base.approx module
BaseShapApproximation
- BaseShapApproximation.allow_mask_duplicates
- BaseShapApproximation.embedding_model
- BaseShapApproximation.embedding_reducer
- BaseShapApproximation.fraction
- BaseShapApproximation.include_minimal_masks
- BaseShapApproximation.mode
- BaseShapApproximation.normalizer
- BaseShapApproximation.num_samples
- BaseShapApproximation.similarity_measure
- BaseShapApproximation.total_n_calls
- mllm_shap.shap.base.complementary module
BaseComplementaryShapApproximation
- BaseComplementaryShapApproximation.allow_mask_duplicates
- BaseComplementaryShapApproximation.embedding_model
- BaseComplementaryShapApproximation.embedding_reducer
- BaseComplementaryShapApproximation.fraction
- BaseComplementaryShapApproximation.include_minimal_masks
- BaseComplementaryShapApproximation.mode
- BaseComplementaryShapApproximation.normalizer
- BaseComplementaryShapApproximation.num_samples
- BaseComplementaryShapApproximation.similarity_measure
- BaseComplementaryShapApproximation.total_n_calls
- mllm_shap.shap.base.embeddings module
- mllm_shap.shap.base.explainer module
- mllm_shap.shap.base.normalizers module
- mllm_shap.shap.base.shap_explainer module
- mllm_shap.shap.base.similarity module
- Module contents
- mllm_shap.shap.complementary package
- Submodules
- mllm_shap.shap.complementary.limited module
LimitedComplementaryShapExplainer
- LimitedComplementaryShapExplainer.allow_mask_duplicates
- LimitedComplementaryShapExplainer.embedding_model
- LimitedComplementaryShapExplainer.embedding_reducer
- LimitedComplementaryShapExplainer.fraction
- LimitedComplementaryShapExplainer.include_minimal_masks
- LimitedComplementaryShapExplainer.mode
- LimitedComplementaryShapExplainer.normalizer
- LimitedComplementaryShapExplainer.num_samples
- LimitedComplementaryShapExplainer.similarity_measure
- LimitedComplementaryShapExplainer.total_n_calls
- mllm_shap.shap.complementary.standard module
StandardComplementaryShapExplainer
- StandardComplementaryShapExplainer.allow_mask_duplicates
- StandardComplementaryShapExplainer.embedding_model
- StandardComplementaryShapExplainer.embedding_reducer
- StandardComplementaryShapExplainer.fraction
- StandardComplementaryShapExplainer.include_minimal_masks
- StandardComplementaryShapExplainer.mode
- StandardComplementaryShapExplainer.normalizer
- StandardComplementaryShapExplainer.num_samples
- StandardComplementaryShapExplainer.similarity_measure
- StandardComplementaryShapExplainer.total_n_calls
- Module contents
LimitedComplementaryShapExplainer
- LimitedComplementaryShapExplainer.allow_mask_duplicates
- LimitedComplementaryShapExplainer.embedding_model
- LimitedComplementaryShapExplainer.embedding_reducer
- LimitedComplementaryShapExplainer.fraction
- LimitedComplementaryShapExplainer.include_minimal_masks
- LimitedComplementaryShapExplainer.mode
- LimitedComplementaryShapExplainer.normalizer
- LimitedComplementaryShapExplainer.num_samples
- LimitedComplementaryShapExplainer.similarity_measure
- LimitedComplementaryShapExplainer.total_n_calls
StandardComplementaryShapExplainer
- StandardComplementaryShapExplainer.allow_mask_duplicates
- StandardComplementaryShapExplainer.embedding_model
- StandardComplementaryShapExplainer.embedding_reducer
- StandardComplementaryShapExplainer.fraction
- StandardComplementaryShapExplainer.include_minimal_masks
- StandardComplementaryShapExplainer.mode
- StandardComplementaryShapExplainer.normalizer
- StandardComplementaryShapExplainer.num_samples
- StandardComplementaryShapExplainer.similarity_measure
- StandardComplementaryShapExplainer.total_n_calls
- mllm_shap.shap.hierarchical package
- Submodules
- mllm_shap.shap.hierarchical.enums module
- mllm_shap.shap.hierarchical.explainer module
HierarchicalExplainer
- HierarchicalExplainer.computation_graph
- HierarchicalExplainer.first_layer_explainer
- HierarchicalExplainer.importance_sampling_min_fraction
- HierarchicalExplainer.k
- HierarchicalExplainer.mode
- HierarchicalExplainer.model
- HierarchicalExplainer.n_calls
- HierarchicalExplainer.shap_explainer
- HierarchicalExplainer.total_n_calls
- HierarchicalExplainer.use_importance_sampling
- Module contents
HierarchicalExplainer
- HierarchicalExplainer.computation_graph
- HierarchicalExplainer.first_layer_explainer
- HierarchicalExplainer.importance_sampling_min_fraction
- HierarchicalExplainer.k
- HierarchicalExplainer.mode
- HierarchicalExplainer.model
- HierarchicalExplainer.n_calls
- HierarchicalExplainer.shap_explainer
- HierarchicalExplainer.total_n_calls
- HierarchicalExplainer.use_importance_sampling
- mllm_shap.shap.monte_carlo package
- Submodules
- mllm_shap.shap.monte_carlo.limited module
LimitedMcShapExplainer
- LimitedMcShapExplainer.allow_mask_duplicates
- LimitedMcShapExplainer.embedding_model
- LimitedMcShapExplainer.embedding_reducer
- LimitedMcShapExplainer.fraction
- LimitedMcShapExplainer.include_minimal_masks
- LimitedMcShapExplainer.mode
- LimitedMcShapExplainer.normalizer
- LimitedMcShapExplainer.num_samples
- LimitedMcShapExplainer.similarity_measure
- LimitedMcShapExplainer.total_n_calls
- mllm_shap.shap.monte_carlo.standard module
StandardMcShapExplainer
- StandardMcShapExplainer.allow_mask_duplicates
- StandardMcShapExplainer.embedding_model
- StandardMcShapExplainer.embedding_reducer
- StandardMcShapExplainer.fraction
- StandardMcShapExplainer.include_minimal_masks
- StandardMcShapExplainer.mode
- StandardMcShapExplainer.normalizer
- StandardMcShapExplainer.num_samples
- StandardMcShapExplainer.similarity_measure
- StandardMcShapExplainer.total_n_calls
- mllm_shap.shap.monte_carlo.utils module
- Module contents
LimitedMcShapExplainer
- LimitedMcShapExplainer.allow_mask_duplicates
- LimitedMcShapExplainer.embedding_model
- LimitedMcShapExplainer.embedding_reducer
- LimitedMcShapExplainer.fraction
- LimitedMcShapExplainer.include_minimal_masks
- LimitedMcShapExplainer.mode
- LimitedMcShapExplainer.normalizer
- LimitedMcShapExplainer.num_samples
- LimitedMcShapExplainer.similarity_measure
- LimitedMcShapExplainer.total_n_calls
StandardMcShapExplainer
- StandardMcShapExplainer.allow_mask_duplicates
- StandardMcShapExplainer.embedding_model
- StandardMcShapExplainer.embedding_reducer
- StandardMcShapExplainer.fraction
- StandardMcShapExplainer.include_minimal_masks
- StandardMcShapExplainer.mode
- StandardMcShapExplainer.normalizer
- StandardMcShapExplainer.num_samples
- StandardMcShapExplainer.similarity_measure
- StandardMcShapExplainer.total_n_calls
approximate_budget()
Submodules#
mllm_shap.shap.compact module#
Compact SHAP explainer implementation.
- class mllm_shap.shap.compact.Explainer(shap_explainer: BaseShapExplainer | None = None, **kwargs: Any)[source]#
Bases: BaseExplainer
Convenience client class for SHAP explanation.
It generates the full response from the model and then uses the provided SHAP explainer to compute SHAP values.
Uses PreciseShapExplainer as the default SHAP explainer.
- model: BaseMllmModel#
The model connector instance.
- shap_explainer: BaseShapExplainer#
The SHAP explainer instance.
- total_n_calls: int = 0#
Total number of MLLM calls made for last explanation.
mllm_shap.shap.embeddings module#
Embedding calculation and reduction strategies for SHAP explanations.
- class mllm_shap.shap.embeddings.CustomEmbedding(*, generation_tokenizer: PreTrainedTokenizerBase, embed_model_id: str, embed_revision: str, device: device, max_length: int = 64, batch_size: int = 64, l2_normalize: bool = True, local_files_only: bool = True)[source]#
Bases: BaseExternalEmbedding
External embeddings using a local encoder model (e.g., E5/SBERT).
For each ModelResponse, we:
- take generated_text_tokens (shape [T]),
- decode each token id with the generation tokenizer to a short text piece,
- embed each piece independently with the embedding encoder,
- return a tensor of shape [T, hidden] per response (aligned 1:1 with tokens).
- batch_size: int#
- device: device#
- emb_model: PreTrainedModel#
- emb_tokenizer: PreTrainedTokenizerBase#
- l2_normalize: bool#
- max_length: int#
- tokenizer_decode: PreTrainedTokenizerBase#
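The per-token flow above can be sketched with stand-in callables for the generation tokenizer and the encoder model. This is a minimal illustration, not library API: `per_token_embeddings`, `decode`, and `embed` are hypothetical names.

```python
def per_token_embeddings(token_ids, decode, embed):
    # For each generated token id: decode it to a short text piece with the
    # generation tokenizer, then embed each piece independently with the
    # encoder -> one embedding row per token (shape [T, hidden]).
    return [embed(decode(t)) for t in token_ids]

# Toy stand-ins for the tokenizer and encoder (illustration only).
vocab = {0: "hi", 1: "there"}
rows = per_token_embeddings([0, 1], vocab.get, lambda piece: [float(len(piece)), 0.0])
# rows has one embedding per token, aligned 1:1 with the generated tokens
```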
- class mllm_shap.shap.embeddings.FirstReducer(n: int | None = None)[source]#
Bases: BaseEmbeddingReducer
Reducer that selects the first embedding. The n parameter is ignored in this reducer.
- n: int | None#
Maximum number of embeddings to sample before reduction. None means no sampling.
- class mllm_shap.shap.embeddings.MaxReducer(n: int | None = None)[source]#
Bases: BaseEmbeddingReducer
Reducer that computes the max of embeddings.
- n: int | None#
Maximum number of embeddings to sample before reduction. None means no sampling.
- class mllm_shap.shap.embeddings.MeanReducer(n: int | None = None)[source]#
Bases: BaseEmbeddingReducer
Reducer that computes the mean of embeddings.
- n: int | None#
Maximum number of embeddings to sample before reduction. None means no sampling.
- class mllm_shap.shap.embeddings.MinReducer(n: int | None = None)[source]#
Bases: BaseEmbeddingReducer
Reducer that computes the min of embeddings.
- n: int | None#
Maximum number of embeddings to sample before reduction. None means no sampling.
- class mllm_shap.shap.embeddings.SumReducer(n: int | None = None)[source]#
Bases: BaseEmbeddingReducer
Reducer that computes the sum of embeddings.
- n: int | None#
Maximum number of embeddings to sample before reduction. None means no sampling.
- class mllm_shap.shap.embeddings.ZeroReducer(n: int | None = None)[source]#
Bases: BaseEmbeddingReducer
Dummy reducer that returns embeddings unchanged.
- n: int | None#
Maximum number of embeddings to sample before reduction. None means no sampling.
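The reducer family above can be sketched in plain Python, with lists standing in for tensors. The optional pre-reduction sampling controlled by n is omitted, and `reduce_embeddings` is an illustrative name rather than library API:

```python
def reduce_embeddings(embeddings, how):
    # Collapse a [T, hidden] list of per-token embeddings into one vector.
    if how == "first":  # FirstReducer: keep the first row
        return list(embeddings[0])
    if how == "zero":   # ZeroReducer: return embeddings unchanged
        return embeddings
    columns = list(zip(*embeddings))  # transpose to iterate per dimension
    if how == "mean":
        return [sum(c) / len(embeddings) for c in columns]
    if how == "max":
        return [max(c) for c in columns]
    if how == "min":
        return [min(c) for c in columns]
    if how == "sum":
        return [sum(c) for c in columns]
    raise ValueError(f"unknown reducer: {how}")

emb = [[1.0, 4.0], [3.0, 2.0]]  # T=2 tokens, hidden=2
# mean -> [2.0, 3.0], max -> [3.0, 4.0], first -> [1.0, 4.0]
```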
mllm_shap.shap.enums module#
Configuration for SHAP.
- class mllm_shap.shap.enums.Mode(*values)[source]#
Bases: str, Enum
Possible modes.
- CONTEXTUAL = 'contextual'#
Contextual mode - embeddings are computed using model internal states, therefore they carry contextual information.
- STATIC = 'static'#
Static mode - embeddings are computed using final model response tokens only, therefore they do not carry contextual information.
- capitalize()#
Return a capitalized version of the string.
More specifically, make the first character have upper case and the rest lower case.
- casefold()#
Return a version of the string suitable for caseless comparisons.
- center(width, fillchar=' ', /)#
Return a centered string of length width.
Padding is done using the specified fill character (default is a space).
- count(sub[, start[, end]]) int#
Return the number of non-overlapping occurrences of substring sub in string S[start:end]. Optional arguments start and end are interpreted as in slice notation.
- encode(encoding='utf-8', errors='strict')#
Encode the string using the codec registered for encoding.
- encoding
The encoding in which to encode the string.
- errors
The error handling scheme to use for encoding errors. The default is ‘strict’ meaning that encoding errors raise a UnicodeEncodeError. Other possible values are ‘ignore’, ‘replace’ and ‘xmlcharrefreplace’ as well as any other name registered with codecs.register_error that can handle UnicodeEncodeErrors.
- endswith(suffix[, start[, end]]) bool#
Return True if S ends with the specified suffix, False otherwise. With optional start, test S beginning at that position. With optional end, stop comparing S at that position. suffix can also be a tuple of strings to try.
- expandtabs(tabsize=8)#
Return a copy where all tab characters are expanded using spaces.
If tabsize is not given, a tab size of 8 characters is assumed.
- find(sub[, start[, end]]) int#
Return the lowest index in S where substring sub is found, such that sub is contained within S[start:end]. Optional arguments start and end are interpreted as in slice notation.
Return -1 on failure.
- format(*args, **kwargs) str#
Return a formatted version of S, using substitutions from args and kwargs. The substitutions are identified by braces (‘{’ and ‘}’).
- format_map(mapping) str#
Return a formatted version of S, using substitutions from mapping. The substitutions are identified by braces (‘{’ and ‘}’).
- index(sub[, start[, end]]) int#
Return the lowest index in S where substring sub is found, such that sub is contained within S[start:end]. Optional arguments start and end are interpreted as in slice notation.
Raises ValueError when the substring is not found.
- isalnum()#
Return True if the string is an alpha-numeric string, False otherwise.
A string is alpha-numeric if all characters in the string are alpha-numeric and there is at least one character in the string.
- isalpha()#
Return True if the string is an alphabetic string, False otherwise.
A string is alphabetic if all characters in the string are alphabetic and there is at least one character in the string.
- isascii()#
Return True if all characters in the string are ASCII, False otherwise.
ASCII characters have code points in the range U+0000-U+007F. Empty string is ASCII too.
- isdecimal()#
Return True if the string is a decimal string, False otherwise.
A string is a decimal string if all characters in the string are decimal and there is at least one character in the string.
- isdigit()#
Return True if the string is a digit string, False otherwise.
A string is a digit string if all characters in the string are digits and there is at least one character in the string.
- isidentifier()#
Return True if the string is a valid Python identifier, False otherwise.
Call keyword.iskeyword(s) to test whether string s is a reserved identifier, such as “def” or “class”.
- islower()#
Return True if the string is a lowercase string, False otherwise.
A string is lowercase if all cased characters in the string are lowercase and there is at least one cased character in the string.
- isnumeric()#
Return True if the string is a numeric string, False otherwise.
A string is numeric if all characters in the string are numeric and there is at least one character in the string.
- isprintable()#
Return True if all characters in the string are printable, False otherwise.
A character is printable if repr() may use it in its output.
- isspace()#
Return True if the string is a whitespace string, False otherwise.
A string is whitespace if all characters in the string are whitespace and there is at least one character in the string.
- istitle()#
Return True if the string is a title-cased string, False otherwise.
In a title-cased string, upper- and title-case characters may only follow uncased characters and lowercase characters only cased ones.
- isupper()#
Return True if the string is an uppercase string, False otherwise.
A string is uppercase if all cased characters in the string are uppercase and there is at least one cased character in the string.
- join(iterable, /)#
Concatenate any number of strings.
The string whose method is called is inserted in between each given string. The result is returned as a new string.
Example: ‘.’.join([‘ab’, ‘pq’, ‘rs’]) -> ‘ab.pq.rs’
- ljust(width, fillchar=' ', /)#
Return a left-justified string of length width.
Padding is done using the specified fill character (default is a space).
- lower()#
Return a copy of the string converted to lowercase.
- lstrip(chars=None, /)#
Return a copy of the string with leading whitespace removed.
If chars is given and not None, remove characters in chars instead.
- static maketrans()#
Return a translation table usable for str.translate().
If there is only one argument, it must be a dictionary mapping Unicode ordinals (integers) or characters to Unicode ordinals, strings or None. Character keys will be then converted to ordinals. If there are two arguments, they must be strings of equal length, and in the resulting dictionary, each character in x will be mapped to the character at the same position in y. If there is a third argument, it must be a string, whose characters will be mapped to None in the result.
- partition(sep, /)#
Partition the string into three parts using the given separator.
This will search for the separator in the string. If the separator is found, returns a 3-tuple containing the part before the separator, the separator itself, and the part after it.
If the separator is not found, returns a 3-tuple containing the original string and two empty strings.
- removeprefix(prefix, /)#
Return a str with the given prefix string removed if present.
If the string starts with the prefix string, return string[len(prefix):]. Otherwise, return a copy of the original string.
- removesuffix(suffix, /)#
Return a str with the given suffix string removed if present.
If the string ends with the suffix string and that suffix is not empty, return string[:-len(suffix)]. Otherwise, return a copy of the original string.
- replace(old, new, count=-1, /)#
Return a copy with all occurrences of substring old replaced by new.
- count
Maximum number of occurrences to replace. -1 (the default value) means replace all occurrences.
If the optional argument count is given, only the first count occurrences are replaced.
- rfind(sub[, start[, end]]) int#
Return the highest index in S where substring sub is found, such that sub is contained within S[start:end]. Optional arguments start and end are interpreted as in slice notation.
Return -1 on failure.
- rindex(sub[, start[, end]]) int#
Return the highest index in S where substring sub is found, such that sub is contained within S[start:end]. Optional arguments start and end are interpreted as in slice notation.
Raises ValueError when the substring is not found.
- rjust(width, fillchar=' ', /)#
Return a right-justified string of length width.
Padding is done using the specified fill character (default is a space).
- rpartition(sep, /)#
Partition the string into three parts using the given separator.
This will search for the separator in the string, starting at the end. If the separator is found, returns a 3-tuple containing the part before the separator, the separator itself, and the part after it.
If the separator is not found, returns a 3-tuple containing two empty strings and the original string.
- rsplit(sep=None, maxsplit=-1)#
Return a list of the substrings in the string, using sep as the separator string.
- sep
The separator used to split the string.
When set to None (the default value), will split on any whitespace character (including \n, \r, \t, \f, and spaces) and will discard empty strings from the result.
- maxsplit
Maximum number of splits. -1 (the default value) means no limit.
Splitting starts at the end of the string and works to the front.
- rstrip(chars=None, /)#
Return a copy of the string with trailing whitespace removed.
If chars is given and not None, remove characters in chars instead.
- split(sep=None, maxsplit=-1)#
Return a list of the substrings in the string, using sep as the separator string.
- sep
The separator used to split the string.
When set to None (the default value), will split on any whitespace character (including \n, \r, \t, \f, and spaces) and will discard empty strings from the result.
- maxsplit
Maximum number of splits. -1 (the default value) means no limit.
Splitting starts at the front of the string and works to the end.
Note, str.split() is mainly useful for data that has been intentionally delimited. With natural text that includes punctuation, consider using the regular expression module.
- splitlines(keepends=False)#
Return a list of the lines in the string, breaking at line boundaries.
Line breaks are not included in the resulting list unless keepends is given and true.
- startswith(prefix[, start[, end]]) bool#
Return True if S starts with the specified prefix, False otherwise. With optional start, test S beginning at that position. With optional end, stop comparing S at that position. prefix can also be a tuple of strings to try.
- strip(chars=None, /)#
Return a copy of the string with leading and trailing whitespace removed.
If chars is given and not None, remove characters in chars instead.
- swapcase()#
Convert uppercase characters to lowercase and lowercase characters to uppercase.
- title()#
Return a version of the string where each word is titlecased.
More specifically, words start with uppercased characters and all remaining cased characters have lower case.
- translate(table, /)#
Replace each character in the string using the given translation table.
- table
Translation table, which must be a mapping of Unicode ordinals to Unicode ordinals, strings, or None.
The table must implement lookup/indexing via __getitem__, for instance a dictionary or list. If this operation raises LookupError, the character is left untouched. Characters mapped to None are deleted.
- upper()#
Return a copy of the string converted to uppercase.
- zfill(width, /)#
Pad a numeric string with zeros on the left, to fill a field of the given width.
The string is never truncated.
mllm_shap.shap.explainer_result module#
Result model for high-level SHAP explainers.
- class mllm_shap.shap.explainer_result.ExplainerResult(*, full_chat: BaseMllmChat, source_chat: BaseMllmChat, history: list[tuple[Tensor, int, BaseMllmChat | None, ModelResponse]] | None, total_n_calls: int = 0)[source]#
Bases: BaseModel
Result model for Explainer.
- full_chat: BaseMllmChat#
The full chat instance after generation (entire conversation). It will be set with SHAP values and cache.
- history: list[tuple[Tensor, int, BaseMllmChat | None, ModelResponse]] | None#
The history of chats and masks used during explanation (if applicable, that is, if the explainer was called with verbose=True). Each entry is a tuple of (mask, mask_hash, masked_chat, model_response). If the cache was used, masked_chat will be None.
- model_config = {'arbitrary_types_allowed': True}#
Configuration for pydantic model.
- source_chat: BaseMllmChat#
Chat to get explained (without base response).
- total_n_calls: int#
Total number of MLLM calls made for last explanation.
mllm_shap.shap.neyman module#
Neyman SHAP explainers module.
- LimitedComplementaryNeymanShapExplainer implements a limited Neyman sampling that samples initial masks of a given size with a pre-defined member.
- StandardComplementaryNeymanShapExplainer does standard Neyman sampling that samples initial masks of a given size randomly.
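The Neyman idea behind both explainers can be illustrated with a small allocation sketch: an initial draw (initial_num_samples or initial_fraction) estimates per-stratum variability, and the remaining budget is then split proportionally to stratum size times standard deviation. The stratification details and the function name `neyman_allocation` here are assumptions for illustration, not the library's exact scheme:

```python
def neyman_allocation(budget, stratum_sizes, stratum_stds):
    # Neyman allocation: n_h proportional to N_h * s_h, so strata with
    # more members and noisier marginal contributions get more samples.
    weights = [n * s for n, s in zip(stratum_sizes, stratum_stds)]
    total = sum(weights)
    return [round(budget * w / total) for w in weights]

# e.g. strata = mask sizes, stds estimated from the initial samples
alloc = neyman_allocation(100, [3, 3, 1], [0.1, 0.4, 0.2])
# the noisiest stratum receives the bulk of the remaining budget
```

Note that rounding can make the allocations sum to slightly more or less than the budget; a real implementation would reconcile the remainder.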
- class mllm_shap.shap.neyman.LimitedComplementaryNeymanShapExplainer(*args: Any, initial_num_samples: int | None = None, initial_fraction: float | None = None, **kwargs: Any)[source]#
Bases: BaseComplementaryNeymanShapExplainer
Limited Neyman SHAP implementation.
- allow_mask_duplicates: bool#
Whether to allow duplicate masks during generation.
- embedding_model: BaseExternalEmbedding | None#
The external embedding model to use. If provided, overrides
mode.
- embedding_reducer: BaseEmbeddingReducer#
The embedding reduction strategy to use.
- fraction: float | None#
Fraction of total possible masks to generate if num_samples is None.
- include_minimal_masks: bool = True#
Whether to include minimal masks (single-feature and empty masks) in the sampling.
- initial_fraction: float | None#
Initial fraction of samples to draw in the first step.
- initial_num_samples: int | None#
Initial number of samples to draw in the first step.
- initial_steps: int | None#
Number of initial steps performed in last call.
- mode: Mode#
The SHAP mode, either STATIC or CONTEXTUAL. Used if no embedding_model is provided.
- normalizer: BaseNormalizer#
The SHAP value normalizer to use.
- num_samples: int | None#
Number of random masks to generate. If None, uses fraction. -1 stands for minimal number of samples (only single-feature masks and empty mask).
- similarity_measure: BaseEmbeddingSimilarity#
The embedding similarity measure to use.
- total_n_calls: int = 0#
Total number of MLLM calls made for last explanation.
- use_standard_method: bool = False#
Whether to use the standard method for initial sampling. Default is False, which uses the modified method with pre-defined members.
- class mllm_shap.shap.neyman.StandardComplementaryNeymanShapExplainer(*args: Any, initial_num_samples: int | None = None, initial_fraction: float | None = None, **kwargs: Any)[source]#
Bases: BaseComplementaryNeymanShapExplainer
Standard Neyman SHAP Explainer.
- allow_mask_duplicates: bool#
Whether to allow duplicate masks during generation.
- embedding_model: BaseExternalEmbedding | None#
The external embedding model to use. If provided, overrides
mode.
- embedding_reducer: BaseEmbeddingReducer#
The embedding reduction strategy to use.
- fraction: float | None#
Fraction of total possible masks to generate if num_samples is None.
- include_minimal_masks: bool = True#
Whether to include minimal masks (single-feature and empty masks) in the sampling.
- initial_fraction: float | None#
Initial fraction of samples to draw in the first step.
- initial_num_samples: int | None#
Initial number of samples to draw in the first step.
- initial_steps: int | None#
Number of initial steps performed in last call.
- mode: Mode#
The SHAP mode, either STATIC or CONTEXTUAL. Used if no embedding_model is provided.
- normalizer: BaseNormalizer#
The SHAP value normalizer to use.
- num_samples: int | None#
Number of random masks to generate. If None, uses fraction. -1 stands for minimal number of samples (only single-feature masks and empty mask).
- similarity_measure: BaseEmbeddingSimilarity#
The embedding similarity measure to use.
- total_n_calls: int = 0#
Total number of MLLM calls made for last explanation.
- use_standard_method: bool = True#
Whether to use the standard method for initial sampling. Default is False, which uses the modified method with pre-defined members.
mllm_shap.shap.normalizers module#
Normalizers for SHAP values.
- class mllm_shap.shap.normalizers.AbsSumNormalizer[source]#
Bases: BaseNormalizer
Normalizer that scales SHAP values by the sum of their absolute values.
- class mllm_shap.shap.normalizers.IdentityNormalizer[source]#
Bases: BaseNormalizer
Normalizer that returns SHAP values unchanged.
- class mllm_shap.shap.normalizers.MinMaxNormalizer[source]#
Bases: BaseNormalizer
Normalizer that scales SHAP values to the [0, 1] range using min-max normalization, then normalizes them to sum to 1.
- class mllm_shap.shap.normalizers.PowerShiftNormalizer(power: float = 1.0)[source]#
Bases: BaseNormalizer
Normalizer that applies power shift normalization to SHAP values.
- power: float#
The power to which SHAP values are raised.
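Two of the normalizers above can be sketched in plain Python (lists standing in for tensors; the function names are illustrative, and the exact power-shift formula is not shown because it is not specified here):

```python
def abs_sum_normalize(values):
    # AbsSumNormalizer idea: scale by the sum of absolute values.
    denom = sum(abs(v) for v in values) or 1.0
    return [v / denom for v in values]

def min_max_normalize(values):
    # MinMaxNormalizer idea: rescale to [0, 1], then renormalize to sum to 1.
    lo, hi = min(values), max(values)
    scaled = [(v - lo) / (hi - lo) if hi > lo else 0.0 for v in values]
    total = sum(scaled) or 1.0
    return [v / total for v in scaled]

# e.g. raw SHAP values [-1.0, 0.0, 3.0]
# min_max_normalize -> [0.0, 0.2, 0.8] (non-negative, sums to 1)
# abs_sum_normalize -> [-0.25, 0.0, 0.75] (signs preserved)
```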
mllm_shap.shap.precise module#
Precise SHAP explainer implementation.
- class mllm_shap.shap.precise.PreciseShapExplainer(mode: Mode = Mode.CONTEXTUAL, embedding_model: BaseExternalEmbedding | None = None, embedding_reducer: BaseEmbeddingReducer | None = None, similarity_measure: BaseEmbeddingSimilarity | None = None, normalizer: BaseNormalizer | None = None, allow_mask_duplicates: bool = False)[source]#
Bases: BaseShapExplainer
Precise SHAP implementation generating all possible masks.
- allow_mask_duplicates: bool#
Whether to allow duplicate masks during generation.
- embedding_model: BaseExternalEmbedding | None#
The external embedding model to use. If provided, overrides
mode.
- embedding_reducer: BaseEmbeddingReducer#
The embedding reduction strategy to use.
- mode: Mode#
The SHAP mode, either STATIC or CONTEXTUAL. Used if no embedding_model is provided.
- normalizer: BaseNormalizer#
The SHAP value normalizer to use.
- similarity_measure: BaseEmbeddingSimilarity#
The embedding similarity measure to use.
- total_n_calls: int = 0#
Total number of MLLM calls made for last explanation.
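Generating all possible masks corresponds to the exact Shapley formula. A minimal sketch with a toy additive value function follows; in the real explainer a mask's value would come from comparing the masked response's embedding to the full response, and `exact_shapley` is an illustrative name, not library API:

```python
from itertools import combinations
from math import factorial

def exact_shapley(n_features, value):
    # Exact Shapley values: for each feature, enumerate every coalition
    # (mask) of the remaining features and weight its marginal contribution.
    shap = [0.0] * n_features
    for i in range(n_features):
        others = [j for j in range(n_features) if j != i]
        for size in range(len(others) + 1):
            weight = (factorial(size) * factorial(n_features - size - 1)
                      / factorial(n_features))
            for subset in combinations(others, size):
                marginal = value(set(subset) | {i}) - value(set(subset))
                shap[i] += weight * marginal
    return shap

# Toy additive game: a coalition is worth the sum of its members' weights,
# so the Shapley values recover the weights exactly.
weights = [1.0, 2.0, 3.0]
phi = exact_shapley(3, lambda s: sum(weights[j] for j in s))
# phi -> [1.0, 2.0, 3.0]
```

For n tokens this requires on the order of 2^n mask evaluations, which is why the Monte Carlo, complementary, and Neyman approximations in this package exist.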
mllm_shap.shap.similarity module#
Embedding similarity calculations for SHAP explanations.
- class mllm_shap.shap.similarity.CosineSimilarity[source]#
Bases: BaseEmbeddingSimilarity
Cosine similarity calculation, used in implementation of U1 and U2 utility functions from the paper.
- operates_on_embeddings: bool = True#
Indicates that the similarity operates on embeddings. If False, it operates on raw tokens.
Used to resolve input to
__call__().
- class mllm_shap.shap.similarity.EuclideanSimilarity[source]#
Bases: BaseEmbeddingSimilarity
Euclidean similarity calculation, used in implementation of U4 utility function from the paper.
- operates_on_embeddings: bool = True#
Indicates that the similarity operates on embeddings. If False, it operates on raw tokens.
Used to resolve input to
__call__().
- class mllm_shap.shap.similarity.TfIdfCosineSimilarity[source]#
Bases: BaseEmbeddingSimilarity
TF-IDF weighted Cosine similarity calculation, used in implementation of U3 utility function from the paper.
- operates_on_embeddings: bool = False#
Indicates that the similarity operates on embeddings. If False, it operates on raw tokens.
Used to resolve input to
__call__().
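The two embedding-based measures can be sketched as follows. The cosine form is standard; the exact mapping from Euclidean distance to a similarity used by the library is not specified here, so 1 / (1 + d) is an assumed form, and the TF-IDF variant (which operates on raw tokens) is not sketched:

```python
import math

def cosine_similarity(a, b):
    # Cosine of the angle between two embedding vectors.
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    return dot / (na * nb) if na and nb else 0.0

def euclidean_similarity(a, b):
    # Distance-based similarity: higher when vectors are closer.
    # 1 / (1 + d) is an assumed mapping, not the library's exact formula.
    dist = math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))
    return 1.0 / (1.0 + dist)
```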
Module contents#
SHAP explainers module.
- PreciseShapExplainer implements the precise SHAP value computation using the original SHAP algorithm.
- Explainer implements the compact SHAP explainer that optimizes the computation of SHAP values for large models and datasets.
- McShapExplainer is an alias for the limited Monte Carlo SHAP explainer.
- ComplementaryShapExplainer implements the complementary SHAP explainer that focuses on explaining the complementary contributions of features. It is an alias for the limited complementary SHAP explainer.
- ComplementaryNeymanShapExplainer implements the complementary SHAP explainer using Neyman allocation for improved sample efficiency. It is an alias for the limited complementary Neyman SHAP explainer.
- HierarchicalExplainer implements a hierarchical approach to SHAP value computation, allowing for significant speed-ups.
- mllm_shap.shap.ComplementaryNeymanShapExplainer#
alias of LimitedComplementaryNeymanShapExplainer
- mllm_shap.shap.ComplementaryShapExplainer#
alias of
LimitedComplementaryShapExplainer
- class mllm_shap.shap.Explainer(shap_explainer: BaseShapExplainer | None = None, **kwargs: Any)[source]#
Bases: BaseExplainer
Convenience client class for SHAP explanation.
It generates the full response from the model and then uses the provided SHAP explainer to compute SHAP values.
Uses PreciseShapExplainer as the default SHAP explainer.
- model: BaseMllmModel#
The model connector instance.
- shap_explainer: BaseShapExplainer#
The SHAP explainer instance.
- total_n_calls: int = 0#
Total number of MLLM calls made for last explanation.
- class mllm_shap.shap.HierarchicalExplainer(shap_explainer: BaseShapExplainer | None = None, first_layer_explainer: BaseShapExplainer | None = None, mode: Mode = Mode.TEXT, k: int = 10, use_importance_sampling: bool = False, importance_sampling_min_fraction: float = 0.1, **kwargs: Any)[source]#
Bases: BaseExplainer
SHAP explainer implementing a hierarchical approach for speed-up.
Groups are divided into subgroups recursively until the final group size. Groups cannot share different modalities (e.g., text and audio tokens). Uses an underlying SHAP explainer for group explanations.
Should be used with SHAP explainers that normalize using MinMaxNormalizer. It has no history nor non-normalized SHAP values available. Refer to Mode for details on how groups are formed at the first level.
- computation_graph: GraphNode | None#
Computation graph for the last explanation.
- first_layer_explainer: BaseShapExplainer | None#
The SHAP explainer instance for the first layer. If provided, the first-layer explanation is done with this explainer, fitted with all explainable tokens; the calculated SHAP values are summed per group and used as the first-layer SHAP values.
- importance_sampling_min_fraction: float#
Minimum fraction for importance sampling.
- k: int#
Maximum final group size at each level.
- model: BaseMllmModel#
The model connector instance.
- n_calls: int#
Number of internal SHAP explainer calls made for last explanation.
- shap_explainer: BaseShapExplainer#
The SHAP explainer instance.
- total_n_calls: int = 0#
Total number of MLLM calls made for last explanation.
- use_importance_sampling: bool#
Whether to use importance sampling for setting the sampling budget for each group.
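One level of the grouping can be sketched as contiguous chunking of token indices; this is a simplification (the real explainer also keeps modalities separate, builds a computation graph, and recurses into groups until the final group size), and `split_into_groups` is an illustrative name:

```python
def split_into_groups(indices, k):
    # One hierarchy level: contiguous groups of at most k token indices.
    # Groups would then be explained as units, and promising groups split
    # again at the next level until the final group size is reached.
    return [indices[i:i + k] for i in range(0, len(indices), k)]

groups = split_into_groups(list(range(10)), 4)
# -> [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
```

Explaining a handful of groups instead of every token is where the speed-up comes from: each level needs far fewer masks than a flat per-token explanation.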
- mllm_shap.shap.McShapExplainer#
alias of
LimitedMcShapExplainer
- class mllm_shap.shap.PreciseShapExplainer(mode: Mode = Mode.CONTEXTUAL, embedding_model: BaseExternalEmbedding | None = None, embedding_reducer: BaseEmbeddingReducer | None = None, similarity_measure: BaseEmbeddingSimilarity | None = None, normalizer: BaseNormalizer | None = None, allow_mask_duplicates: bool = False)[source]#
Bases: BaseShapExplainer
Precise SHAP implementation generating all possible masks.
- allow_mask_duplicates: bool#
Whether to allow duplicate masks during generation.
- embedding_model: BaseExternalEmbedding | None#
The external embedding model to use. If provided, overrides
mode.
- embedding_reducer: BaseEmbeddingReducer#
The embedding reduction strategy to use.
- mode: Mode#
The SHAP mode, either STATIC or CONTEXTUAL. Used if no embedding_model is provided.
- normalizer: BaseNormalizer#
The SHAP value normalizer to use.
- similarity_measure: BaseEmbeddingSimilarity#
The embedding similarity measure to use.
- total_n_calls: int = 0#
Total number of MLLM calls made for last explanation.