# Source code for mllm_shap.connectors.base.chat_entry
"""Conversation entry data structure for audio and text modalities."""
from typing import cast
from pydantic import BaseModel
from pydantic import ConfigDict
from ...utils.audio import display_audio
from ..enums import ModalityFlag, Role
# [docs]
class ChatEntry(BaseModel):
    """Conversation entry data structure.

    Holds the content pieces of one conversation entry together with the
    role of each piece, the modality of the content, and optional SHAP
    values.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)
    """Configuration for pydantic model."""

    content_type: int
    """Modality of the content (e.g., text or audio), refer to :class:`ModalityFlag`."""

    roles: list[int]
    """List of roles associated with this entry, refer to :class:`Role`."""

    content: list[str | bytes]
    """Content of the entry, can be text (str) or audio bytes (bytes)."""

    shap_values: list[float | None] | None
    """SHAP values associated with the content tokens, if they have been computed for this entry."""

    def display(self) -> None:
        """
        Display the ChatEntry content.

        Prints the distinct roles, then either the text content (joined with
        spaces, continuation lines tab-indented) or an audio widget built
        from the concatenated audio bytes.

        Raises:
            ValueError: If the number of roles does not match the number of content pieces.

        Example:
            ::

                BY: USER, SYSTEM
                TEXT CONTENT:
                    <|im_start|> user
                    Who are you ? <|im_end|>
        """
        if len(self.roles) != len(self.content):
            raise ValueError("Number of roles must match number of content pieces.")

        # Imported lazily so the module can be imported without IPython installed.
        from IPython.display import display  # pylint: disable=import-outside-toplevel

        # Each distinct role is listed once, in ascending order of its value.
        roles = sorted(set(self.roles))
        roles_str: list[str] = [str(Role(v)) for v in roles]
        print("BY: " + ", ".join(roles_str))

        if self.content_type == ModalityFlag.TEXT.value:
            print("TEXT CONTENT:")
            print("\t" + " ".join(cast(list[str], self.content)).replace("\n", "\n\t"))
        else:  # ModalityFlag.AUDIO
            print("AUDIO CONTENT:")
            audio_bytes = b"".join(cast(list[bytes], self.content))
            _ = display(display_audio(audio_bytes))  # type: ignore[no-untyped-call]

    def __str__(self) -> str:
        """String representation of the ChatEntry (same as :meth:`__repr__`)."""
        return self.__repr__()

    def __repr__(self) -> str:
        """
        Official string representation of the ChatEntry.

        Long fields are abbreviated: the content string is cut to 50
        characters followed by ``...``, and ``roles`` / ``shap_values`` with
        more than 5 items show only the first two and last two items.

        Returns:
            A single-line ``ChatEntry(...)`` description.
        """
        max_content_chars = 50  # longest content preview before truncation
        max_listed_items = 5  # longest list shown in full
        edge_items = 2  # items kept at each end of an abbreviated list

        if self.content_type == ModalityFlag.TEXT.value:
            # Escape newlines so the representation stays on one line.
            content_str = ", ".join(cast(list[str], self.content)).replace("\n", "\\n")
        else:
            content_str = (
                f"Audio bytes of total length {sum(len(c) if isinstance(c, bytes) else 0 for c in self.content)}"
            )

        # Limit to 50 characters; the slice uses the same limit as the check,
        # so truncation never yields a longer string than the original.
        if len(content_str) > max_content_chars:
            content_str = content_str[:max_content_chars] + "..."

        if len(self.roles) > max_listed_items:
            roles_str = f"[{', '.join(str(Role(v)) for v in self.roles[:edge_items])}"
            roles_str += ", ..., "
            roles_str += f"{', '.join(str(Role(v)) for v in self.roles[-edge_items:])}]"
        else:
            roles_str = f"[{', '.join(str(Role(v)) for v in self.roles)}]"

        if self.shap_values is not None and len(self.shap_values) > max_listed_items:
            shap_values_str = f"[{', '.join(str(v) for v in self.shap_values[:edge_items])}"
            shap_values_str += ", ..., "
            shap_values_str += f"{', '.join(str(v) for v in self.shap_values[-edge_items:])}]"
        else:
            # Covers both ``None`` and short lists.
            shap_values_str = str(self.shap_values)

        # Build final representation
        return (
            f"ChatEntry("
            f"content_type={self.content_type}, "
            f"roles={roles_str}, "
            f"content='{content_str}', "
            f"shap_values={shap_values_str}"
            f")"
        )