mllm_shap.connectors.base package#
Submodules#
mllm_shap.connectors.base.chat module#
Base class for chat state management.
- exception mllm_shap.connectors.base.chat.AllTextTokensFilteredOutError[source]#
Bases: ValueError
Raised when all text tokens are filtered out from the chat.
- add_note()#
Exception.add_note(note) – add a note to the exception
- args#
- with_traceback()#
Exception.with_traceback(tb) – set self.__traceback__ to tb and return self.
- class mllm_shap.connectors.base.chat.BaseMllmChat(device: device, empty_turn_sequences: set[str], get_new_chat_callable: Callable[[...], BaseMllmChat], token_filter: TokenFilter | None = None, system_roles_setup: SystemRolesSetup | None = None)[source]#
Bases: ABC
Base class for chat state management.
- Important: Audio tokens are always added for SHAP calculations.
When using audio segments, make sure each one is the only message within its respective turn.
- add_audio(audio_content: bytes, audio_format: str = 'mp3', _waveform: Tensor | None = None, _sample_rate: int | None = None, _internal: bool = False) None[source]#
Add audio content to the chat state.
- Parameters:
audio_content – The audio content in bytes.
audio_format – The format of the audio content (default is “mp3”).
- Raises:
ValueError – If audio_content is not non-empty bytes.
RuntimeError – If an error occurs in the underlying connector implementation.
- add_audio_with_transcript(audio_content: bytes, transcript: str | list[str], aligner: SpectrogramGuidedAligner, audio_format: str = 'mp3', attach_audio: bool = False) None[source]#
Add audio content along with its transcript to the chat state.
- Parameters:
audio_content – The audio content in bytes.
transcript – The transcript of the audio content.
aligner – The SpectrogramGuidedAligner instance for aligning audio and transcript.
audio_format – The format of the audio content (default is “mp3”).
attach_audio – Whether to attach raw audio bytes to each segment (default is False). If False, segments will have empty audio bytes to save memory.
- Raises:
ValueError – If audio_content is not non-empty bytes or transcript is not a non-empty string.
RuntimeError – If an error occurs in the underlying connector implementation.
- add_text(text: str) None[source]#
Add text to the chat state.
- Parameters:
text – The text to add.
- Raises:
ValueError – If text is not a non-empty string.
RuntimeError – If an error occurs in the underlying connector implementation.
- append(text: Tensor, audio_out: Tensor, modality_flag: Tensor, history_tracking_mode: ModelHistoryTrackingMode) None[source]#
Append text and audio tokens along with modality flags to the chat state. Assumes that the entry data is correct and non-system.
Warning
This method does not validate input data.
- Parameters:
text – The text tokens to append.
audio_out – The audio tokens to append (intended for model’s output).
modality_flag – The modality flags corresponding to the tokens.
history_tracking_mode – The mode for tracking chat history.
- Raises:
ValueError – If length mismatch occurs after appending.
RuntimeError – If an error occurs in the underlying connector implementation.
- attach_audio_to_segments(aligner: SpectrogramGuidedAligner, audio_content: bytes, audio_format: str = 'mp3', turn_number: int | None = None) None[source]#
Attach raw audio bytes to segments for a specific turn.
This is a helper method to materialize audio bytes for segments that were created without audio attachment (i.e., with attach_audio=False).
- Parameters:
aligner – The SpectrogramGuidedAligner instance to use for attaching audio.
audio_content – The audio content in bytes.
audio_format – The format of the audio content (default is “mp3”).
turn_number – The turn number to attach audio for. If None, uses the current turn.
- Raises:
ValueError – If no audio segments exist for the specified turn.
- property audio_tokens: Tensor#
Input audio tensor (tokens) of shape (T, K).
- property audio_tokens_mask: Tensor#
Boolean mask indicating positions of audio tokens in the input (input_tokens tensor).
- audio_tokens_no_system_mask: Tensor#
A boolean tensor indicating which audio tokens are user-generated.
- property audio_tokens_no_system_mask_filtered: Tensor#
Boolean mask indicating which audio tokens are not system generated. Relative to audio_tokens_mask.
- property cache: ExplainerCache | None#
Accessor for the explainer cache, if set.
- decode_audio(audio_tokens: list[Tensor] | Tensor | None = None, sample_rate: int = 24000, audio_format: str = 'mp3') bytes[source]#
Decode the generated audio tokens.
Warning
This method does not validate input data.
- Parameters:
audio_tokens – Audio tokens to decode in format (T, K).
sample_rate – The sample rate for the decoded audio (default is 24,000 Hz).
audio_format – The desired output audio format (default is “mp3”).
- Returns:
The decoded audio content in bytes. Empty bytes if decoding is not available.
- Raises:
RuntimeError – If an error occurs in the underlying connector implementation.
- decode_text(text_tokens: list[Tensor] | Tensor | None = None) str[source]#
Decode the generated text tokens.
Warning
This method does not validate input data.
- Parameters:
text_tokens – The generated text tokens.
- Returns:
The decoded text.
- Raises:
RuntimeError – If an error occurs in the underlying connector implementation.
- empty_turn_sequences: list[Tensor]#
A list of tensors representing the tokenized empty turn sequences.
- end_turn() None[source]#
End the current turn in the chat state.
Warning
For Developers: This method assumes cached property refresh is handled in _new_turn() or by calling add_text/add_audio methods.
- Raises:
ValueError – If no turn is active.
RuntimeError – If an error occurs in the underlying connector implementation.
- property external_group_ids: Tensor | None#
Optional external group IDs for tokens. Should be an integer tensor with size equal to the number of tokens in the chat, set directly before the explanation process.
Entries sharing the same positive ID are treated as one group for SHAP value calculations; entries equal to 0 are excluded from calculations (shap_values_mask for them is set to False). Takes precedence over external_shap_values_mask. Adding new tokens to the chat is forbidden while this is set.
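The grouping semantics can be sketched in plain Python (a simplification for illustration only: the library operates on torch tensors, and the reading that tokens sharing the same positive ID form one group is an assumption):

```python
def group_tokens(group_ids):
    """Map each positive group ID to the token indices it covers.

    Tokens with ID 0 are excluded from SHAP calculations entirely.
    """
    groups = {}
    for idx, gid in enumerate(group_ids):
        if gid > 0:  # 0 means: ignore this token
            groups.setdefault(gid, []).append(idx)
    return groups

# Tokens 1-2 form group 1, tokens 4-5 form group 2; tokens 0 and 3 are ignored.
print(group_tokens([0, 1, 1, 0, 2, 2]))  # {1: [1, 2], 2: [4, 5]}
```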
- property external_group_ids_first_positions: Tensor#
Get the positions (indices) of the first occurrences of each consecutive non-zero group ID in external_group_ids.
- Returns:
Tensor of positions (indices) of the first occurrence of each non-zero group ID.
- Raises:
ValueError – If external_group_ids is not set.
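A minimal pure-Python sketch of the "first occurrence of each consecutive non-zero group ID" logic (a hypothetical simplification; the property itself works on tensors):

```python
def first_positions(group_ids):
    """Index of the first token in each consecutive run of a non-zero group ID."""
    positions = []
    prev = 0
    for idx, gid in enumerate(group_ids):
        if gid != 0 and gid != prev:  # a new non-zero run starts here
            positions.append(idx)
        prev = gid
    return positions

# Runs: [3, 3] starting at 1, [3] starting at 4, [5, 5] starting at 5.
print(first_positions([0, 3, 3, 0, 3, 5, 5]))  # [1, 4, 5]
```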
- property external_group_ids_positive_mask: Tensor#
Boolean mask indicating which tokens have positive group IDs.
- Returns:
Boolean mask indicating which tokens have positive group IDs.
- Raises:
ValueError – If external_group_ids is not set.
- property external_shap_values_mask: Tensor | None#
An optional external SHAP values mask. Should be a boolean tensor with size equal to the number of tokens in the chat.
If provided, shap_values_mask will be and-ed with this mask. Adding new tokens to the chat is forbidden while this mask is set.
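The and-ing described above amounts to an element-wise boolean AND; a toy sketch with plain Python lists (the real masks are boolean tensors):

```python
base_mask = [True, True, False, True]      # e.g. non-system text tokens
external_mask = [True, False, True, True]  # user-supplied restriction

# shap_values_mask is and-ed element-wise with the external mask
effective = [b and e for b, e in zip(base_mask, external_mask)]
print(effective)  # [True, False, False, True]
```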
- classmethod from_chat(mask: Tensor, chat: BaseMllmChat) BaseMllmChat[source]#
Create a new chat instance from an existing chat and a mask.
- Parameters:
mask – A boolean tensor indicating which messages to include.
chat – The existing chat instance to copy.
- Returns:
An instance of BaseMllmChat.
- Raises:
ValueError – If the mask size does not match the number of tokens in the chat, or if the mask is all False, or if all text tokens are filtered out.
RuntimeError – If audio segments are inconsistent with stored waveforms.
- get_conversation() list[list[ChatEntry]][source]#
Serialize the chat state into a list of turns of ChatEntry objects.
- Returns:
- A list of turns, where each turn is a list of ChatEntry objects. shap_values=None
indicates that SHAP values are not yet available.
- Raises:
NotImplementedError – If the chat contains audio segments.
Example:
[
    [
        ChatEntry(
            content_type=0,
            roles=[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
            content='<|im_start|>, system, \n, You, are, a helpful assistant that answers questions briefly...',
            shap_values=None
        )
    ],
    [
        ChatEntry(
            content_type=0,
            roles=[2, 2, 2, 0, 0, 0, 0, 2, 2],
            content='<|im_start|>, user, \n, Who, are, you, ?, <|im_end|>, \n',
            shap_values=None
        )
    ]
]
- get_new_chat_callable: Callable[[...], BaseMllmChat]#
A callable to create a new chat instance.
- property input_tokens: list[Tensor]#
Combined input tensor (text + audio).
- property input_tokens_num: int#
Total number of input tokens (text + audio).
- property is_system_turn: bool#
Flag indicating whether the current turn is a system turn.
- new_turn(speaker: Role) None[source]#
Start a new turn in the chat state.
Warning
For Developers: This method assumes cached property refresh is handled in _new_turn() or by calling add_text/add_audio methods.
- Parameters:
speaker – The role of the speaker for the new turn.
- Raises:
ValueError – If a turn is already active.
RuntimeError – If an error occurs in the underlying connector implementation.
- refresh(full: bool = False, shap: bool = False) None[source]#
Refresh cached properties, that is:
input_tokens
tokens_modality_flag
text_tokens_mask
text_tokens
audio_tokens_mask
audio_tokens
If full is True, also refresh:
audio_tokens_no_system_mask_filtered
text_tokens_no_system_mask_filtered
shap_values_mask
external_group_ids_first_positions
external_group_ids_positive_mask
If shap is True, will only refresh:
shap_values_mask
external_group_ids_first_positions
external_group_ids_positive_mask
- Parameters:
full – If True, refreshes all cached properties.
shap – If True, refreshes shap-related cached properties.
- property shap_values_mask: Tensor#
Boolean mask indicating which tokens should be considered for SHAP value calculations (i.e., non-system text tokens).
- Raises:
ValueError – If the external SHAP values mask size does not match the number of tokens in the chat.
RuntimeError – If external_group_ids is set but has no positive IDs.
- system_roles_setup: SystemRolesSetup#
A set of roles that are considered system roles. If Role.ASSISTANT is added, multi-turn assistant messages will also be marked as system and therefore excluded from Shapley value calculations.
- property text_tokens: Tensor#
Input text tensor (tokens).
- property text_tokens_mask: Tensor#
Boolean mask indicating positions of text tokens in the input (input_tokens tensor).
- text_tokens_no_system_mask: Tensor#
A boolean tensor indicating which text tokens are user-generated.
- property text_tokens_no_system_mask_filtered: Tensor#
Boolean mask indicating which text tokens are not system generated, after filtering out specified sequences. Relative to text_tokens_mask.
- token_filter: TokenFilter#
The token filtering strategy.
- token_roles: Tensor#
A tensor indicating the role of each token in the chat.
- token_sequences_to_exclude: list[Tensor]#
A list of token ID sequences to exclude from processing.
- token_turns: Tensor#
A tensor indicating the turn structure of the chat.
- property tokens_modality_flag: Tensor#
The modality flag tensor indicating token types according to the ModalityFlag enum.
- torch_device: device#
The device on which tensors are stored.
- translate_groups_ids_mask(mask: Tensor) Tensor[source]#
Translate a mask over group IDs to a mask over all tokens.
- Parameters:
mask – A boolean tensor indicating which group IDs to include.
- Returns:
A boolean tensor indicating which tokens to include.
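One possible reading of this translation, sketched with plain lists (how the group mask indexes groups, here by order of first appearance, is an assumption; the actual method works on tensors):

```python
def translate_group_mask(group_ids, group_mask):
    """Expand a per-group boolean mask to a per-token mask.

    group_mask[i] refers to the i-th group in order of first appearance;
    tokens with group ID 0 are never included.
    """
    order = []
    for gid in group_ids:
        if gid > 0 and gid not in order:
            order.append(gid)
    keep = {gid for gid, flag in zip(order, group_mask) if flag}
    return [gid in keep for gid in group_ids]

# Keep group 1 (tokens 1-2), drop group 2 (tokens 3-4); zeros stay excluded.
print(translate_group_mask([0, 1, 1, 2, 2, 0], [True, False]))
# [False, True, True, False, False, False]
```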
- turn_number: int#
The current turn number in the chat.
mllm_shap.connectors.base.chat_entry module#
Conversation entry data structure for audio and text modalities.
- class mllm_shap.connectors.base.chat_entry.ChatEntry(*, content_type: int, roles: list[int], content: list[str | bytes], shap_values: list[float | None] | None)[source]#
Bases: BaseModel
Conversation entry data structure.
- content: list[str | bytes]#
Content of the entry, can be text (str) or audio bytes (bytes).
- content_type: int#
Modality of the content (e.g., text or audio); refer to ModalityFlag.
- display() None[source]#
Display the ChatEntry content.
- Raises:
ValueError – If the number of roles does not match the number of content pieces.
Example:
BY: USER, SYSTEM
TEXT CONTENT: <|im_start|> user Who are you ? <|im_end|>
- model_config = {'arbitrary_types_allowed': True}#
Configuration for pydantic model.
- roles: list[int]#
List of roles associated with this entry; refer to Role.
- shap_values: list[float | None] | None#
SHAP values associated with the content tokens, if they have been computed for this entry.
mllm_shap.connectors.base.explainer_cache module#
Cache for explainer computations.
- class mllm_shap.connectors.base.explainer_cache.ExplainerCache(chat: BaseMllmChat, responses: list[ModelResponse], masks: Tensor, shap_values_mask: Tensor, *, calculated_by: int, n: int, had_different_masks: bool = False)[source]#
Bases: BaseModel
Cache for explainer computations associated with a chat. Saves and validates calculated SHAP values, masks, and reduced embeddings.
- calculated_by: int#
Hash of the explainer that calculated the SHAP values.
- chat: BaseMllmChat#
The chat instance the cache is for.
- classmethod create(chat: BaseMllmChat, explainer_hash: int, responses: list[ModelResponse], masks: Tensor, normalized_values: Tensor, shap_values_mask: Tensor, values: Tensor | None = None) ExplainerCache[source]#
Create a new ExplainerCache instance.
- Parameters:
chat – The chat instance the cache is for.
explainer_hash – Hash of the explainer that calculated the SHAP values.
responses – The model responses used for SHAP calculations.
masks – The masks used for SHAP calculations.
values – The SHAP values calculated.
normalized_values – The normalized SHAP values calculated.
shap_values_mask – The mask indicating which SHAP values are relevant.
- Returns:
A new ExplainerCache instance.
- static extend_values(values: Tensor, shape: tuple[int, ...], dim: int, fill_value: Any, device: device) Tensor[source]#
Extend SHAP values to match the chat length.
- Parameters:
values – The SHAP values to extend.
shape – The target shape for extension.
dim – The dimension along which to extend.
fill_value – The value to use for extension.
device – The device to create the extended tensor on.
- Returns:
The extended SHAP values.
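For intuition, a simplified 1-D version of this padding (the actual static method works on tensors of arbitrary shape, pads along a chosen dimension, and places the result on a device):

```python
def extend_values_1d(values, target_len, fill_value):
    """Pad a list of SHAP values with fill_value up to target_len."""
    return values + [fill_value] * (target_len - len(values))

# Three computed values extended to a chat of five tokens.
print(extend_values_1d([0.4, -0.1, 0.2], 5, 0.0))  # [0.4, -0.1, 0.2, 0.0, 0.0]
```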
- had_different_masks: bool#
Whether the masks used for SHAP calculations differed from chat’s masks.
- masks: Tensor#
The masks used for SHAP calculations.
- model_config = {'arbitrary_types_allowed': True}#
Configuration for the model; should be a dictionary conforming to pydantic's ConfigDict.
- model_post_init(context: Any, /) None#
This function is meant to behave like a BaseModel method to initialise private attributes.
It takes context as an argument since that’s what pydantic-core passes when calling it.
- Parameters:
self – The BaseModel instance.
context – The context.
- n: int#
Index of the last token used for SHAP calculations.
- property normalized_values: Tensor#
Normalized SHAP values.
- Raises:
ValueError – If SHAP values are no longer valid or have not been computed yet.
- responses: list[ModelResponse]#
The model responses used for SHAP calculations.
- shap_values_mask: Tensor#
The mask indicating which SHAP values are relevant.
- property values: Tensor | None#
SHAP values. Can be None if HierarchicalExplainer is used.
- Raises:
ValueError – If SHAP values are no longer valid or have not been computed yet.
mllm_shap.connectors.base.filters module#
Base classes for token filtering strategies.
- class mllm_shap.connectors.base.filters.TokenFilter(*, phrases_to_exclude: set[str])[source]#
Bases: ABC, BaseModel
Base class for token filtering strategies.
- model_config = {}#
Configuration for the model; should be a dictionary conforming to pydantic's ConfigDict.
- phrases_to_exclude: set[str]#
Set of phrases to exclude from SHAP calculations.
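As a rough sketch of what such a strategy might do, here is a hypothetical per-token filter (the real TokenFilter is abstract and concrete implementations may match multi-token phrases rather than single tokens):

```python
def filter_tokens(tokens, phrases_to_exclude):
    """Return a keep-mask: False for tokens matching an excluded phrase."""
    return [tok not in phrases_to_exclude for tok in tokens]

# Structural tokens are masked out of SHAP calculations.
mask = filter_tokens(["<|im_start|>", "user", "\n", "Who"], {"<|im_start|>", "\n"})
print(mask)  # [False, True, False, True]
```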
mllm_shap.connectors.base.model module#
Base model connector class.
- class mllm_shap.connectors.base.model.BaseMllmModel(config: HuggingFaceModelConfig, device: device, processor: Any, model: Any, history_tracking_mode: ModelHistoryTrackingMode = ModelHistoryTrackingMode.TEXT)[source]#
Bases: ABC
Base class for model connectors.
- config: HuggingFaceModelConfig#
The model configuration.
- device: device#
The device to run the model on.
- abstractmethod generate(chat: BaseMllmChat, max_new_tokens: int = 128, model_config: ModelConfig = ModelConfig(text_temperature=0.0, text_top_k=1, audio_temperature=0.0, audio_top_k=1), keep_history: bool = False) ModelResponse[source]#
Generate a response based on the current chat state.
- Parameters:
chat – The current chat state.
max_new_tokens – The maximum number of new tokens to generate (default is 128).
model_config – Additional model configuration parameters.
keep_history – Whether to return chat state with full history or only generated content.
- Returns:
The model response after generation.
- Return type:
ModelResponse
- get_contextual_embeddings(*args: Any, static_embeddings: list[Tensor] | None = None, **kwargs: Any) list[Tensor][source]#
Get contextual embeddings for the current chat state.
- Parameters:
static_embeddings – Precomputed static embeddings (if any).
*args – Additional positional arguments for get_static_embeddings(). Used if static_embeddings is None.
**kwargs – Additional keyword arguments for get_static_embeddings(). Used if static_embeddings is None.
- Returns:
- The context embeddings for the text and audio tokens, same format as in get_static_embeddings().
- Raises:
ValueError – If static_embeddings is not an instance of Tensor.
- abstractmethod get_new_chat() BaseMllmChat[source]#
Get a new chat state for the model.
- abstractmethod get_static_embeddings(responses: list[ModelResponse]) list[Tensor][source]#
Get static embeddings for the current chat state.
- Parameters:
responses – The model responses to get embeddings for.
- Returns:
The static embeddings for the text and audio tokens.
- Raises:
ValueError – If responses is not a list of ModelResponse.
- history_tracking_mode: ModelHistoryTrackingMode#
The mode for tracking chat history.
- model: Any#
The model instance.
- processor: Any#
The model processor (tokenizer).
mllm_shap.connectors.base.model_response module#
Model response wrapper module.
- class mllm_shap.connectors.base.model_response.ModelResponse(*, chat: BaseMllmChat | None, generated_text_tokens: Tensor, generated_audio_tokens: Tensor, generated_modality_flag: Tensor)[source]#
Bases: BaseModel
Model response wrapper. Used to standardize the output from different models.
- chat: BaseMllmChat | None#
The updated chat with full history if keep_history was True.
- generated_audio_tokens: Tensor#
The generated audio tokens.
- generated_modality_flag: Tensor#
The modality flag for the generated tokens.
- generated_text_tokens: Tensor#
The generated text tokens.
- model_config = {'arbitrary_types_allowed': True}#
Configuration for the model; should be a dictionary conforming to pydantic's ConfigDict.
Module contents#
Base module for mllm_shap.connectors package.