mllm_shap.connectors.base package#

Submodules#

mllm_shap.connectors.base.chat module#

Base class for chat state management.

exception mllm_shap.connectors.base.chat.AllTextTokensFilteredOutError[source]#

Bases: ValueError

Raised when all text tokens are filtered out from the chat.

add_note()#

Exception.add_note(note) – add a note to the exception

args#
with_traceback()#

Exception.with_traceback(tb) – set self.__traceback__ to tb and return self.

class mllm_shap.connectors.base.chat.BaseMllmChat(device: device, empty_turn_sequences: set[str], get_new_chat_callable: Callable[[...], BaseMllmChat], token_filter: TokenFilter | None = None, system_roles_setup: SystemRolesSetup | None = None)[source]#

Bases: ABC

Base class for chat state management.

Important: Audio tokens are always added for SHAP calculations.

When using audio segments, make sure each one is the only message within its turn.

add_audio(audio_content: bytes, audio_format: str = 'mp3', _waveform: Tensor | None = None, _sample_rate: int | None = None, _internal: bool = False) None[source]#

Add audio content to the chat state.

Parameters:
  • audio_content – The audio content in bytes.

  • audio_format – The format of the audio content (default is “mp3”).

Raises:
  • ValueError – If audio_content is not non-empty bytes.

  • RuntimeError – If an error occurs in the underlying connector implementation.

add_audio_with_transcript(audio_content: bytes, transcript: str | list[str], aligner: SpectrogramGuidedAligner, audio_format: str = 'mp3', attach_audio: bool = False) None[source]#

Add audio content along with its transcript to the chat state.

Parameters:
  • audio_content – The audio content in bytes.

  • transcript – The transcript of the audio content.

  • aligner – The SpectrogramGuidedAligner instance for aligning audio and transcript.

  • audio_format – The format of the audio content (default is “mp3”).

  • attach_audio – Whether to attach raw audio bytes to each segment (default is False). If False, segments will have empty audio bytes to save memory.

Raises:
  • ValueError – If audio_content is not non-empty bytes or transcript is empty.

  • RuntimeError – If an error occurs in the underlying connector implementation.

add_text(text: str) None[source]#

Add text to the chat state.

Parameters:

text – The text to add.

Raises:
  • ValueError – If text is not a non-empty string.

  • RuntimeError – If an error occurs in the underlying connector implementation.

append(text: Tensor, audio_out: Tensor, modality_flag: Tensor, history_tracking_mode: ModelHistoryTrackingMode) None[source]#

Append text and audio tokens along with modality flags to the chat state. Assumes that the entry data is correct and non-system.

Warning

This method does not validate input data.

Parameters:
  • text – The text tokens to append.

  • audio_out – The audio tokens to append (intended for model’s output).

  • modality_flag – The modality flags corresponding to the tokens.

  • history_tracking_mode – The mode for tracking chat history.

Raises:
  • ValueError – If length mismatch occurs after appending.

  • RuntimeError – If an error occurs in the underlying connector implementation.

attach_audio_to_segments(aligner: SpectrogramGuidedAligner, audio_content: bytes, audio_format: str = 'mp3', turn_number: int | None = None) None[source]#

Attach raw audio bytes to segments for a specific turn.

This is a helper method to materialize audio bytes for segments that were created without audio attachment (i.e., with attach_audio=False).

Parameters:
  • aligner – The SpectrogramGuidedAligner instance to use for attaching audio.

  • audio_content – The audio content in bytes.

  • audio_format – The format of the audio content (default is “mp3”).

  • turn_number – The turn number to attach audio for. If None, uses the current turn.

Raises:

ValueError – If no audio segments exist for the specified turn.

property audio_tokens: Tensor#

Input audio tensor (tokens) of shape (T, K).

property audio_tokens_mask: Tensor#

Boolean mask indicating positions of audio tokens in the input (input_tokens tensor).

audio_tokens_no_system_mask: Tensor#

A boolean tensor indicating which audio tokens are user-generated.

property audio_tokens_no_system_mask_filtered: Tensor#

Boolean mask indicating which audio tokens are not system-generated, expressed relative to audio_tokens_mask.
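The "relative to audio_tokens_mask" phrasing means the filtered mask indexes into the audio-token subsequence rather than the full input sequence. A minimal pure-Python sketch of the idea (the variable names here are illustrative, not the library's internals):

```python
# Global input positions: True where the token is an audio token.
audio_tokens_mask = [False, True, True, False, True]

# Global mask: True where the token is audio AND not system-generated.
audio_no_system = [False, False, True, False, True]

# The "filtered" mask has one entry per True in audio_tokens_mask,
# i.e. it indexes into the audio-token subsequence.
filtered = [
    ns for ns, is_audio in zip(audio_no_system, audio_tokens_mask) if is_audio
]

print(filtered)  # [False, True, True]
```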

property cache: ExplainerCache | None#

Access to the explainer cache, if set.

decode_audio(audio_tokens: list[Tensor] | Tensor | None = None, sample_rate: int = 24000, audio_format: str = 'mp3') bytes[source]#

Decode the generated audio tokens.

Warning

This method does not validate input data.

Parameters:
  • audio_tokens – Audio tokens to decode in format (T, K).

  • sample_rate – The sample rate for the decoded audio (default is 24,000 Hz).

  • audio_format – The desired output audio format (default is “mp3”).

Returns:

The decoded audio content in bytes. Empty bytes if decoding is not available.

Raises:

RuntimeError – If an error occurs in the underlying connector implementation.

decode_text(text_tokens: list[Tensor] | Tensor | None = None) str[source]#

Decode the generated text tokens.

Warning

This method does not validate input data.

Parameters:

text_tokens – The generated text tokens.

Returns:

The decoded text.

Raises:

RuntimeError – If an error occurs in the underlying connector implementation.

empty_turn_sequences: list[Tensor]#

A list of tensors encoding the empty-turn token sequences.

end_turn() None[source]#

End the current turn in the chat state.

Warning

For Developers: This method assumes cached property refresh is handled in _new_turn() or by calling add_text/add_audio methods.

Raises:
  • ValueError – If no turn is active.

  • RuntimeError – If an error occurs in the underlying connector implementation.

property external_group_ids: Tensor | None#

Optional external group IDs for tokens. Should be an integer tensor with size equal to the number of tokens in the chat, set directly before the explanation process.

Tokens sharing the same positive ID are treated as one group for SHAP value calculations; entries equal to 0 are excluded from the calculations (shap_values_mask for them is set to False). Takes precedence over external_shap_values_mask.

Adding new tokens to the chat is forbidden while this is set.
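The derived quantities exposed by the two properties below can be pictured with a plain-Python sketch (illustrative only; the library operates on tensors):

```python
# One integer per token: 0 = excluded from SHAP, >0 = group membership.
external_group_ids = [0, 1, 1, 2, 0, 2]

# Tokens with a positive ID participate in SHAP calculations.
positive_mask = [g > 0 for g in external_group_ids]

# First position of each consecutive run of a non-zero group ID.
first_positions = [
    i for i, g in enumerate(external_group_ids)
    if g > 0 and (i == 0 or external_group_ids[i - 1] != g)
]

print(positive_mask)    # [False, True, True, True, False, True]
print(first_positions)  # [1, 3, 5]
```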

property external_group_ids_first_positions: Tensor#

Get the positions (indices) of the first occurrences of each consecutive non-zero group ID in external_group_ids.

Returns:

Tensor of positions (indices) of the first occurrence of each consecutive non-zero group ID.

Raises:

ValueError – If external_group_ids is not set.

property external_group_ids_positive_mask: Tensor#

Boolean mask indicating which tokens have positive group IDs.

Returns:

Boolean mask indicating which tokens have positive group IDs.

Raises:

ValueError – If external_group_ids is not set.

property external_shap_values_mask: Tensor | None#

An optional external SHAP values mask. Should be a boolean tensor with size equal to the number of tokens in the chat.

If provided, shap_values_mask will be and-ed with this mask.

Forbids adding new tokens to the chat while set.
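The and-ing behaviour described above can be illustrated in plain Python (an illustrative sketch, not the implementation):

```python
# Mask the chat computes internally (e.g. non-system text tokens).
internal_mask = [True, True, False, True]

# User-supplied restriction via external_shap_values_mask.
external_mask = [True, False, True, True]

# The effective shap_values_mask is the element-wise AND of both.
shap_values_mask = [a and b for a, b in zip(internal_mask, external_mask)]
print(shap_values_mask)  # [True, False, False, True]
```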

classmethod from_chat(mask: Tensor, chat: BaseMllmChat) BaseMllmChat[source]#

Create a new chat instance from an existing chat and a mask.

Parameters:
  • mask – A boolean tensor indicating which messages to include.

  • chat – The existing chat instance to copy.

Returns:

An instance of BaseMllmChat.

Raises:
  • ValueError – If the mask size does not match the number of tokens in the chat, or if the mask is all False, or if all text tokens are filtered out.

  • RuntimeError – If audio segments are inconsistent with stored waveforms.

get_conversation() list[list[ChatEntry]][source]#

Serialize the chat state into a list of turns, each a list of ChatEntry objects.

Returns:

A list of turns, where each turn is a list of ChatEntry objects. shap_values=None indicates that SHAP values are not yet available.

Raises:

NotImplementedError – If the chat contains audio segments.

Example:

[
    [
        ChatEntry(
            content_type=0,
            roles=[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
            content='<|im_start|>, system, \n, You, are, a helpful assistant that answers questions briefly...',
            shap_values=None
        )
    ],
    [
        ChatEntry(
            content_type=0,
            roles=[2, 2, 2, 0, 0, 0, 0, 2, 2],
            content='<|im_start|>, user, \n, Who, are, you, ?, <|im_end|>, \n',
            shap_values=None
        )
    ]
]
get_new_chat_callable: Callable[[...], BaseMllmChat]#

A callable to create a new chat instance.

property input_tokens: list[Tensor]#

Combined input tensor (text + audio).

property input_tokens_num: int#

Total number of input tokens (text + audio).

property is_system_turn: bool#

Flag indicating whether the current turn is a system turn.

new_turn(speaker: Role) None[source]#

Start a new turn in the chat state.

Warning

For Developers: This method assumes cached property refresh is handled in _new_turn() or by calling add_text/add_audio methods.

Parameters:

speaker – The role of the speaker for the new turn.

Raises:
  • ValueError – If a turn is already active.

  • RuntimeError – If an error occurs in the underlying connector implementation.

refresh(full: bool = False, shap: bool = False) None[source]#

Refresh cached properties, that is:

  • input_tokens

  • tokens_modality_flag

  • text_tokens_mask

  • text_tokens

  • audio_tokens_mask

  • audio_tokens

If full is True, also refresh:

  • audio_tokens_no_system_mask_filtered

  • text_tokens_no_system_mask_filtered

  • shap_values_mask

  • external_group_ids_first_positions

  • external_group_ids_positive_mask

If shap is True, will only refresh:

  • shap_values_mask

  • external_group_ids_first_positions

  • external_group_ids_positive_mask

Parameters:
  • full – If True, refreshes all cached properties.

  • shap – If True, refreshes shap-related cached properties.

property shap_values_mask: Tensor#

Boolean mask indicating which tokens should be considered for SHAP value calculations (i.e., non-system text tokens).

Raises:
  • ValueError – If the external SHAP values mask size does not match the number of tokens in the chat.

  • RuntimeError – If external_group_ids is set but has no positive IDs.

speaker: Role | None = None#

The role of the current speaker in the chat.

system_roles_setup: SystemRolesSetup#

A set of roles that are considered system roles. If Role.ASSISTANT is added, multi-turn assistant messages will also be marked as system and therefore excluded from Shapley value calculations.

property text_tokens: Tensor#

Input text tensor (tokens).

property text_tokens_mask: Tensor#

Boolean mask indicating positions of text tokens in the input (input_tokens tensor).

text_tokens_no_system_mask: Tensor#

A boolean tensor indicating which text tokens are user-generated.

property text_tokens_no_system_mask_filtered: Tensor#

Boolean mask indicating which text tokens are not system-generated, after filtering out excluded sequences, expressed relative to text_tokens_mask.

token_filter: TokenFilter#

The token filtering strategy.

token_roles: Tensor#

A tensor indicating the role of each token in the chat.

token_sequences_to_exclude: list[Tensor]#

A list of token-ID sequences (tensors) to exclude from processing.

token_turns: Tensor#

A tensor indicating the turn structure of the chat.

property tokens_modality_flag: Tensor#

The modality flag tensor indicating token types according to ModalityFlag enum.

torch_device: device#

The device on which tensors are stored.

translate_groups_ids_mask(mask: Tensor) Tensor[source]#

Translate a mask over group IDs to a mask over all tokens.

Parameters:

mask – A boolean tensor indicating which group IDs to include.

Returns:

A boolean tensor indicating which tokens to include.
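The translation can be pictured as follows (a plain-Python sketch under the assumption that the group mask is ordered by first appearance of each positive ID; not the actual implementation): a mask with one entry per group is expanded so that every token carrying that group's ID receives the group's value.

```python
external_group_ids = [0, 1, 1, 2, 2, 0]

# One boolean per group, in the order groups first appear:
# keep group 1, drop group 2.
group_mask = [True, False]

# Map each distinct positive ID to its position in order of first appearance.
order = {}
for g in external_group_ids:
    if g > 0 and g not in order:
        order[g] = len(order)

# Expand the per-group mask to a per-token mask; zero-ID tokens stay False.
token_mask = [g > 0 and group_mask[order[g]] for g in external_group_ids]
print(token_mask)  # [False, True, True, False, False, False]
```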

turn_number: int#

The current turn number in the chat.

mllm_shap.connectors.base.chat_entry module#

Conversation entry data structure for audio and text modalities.

class mllm_shap.connectors.base.chat_entry.ChatEntry(*, content_type: int, roles: list[int], content: list[str | bytes], shap_values: list[float | None] | None)[source]#

Bases: BaseModel

Conversation entry data structure.

content: list[str | bytes]#

Content of the entry, can be text (str) or audio bytes (bytes).

content_type: int#

Modality of the content (e.g., text or audio), refer to ModalityFlag.

display() None[source]#

Display the ChatEntry content.

Raises:

ValueError – If the number of roles does not match the number of content pieces.

Example:

BY: USER, SYSTEM
TEXT CONTENT:
    <|im_start|> user
    Who  are  you ? <|im_end|>
model_config = {'arbitrary_types_allowed': True}#

Configuration for the pydantic model.

roles: list[int]#

List of roles associated with this entry, refer to Role.

shap_values: list[float | None] | None#

SHAP values associated with the content tokens, if they have been computed for this entry.

mllm_shap.connectors.base.explainer_cache module#

Cache for explainer computations.

class mllm_shap.connectors.base.explainer_cache.ExplainerCache(chat: BaseMllmChat, responses: list[ModelResponse], masks: Tensor, shap_values_mask: Tensor, *, calculated_by: int, n: int, had_different_masks: bool = False)[source]#

Bases: BaseModel

Cache for explainer computations associated with a chat. Saves and validates calculated SHAP values, masks, and reduced embeddings.

calculated_by: int#

Hash of the explainer that calculated the SHAP values.

chat: BaseMllmChat#

The chat instance the cache is for.

classmethod create(chat: BaseMllmChat, explainer_hash: int, responses: list[ModelResponse], masks: Tensor, normalized_values: Tensor, shap_values_mask: Tensor, values: Tensor | None = None) ExplainerCache[source]#

Create a new ExplainerCache instance.

Parameters:
  • chat – The chat instance the cache is for.

  • explainer_hash – Hash of the explainer that calculated the SHAP values.

  • responses – The model responses used for SHAP calculations.

  • masks – The masks used for SHAP calculations.

  • values – The SHAP values calculated.

  • normalized_values – The normalized SHAP values calculated.

  • shap_values_mask – The mask indicating which SHAP values are relevant.

Returns:

A new ExplainerCache instance.

extend_masks() None[source]#

Extend masks to match the chat length.

static extend_values(values: Tensor, shape: tuple[int, ...], dim: int, fill_value: Any, device: device) Tensor[source]#

Extend SHAP values to match the chat length.

Parameters:
  • values – The SHAP values to extend.

  • shape – The target shape for extension.

  • dim – The dimension along which to extend.

  • fill_value – The value to use for extension.

  • device – The device to create the extended tensor on.

Returns:

The extended SHAP values.
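For the one-dimensional case, the extension amounts to padding with fill_value up to the target length. An illustrative pure-Python sketch (the real method works on tensors of arbitrary shape and pads along the given dimension):

```python
def extend_1d(values, target_len, fill_value):
    """Pad a 1-D list of SHAP values to target_len with fill_value."""
    return values + [fill_value] * (target_len - len(values))

print(extend_1d([0.3, -0.1, 0.2], 5, 0.0))  # [0.3, -0.1, 0.2, 0.0, 0.0]
```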

had_different_masks: bool#

Whether the masks used for SHAP calculations differed from chat’s masks.

masks: Tensor#

The masks used for SHAP calculations.

model_config = {'arbitrary_types_allowed': True}#

Configuration for the model; should be a dictionary conforming to pydantic's ConfigDict.

model_post_init(context: Any, /) None#

This function is meant to behave like a BaseModel method to initialise private attributes.

It takes context as an argument since that’s what pydantic-core passes when calling it.

Parameters:
  • self – The BaseModel instance.

  • context – The context.

n: int#

Index of last token used for SHAP calculations.

property normalized_values: Tensor#

Normalized SHAP values.

Raises:

ValueError – If SHAP values are no longer valid or have not been computed yet.

responses: list[ModelResponse]#

The model responses used for SHAP calculations.

shap_values_mask: Tensor#

The mask indicating which SHAP values are relevant.

property values: Tensor | None#

SHAP values. Can be None if HierarchicalExplainer is used.

Raises:

ValueError – If SHAP values are no longer valid or have not been computed yet.

mllm_shap.connectors.base.filters module#

Base classes for token filtering strategies.

class mllm_shap.connectors.base.filters.TokenFilter(*, phrases_to_exclude: set[str])[source]#

Bases: ABC, BaseModel

Base class for token filtering strategies.

model_config = {}#

Configuration for the model; should be a dictionary conforming to pydantic's ConfigDict.

phrases_to_exclude: set[str]#

Set of phrases to exclude from SHAP calculations.
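A concrete filtering strategy might mask out tokens whose decoded text matches an excluded phrase. A minimal pure-Python sketch of the idea (the function and variable names below are illustrative, not the library's API):

```python
def filter_tokens(decoded_tokens, phrases_to_exclude):
    """Return a keep-mask: False for tokens matching an excluded phrase."""
    return [tok not in phrases_to_exclude for tok in decoded_tokens]

# Drop punctuation and newline tokens from SHAP calculations.
mask = filter_tokens(["Who", " are", " you", "?", "\n"], {"?", "\n"})
print(mask)  # [True, True, True, False, False]
```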

mllm_shap.connectors.base.model module#

Base model connector class.

class mllm_shap.connectors.base.model.BaseMllmModel(config: HuggingFaceModelConfig, device: device, processor: Any, model: Any, history_tracking_mode: ModelHistoryTrackingMode = ModelHistoryTrackingMode.TEXT)[source]#

Bases: ABC

Base class for model connectors.

config: HuggingFaceModelConfig#

The model configuration.

device: device#

The device to run the model on.

abstractmethod generate(chat: BaseMllmChat, max_new_tokens: int = 128, model_config: ModelConfig = ModelConfig(text_temperature=0.0, text_top_k=1, audio_temperature=0.0, audio_top_k=1), keep_history: bool = False) ModelResponse[source]#

Generate a model response based on the current chat state.

Parameters:
  • chat – The current chat state.

  • max_new_tokens – The maximum number of new tokens to generate (default is 128).

  • model_config – Additional model configuration parameters.

  • keep_history – Whether to return chat state with full history or only generated content.

Returns:

The model response containing the generated tokens and, if keep_history is True, the updated chat state.

Return type:

ModelResponse

get_contextual_embeddings(*args: Any, static_embeddings: list[Tensor] | None = None, **kwargs: Any) list[Tensor][source]#

Get contextual embeddings for the current chat state.

Parameters:
  • static_embeddings – Precomputed static embeddings (if any).

  • *args – Additional positional arguments for get_static_embeddings(). Used if static_embeddings is None.

  • **kwargs – Additional keyword arguments for get_static_embeddings(). Used if static_embeddings is None.

Returns:

The contextual embeddings for the text and audio tokens, in the same format as get_static_embeddings().

Raises:

ValueError – If static_embeddings is not a list of Tensor.

abstractmethod get_new_chat() BaseMllmChat[source]#

Get a new chat state for the model.

abstractmethod get_static_embeddings(responses: list[ModelResponse]) list[Tensor][source]#

Get static embeddings for the current chat state.

Parameters:

responses – The model responses to get embeddings for.

Returns:

The static embeddings for the text and audio tokens.

Raises:

ValueError – If responses is not a list of ModelResponse.

history_tracking_mode: ModelHistoryTrackingMode#

The mode for tracking chat history.

model: Any#

The model instance.

processor: Any#

The model processor (tokenizer).

mllm_shap.connectors.base.model_response module#

Model response wrapper module.

class mllm_shap.connectors.base.model_response.ModelResponse(*, chat: BaseMllmChat | None, generated_text_tokens: Tensor, generated_audio_tokens: Tensor, generated_modality_flag: Tensor)[source]#

Bases: BaseModel

Model response wrapper. Used to standardize the output from different models.

chat: BaseMllmChat | None#

The updated chat with full history if keep_history was True.

generated_audio_tokens: Tensor#

The generated audio tokens.

generated_modality_flag: Tensor#

The modality flag for the generated tokens.

generated_text_tokens: Tensor#

The generated text tokens.

model_config = {'arbitrary_types_allowed': True}#

Configuration for the model; should be a dictionary conforming to pydantic's ConfigDict.

Module contents#

Base module for mllm_shap.connectors package.