mllm_shap.connectors package
Subpackages
- mllm_shap.connectors.base package
- Submodules
- mllm_shap.connectors.base.chat module
- AllTextTokensFilteredOutError
- BaseMllmChat: add_audio(), add_audio_with_transcript(), add_text(), append(), attach_audio_to_segments(), audio_tokens, audio_tokens_mask, audio_tokens_no_system_mask, audio_tokens_no_system_mask_filtered, cache, decode_audio(), decode_text(), empty_turn_sequences, end_turn(), external_group_ids, external_group_ids_first_positions, external_group_ids_positive_mask, external_shap_values_mask, from_chat(), get_conversation(), get_new_chat_callable, input_tokens, input_tokens_num, is_system_turn, new_turn(), refresh(), shap_values_mask, speaker, system_roles_setup, text_tokens, text_tokens_mask, text_tokens_no_system_mask, text_tokens_no_system_mask_filtered, token_filter, token_roles, token_sequences_to_exclude, token_turns, tokens_modality_flag, torch_device, translate_groups_ids_mask(), turn_number
- mllm_shap.connectors.base.chat_entry module
- mllm_shap.connectors.base.explainer_cache module
- ExplainerCache: calculated_by, chat, create(), extend_masks(), extend_values(), had_different_masks, masks, model_config, model_post_init(), n, normalized_values, responses, shap_values_mask, values
- mllm_shap.connectors.base.filters module
- mllm_shap.connectors.base.model module
- mllm_shap.connectors.base.model_response module
- Module contents
- mllm_shap.connectors.liquid package
- Submodules
- mllm_shap.connectors.liquid.chat module
- LiquidAudioChat: AUDIO_IN_SHAPE, AUDIO_OUT_SHAPE, EMPTY_ASSISTANT_TURN, EMPTY_SYSTEM_TURN, EMPTY_USER_TURN, START_MARK, add_audio(), add_audio_with_transcript(), add_text(), append(), attach_audio_to_segments(), audio_empty_value, audio_tokens, audio_tokens_mask, audio_tokens_no_system_mask, audio_tokens_no_system_mask_filtered, cache, decode_audio(), decode_text(), device, empty_turn_sequences, end_turn(), external_group_ids, external_group_ids_first_positions, external_group_ids_positive_mask, external_shap_values_mask, from_chat(), get_conversation(), get_new_chat_callable, input_tokens, input_tokens_num, is_system_turn, model_inputs, new_turn(), refresh(), shap_values_mask, speaker, system_roles_setup, text_tokens, text_tokens_mask, text_tokens_no_system_mask, text_tokens_no_system_mask_filtered, token_filter, token_roles, token_sequences_to_exclude, token_turns, tokens_modality_flag, torch_device, translate_groups_ids_mask(), turn_number, validate_from_chat
- mllm_shap.connectors.liquid.config module
- mllm_shap.connectors.liquid.model module
- Module contents
- LiquidAudio
- LiquidAudioChat (same members as listed for the mllm_shap.connectors.liquid.chat module)
Submodules
mllm_shap.connectors.config module
Configuration for Hugging Face interfaces.
- class mllm_shap.connectors.config.HuggingFaceModelConfig(*, repo_id: str, revision: str)
Bases: BaseModel
Holds the information needed to load a model from the Hugging Face Hub.
- model_config = {}
Configuration for the model; should be a dictionary conforming to pydantic.config.ConfigDict.
- repo_id: str
The repository ID of the model on Hugging Face.
- revision: str
The specific revision or branch of the model to use.
- class mllm_shap.connectors.config.ModelConfig(*, text_temperature: float | None = 0.0, text_top_k: int | None = 1, audio_temperature: float | None = 0.0, audio_top_k: int | None = 1)
Bases: BaseModel
Defines settings that control text and audio generation behavior.
- audio_temperature: float | None
Sampling temperature for audio generation; controls its randomness.
- audio_top_k: int | None
Restricts audio sampling to the top-k most probable tokens.
- model_config = {}
Configuration for the model; should be a dictionary conforming to pydantic.config.ConfigDict.
- text_temperature: float | None
Sampling temperature for text generation; controls its randomness.
- text_top_k: int | None
Restricts text sampling to the top-k most probable tokens.
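The decoding loop used by the connector is not documented in this section. As a rough illustration of what temperature and top-k fields conventionally control, the sketch below picks one token index from a list of logits. The helper sample_token is hypothetical, and its conventions (a temperature of 0.0 or None means greedy argmax, a top_k of None means no restriction) are assumptions rather than mllm_shap behavior.

```python
import math
import random


def sample_token(logits, temperature=0.0, top_k=1, rng=None):
    """Pick one token index from raw logits (hypothetical helper).

    Assumed conventions: temperature 0.0 or None means greedy decoding,
    top_k None means no top-k restriction.
    """
    rng = rng or random.Random()
    # Rank token indices by logit, highest first.
    ranked = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)
    if top_k is not None:
        ranked = ranked[:top_k]  # keep only the k most probable candidates
    if not temperature:
        return ranked[0]  # greedy: the single most probable token
    # Temperature-scaled softmax over the remaining candidates.
    scaled = [logits[i] / temperature for i in ranked]
    peak = max(scaled)  # subtract the max for numerical stability
    weights = [math.exp(s - peak) for s in scaled]
    return rng.choices(ranked, weights=weights, k=1)[0]


logits = [1.0, 3.0, 2.5, -1.0]
# Under this sketch's conventions, temperature 0.0 with top_k 1 is greedy.
print(sample_token(logits, temperature=0.0, top_k=1))  # 1
# A positive temperature with top_k=2 samples between indices 1 and 2.
print(sample_token(logits, temperature=1.0, top_k=2, rng=random.Random(0)))
```

Raising the temperature flattens the distribution over the surviving candidates, while top_k bounds how many candidates survive at all.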
mllm_shap.connectors.enums module
Configuration for possible roles and modalities.
- class mllm_shap.connectors.enums.ModalityFlag(*values)
Bases: int, Enum
Defines supported input/output modalities.
- AUDIO = 1
Represents audio inputs or outputs.
- IGNORE = -1
Marks a token to be ignored by modality-specific operations; used in internal calculations.
- TEXT = 0
Represents text inputs or outputs.
- as_integer_ratio()
Return a pair of integers, whose ratio is equal to the original int.
The ratio is in lowest terms and has a positive denominator.
>>> (10).as_integer_ratio()
(10, 1)
>>> (-10).as_integer_ratio()
(-10, 1)
>>> (0).as_integer_ratio()
(0, 1)
- bit_count()
Number of ones in the binary representation of the absolute value of self.
Also known as the population count.
>>> bin(13)
'0b1101'
>>> (13).bit_count()
3
- bit_length()
Number of bits necessary to represent self in binary.
>>> bin(37)
'0b100101'
>>> (37).bit_length()
6
- conjugate()
Returns self, the complex conjugate of any int.
- denominator
the denominator of a rational number in lowest terms
- classmethod from_bytes(bytes, byteorder='big', *, signed=False)
Return the integer represented by the given array of bytes.
- bytes
Holds the array of bytes to convert. The argument must either support the buffer protocol or be an iterable object producing bytes. Bytes and bytearray are examples of built-in objects that support the buffer protocol.
- byteorder
The byte order used to represent the integer. If byteorder is ‘big’, the most significant byte is at the beginning of the byte array. If byteorder is ‘little’, the most significant byte is at the end of the byte array. To request the native byte order of the host system, use sys.byteorder as the byte order value. Default is to use ‘big’.
- signed
Indicates whether two’s complement is used to represent the integer.
- imag
the imaginary part of a complex number
- is_integer()
Returns True. Exists for duck type compatibility with float.is_integer.
- numerator
the numerator of a rational number in lowest terms
- real
the real part of a complex number
- to_bytes(length=1, byteorder='big', *, signed=False)
Return an array of bytes representing an integer.
- length
Length of bytes object to use. An OverflowError is raised if the integer is not representable with the given number of bytes. Default is length 1.
- byteorder
The byte order used to represent the integer. If byteorder is ‘big’, the most significant byte is at the beginning of the byte array. If byteorder is ‘little’, the most significant byte is at the end of the byte array. To request the native byte order of the host system, use sys.byteorder as the byte order value. Default is to use ‘big’.
- signed
Determines whether two’s complement is used to represent the integer. If signed is False and a negative integer is given, an OverflowError is raised.
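tokens_modality_flag stores one ModalityFlag value per input token, and text_tokens_mask and audio_tokens_mask mark where each modality occurs. The sketch below shows that relationship with plain Python lists in place of torch tensors. The enum is a stand-in that mirrors the documented values, and deriving the masks by equality comparison is an assumption about the implementation.

```python
from enum import Enum


class ModalityFlag(int, Enum):
    """Stand-in mirroring the documented values of ModalityFlag."""
    IGNORE = -1
    TEXT = 0
    AUDIO = 1


def modality_masks(flags):
    """Return (text_mask, audio_mask) for a per-token list of modality flags."""
    text_mask = [f == ModalityFlag.TEXT for f in flags]
    audio_mask = [f == ModalityFlag.AUDIO for f in flags]
    return text_mask, audio_mask


# Two text tokens, two audio tokens, and one token ignored by both masks.
flags = [ModalityFlag.TEXT, ModalityFlag.TEXT, ModalityFlag.AUDIO,
         ModalityFlag.AUDIO, ModalityFlag.IGNORE]
text_mask, audio_mask = modality_masks(flags)
print(text_mask)   # [True, True, False, False, False]
print(audio_mask)  # [False, False, True, True, False]
```

Because the enum mixes in int, raw integers such as 0 and 1 compare equal to the corresponding members, so the same function works on integer flag lists.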
- class mllm_shap.connectors.enums.ModelHistoryTrackingMode(*values)
Bases: int, Enum
Controls which model outputs are tracked in history.
- AUDIO = 1
Track only audio outputs.
- TEXT = 0
Track only text outputs.
- TEXT_AUDIO = 2
Track both text and audio outputs.
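One plausible reading of the three tracking modes is a per-token filter over generated output. In the sketch below, the helper tracked_tokens, the (modality, token) tuple representation, and the mode-to-modality mapping are all hypothetical illustrations based on the member descriptions; the enum is a stand-in with the documented values.

```python
from enum import Enum


class ModelHistoryTrackingMode(int, Enum):
    """Stand-in mirroring the documented values."""
    TEXT = 0
    AUDIO = 1
    TEXT_AUDIO = 2


# Which output modalities each mode keeps (assumed from the member docs).
_KEPT = {
    ModelHistoryTrackingMode.TEXT: {"text"},
    ModelHistoryTrackingMode.AUDIO: {"audio"},
    ModelHistoryTrackingMode.TEXT_AUDIO: {"text", "audio"},
}


def tracked_tokens(generated, mode):
    """Keep only the (modality, token) pairs that the mode tracks in history."""
    kept = _KEPT[mode]
    return [(modality, token) for modality, token in generated if modality in kept]


generated = [("text", "Hi"), ("audio", 512), ("text", "!"), ("audio", 77)]
print(tracked_tokens(generated, ModelHistoryTrackingMode.TEXT))
# [('text', 'Hi'), ('text', '!')]
```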
- class mllm_shap.connectors.enums.Role(*values)
Bases: int, Enum
Defines roles for participants in a conversation.
- ASSISTANT = 1
Represents outputs generated by the assistant.
- SYSTEM = 2
Represents system-level inputs or outputs.
- USER = 0
Represents user-provided inputs.
- classmethod from_ordinal(ordinal: int) → Role
Creates a Role enum member from its ordinal value.
- Parameters:
ordinal – The integer ordinal value of the role.
- Returns:
The corresponding Role enum member.
- Raises:
ValueError – If the ordinal does not correspond to any Role.
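A minimal sketch of the documented from_ordinal contract (return the matching member, raise ValueError otherwise), written against a stand-in Role enum with the documented values. The body is an assumption; the library's implementation may differ, for example in its error message.

```python
from enum import Enum


class Role(int, Enum):
    """Stand-in mirroring the documented Role values."""
    USER = 0
    ASSISTANT = 1
    SYSTEM = 2

    @classmethod
    def from_ordinal(cls, ordinal: int) -> "Role":
        """Return the member whose value equals ordinal."""
        # Enum value lookup already raises ValueError for unknown ordinals.
        return cls(ordinal)


role = Role.from_ordinal(1)  # Role.ASSISTANT
```

Calling Role(ordinal) directly performs the same lookup; a dedicated classmethod mainly gives the conversion a descriptive name at call sites.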
- class mllm_shap.connectors.enums.SystemRolesSetup(*values)
Bases: int, Enum
Specifies how system roles affect explainability.
- NONE = 0
No system role; all tokens are considered for explainability.
- SYSTEM = 1
System role is active; explainability for system tokens is disabled. Depending on the connector implementation, this also covers steering tokens.
- SYSTEM_ASSISTANT = 2
System-assistant role is active; explainability for assistant tokens is disabled.
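To make the three setups concrete, the sketch below maps each SystemRolesSetup value to the roles excluded from explanation and derives a per-token "no system" mask from token roles. It assumes, based on the system_roles_setup attribute description of LiquidAudioChat, that SYSTEM_ASSISTANT excludes both system and assistant tokens. The helper names are hypothetical and the enums are stand-ins with the documented values.

```python
from enum import Enum


class Role(int, Enum):
    USER = 0
    ASSISTANT = 1
    SYSTEM = 2


class SystemRolesSetup(int, Enum):
    NONE = 0
    SYSTEM = 1
    SYSTEM_ASSISTANT = 2


def excluded_roles(setup):
    """Roles whose tokens are treated as system tokens (assumed mapping)."""
    if setup == SystemRolesSetup.NONE:
        return set()
    if setup == SystemRolesSetup.SYSTEM:
        return {Role.SYSTEM}
    return {Role.SYSTEM, Role.ASSISTANT}


def no_system_mask(token_roles, setup):
    """True for tokens that remain eligible for explanation."""
    excluded = excluded_roles(setup)
    # Convert raw integers to Role members before the set lookup.
    return [Role(r) not in excluded for r in token_roles]


# Roles per token as integers, like the roles lists in get_conversation().
roles = [2, 2, 0, 0, 1, 1]
print(no_system_mask(roles, SystemRolesSetup.SYSTEM))
# [False, False, True, True, True, True]
print(no_system_mask(roles, SystemRolesSetup.SYSTEM_ASSISTANT))
# [False, False, True, True, False, False]
```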
mllm_shap.connectors.filters module
Token filtering strategies for audio-shap connectors.
- class mllm_shap.connectors.filters.ExcludePunctuationTokensFilter(*, phrases_to_exclude: set[str] = {'!', ',', '.', ':', ';', '?'})
Bases: TokenFilter
A token filter that removes common punctuation tokens.
- model_config = {}
Configuration for the model; should be a dictionary conforming to pydantic.config.ConfigDict.
- phrases_to_exclude: set[str]
The punctuation tokens to exclude; by default the standard marks ! , . : ; and ?.
- class mllm_shap.connectors.filters.KeepAllTokens(*, phrases_to_exclude: set[str] = {})
Bases: TokenFilter
A token filter that does not exclude any tokens.
- model_config = {}
Configuration for the model; should be a dictionary conforming to pydantic.config.ConfigDict.
- phrases_to_exclude: set[str]
No tokens are excluded by this strategy.
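A rough sketch of how a phrase-exclusion filter such as ExcludePunctuationTokensFilter could be applied: each decoded token is compared against phrases_to_exclude and a keep-mask is returned. The function keep_mask and the whitespace-stripping comparison are assumptions for illustration; the TokenFilter interface itself is not shown in this section.

```python
# Documented default of ExcludePunctuationTokensFilter.phrases_to_exclude.
DEFAULT_PUNCTUATION = frozenset({"!", ",", ".", ":", ";", "?"})


def keep_mask(decoded_tokens, phrases_to_exclude=DEFAULT_PUNCTUATION):
    """Return True for tokens that should stay in the explanation.

    Assumes tokens are compared after stripping surrounding whitespace.
    """
    return [tok.strip() not in phrases_to_exclude for tok in decoded_tokens]


tokens = ["Who", " are", " you", "?"]
print(keep_mask(tokens))         # [True, True, True, False]
print(keep_mask(tokens, set()))  # KeepAllTokens-style: [True, True, True, True]
```

An empty exclusion set reproduces the KeepAllTokens behavior, which is why both strategies can share one field.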
Module contents
Connectors module.
- class mllm_shap.connectors.LiquidAudio(device: device, *args: Any, **kwargs: Any)
Bases: BaseMllmModel
Connector for the LiquidAudio model.
- Fields:
processor (LFM2AudioProcessor): The audio processor.
model (LFM2AudioModel): The LiquidAudio model.
- config: HuggingFaceModelConfig
The model configuration.
- device: device
The device to run the model on.
- generate(chat: BaseMllmChat, max_new_tokens: int = 128, model_config: ModelConfig = ModelConfig(text_temperature=0.0, text_top_k=1, audio_temperature=0.0, audio_top_k=1), keep_history: bool = False) → ModelResponse
Generate audio based on the current chat state.
- Parameters:
chat – The current chat state.
max_new_tokens – The maximum number of new tokens to generate (default is 128).
model_config – Additional model configuration parameters.
keep_history – Whether to return the chat state with the full history or only the generated content.
- Returns:
The updated chat state after generation.
- Return type:
ModelResponse
- get_contextual_embeddings(*args: Any, static_embeddings: list[Tensor] | None = None, **kwargs: Any) → list[Tensor]
Get contextual embeddings for the current chat state.
- Parameters:
static_embeddings – Precomputed static embeddings (if any).
*args – Additional positional arguments for get_static_embeddings(). Used if static_embeddings is None.
**kwargs – Additional keyword arguments for get_static_embeddings(). Used if static_embeddings is None.
- Returns:
- The context embeddings for the text and audio tokens, same format as in
- Raises:
ValueError – If static_embeddings is not an instance of Tensor.
- get_new_chat(*args: Any, **kwargs: Any) → LiquidAudioChat
Get a new chat state for the model.
- get_static_embeddings(responses: list[ModelResponse]) → list[Tensor]
Get static embeddings for the current chat state.
- Parameters:
responses – The model responses to get embeddings for.
- Returns:
The static embeddings for the text and audio tokens.
- Raises:
ValueError – If responses is not a list of ModelResponse.
- history_tracking_mode: ModelHistoryTrackingMode
The mode for tracking chat history.
- model: LFM2AudioModel
The model instance.
- processor: LFM2AudioProcessor
The model processor (tokenizer).
- class mllm_shap.connectors.LiquidAudioChat(device: device, validate_from_chat: bool = False, empty_turn_sequences: set[str] | None = None, token_filter: TokenFilter | None = None, system_roles_setup: SystemRolesSetup | None = None, get_new_chat_callable: Callable[..., LiquidAudioChat] | None = None, **liquid_kwargs: Any)
Bases: BaseMllmChat, ChatState
Represents the chat state for a LiquidAudio model.
Handles text and audio token sequences, speaker roles, and special turn markers. Includes configuration for audio input/output shapes and empty-token handling.
- AUDIO_IN_SHAPE: int = 128
Number of audio codebooks used for audio input tokens.
- AUDIO_OUT_SHAPE: int = 8
Number of audio codebooks used for audio output tokens.
- EMPTY_ASSISTANT_TURN: str = '<|im_start|>Role.ASSISTANT\n<|im_end|>\n'
Marker representing an empty assistant turn.
- EMPTY_SYSTEM_TURN: str = '<|im_start|>Role.SYSTEM\n<|im_end|>\n'
Marker representing an empty system turn.
- EMPTY_USER_TURN: str = '<|im_start|>user\n<|im_end|>\n'
Marker representing an empty user turn.
- START_MARK: str = '<|startoftext|>'
Marker indicating the start of a text sequence.
- add_audio(audio_content: bytes, audio_format: str = 'mp3', _waveform: Tensor | None = None, _sample_rate: int | None = None, _internal: bool = False) → None
Add audio content to the chat state.
- Parameters:
audio_content – The audio content in bytes.
audio_format – The format of the audio content (default is “mp3”).
- Raises:
ValueError – If audio_content is not a non-empty bytes object.
RuntimeError – If an error occurs in the underlying connector implementation.
- add_audio_with_transcript(audio_content: bytes, transcript: str | list[str], aligner: SpectrogramGuidedAligner, audio_format: str = 'mp3', attach_audio: bool = False) → None
Add audio content along with its transcript to the chat state.
- Parameters:
audio_content – The audio content in bytes.
transcript – The transcript of the audio content.
aligner – The SpectrogramGuidedAligner instance for aligning audio and transcript.
audio_format – The format of the audio content (default is “mp3”).
attach_audio – Whether to attach raw audio bytes to each segment (default is False). If False, segments will have empty audio bytes to save memory.
- Raises:
ValueError – If audio_content is not a non-empty bytes object or transcript is not a non-empty string.
RuntimeError – If an error occurs in the underlying connector implementation.
- add_text(text: str) → None
Add text to the chat state.
- Parameters:
text – The text to add.
- Raises:
ValueError – If text is not a non-empty string.
RuntimeError – If an error occurs in the underlying connector implementation.
- append(text: Tensor, audio_out: Tensor, modality_flag: Tensor, history_tracking_mode: ModelHistoryTrackingMode) → None
Append text and audio tokens, along with their modality flags, to the chat state. Assumes that the input data is correct and non-system.
Warning
This method does not validate input data.
- Parameters:
text – The text tokens to append.
audio_out – The audio tokens to append (intended for model’s output).
modality_flag – The modality flags corresponding to the tokens.
history_tracking_mode – The mode for tracking chat history.
- Raises:
ValueError – If length mismatch occurs after appending.
RuntimeError – If an error occurs in the underlying connector implementation.
- attach_audio_to_segments(aligner: SpectrogramGuidedAligner, audio_content: bytes, audio_format: str = 'mp3', turn_number: int | None = None) → None
Attach raw audio bytes to segments for a specific turn.
This is a helper method to materialize audio bytes for segments that were created without audio attachment (i.e., with attach_audio=False).
- Parameters:
aligner – The SpectrogramGuidedAligner instance to use for attaching audio.
audio_content – The audio content in bytes.
audio_format – The format of the audio content (default is “mp3”).
turn_number – The turn number to attach audio for. If None, uses the current turn.
- Raises:
ValueError – If no audio segments exist for the specified turn.
- audio_empty_value: float = -3.4028234663852886e+38
Represents a placeholder value for empty audio tokens.
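The value -3.4028234663852886e+38 is the negative of the largest finite IEEE 754 single-precision (float32) number, that is, the most negative finite float32. The check below reconstructs that bound from its bit pattern with the standard struct module and confirms that the placeholder survives a float32 round trip; it does not show how the connector uses the value internally.

```python
import struct

# 0x7F7FFFFF is the bit pattern of the largest finite float32 value.
FLOAT32_MAX = struct.unpack(">f", bytes.fromhex("7f7fffff"))[0]
AUDIO_EMPTY_VALUE = -3.4028234663852886e+38  # documented audio_empty_value

print(FLOAT32_MAX)                        # 3.4028234663852886e+38
print(AUDIO_EMPTY_VALUE == -FLOAT32_MAX)  # True

# Packing as float32 and unpacking again leaves the value unchanged.
roundtrip = struct.unpack("<f", struct.pack("<f", AUDIO_EMPTY_VALUE))[0]
print(roundtrip == AUDIO_EMPTY_VALUE)     # True
```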
- property audio_tokens: Tensor
Input audio tensor (tokens) of shape (T, K).
- property audio_tokens_mask: Tensor
Boolean mask indicating positions of audio tokens in the input (input_tokens tensor).
- audio_tokens_no_system_mask: Tensor
A boolean tensor indicating which audio tokens are user-generated.
- property audio_tokens_no_system_mask_filtered: Tensor
Boolean mask indicating which audio tokens are not system-generated. Relative to audio_tokens_mask.
- property cache: ExplainerCache | None
The explainer cache, if set.
- decode_audio(audio_tokens: list[Tensor] | Tensor | None = None, sample_rate: int = 24000, audio_format: str = 'mp3') → bytes
Decode the generated audio tokens.
Warning
This method does not validate input data.
- Parameters:
audio_tokens – Audio tokens to decode in format (T, K).
sample_rate – The sample rate for the decoded audio (default is 24,000 Hz).
audio_format – The desired output audio format (default is “mp3”).
- Returns:
The decoded audio content in bytes. Empty bytes if decoding is not available.
- Raises:
RuntimeError – If an error occurs in the underlying connector implementation.
- decode_text(text_tokens: list[Tensor] | Tensor | None = None) → str
Decode the generated text tokens.
Warning
This method does not validate input data.
- Parameters:
text_tokens – The generated text tokens.
- Returns:
The decoded text.
- Raises:
RuntimeError – If an error occurs in the underlying connector implementation.
- property device: device
- empty_turn_sequences: list[Tensor]
A list of tensors holding the empty-turn token sequences.
- end_turn() → None
End the current turn in the chat state.
Warning
For developers: this method assumes that cached properties are refreshed in _new_turn() or by calling the add_text()/add_audio() methods.
- Raises:
ValueError – If no turn is active.
RuntimeError – If an error occurs in the underlying connector implementation.
- property external_group_ids: Tensor | None
Optional external group IDs for tokens. Should be an integer tensor with one entry per token in the chat, set directly before the explanation process.
Entries > 0 assign tokens to groups for SHAP value calculations; entries equal to 0 are ignored (shap_values_mask is set to False for those tokens). Takes precedence over external_shap_values_mask. Adding new tokens to the chat is forbidden while this is set.
- property external_group_ids_first_positions: Tensor
Get the positions (indices) of the first occurrences of each consecutive non-zero group ID in external_group_ids.
- Returns:
Tensor of positions (indices) of the first occurrence of each non-zero group ID.
- Raises:
ValueError – If external_group_ids is not set.
- property external_group_ids_positive_mask: Tensor
Boolean mask indicating which tokens have positive group IDs.
- Returns:
Boolean mask indicating which tokens have positive group IDs.
- Raises:
ValueError – If external_group_ids is not set.
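A plain-Python sketch of the external_group_ids_positive_mask and external_group_ids_first_positions properties. It assumes, as the "consecutive non-zero group ID" wording suggests, that a group is a run of consecutive tokens sharing the same non-zero ID. Lists stand in for torch tensors and both function names are hypothetical.

```python
def positive_mask(group_ids):
    """True where a token belongs to some group (ID > 0)."""
    return [gid > 0 for gid in group_ids]


def first_positions(group_ids):
    """Index of the first token of each run of consecutive equal non-zero IDs."""
    positions = []
    previous = 0
    for index, gid in enumerate(group_ids):
        if gid != 0 and gid != previous:
            positions.append(index)
        previous = gid
    return positions


# Tokens 0 and 5 are ignored (ID 0); three groups at [1, 2], [3, 4], [6].
ids = [0, 1, 1, 2, 2, 0, 3]
print(positive_mask(ids))    # [False, True, True, True, True, False, True]
print(first_positions(ids))  # [1, 3, 6]
```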
- property external_shap_values_mask: Tensor | None
An optional external SHAP values mask. Should be a boolean tensor with one entry per token in the chat.
If provided, shap_values_mask is combined with this mask using a logical AND. Adding new tokens to the chat is forbidden while this is set.
- classmethod from_chat(mask: Tensor, chat: BaseMllmChat) → BaseMllmChat
Create a new chat instance from an existing chat and a mask.
- Parameters:
mask – A boolean tensor indicating which tokens to include.
chat – The existing chat instance to copy.
- Returns:
An instance of BaseMllmChat.
- Raises:
ValueError – If the mask size does not match the number of tokens in the chat, or if the mask is all False, or if all text tokens are filtered out.
RuntimeError – If audio segments are inconsistent with stored waveforms.
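In SHAP-style explanations, masks like this typically encode which tokens are kept in a coalition. The sketch below mirrors only the documented size and all-False checks of from_chat on plain lists; select_tokens is hypothetical, and the library's other checks (text filtering, audio segments) are omitted.

```python
def select_tokens(tokens, mask):
    """Keep tokens where mask is True, mirroring from_chat's documented checks."""
    if len(mask) != len(tokens):
        raise ValueError(
            f"Mask size {len(mask)} does not match the number of tokens {len(tokens)}"
        )
    if not any(mask):
        raise ValueError("Mask selects no tokens")
    return [tok for tok, keep in zip(tokens, mask) if keep]


tokens = ["<|im_start|>", "user", "\n", "Who", "are", "you", "?"]
print(select_tokens(tokens, [True, True, True, False, True, True, True]))
# ['<|im_start|>', 'user', '\n', 'are', 'you', '?']
```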
- get_conversation() → list[list[ChatEntry]]
Serialize the chat state into a list of turns.
- Returns:
A list of turns, where each turn is a list of ChatEntry objects. shap_values=None indicates that SHAP values are not yet available.
- Raises:
NotImplementedError – If the chat contains audio segments.
Example:
[
    [
        ChatEntry(
            content_type=0,
            roles=[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
            content='<|im_start|>, system, \n, You, are, a helpful assistant that answers questions briefly...',
            shap_values=None
        )
    ],
    [
        ChatEntry(
            content_type=0,
            roles=[2, 2, 2, 0, 0, 0, 0, 2, 2],
            content='<|im_start|>, user, \n, Who, are, you, ?, <|im_end|>, \n',
            shap_values=None
        )
    ]
]
- get_new_chat_callable: Callable[..., BaseMllmChat]
A callable to create a new chat instance.
- property input_tokens: list[Tensor]
Combined input tokens (text + audio).
- property input_tokens_num: int
Total number of input tokens (text + audio).
- property is_system_turn: bool
Flag indicating whether the current turn is a system turn.
- model_inputs: ClassVar[list[str]] = ['text', 'audio_in', 'audio_in_lens', 'audio_out', 'modality_flag']
- new_turn(speaker: Role) → None
Start a new turn in the chat state.
Warning
For developers: this method assumes that cached properties are refreshed in _new_turn() or by calling the add_text()/add_audio() methods.
- Parameters:
speaker – The role of the speaker for the new turn.
- Raises:
ValueError – If a turn is already active.
RuntimeError – If an error occurs in the underlying connector implementation.
- refresh(full: bool = False, shap: bool = False) → None
Refresh cached properties, that is:
input_tokens
tokens_modality_flag
text_tokens_mask
text_tokens
audio_tokens_mask
audio_tokens
If full is True, also refresh:
audio_tokens_no_system_mask_filtered
text_tokens_no_system_mask_filtered
shap_values_mask
external_group_ids_first_positions
external_group_ids_positive_mask
If shap is True, only refresh:
shap_values_mask
external_group_ids_first_positions
external_group_ids_positive_mask
- Parameters:
full – If True, refreshes all cached properties.
shap – If True, refreshes shap-related cached properties.
- property shap_values_mask: Tensor
Boolean mask indicating which tokens should be considered for SHAP value calculations (i.e., non-system text tokens).
- Raises:
ValueError – If the external SHAP values mask size does not match the number of tokens in the chat.
RuntimeError – If external_group_ids is set but has no positive IDs.
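A minimal sketch of how the documented rules for shap_values_mask could combine, using plain Python lists in place of tensors. The function `compute_shap_values_mask` is hypothetical; it assumes the non-system text mask is expressed over all tokens, that external_group_ids overrides external_shap_values_mask, and that a zero group ID clears a token's entry, as the attribute descriptions state.

```python
from typing import List, Optional


def compute_shap_values_mask(
    text_no_system: List[bool],
    external_mask: Optional[List[bool]] = None,
    external_group_ids: Optional[List[int]] = None,
) -> List[bool]:
    """Hypothetical combination of the documented shap_values_mask rules."""
    n_tokens = len(text_no_system)
    if external_group_ids is not None:
        # Group IDs take precedence over the external boolean mask.
        if len(external_group_ids) != n_tokens:
            raise ValueError("external_group_ids size does not match the number of tokens")
        if not any(gid > 0 for gid in external_group_ids):
            raise RuntimeError("external_group_ids is set but has no positive IDs")
        # Tokens with group ID 0 are excluded from SHAP calculations.
        return [keep and gid > 0 for keep, gid in zip(text_no_system, external_group_ids)]
    if external_mask is not None:
        if len(external_mask) != n_tokens:
            raise ValueError("external SHAP values mask size does not match the number of tokens")
        # The external mask is AND-ed with the base mask.
        return [keep and ext for keep, ext in zip(text_no_system, external_mask)]
    return list(text_no_system)
```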
- system_roles_setup: SystemRolesSetup#
A set of roles that are considered system roles. If Role.ASSISTANT is added, multi-turn assistant messages will also be marked as system and therefore excluded from Shapley value calculations.
- property text_tokens: Tensor#
Input text tensor (tokens).
- property text_tokens_mask: Tensor#
Boolean mask indicating positions of text tokens in the input (input_tokens tensor).
- text_tokens_no_system_mask: Tensor#
A boolean tensor indicating which text tokens are user-generated.
- property text_tokens_no_system_mask_filtered: Tensor#
Boolean mask indicating which text tokens are not system generated, after filtering out specified sequences. Relative to text_tokens_mask.
- token_filter: TokenFilter#
The token filtering strategy.
- token_roles: Tensor#
A tensor indicating the role of each token in the chat.
- token_sequences_to_exclude: list[Tensor]#
A list of token ID sequences to exclude from processing.
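The filtering behind text_tokens_no_system_mask_filtered can be pictured as masking out every occurrence of each excluded token sequence. The sketch below is a hypothetical illustration over plain integer lists, not the library's implementation; `exclude_sequences_mask` is an invented name, and overlapping matches are all masked.

```python
from typing import List, Sequence


def exclude_sequences_mask(tokens: Sequence[int], sequences: Sequence[Sequence[int]]) -> List[bool]:
    """Return a keep-mask that is False wherever any excluded sequence occurs."""
    keep = [True] * len(tokens)
    for seq in sequences:
        width = len(seq)
        if width == 0:
            continue  # an empty pattern would match everywhere; ignore it
        for start in range(len(tokens) - width + 1):
            if list(tokens[start:start + width]) == list(seq):
                # Mask every position covered by this match.
                for i in range(start, start + width):
                    keep[i] = False
    return keep
```

For example, `exclude_sequences_mask([5, 1, 2, 7, 1, 2], [[1, 2]])` keeps only positions 0 and 3.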
- token_turns: Tensor#
A tensor indicating the turn structure of the chat.
- property tokens_modality_flag: Tensor#
The modality flag tensor indicating token types according to the ModalityFlag enum.
- torch_device: device#
The device on which tensors are stored.
- translate_groups_ids_mask(mask: Tensor) Tensor#
Translate a mask over group IDs to a mask over all tokens.
- Parameters:
mask – A boolean tensor indicating which group IDs to include.
- Returns:
A boolean tensor indicating which tokens to include.
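A sketch of the group-to-token translation, assuming (this is not confirmed by the docs) that entry j of the group mask refers to the j-th distinct positive group ID in order of first appearance in external_group_ids. The function below mirrors the method name but is a standalone, hypothetical stand-in operating on lists.

```python
from typing import List, Sequence


def translate_groups_ids_mask(group_mask: Sequence[bool], group_ids: Sequence[int]) -> List[bool]:
    """Expand a per-group boolean mask to a per-token mask (illustrative only)."""
    # Distinct positive IDs in order of first appearance; group_mask[j] refers to order[j].
    order: List[int] = []
    seen = set()
    for gid in group_ids:
        if gid > 0 and gid not in seen:
            seen.add(gid)
            order.append(gid)
    if len(group_mask) != len(order):
        raise ValueError("group mask size does not match the number of groups")
    selected = {gid for gid, keep in zip(order, group_mask) if keep}
    # Tokens with ID 0 are never selected because 0 is never in `selected`.
    return [gid in selected for gid in group_ids]
```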
- turn_number: int#
The current turn number in the chat.
- validate_from_chat: bool#
Determines whether to validate the chat state when creating new instances.
- class mllm_shap.connectors.ModelConfig(*, text_temperature: float | None = 0.0, text_top_k: int | None = 1, audio_temperature: float | None = 0.0, audio_top_k: int | None = 1)[source]#
Bases: BaseModel
Defines settings for controlling text and audio generation behavior.
- audio_temperature: float | None#
Controls the randomness in audio generation.
- audio_top_k: int | None#
Restricts audio sampling to the top-k most probable tokens.
- model_config = {}#
Configuration for the model; should be a dictionary conforming to pydantic.config.ConfigDict.
- text_temperature: float | None#
Controls the randomness in text generation.
- text_top_k: int | None#
Restricts text sampling to the top-k most probable tokens.
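The ModelConfig defaults (temperature 0.0, top-k 1) suggest deterministic, greedy decoding. The sketch below shows the standard temperature plus top-k sampling technique over a list of logits; `sample_token` is hypothetical, and treating `None` as "no temperature scaling" or "no top-k cutoff" is an assumption rather than documented behavior.

```python
import math
import random
from typing import Optional, Sequence


def sample_token(
    logits: Sequence[float],
    temperature: Optional[float] = 0.0,
    top_k: Optional[int] = 1,
    rng: Optional[random.Random] = None,
) -> int:
    """Pick a token index using temperature + top-k sampling (illustrative only)."""
    # Token indices sorted by logit, highest first (ties keep the lower index first).
    ranked = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)
    if top_k is not None:
        ranked = ranked[:top_k]  # assumption: None means "no top-k cutoff"
    if temperature == 0.0:
        return ranked[0]  # temperature 0 => greedy (argmax) decoding
    temp = 1.0 if temperature is None else temperature  # assumption: None => no scaling
    scaled = [logits[i] / temp for i in ranked]
    peak = max(scaled)  # subtract the max for a numerically stable softmax
    weights = [math.exp(s - peak) for s in scaled]
    rng = rng or random.Random()
    return rng.choices(ranked, weights=weights, k=1)[0]
```

With top_k=1 the result is the argmax regardless of temperature, so either default alone is enough to make generation deterministic in this sketch.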
- class mllm_shap.connectors.SpectrogramGuidedAligner(device: device, model_name: str = 'facebook/wav2vec2-large-960h', model_revision: str = 'main', sample_rate: int = 16000, ctc_separator: str = '|')[source]#
Bases: object
Spectrogram-Guided Forced Aligner using Wav2Vec2 and Torchaudio.
It takes raw audio bytes and a transcript (string or list of tokens) and produces time-aligned segments with refined boundaries.
The alignment process consists of four phases:
1. Acoustic Modeling: A CTC model (Wav2Vec2) maps audio to character probabilities.
2. Forced Alignment: Dynamic programming finds the optimal alignment path.
3. Boundary Refinement: Spectrogram features (energy and flux) refine the boundaries.
4. Aggregation: Character-level segments are grouped into user-defined tokens.
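Phase 2 is typically a Viterbi pass over the CTC lattice. The following is a minimal pure-Python sketch of CTC forced alignment: the targets are interleaved with blanks (blank index 0 by default), and each state's best predecessor is the same state, the previous state, or a skip over a blank between two distinct labels. It illustrates the technique only; it is not SpectrogramGuidedAligner's code and omits the refinement and aggregation phases.

```python
import math
from typing import List, Optional, Sequence


def ctc_forced_align(
    log_probs: Sequence[Sequence[float]],
    targets: Sequence[int],
    blank: int = 0,
) -> List[Optional[int]]:
    """Viterbi CTC alignment: per frame, the index into `targets` it emits, or None for blank."""
    num_frames = len(log_probs)
    # Extended label sequence: blank, t0, blank, t1, ..., blank
    ext = [blank]
    for tok in targets:
        ext += [tok, blank]
    num_states = len(ext)
    neg_inf = -math.inf
    score = [[neg_inf] * num_states for _ in range(num_frames)]
    back = [[0] * num_states for _ in range(num_frames)]
    # A path may start on the leading blank or on the first label.
    score[0][0] = log_probs[0][ext[0]]
    if num_states > 1:
        score[0][1] = log_probs[0][ext[1]]
    for t in range(1, num_frames):
        for s in range(num_states):
            candidates = [(score[t - 1][s], s)]  # stay in the same state
            if s >= 1:
                candidates.append((score[t - 1][s - 1], s - 1))  # advance by one
            if s >= 2 and ext[s] != blank and ext[s] != ext[s - 2]:
                candidates.append((score[t - 1][s - 2], s - 2))  # skip a blank between distinct labels
            best, prev = max(candidates)
            if best == neg_inf:
                continue  # state unreachable at this frame
            score[t][s] = best + log_probs[t][ext[s]]
            back[t][s] = prev
    # A valid path ends on the final label or the trailing blank.
    end = num_states - 1
    if num_states > 1 and score[-1][num_states - 2] > score[-1][end]:
        end = num_states - 2
    if score[-1][end] == neg_inf:
        raise ValueError("audio has too few frames for this transcript")
    # Backtrack from the best end state to recover the state of every frame.
    states = [0] * num_frames
    s = end
    for t in range(num_frames - 1, -1, -1):
        states[t] = s
        s = back[t][s]
    # Odd extended positions are labels; map them back to target indices.
    return [(s - 1) // 2 if s % 2 == 1 else None for s in states]
```

Repeated labels (for example a double letter) need at least one blank frame between them, which is why the skip transition is disallowed when `ext[s] == ext[s - 2]`.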
- attach_audio_to_segments(segments: list[AudioSegment], audio_content: bytes | None = None, waveform: Tensor | None = None, original_sr: int | None = None, audio_format: str = 'mp3') None[source]#
Attach raw audio bytes to existing AudioSegment objects.
This is a helper method to materialize audio bytes for segments that were created without audio attachment (i.e., with attach_audio=False).
- Parameters:
segments – list of AudioSegment objects to attach audio to.
audio_content – Raw audio bytes. Either this or (waveform, original_sr) must be provided.
waveform – Audio waveform tensor. Either this (together with original_sr) or audio_content must be provided.
original_sr – Original sampling rate of the waveform.
audio_format – Format of the audio content (default is “mp3”).
- Raises:
ValueError – If neither audio_content nor (waveform, original_sr) are provided.
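Materializing audio for a segment essentially means slicing the waveform between the segment's boundaries. This sketch uses a hypothetical `Segment` dataclass (start/end in seconds, samples as a list) and a hypothetical `attach_samples` helper; encoding the slice to bytes in `audio_format`, and any resampling from `original_sr`, are left out.

```python
from dataclasses import dataclass
from typing import List, Optional, Sequence


@dataclass
class Segment:
    """Hypothetical stand-in for AudioSegment: boundaries in seconds."""
    start: float
    end: float
    samples: Optional[List[float]] = None


def attach_samples(segments: List[Segment], waveform: Sequence[float], sample_rate: int) -> None:
    """Fill each segment with the waveform samples between its boundaries (in place)."""
    for seg in segments:
        # Convert seconds to sample indices and clamp to the waveform length.
        lo = max(0, int(round(seg.start * sample_rate)))
        hi = min(len(waveform), int(round(seg.end * sample_rate)))
        seg.samples = list(waveform[lo:hi])
```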
- static normalize_text(text: str) str[source]#
Normalize text for alignment by stripping diacritics and keeping only alphanumeric characters.
This normalization is applied consistently to both the transcript used for forced alignment and the target segments used for aggregation to ensure matching.
- Parameters:
text – The text to normalize.
- Returns:
Normalized text (uppercase, no diacritics, only alphanumeric).
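The documented normalization (strip diacritics, keep only alphanumerics, uppercase) can be reproduced with the standard `unicodedata` module. This is a sketch of the described behavior, not the library's source; whether the real method also treats spaces or the CTC separator specially is not stated here.

```python
import unicodedata


def normalize_text(text: str) -> str:
    """Strip diacritics, keep alphanumeric characters only, and uppercase."""
    # NFKD splits accented letters into a base letter plus combining marks.
    decomposed = unicodedata.normalize("NFKD", text)
    without_marks = "".join(ch for ch in decomposed if not unicodedata.combining(ch))
    # Drop punctuation and whitespace, then uppercase.
    return "".join(ch for ch in without_marks if ch.isalnum()).upper()
```

For instance, `normalize_text("Café, naïve!")` yields `"CAFENAIVE"`.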
- class mllm_shap.connectors.TransformersCausalText(device: device, **kwargs: Any)[source]#
Bases: BaseMllmModel
Connector for classic Hugging Face causal LMs (text-only).
- Fields:
processor (PreTrainedTokenizerBase): the tokenizer
model (PreTrainedModel): the causal LM (AutoModelForCausalLM)
- config: HuggingFaceModelConfig#
The model configuration.
- device: device#
The device to run the model on.
- generate(chat: BaseMllmChat, max_new_tokens: int = 128, model_config: ModelConfig = ModelConfig(text_temperature=0.0, text_top_k=1, audio_temperature=0.0, audio_top_k=1), keep_history: bool = False) ModelResponse[source]#
Generate a text response based on the current chat state.
- Parameters:
chat – The current chat state.
max_new_tokens – The maximum number of new tokens to generate (default is 128).
model_config – Additional model configuration parameters.
keep_history – Whether to return chat state with full history or only generated content.
- Returns:
The updated chat state after generation.
- Return type:
ModelResponse
- get_contextual_embeddings(*args: Any, static_embeddings: list[Tensor] | None = None, **kwargs: Any) list[Tensor]#
Get contextual embeddings for the current chat state.
- Parameters:
static_embeddings – Precomputed static embeddings (if any).
*args – Additional positional arguments for get_static_embeddings(). Used if static_embeddings is None.
**kwargs – Additional keyword arguments for get_static_embeddings(). Used if static_embeddings is None.
- Returns:
- The context embeddings for the text and audio tokens, same format as in
- Raises:
ValueError – If static_embeddings is not an instance of Tensor.
- get_new_chat(**kwargs: Any) TransformersTextChat[source]#
Get a new chat state for the model.
- get_static_embeddings(responses: list[ModelResponse]) list[Tensor][source]#
Get static embeddings for the current chat state.
- Parameters:
responses – The model responses to get embeddings for.
- Returns:
The static embeddings for the text and audio tokens.
- Raises:
ValueError – If responses is not a list of ModelResponse.
- history_tracking_mode: ModelHistoryTrackingMode#
The mode for tracking chat history.
- model: PreTrainedModel#
The model instance.
- processor: PreTrainedTokenizerBase#
The model processor (tokenizer).
- class mllm_shap.connectors.TransformersTextChat(device: device, tokenizer: PreTrainedTokenizerBase, empty_turn_sequences: set[str] | None = None, token_filter: TokenFilter | None = None, system_roles_setup: SystemRolesSetup | None = None, get_new_chat_callable: Callable[[...], TransformersTextChat] | None = None)[source]#
Bases: BaseMllmChat
Chat state for text-only causal LMs.
Stores only text token IDs. Audio is unsupported: audio-related calls emit a warning and otherwise do nothing.
- add_audio(audio_content: bytes, audio_format: str = 'mp3', _waveform: Tensor | None = None, _sample_rate: int | None = None, _internal: bool = False) None#
Add audio content to the chat state.
- Parameters:
audio_content – The audio content in bytes.
audio_format – The format of the audio content (default is “mp3”).
- Raises:
ValueError – If audio_content is empty or not of type bytes.
RuntimeError – If an error occurs in the underlying connector implementation.
- add_audio_with_transcript(audio_content: bytes, transcript: str | list[str], aligner: SpectrogramGuidedAligner, audio_format: str = 'mp3', attach_audio: bool = False) None#
Add audio content along with its transcript to the chat state.
- Parameters:
audio_content – The audio content in bytes.
transcript – The transcript of the audio content.
aligner – The SpectrogramGuidedAligner instance for aligning audio and transcript.
audio_format – The format of the audio content (default is “mp3”).
attach_audio – Whether to attach raw audio bytes to each segment (default is False). If False, segments will have empty audio bytes to save memory.
- Raises:
ValueError – If audio_content is empty or not of type bytes, or if transcript is not a non-empty string.
RuntimeError – If an error occurs in the underlying connector implementation.
- add_text(text: str) None#
Add text to the chat state.
- Parameters:
text – The text to add.
- Raises:
ValueError – If text is not a non-empty string.
RuntimeError – If an error occurs in the underlying connector implementation.
- append(text: Tensor, audio_out: Tensor, modality_flag: Tensor, history_tracking_mode: ModelHistoryTrackingMode) None#
Append text and audio tokens along with modality flags to the chat state. Assumes that the entry data is correct and non-system.
Warning
This method does not validate input data.
- Parameters:
text – The text tokens to append.
audio_out – The audio tokens to append (intended for model’s output).
modality_flag – The modality flags corresponding to the tokens.
history_tracking_mode – The mode for tracking chat history.
- Raises:
ValueError – If length mismatch occurs after appending.
RuntimeError – If an error occurs in the underlying connector implementation.
- apply_text_mask(text_mask_relative: Tensor) None[source]#
Apply a relative text mask to this chat instance (public on purpose).
- attach_audio_to_segments(aligner: SpectrogramGuidedAligner, audio_content: bytes, audio_format: str = 'mp3', turn_number: int | None = None) None#
Attach raw audio bytes to segments for a specific turn.
This is a helper method to materialize audio bytes for segments that were created without audio attachment (i.e., with attach_audio=False).
- Parameters:
aligner – The SpectrogramGuidedAligner instance to use for attaching audio.
audio_content – The audio content in bytes.
audio_format – The format of the audio content (default is “mp3”).
turn_number – The turn number to attach audio for. If None, uses the current turn.
- Raises:
ValueError – If no audio segments exist for the specified turn.
- property audio_tokens: Tensor#
Input audio tensor (tokens) in shape (T, K).
- property audio_tokens_mask: Tensor#
Boolean mask indicating positions of audio tokens in the input (input_tokens tensor).
- audio_tokens_no_system_mask: Tensor#
A boolean tensor indicating which audio tokens are user-generated.
- property audio_tokens_no_system_mask_filtered: Tensor#
Boolean mask indicating which audio tokens are not system generated. Relative to audio_tokens_mask.
- property cache: ExplainerCache | None#
Access to the explainer cache, if set.
- decode_audio(audio_tokens: list[Tensor] | Tensor | None = None, sample_rate: int = 24000, audio_format: str = 'mp3') bytes#
Decode the generated audio tokens.
Warning
This method does not validate input data.
- Parameters:
audio_tokens – Audio tokens to decode in format (T, K).
sample_rate – The sample rate for the decoded audio (default is 24,000 Hz).
audio_format – The desired output audio format (default is “mp3”).
- Returns:
The decoded audio content in bytes. Empty bytes if decoding is not available.
- Raises:
RuntimeError – If an error occurs in the underlying connector implementation.
- decode_text(text_tokens: list[Tensor] | Tensor | None = None) str#
Decode the generated text tokens.
Warning
This method does not validate input data.
- Parameters:
text_tokens – The generated text tokens.
- Returns:
The decoded text.
- Raises:
RuntimeError – If an error occurs in the underlying connector implementation.
- empty_turn_sequences: list[Tensor]#
A list of token tensors representing empty turn sequences.
- end_turn() None#
End the current turn in the chat state.
Warning
For Developers: This method assumes cached property refresh is handled in _new_turn() or by calling add_text/add_audio methods.
- Raises:
ValueError – If no turn is active.
RuntimeError – If an error occurs in the underlying connector implementation.
- property external_group_ids: Tensor | None#
Optional external group IDs for tokens. Should be an integer tensor with size equal to the number of tokens in the chat, set directly before the explanation process.
Entries > 0 are treated as group labels for SHAP value calculations, with tokens sharing the same ID forming one group; entries equal to 0 are excluded from the calculations (their shap_values_mask is set to False). Takes precedence over external_shap_values_mask. Forbids adding new tokens to the chat while set.
- property external_group_ids_first_positions: Tensor#
Get the positions (indices) of the first occurrences of each consecutive non-zero group ID in external_group_ids.
- Returns:
Tensor of positions (indices) of the first occurrence of each non-zero group ID, or None if not set.
- Raises:
ValueError – If external_group_ids is not set.
- property external_group_ids_positive_mask: Tensor#
Boolean mask indicating which tokens have positive group IDs.
- Returns:
Boolean mask indicating which tokens have positive group IDs.
- Raises:
ValueError – If external_group_ids is not set.
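Both properties above can be sketched over a plain list of group IDs. The helper names `first_positions` and `positive_mask` are hypothetical; "first occurrence of each consecutive non-zero group ID" is read here as the start index of every run of equal positive IDs.

```python
from typing import List, Optional, Sequence


def first_positions(group_ids: Optional[Sequence[int]]) -> List[int]:
    """Start index of every run of equal, positive group IDs."""
    if group_ids is None:
        raise ValueError("external_group_ids is not set")
    positions: List[int] = []
    previous = 0
    for index, gid in enumerate(group_ids):
        # A new run starts when a positive ID differs from the previous token's ID.
        if gid > 0 and gid != previous:
            positions.append(index)
        previous = gid
    return positions


def positive_mask(group_ids: Optional[Sequence[int]]) -> List[bool]:
    """True for tokens whose group ID is positive."""
    if group_ids is None:
        raise ValueError("external_group_ids is not set")
    return [gid > 0 for gid in group_ids]
```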
- property external_shap_values_mask: Tensor | None#
An optional external SHAP values mask. Should be a boolean tensor with size equal to the number of tokens in the chat.
If provided, shap_values_mask will be AND-ed with this mask. Forbids adding new tokens to the chat while set.
- classmethod from_chat(mask: Tensor, chat: BaseMllmChat) BaseMllmChat#
Create a new chat instance from an existing chat and a mask.
- Parameters:
mask – A boolean tensor indicating which tokens to include.
chat – The existing chat instance to copy.
- Returns:
An instance of BaseMllmChat.
- Raises:
ValueError – If the mask size does not match the number of tokens in the chat, or if the mask is all False, or if all text tokens are filtered out.
RuntimeError – If audio segments are inconsistent with stored waveforms.
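The ValueError conditions listed for from_chat can be expressed as a small validation step. `validate_from_chat_mask` is hypothetical: it assumes the mask and a text-token mask are both indexed over all tokens, and simply returns the indices that would be kept. The package also defines AllTextTokensFilteredOutError, which may be what the real code raises for the last case.

```python
from typing import List, Sequence


def validate_from_chat_mask(mask: Sequence[bool], text_tokens_mask: Sequence[bool]) -> List[int]:
    """Check a token mask the way from_chat's docs describe; return kept token indices."""
    if len(mask) != len(text_tokens_mask):
        raise ValueError("mask size does not match the number of tokens in the chat")
    if not any(mask):
        raise ValueError("mask is all False")
    # At least one text token must survive the mask.
    if not any(keep and is_text for keep, is_text in zip(mask, text_tokens_mask)):
        raise ValueError("all text tokens are filtered out")
    return [i for i, keep in enumerate(mask) if keep]
```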
- get_conversation() list[list[ChatEntry]]#
Serialize the chat state into a list of turns.
- Returns:
- A list of turns, where each turn is a list of ChatEntry objects. shap_values=None indicates that SHAP values are not yet available.
- Raises:
NotImplementedError – If the chat contains audio segments.
Example:
[
  [
    ChatEntry(
      content_type=0,
      roles=[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
      content='<|im_start|>, system, \n, You, are, a helpful assistant that answers questions briefly...',
      shap_values=None
    )
  ],
  [
    ChatEntry(
      content_type=0,
      roles=[2, 2, 2, 0, 0, 0, 0, 2, 2],
      content='<|im_start|>, user, \n, Who, are, you, ?, <|im_end|>, \n',
      shap_values=None
    )
  ]
]
- get_new_chat_callable: Callable[[...], BaseMllmChat]#
A callable to create a new chat instance.
- property input_tokens: list[Tensor]#
Combined input tensor (text + audio).
- property input_tokens_num: int#
Total number of input tokens (text + audio).
- property is_system_turn: bool#
Flag indicating whether the current turn is a system turn.
- new_turn(speaker: Role) None#
Start a new turn in the chat state.
Warning
For Developers: This method assumes cached property refresh is handled in _new_turn() or by calling add_text/add_audio methods.
- Parameters:
speaker – The role of the speaker for the new turn.
- Raises:
ValueError – If a turn is already active.
RuntimeError – If an error occurs in the underlying connector implementation.
- refresh(full: bool = False, shap: bool = False) None#
Refresh cached properties, that is:
input_tokens
tokens_modality_flag
text_tokens_mask
text_tokens
audio_tokens_mask
audio_tokens
If full is True, also refresh:
audio_tokens_no_system_mask_filtered
text_tokens_no_system_mask_filtered
shap_values_mask
external_group_ids_first_positions
external_group_ids_positive_mask
If shap is True, only the following are refreshed:
shap_values_mask
external_group_ids_first_positions
external_group_ids_positive_mask
- Parameters:
full – If True, refreshes all cached properties.
shap – If True, refreshes shap-related cached properties.
- property shap_values_mask: Tensor#
Boolean mask indicating which tokens should be considered for SHAP value calculations (i.e., non-system text tokens).
- Raises:
ValueError – If the external SHAP values mask size does not match the number of tokens in the chat.
RuntimeError – If external_group_ids is set but has no positive IDs.
- system_roles_setup: SystemRolesSetup#
A set of roles that are considered system roles. If Role.ASSISTANT is added, multi-turn assistant messages will also be marked as system and therefore excluded from Shapley value calculations.
- property text_tokens: Tensor#
Input text tensor (tokens).
- property text_tokens_mask: Tensor#
Boolean mask indicating positions of text tokens in the input (input_tokens tensor).
- text_tokens_no_system_mask: Tensor#
A boolean tensor indicating which text tokens are user-generated.
- property text_tokens_no_system_mask_filtered: Tensor#
Boolean mask indicating which text tokens are not system generated, after filtering out specified sequences. Relative to text_tokens_mask.
- token_filter: TokenFilter#
The token filtering strategy.
- token_roles: Tensor#
A tensor indicating the role of each token in the chat.
- token_sequences_to_exclude: list[Tensor]#
A list of token ID sequences to exclude from processing.
- token_turns: Tensor#
A tensor indicating the turn structure of the chat.
- tokenizer: PreTrainedTokenizerBase#
- property tokens_modality_flag: Tensor#
The modality flag tensor indicating token types according to the ModalityFlag enum.
- torch_device: device#
The device on which tensors are stored.
- translate_groups_ids_mask(mask: Tensor) Tensor#
Translate a mask over group IDs to a mask over all tokens.
- Parameters:
mask – A boolean tensor indicating which group IDs to include.
- Returns:
A boolean tensor indicating which tokens to include.
- turn_number: int#
The current turn number in the chat.