"""LiquidAudio chat state."""
import math
from collections.abc import Callable, Iterable
from copy import deepcopy
from functools import cached_property
from logging import Logger
from typing import Any, Literal, cast
import torch
from liquid_audio import ChatState as _ChatState
from liquid_audio import LFMModality
from torch import Tensor
from ...utils.logger import get_logger
from ...utils.other import safe_mask
from ..base.chat import BaseMllmChat
from ..base.filters import TokenFilter
from ..enums import ModalityFlag, ModelHistoryTrackingMode, Role, SystemRolesSetup
logger: Logger = get_logger(__name__)
class LiquidAudioChat(BaseMllmChat, _ChatState): # type: ignore[misc]
"""Represents the chat state for a LiquidAudio model.
Handles text and audio token sequences, speaker roles, and special turn markers.
Includes configuration for audio input/output shapes and empty token handling.
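
    Example (illustrative sketch; ``liquid_kwargs`` stands for whatever configuration
    the underlying ``ChatState`` requires, and the calls use the protected hooks
    defined below):

        >>> chat = LiquidAudioChat(torch.device("cpu"), **liquid_kwargs)  # doctest: +SKIP
        >>> chat._new_turn(Role.USER)  # doctest: +SKIP
        >>> chat._add_text("Hello!")  # doctest: +SKIP
        >>> chat._end_turn()  # doctest: +SKIP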
"""
audio_empty_value: float = torch.finfo(torch.float32).min
"""Represents a placeholder value for empty audio tokens."""
validate_from_chat: bool
"""Determines whether to validate the chat state when creating new instances."""
START_MARK: str = "<|startoftext|>"
"""Marker indicating the start of a text sequence."""
    EMPTY_SYSTEM_TURN: str = "<|im_start|>system\n<|im_end|>\n"
    """Marker representing an empty system turn."""
    EMPTY_ASSISTANT_TURN: str = "<|im_start|>assistant\n<|im_end|>\n"
    """Marker representing an empty assistant turn."""
EMPTY_USER_TURN: str = "<|im_start|>user\n<|im_end|>\n"
"""Marker representing an empty user turn."""
AUDIO_IN_SHAPE: int = 128
"""Number of audio codebooks used for audio input tokens."""
AUDIO_OUT_SHAPE: int = 8
"""Number of audio codebooks used for audio output tokens."""
_TWO_DIMS: int = 2
_SINGLE_BATCH: int = 1
_SHARED_ATTRIBUTES: frozenset[str] = frozenset({
"proc", # processor - large, read-only
"_logger",
})
# for each element x in _audio_map:
# x > 0 -> index in audio_out + 1
# x < 0 -> -(index in audio_in + 1)
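    # example: _audio_map == [-1, -2, 1] encodes two audio-in tokens
    # (columns 0 and 1 of the chunked audio_in) followed by one
    # audio-out token (column 0 of audio_out)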
_audio_map: Tensor
    # relies on ChatState.text, ChatState.audio_in, ChatState.audio_out, ChatState.modality_flag
    # both audio tensors are in (K, T) format
# pylint: disable=too-many-arguments,too-many-positional-arguments
def __init__(
self,
device: torch.device,
validate_from_chat: bool = False,
empty_turn_sequences: set[str] | None = None,
token_filter: TokenFilter | None = None,
system_roles_setup: SystemRolesSetup | None = None,
get_new_chat_callable: Callable[..., "LiquidAudioChat"] | None = None,
**liquid_kwargs: Any,
) -> None:
"""
Initialize LiquidAudioChat.
Args:
device: The device to use for tensors.
validate_from_chat: Whether to validate chat state when creating new instances.
empty_turn_sequences: String sequences representing empty turns to consider.
token_filter: Token filtering strategy to apply.
system_roles_setup: Configuration for system role handling.
liquid_kwargs: Additional keyword arguments for ChatState.
"""
_ChatState.__init__(self, **liquid_kwargs)
_additional_empty_turn_sequences = {
LiquidAudioChat.EMPTY_SYSTEM_TURN,
LiquidAudioChat.EMPTY_ASSISTANT_TURN,
LiquidAudioChat.EMPTY_USER_TURN,
}
# Consider empty turns with start mark as well
for e in _additional_empty_turn_sequences.copy():
_additional_empty_turn_sequences.add(LiquidAudioChat.START_MARK + e)
empty_turn_sequences = empty_turn_sequences or set()
empty_turn_sequences = empty_turn_sequences.union(_additional_empty_turn_sequences)
BaseMllmChat.__init__(
self,
device=device,
empty_turn_sequences=empty_turn_sequences,
token_filter=token_filter,
system_roles_setup=system_roles_setup,
get_new_chat_callable=get_new_chat_callable # type: ignore[arg-type]
)
# mark starting tokens as system
self.speaker = Role.SYSTEM
self._after_add(1, text_added=True, refresh=True)
self.speaker = None
self.validate_from_chat = validate_from_chat
self._audio_map = torch.empty((0,), dtype=torch.long, device=self.torch_device)
# assume `_{}` are protected methods from BaseMllmChat
# pylint: disable=too-many-locals,protected-access,too-many-branches,too-many-statements
@classmethod
def _set_new_instance(
cls: type["LiquidAudioChat"],
full_mask: Tensor,
text_mask_relative: Tensor,
audio_mask_relative: Tensor,
chat: "LiquidAudioChat", # type: ignore[override]
) -> "LiquidAudioChat":
new_instance: "LiquidAudioChat" = deepcopy(chat)
# filter out text tokens based on the text_mask
# masking done on new_instance as it can mutate the tensors
new_instance.text = safe_mask(new_instance.text, text_mask_relative)
new_instance.text_tokens_no_system_mask = safe_mask(new_instance.text_tokens_no_system_mask, text_mask_relative)
# split audio mask into input and output parts
# masks relative to audio tokens
# this is calculated before filtering out audio tokens
audio_in_mask_relative, audio_out_mask_relative = chat._get_relative_audio_masks()
# audio map is a list of indices in audio_in and audio_out
# after removing some audio tokens, we need to update the audio map accordingly
final_audio_in_relative = audio_mask_relative[audio_in_mask_relative]
final_audio_out_relative = audio_mask_relative[audio_out_mask_relative]
        # calculate index shifts caused by removed tokens, expressed relative to
        # the new audio map and per token type (audio in vs. audio out)
removed_audio_in_relative_shift = torch.cumsum((~final_audio_in_relative).to(torch.long), dim=0)[
final_audio_in_relative
]
removed_audio_out_relative_shift = torch.cumsum((~final_audio_out_relative).to(torch.long), dim=0)[
final_audio_out_relative
]
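        # worked example: final_audio_in_relative = [True, False, True] drops the
        # middle audio-in token; cumsum(~mask) = [0, 1, 1] gives shifts [0, 1] at
        # the kept positions, so map entry -3 (old index 2) becomes -2 (new index 1)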
        # pick map entries < 0 (audio in), keep those selected by final_audio_in_relative, and shift indices
new_audio_map_in = (
chat._audio_map[chat._audio_map < 0][final_audio_in_relative] + removed_audio_in_relative_shift
)
        # pick map entries > 0 (audio out), keep those selected by final_audio_out_relative, and shift indices
new_audio_map_out = (
chat._audio_map[chat._audio_map > 0][final_audio_out_relative] - removed_audio_out_relative_shift
)
new_instance._audio_map = torch.cat(
[new_audio_map_in, new_audio_map_out],
dim=0,
)
        chunk = LiquidAudioChat.AUDIO_OUT_SHAPE  # frames per audio-in token; matches the chunking in _add_audio
t_frames = new_instance.audio_in.shape[1]
# Build frame mask respecting audio segment boundaries from audio_in_lens
# Each audio segment has its own length in audio_in_lens
frame_mask_list: list[Tensor] = []
token_idx = 0 # Index into final_audio_in_relative (token-level mask)
for seg_idx in range(chat.audio_in_lens.shape[0]):
seg_frame_count = int(chat.audio_in_lens[seg_idx].item())
# Number of tokens for this segment (ceiling division)
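            # e.g., a 20-frame segment with chunk == 8 yields 3 tokens (covering 8, 8, and 4 frames)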
seg_token_count = (seg_frame_count + chunk - 1) // chunk
seg_frame_start = 0
for _ in range(seg_token_count):
if token_idx < len(final_audio_in_relative):
keep_val = bool(final_audio_in_relative[token_idx].item())
else:
keep_val = False
token_idx += 1
# How many frames does this token cover in this segment?
frames_for_token = min(chunk, seg_frame_count - seg_frame_start)
if frames_for_token > 0:
frame_mask_list.append(
torch.full(
(frames_for_token,),
fill_value=keep_val,
dtype=torch.bool,
device=new_instance.torch_device,
)
)
seg_frame_start += frames_for_token
        if frame_mask_list:
final_audio_in_frame_mask = torch.cat(frame_mask_list, dim=0)
else:
final_audio_in_frame_mask = torch.empty(0, dtype=torch.bool, device=new_instance.torch_device)
# Ensure frame mask matches audio_in frames
if final_audio_in_frame_mask.shape[0] != t_frames:
# Pad or truncate to match
if final_audio_in_frame_mask.shape[0] < t_frames:
padding = torch.zeros(
t_frames - final_audio_in_frame_mask.shape[0],
dtype=torch.bool,
device=new_instance.torch_device,
)
final_audio_in_frame_mask = torch.cat([final_audio_in_frame_mask, padding], dim=0)
else:
final_audio_in_frame_mask = final_audio_in_frame_mask[:t_frames]
new_instance.audio_in = safe_mask(
new_instance.audio_in,
final_audio_in_frame_mask,
)
new_instance.audio_out = safe_mask(
new_instance.audio_out,
final_audio_out_relative,
)
# Update audio_in_lens based on kept frames per segment
frame_offset = 0
for i in range(new_instance.audio_in_lens.shape[0]):
original_len = int(chat.audio_in_lens[i].item()) # Use original chat's lens
kept_frames = final_audio_in_frame_mask[frame_offset: frame_offset + original_len].sum().item()
new_instance.audio_in_lens[i] = kept_frames
frame_offset += original_len
new_instance.audio_in_lens = new_instance.audio_in_lens[new_instance.audio_in_lens > 0]
if chat.validate_from_chat:
if new_instance._audio_map.shape[0] != audio_mask_relative.sum().item():
raise ValueError("audio_map shape does not match number of audio tokens after filtering.")
indices_in = -new_instance._audio_map[new_instance._audio_map < 0] - 1
if indices_in.numel() > 0 and indices_in.max() >= new_instance.audio_in.shape[1]:
raise ValueError("audio_in index out of bounds after filtering.")
indices_out = new_instance._audio_map[new_instance._audio_map > 0] - 1
if indices_out.numel() > 0 and indices_out.max() >= new_instance.audio_out.shape[1]:
raise ValueError("audio_out index out of bounds after filtering.")
new_instance.modality_flag = safe_mask(new_instance.modality_flag, full_mask)
return new_instance
@cached_property
def input_tokens(self) -> list[Tensor]:
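        """Per-token input tensors in sequence order: text tokens as shape-(1,)
        tensors, audio tokens as single (K, 1) columns resolved through the
        signed indices in ``_audio_map``."""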
text_mask = self.text_tokens_mask
audio_mask = self.audio_tokens_mask
# Total number of tokens
total_len = len(text_mask)
result: list[Tensor] = [torch.empty(0)] * total_len
a_idx = t_idx = 0
for i, is_audio in enumerate(audio_mask):
if is_audio:
token_idx = self.audio_tokens[a_idx]
if token_idx < 0: # audio in
result[i] = self.audio_in[..., -token_idx - 1].unsqueeze(-1)
else: # audio out
result[i] = self.audio_out[..., token_idx - 1].unsqueeze(-1)
a_idx += 1
else:
result[i] = self.text_tokens[t_idx].unsqueeze(-1)
t_idx += 1
return result
@cached_property
def tokens_modality_flag(self) -> Tensor:
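        """Per-token modality flags: ``ModalityFlag.TEXT`` where the underlying
        ``LFMModality`` flag is TEXT, ``ModalityFlag.AUDIO`` everywhere else."""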
modality_flag = torch.full_like(self.modality_flag[0], ModalityFlag.AUDIO)
modality_flag[self.modality_flag[0] == LFMModality.TEXT] = ModalityFlag.TEXT
return modality_flag
@cached_property
def text_tokens(self) -> Tensor:
return cast(Tensor, self.text[0])
@cached_property
def audio_tokens(self) -> Tensor: # return audio tokens map
return self._audio_map
def _decode_text(self, text_tokens: Tensor) -> str:
# Processor decode may return list[str] for batched inputs; normalize to str.
tt = text_tokens
if tt.ndim == self._TWO_DIMS and tt.shape[0] == self._SINGLE_BATCH:
tt = tt[0]
elif tt.ndim > 1:
tt = tt.reshape(-1)
decoded = self.proc.text.decode(tt)
if isinstance(decoded, list):
return "".join(decoded)
return decoded
def _decode_audio(self, audio_tokens: Tensor) -> Tensor | None: # pylint: disable=too-many-branches
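        """Decode audio tokens into a waveform tensor via the mimi codec.

        Accepts either signed ``_audio_map`` indices (1-D) or raw codes (2-D);
        returns None for audio-in tokens, which cannot be decoded.
        """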
        if audio_tokens.ndim == 1:
            logger.debug("Decoding audio tokens based on indices from _audio_map.")
            sign = torch.sign(audio_tokens)
            if (sign > 0).all():  # positive map entries -> audio out
                audio_tokens = self.audio_out[..., audio_tokens - 1]
            elif (sign < 0).all():  # negative map entries -> audio in
                audio_tokens = self.audio_in[..., -audio_tokens - 1]
            else:
                raise ValueError("audio_tokens should contain either only audio in or only audio out tokens.")
        # audio in tokens: decoding is not supported
        if audio_tokens.shape[0] == LiquidAudioChat.AUDIO_IN_SHAPE:
            logger.debug("Decoding audio in tokens is not supported; skipping.")
            return None
# audio out tokens
if audio_tokens.shape[0] == LiquidAudioChat.AUDIO_OUT_SHAPE:
logger.debug("Decoding audio out...")
mimi_codes = audio_tokens.unsqueeze(0)
            # validate and clamp code indices before decoding
mimi_codes = mimi_codes.to(dtype=torch.long, device=self.torch_device, non_blocking=True)
# try to infer per-codebook sizes from quantizer internals
sizes: list[int] = []
try:
q = self.proc.mimi.quantizer
if hasattr(q, "vq") and hasattr(q.vq, "layers") and q.vq.layers is not None:
for layer in cast(Iterable[Any], q.vq.layers):
codebook = getattr(layer, "_codebook", None) or getattr(layer, "codebook", None)
emb = getattr(codebook, "embedding", None)
if emb is None:
raise AttributeError("No embedding on codebook")
sizes.append(int(emb.shape[0]))
else:
# conservative fallback: assume 2048 entries per codebook
sizes = [2048] * mimi_codes.shape[1]
except Exception as e: # pylint: disable=broad-except
logger.warning("Could not introspect codebook sizes (%s). Falling back to 2048.", e)
sizes = [2048] * mimi_codes.shape[1]
num_codebooks = mimi_codes.shape[1]
for k in range(min(num_codebooks, len(sizes))):
n = sizes[k]
ck = mimi_codes[:, k, :] # (B, T)
# detect OOR
if (ck >= n).any() or (ck < 0).any():
mn = int(ck.min().item())
mx = int(ck.max().item())
logger.warning("Audio code OOR on codebook %d: min=%d max=%d valid=[0,%d). Clamping.", k, mn, mx, n)
ck.clamp_(0, n - 1)
return cast(Tensor, self.proc.mimi.decode(mimi_codes).squeeze(0))
raise ValueError(
f"audio tokens first dimension should be either {LiquidAudioChat.AUDIO_OUT_SHAPE} "
f"(audio out) or {LiquidAudioChat.AUDIO_IN_SHAPE} (audio in)."
)
def _add_text(self, text: str) -> int:
starting_tokens_num = self.text.shape[1]
_ChatState.add_text(self, text)
return int(self.text.shape[1] - starting_tokens_num)
def _add_audio(self, waveform: Tensor, sample_rate: int) -> int:
starting_tokens_num = self.audio_in.shape[1]
_ChatState.add_audio(self, waveform, sample_rate)
delta_cols = int(self.audio_in.shape[1] - starting_tokens_num)
added_tokens_num = math.ceil(delta_cols / LiquidAudioChat.AUDIO_OUT_SHAPE)
# update audio map
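        # e.g., with 16 existing audio_in columns (2 tokens of 8 frames each) and
        # 20 new frames, added_tokens_num == 3 and the map gains [-3, -4, -5]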
self._audio_map = torch.cat(
[
self._audio_map,
-(
torch.arange(
starting_tokens_num // LiquidAudioChat.AUDIO_OUT_SHAPE,
starting_tokens_num // LiquidAudioChat.AUDIO_OUT_SHAPE + added_tokens_num,
dtype=torch.long,
device=self.torch_device,
)
+ 1
),
],
dim=0,
)
return added_tokens_num
def _append(
self,
text: Tensor,
audio_out: Tensor,
modality_flag: Tensor,
history_tracking_mode: ModelHistoryTrackingMode,
) -> tuple[int, int]:
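        """Append generated text/audio-out tokens to the chat state, honoring the
        history tracking mode, and extend ``_audio_map`` with positive entries
        for the new audio-out columns. Returns the numbers of text and audio-out
        tokens actually appended."""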
starting_text_tokens_num = self.text[0].shape[0]
starting_audio_tokens_num = self.audio_out[0].shape[0]
if history_tracking_mode == ModelHistoryTrackingMode.TEXT:
audio_out = torch.empty((self.codebooks, 0), dtype=audio_out.dtype, device=audio_out.device)
modality_flag = modality_flag[modality_flag == LFMModality.TEXT].unsqueeze(0)
elif history_tracking_mode == ModelHistoryTrackingMode.AUDIO:
text = torch.empty((1, 0), dtype=text.dtype, device=text.device)
modality_flag = modality_flag[modality_flag != LFMModality.TEXT].unsqueeze(0)
# else: keep both text and audio_out as is
_ChatState.append(self, text, audio_out, modality_flag)
# update audio map
self._audio_map = torch.cat(
[
self._audio_map,
torch.arange(
starting_audio_tokens_num,
self.audio_out.shape[1],
dtype=torch.long,
device=self.torch_device,
)
+ 1,
],
dim=0,
)
return (
self.text[0].shape[0] - starting_text_tokens_num,
self.audio_out[0].shape[0] - starting_audio_tokens_num,
)
def _new_turn(self, speaker: Role) -> None:
role: Literal["system", "user", "assistant"]
if speaker == Role.SYSTEM:
role = "system"
elif speaker == Role.USER:
role = "user"
else: # Role.ASSISTANT
role = "assistant"
_ChatState.new_turn(self, role)
def _end_turn(self) -> None:
_ChatState.end_turn(self)
def _get_tokens_sequences_to_exclude(self, phrases_to_exclude: set[str]) -> list[Tensor]:
token_sequences_to_exclude: list[Tensor] = []
for phrase in phrases_to_exclude:
token_ids = self.proc.text.encode(phrase, add_special_tokens=False)
token_sequences_to_exclude.append(torch.tensor(token_ids, device=self.torch_device))
return token_sequences_to_exclude
def _get_relative_audio_masks(self) -> tuple[Tensor, Tensor]:
"""
Get relative audio in and out masks based on the modality flag
(relative to audio tokens only).
Returns:
            tuple[Tensor, Tensor]: A tuple containing the relative audio-in mask and audio-out mask.
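
        Example (illustrative): for a modality flag of
        [TEXT, AUDIO_IN, AUDIO_OUT, AUDIO_IN], the audio-only positions are
        [AUDIO_IN, AUDIO_OUT, AUDIO_IN], so this returns
        ([True, False, True], [False, True, False]).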
"""
audio_in_mask = self.modality_flag[0] == LFMModality.AUDIO_IN
audio_out_mask = self.modality_flag[0] == LFMModality.AUDIO_OUT
audio_mask = audio_in_mask | audio_out_mask
        audio_in_mask_relative = audio_in_mask[audio_mask]
        audio_out_mask_relative = audio_out_mask[audio_mask]
return audio_in_mask_relative, audio_out_mask_relative