"""LiquidAudio chat state."""
import math
import sys
from collections.abc import Callable
from copy import deepcopy
from functools import cached_property
from logging import Logger
from typing import Any, Iterable, Literal, cast
import torch
from liquid_audio import ChatState as _ChatState
from liquid_audio import LFMModality
from torch import Tensor
from ...utils.logger import get_logger
from ...utils.other import safe_mask
from ..base.chat import BaseMllmChat
from ..base.filters import TokenFilter
from ..enums import ModalityFlag, ModelHistoryTrackingMode, Role, SystemRolesSetup
logger: Logger = get_logger(__name__)
[docs]
class LiquidAudioChat(BaseMllmChat, _ChatState):
"""Represents the chat state for a LiquidAudio model.
Handles text and audio token sequences, speaker roles, and special turn markers.
Includes configuration for audio input/output shapes and empty token handling.
"""
audio_empty_value: float = torch.finfo(torch.float32).min
"""Represents a placeholder value for empty audio tokens."""
validate_from_chat: bool
"""Determines whether to validate the chat state when creating new instances."""
START_MARK: str = "<|startoftext|>"
"""Marker indicating the start of a text sequence."""
EMPTY_SYSTEM_TURN: str = "<|im_start|>Role.SYSTEM\n<|im_end|>\n"
"""Marker representing an empty system turn."""
EMPTY_ASSISTANT_TURN: str = "<|im_start|>Role.ASSISTANT\n<|im_end|>\n"
"""Marker representing an empty assistant turn."""
EMPTY_USER_TURN: str = "<|im_start|>user\n<|im_end|>\n"
"""Marker representing an empty user turn."""
AUDIO_IN_SHAPE: int = 128
"""Number of audio codebooks used for audio input tokens."""
AUDIO_OUT_SHAPE: int = 8
"""Number of audio codebooks used for audio output tokens."""
_SHARED_ATTRIBUTES: frozenset[str] = frozenset({
"proc", # processor - large, read-only
"_logger",
})
"""A set of attribute names that should be shared across chat instances created
from each other, to ensure consistency in SHAP calculations and token filtering a
cross derived chat instances. Includes the processor and logger, which are typically
large objects that should not be duplicated across chat instances. This is used
in the _set_new_instance method to copy these attributes from the original chat
instance to the new one, ensuring that derived chat instances have access to
the same processor and logger without needing to create new instances of
these potentially large objects."""
# for each element x in _audio_map:
# x > 0 -> index in audio_out + 1
# x < 0 -> -(index in audio_in + 1)
_audio_map: Tensor
"""A tensor mapping audio token positions to their corresponding indices in audio_in and audio_out."""
# relies on ChatState.text, ChatState.audio_in, ChatState.audio_out, ChatState.modality_flag
# both audio are in (K, T) format
def __init__(
self,
device: torch.device,
validate_from_chat: bool = False,
empty_turn_sequences: set[str] | None = None,
token_filter: TokenFilter | None = None,
system_roles_setup: SystemRolesSetup | None = None,
get_new_chat_callable: Callable[..., "LiquidAudioChat"] | None = None,
**liquid_kwargs: Any,
) -> None:
"""
Initialize LiquidAudioChat.
Args:
device: The device to use for tensors.
validate_from_chat: Whether to validate chat state when creating new instances.
empty_turn_sequences: String sequences representing empty turns to consider.
token_filter: Token filtering strategy to apply.
system_roles_setup: Configuration for system role handling.
liquid_kwargs: Additional keyword arguments for ChatState.
"""
_ChatState.__init__(self, **liquid_kwargs)
_additional_empty_turn_sequences = {
LiquidAudioChat.EMPTY_SYSTEM_TURN,
LiquidAudioChat.EMPTY_ASSISTANT_TURN,
LiquidAudioChat.EMPTY_USER_TURN,
}
# Consider empty turns with start mark as well
for e in _additional_empty_turn_sequences.copy():
_additional_empty_turn_sequences.add(LiquidAudioChat.START_MARK + e)
empty_turn_sequences = empty_turn_sequences or set()
empty_turn_sequences = empty_turn_sequences.union(
_additional_empty_turn_sequences
)
BaseMllmChat.__init__(
self,
device=device,
empty_turn_sequences=empty_turn_sequences,
token_filter=token_filter,
system_roles_setup=system_roles_setup,
get_new_chat_callable=get_new_chat_callable,
)
# mark starting tokens as system
self.speaker = Role.SYSTEM
self._after_add(1, text_added=True, refresh=True)
self.speaker = None
self.validate_from_chat = validate_from_chat
self._audio_map = torch.empty((0,), dtype=torch.long, device=self.torch_device)
# assume `_{}` are protected methods from BaseMllmChat
@classmethod
def _set_new_instance(
cls: type["LiquidAudioChat"],
full_mask: Tensor,
text_mask_relative: Tensor,
audio_mask_relative: Tensor,
chat: "LiquidAudioChat",
) -> "LiquidAudioChat":
new_instance: "LiquidAudioChat" = deepcopy(chat)
# filter out text tokens based on the text_mask
# masking done on new_instance as it can mutate the tensors
new_instance.text = safe_mask(new_instance.text, text_mask_relative)
new_instance.text_tokens_no_system_mask = safe_mask(
new_instance.text_tokens_no_system_mask, text_mask_relative
)
# split audio mask into input and output parts
# masks relative to audio tokens
# this is calculated before filtering out audio tokens
audio_in_mask_relative, audio_out_mask_relative = (
chat._get_relative_audio_masks()
)
# audio map is a list of indices in audio_in and audio_out
# after removing some audio tokens, we need to update the audio map accordingly
final_audio_in_relative = audio_mask_relative[audio_in_mask_relative]
final_audio_out_relative = audio_mask_relative[audio_out_mask_relative]
# calculate index shifts due to removed tokens, make it
# relative to new audio map and token type (i.e., audio in or out)
removed_audio_in_relative_shift = torch.cumsum(
(~final_audio_in_relative).to(torch.long), dim=0
)[final_audio_in_relative]
removed_audio_out_relative_shift = torch.cumsum(
(~final_audio_out_relative).to(torch.long), dim=0
)[final_audio_out_relative]
# pick < 0 --> audio in, by final_audio_in_relative - what to keep, and adjust indices
new_audio_map_in = (
chat._audio_map[chat._audio_map < 0][final_audio_in_relative]
+ removed_audio_in_relative_shift
)
# pick > 0 --> audio out, by final_audio_out_relative - what to keep, and adjust indices
new_audio_map_out = (
chat._audio_map[chat._audio_map > 0][final_audio_out_relative]
- removed_audio_out_relative_shift
)
new_instance._audio_map = torch.cat(
[new_audio_map_in, new_audio_map_out],
dim=0,
)
chunk = LiquidAudioChat.AUDIO_OUT_SHAPE # 8
t_frames = new_instance.audio_in.shape[1]
# Build frame mask respecting audio segment boundaries from audio_in_lens
# Each audio segment has its own length in audio_in_lens
frame_mask_list: list[Tensor] = []
token_idx = 0 # Index into final_audio_in_relative (token-level mask)
for seg_idx in range(chat.audio_in_lens.shape[0]):
seg_frame_count = int(chat.audio_in_lens[seg_idx].item())
# Number of tokens for this segment (ceiling division)
seg_token_count = (seg_frame_count + chunk - 1) // chunk
seg_frame_start = 0
for _ in range(seg_token_count):
if token_idx < len(final_audio_in_relative):
keep_val = bool(final_audio_in_relative[token_idx].item())
else:
keep_val = False
token_idx += 1
# How many frames does this token cover in this segment?
frames_for_token = min(chunk, seg_frame_count - seg_frame_start)
if frames_for_token > 0: # pragma: no branch
frame_mask_list.append(
torch.full(
(frames_for_token,),
fill_value=keep_val,
dtype=torch.bool,
device=new_instance.torch_device,
)
)
seg_frame_start += frames_for_token
if len(frame_mask_list) > 0:
final_audio_in_frame_mask = torch.cat(frame_mask_list, dim=0)
else:
final_audio_in_frame_mask = torch.empty(
0, dtype=torch.bool, device=new_instance.torch_device
)
# Ensure frame mask matches audio_in frames
if final_audio_in_frame_mask.shape[0] != t_frames:
# Pad or truncate to match
if final_audio_in_frame_mask.shape[0] < t_frames:
padding = torch.zeros(
t_frames - final_audio_in_frame_mask.shape[0],
dtype=torch.bool,
device=new_instance.torch_device,
)
final_audio_in_frame_mask = torch.cat(
[final_audio_in_frame_mask, padding], dim=0
)
else:
final_audio_in_frame_mask = final_audio_in_frame_mask[:t_frames]
new_instance.audio_in = safe_mask(
new_instance.audio_in,
final_audio_in_frame_mask,
)
new_instance.audio_out = safe_mask(
new_instance.audio_out,
final_audio_out_relative,
)
# Update audio_in_lens based on kept frames per segment
frame_offset = 0
for i in range(new_instance.audio_in_lens.shape[0]):
original_len = int(chat.audio_in_lens[i].item()) # Use original chat's lens
kept_frames = (
final_audio_in_frame_mask[frame_offset : frame_offset + original_len]
.sum()
.item()
)
new_instance.audio_in_lens[i] = kept_frames
frame_offset += original_len
new_instance.audio_in_lens = new_instance.audio_in_lens[
new_instance.audio_in_lens > 0
]
if chat.validate_from_chat:
if new_instance._audio_map.shape[0] != audio_mask_relative.sum().item():
raise ValueError(
"audio_map shape does not match number of audio tokens after filtering."
)
indices_in = -new_instance._audio_map[new_instance._audio_map < 0] - 1
if (
indices_in.numel() > 0
and indices_in.max() >= new_instance.audio_in.shape[1]
):
raise ValueError("audio_in index out of bounds after filtering.")
indices_out = new_instance._audio_map[new_instance._audio_map > 0] - 1
if (
indices_out.numel() > 0
and indices_out.max() >= new_instance.audio_out.shape[1]
):
raise ValueError("audio_out index out of bounds after filtering.")
new_instance.modality_flag = safe_mask(new_instance.modality_flag, full_mask)
return new_instance
@cached_property
def input_tokens(self) -> list[Tensor]:
text_mask = self.text_tokens_mask
audio_mask = self.audio_tokens_mask
# Total number of tokens
total_len = len(text_mask)
result: list[Tensor] = [torch.empty(0)] * total_len
a_idx = t_idx = 0
for i, is_audio in enumerate(audio_mask):
if is_audio:
token_idx = self.audio_tokens[a_idx]
if token_idx < 0: # audio in
result[i] = self.audio_in[..., -token_idx - 1].unsqueeze(-1)
else: # audio out
result[i] = self.audio_out[..., token_idx - 1].unsqueeze(-1)
a_idx += 1
else:
result[i] = self.text_tokens[t_idx].unsqueeze(-1)
t_idx += 1
return result
@cached_property
def tokens_modality_flag(self) -> Tensor:
modality_flag = torch.full_like(self.modality_flag[0], ModalityFlag.AUDIO)
modality_flag[self.modality_flag[0] == LFMModality.TEXT] = ModalityFlag.TEXT
return modality_flag
@cached_property
def text_tokens(self) -> Tensor:
return cast(Tensor, self.text[0])
@cached_property
def audio_tokens(self) -> Tensor: # return audio tokens map
return self._audio_map
def _decode_text(self, text_tokens: Tensor) -> str:
# Processor decode may return list[str] for batched inputs; normalize to str.
tt = text_tokens
if tt.ndim == 2 and tt.shape[0] == 1:
tt = tt[0]
elif tt.ndim > 1:
tt = tt.reshape(-1)
decoded = self.proc.text.decode(tt)
if isinstance(decoded, list):
return "".join(decoded)
return decoded
def _decode_audio(self, audio_tokens: Tensor) -> Tensor | None:
if len(audio_tokens.shape) == 1:
logger.debug(
"Resolving 1D audio token indices from _audio_map: "
"positive indices map to audio_out, negative indices map to audio_in."
)
sign = torch.sign(audio_tokens)
if sign.all(): # audio in
audio_tokens = self.audio_in[audio_tokens - 1]
elif not sign.any(): # audio out
audio_tokens = self.audio_out[-audio_tokens - 1]
else:
raise ValueError(
"audio_tokens should contain either only audio in or only audio out tokens."
)
# input tokens
if audio_tokens.shape[0] == LiquidAudioChat.AUDIO_IN_SHAPE:
logger.debug(
"Skipping audio input (audio_in) decoding: codec inverse is not supported "
"for encoded input representations. Shape: %s",
audio_tokens.shape,
)
# logger.warning("Decoding audio in tokens is not supported.")
return None
# audio out tokens
if audio_tokens.shape[0] == LiquidAudioChat.AUDIO_OUT_SHAPE:
logger.debug(
"Decoding audio output tokens to waveform using MIMI codec. "
"Shape: %s (codebooks, time steps)",
audio_tokens.shape,
)
mimi_codes = audio_tokens.unsqueeze(0)
# -validation/clamp of code indices
mimi_codes = mimi_codes.to(
dtype=torch.long, device=self.torch_device, non_blocking=True
)
# try to infer per-codebook sizes from quantizer internals
sizes: list[int] = []
try:
q = self.proc.mimi.quantizer
if (
hasattr(q, "vq")
and hasattr(q.vq, "layers")
and q.vq.layers is not None
):
for layer in cast(Iterable[Any], q.vq.layers):
codebook = getattr(layer, "_codebook", None) or getattr(
layer, "codebook", None
)
emb = getattr(codebook, "embedding", None)
if emb is None:
raise AttributeError("No embedding on codebook")
sizes.append(int(emb.shape[0]))
else:
# conservative fallback: assume 2048 entries per codebook
sizes = [2048] * mimi_codes.shape[1]
except Exception as e:
msg = (
f"Could not introspect codebook sizes ({e}). Falling back to 2048."
)
logger.warning(
"Could not introspect codebook sizes (%s). Falling back to 2048.", e
)
print(msg, file=sys.stderr)
sizes = [2048] * mimi_codes.shape[1]
num_codebooks = mimi_codes.shape[1]
for k in range(min(num_codebooks, len(sizes))):
n = sizes[k]
ck = mimi_codes[:, k, :] # (B, T)
# detect OOR
if (ck >= n).any() or (ck < 0).any():
mn = int(ck.min().item())
mx = int(ck.max().item())
msg = (
f"Audio code OOR on codebook {k}: min={mn} max={mx} "
f"valid=[0,{n}). Clamping."
)
logger.warning(
"Audio code OOR on codebook %d: min=%d max=%d valid=[0,%d). Clamping.",
k,
mn,
mx,
n,
)
print(msg, file=sys.stderr)
ck.clamp_(0, n - 1)
return cast(Tensor, self.proc.mimi.decode(mimi_codes).squeeze(0))
raise ValueError(
f"audio tokens first dimension should be either {LiquidAudioChat.AUDIO_OUT_SHAPE} "
f"(audio out) or {LiquidAudioChat.AUDIO_IN_SHAPE} (audio in)."
)
def _add_text(self, text: str) -> int:
starting_tokens_num = self.text.shape[1]
_ChatState.add_text(self, text)
return int(self.text.shape[1] - starting_tokens_num)
def _add_audio(self, waveform: Tensor, sample_rate: int) -> int:
starting_tokens_num = self.audio_in.shape[1]
_ChatState.add_audio(self, waveform, sample_rate)
delta_cols = int(self.audio_in.shape[1] - starting_tokens_num)
added_tokens_num = math.ceil(delta_cols / LiquidAudioChat.AUDIO_OUT_SHAPE)
# update audio map
self._audio_map = torch.cat(
[
self._audio_map,
-(
torch.arange(
starting_tokens_num // LiquidAudioChat.AUDIO_OUT_SHAPE,
starting_tokens_num // LiquidAudioChat.AUDIO_OUT_SHAPE
+ added_tokens_num,
dtype=torch.long,
device=self.torch_device,
)
+ 1
),
],
dim=0,
)
return added_tokens_num
def _append(
self,
text: Tensor,
audio_out: Tensor,
modality_flag: Tensor,
history_tracking_mode: ModelHistoryTrackingMode,
) -> tuple[int, int]:
starting_text_tokens_num = self.text[0].shape[0]
starting_audio_tokens_num = self.audio_out[0].shape[0]
if history_tracking_mode == ModelHistoryTrackingMode.TEXT:
audio_out = torch.empty(
(self.codebooks, 0), dtype=audio_out.dtype, device=audio_out.device
)
modality_flag = modality_flag[modality_flag == LFMModality.TEXT].unsqueeze(
0
)
elif history_tracking_mode == ModelHistoryTrackingMode.AUDIO:
text = torch.empty((1, 0), dtype=text.dtype, device=text.device)
modality_flag = modality_flag[modality_flag != LFMModality.TEXT].unsqueeze(
0
)
# else: keep both text and audio_out as is
_ChatState.append(self, text, audio_out, modality_flag)
# update audio map
self._audio_map = torch.cat(
[
self._audio_map,
torch.arange(
starting_audio_tokens_num,
self.audio_out.shape[1],
dtype=torch.long,
device=self.torch_device,
)
+ 1,
],
dim=0,
)
return (
self.text[0].shape[0] - starting_text_tokens_num,
self.audio_out[0].shape[0] - starting_audio_tokens_num,
)
def _new_turn(self, speaker: Role) -> None:
role: Literal["system", "user", "assistant"]
if speaker == Role.SYSTEM:
role = "system"
elif speaker == Role.USER:
role = "user"
else: # Role.ASSISTANT
role = "assistant"
_ChatState.new_turn(self, role)
def _end_turn(self) -> None:
_ChatState.end_turn(self)
def _get_tokens_sequences_to_exclude(
self, phrases_to_exclude: set[str]
) -> list[Tensor]:
token_sequences_to_exclude: list[Tensor] = []
for phrase in phrases_to_exclude:
token_ids = self.proc.text.encode(phrase, add_special_tokens=False)
token_sequences_to_exclude.append(
torch.tensor(token_ids, device=self.torch_device)
)
return token_sequences_to_exclude
def _get_relative_audio_masks(self) -> tuple[Tensor, Tensor]:
"""
Get relative audio in and out masks based on the modality flag
(relative to audio tokens only).
Returns:
tuple[Tensor, Tensor]: A tuple containing the relative audio in mask and audio out mask
"""
audio_in_mask = self.modality_flag[0] == LFMModality.AUDIO_IN
audio_out_mask = self.modality_flag[0] == LFMModality.AUDIO_OUT
audio_mask = audio_in_mask | audio_out_mask
audio_in_mask_relative = audio_mask[audio_mask].clone()
audio_in_mask_relative[audio_out_mask[audio_mask]] = False
audio_out_mask_relative = audio_mask[audio_mask].clone()
audio_out_mask_relative[audio_in_mask[audio_mask]] = False
return audio_in_mask_relative, audio_out_mask_relative