Source code for mllm_shap.connectors.transformers_text.chat

"""Chat state for text-only Transformers causal models."""

from collections.abc import Callable
import warnings
from copy import deepcopy
from functools import cached_property

import torch
from torch import Tensor

from transformers import PreTrainedTokenizerBase

from ...utils.other import safe_mask
from ..base.chat import BaseMllmChat
from ..base.filters import TokenFilter
from ..enums import ModalityFlag, ModelHistoryTrackingMode, Role, SystemRolesSetup


class TransformersTextChat(BaseMllmChat):
    """
    Chat state for text-only causal LMs.

    Stores only TEXT token IDs. AUDIO is unsupported and will warn+no-op.
    """

    tokenizer: PreTrainedTokenizerBase
    _text_ids: Tensor

    # Shape constants used when normalising generated tokens in ``_append``.
    _TWO_DIMS: int = 2
    _SINGLE_BATCH: int = 1

    _SHARED_ATTRIBUTES: frozenset[str] = frozenset({
        "tokenizer",  # Large read-only object, safe to share across copies
    })

    # pylint: disable=too-many-arguments,too-many-positional-arguments
    def __init__(
        self,
        device: torch.device,
        tokenizer: PreTrainedTokenizerBase,
        empty_turn_sequences: set[str] | None = None,
        token_filter: TokenFilter | None = None,
        system_roles_setup: SystemRolesSetup | None = None,
        get_new_chat_callable: Callable[..., "TransformersTextChat"] | None = None,
    ) -> None:
        """
        Create an empty text-only chat.

        Args:
            device: Device on which the token-id tensor is allocated.
            tokenizer: Tokenizer used to encode and decode text.
            empty_turn_sequences: Forwarded to the base class; ``None`` is
                replaced by an empty set.
            token_filter: Forwarded to the base class unchanged.
            system_roles_setup: Forwarded to the base class unchanged.
            get_new_chat_callable: Forwarded to the base class unchanged.
        """
        empty_turn_sequences = empty_turn_sequences or set()
        # The tokenizer is assigned before the base initialiser runs.
        self.tokenizer = tokenizer

        super().__init__(
            device=device,
            empty_turn_sequences=empty_turn_sequences,
            token_filter=token_filter,
            system_roles_setup=system_roles_setup,
            get_new_chat_callable=get_new_chat_callable,  # type: ignore[arg-type]
        )

        # Start with no tokens: an empty 1-D long tensor.
        self._text_ids = torch.empty(0, dtype=torch.long, device=device)

    def apply_text_mask(self, text_mask_relative: Tensor) -> None:
        """
        Apply a relative text mask to this chat instance (public on purpose).

        Both ``_text_ids`` and ``text_tokens_no_system_mask`` are passed
        through ``safe_mask`` with the same mask.

        NOTE(review): cached properties derived from ``_text_ids`` are not
        invalidated here; presumably the caller refreshes them - confirm.
        """
        self._text_ids = safe_mask(self._text_ids, text_mask_relative)
        self.text_tokens_no_system_mask = safe_mask(
            self.text_tokens_no_system_mask, text_mask_relative
        )

    @classmethod
    def _set_new_instance(
        cls,
        full_mask: Tensor,
        text_mask_relative: Tensor,
        audio_mask_relative: Tensor,  # unused (no audio)
        chat: "TransformersTextChat",  # type: ignore[override]
    ) -> "TransformersTextChat":
        """
        Return a deep copy of ``chat`` with ``text_mask_relative`` applied.

        ``full_mask`` and ``audio_mask_relative`` are accepted for interface
        compatibility but are not used by this connector.
        """
        new_instance: "TransformersTextChat" = deepcopy(chat)
        new_instance.apply_text_mask(text_mask_relative)
        # full input token mask affects only text here
        # token_roles/turns are handled in BaseMllmChat._after_add via refresh()
        return new_instance

    @cached_property
    def input_tokens(self) -> list[Tensor]:
        """Return the stored token IDs as a list of tensors of shape ``[1]``."""
        # One Tensor per token, each shaped [1], to match BaseMllmChat expectations.
        return [tid.unsqueeze(0) for tid in self._text_ids]

    @cached_property
    def tokens_modality_flag(self) -> Tensor:
        """Return a long tensor with one ``ModalityFlag.TEXT`` entry per token."""
        # All tokens are TEXT
        return torch.full(
            (self._text_ids.shape[0],),
            ModalityFlag.TEXT,
            dtype=torch.long,
            device=self.torch_device,
        )

    @cached_property
    def text_tokens(self) -> Tensor:
        """Return the stored text token IDs."""
        return self._text_ids

    @cached_property
    def audio_tokens(self) -> Tensor:
        """Return an empty long tensor; this connector holds no audio tokens."""
        # No audio tokens in this connector
        return torch.empty(0, dtype=torch.long, device=self.torch_device)

    def _decode_text(self, text_tokens: Tensor) -> str:
        """
        Decode token IDs into a single string.

        The tensor is flattened first, so any shape is accepted. Special
        tokens are kept (``skip_special_tokens=False``).
        """
        # reshape(-1) always yields a 1-D tensor, so tolist() is a flat list.
        flat = text_tokens.detach().to("cpu").reshape(-1)
        ids: list[int] = [int(x) for x in flat.tolist()]
        decoded = self.tokenizer.decode(ids, skip_special_tokens=False)
        # Defensive: join the pieces if the tokenizer hands back a list.
        if isinstance(decoded, list):
            return "".join(decoded)
        return decoded

    def _decode_audio(self, audio_tokens: Tensor) -> Tensor | None:  # pragma: no cover - unsupported
        """Return ``None``; audio cannot be decoded by this connector."""
        return None  # decoding audio is impossible here

    def _add_text(self, text: str) -> int:
        """
        Tokenize ``text`` without special tokens and append the IDs.

        Returns:
            Number of tokens appended (``0`` when the encoding is empty).
        """
        ids = self.tokenizer.encode(text, add_special_tokens=False)
        if not ids:
            return 0
        ids_t = torch.tensor(ids, dtype=torch.long, device=self.torch_device)
        start = self._text_ids.shape[0]
        self._text_ids = torch.cat([self._text_ids, ids_t], dim=0)
        return int(self._text_ids.shape[0] - start)

    def _add_audio(self, waveform: Tensor, sample_rate: int) -> int:  # pragma: no cover
        """Warn that audio is unsupported and report zero tokens added."""
        warnings.warn(
            "Audio input is not supported by the TransformersText connector. Ignoring provided audio.",
            stacklevel=2,
        )
        return 0

    def _append(
        self,
        text: Tensor,
        audio_out: Tensor,  # ignored for text-only
        modality_flag: Tensor,  # ignored for text-only
        history_tracking_mode: ModelHistoryTrackingMode,
    ) -> tuple[int, int]:
        """
        Append token IDs in ``text`` to the stored sequence.

        Args:
            text: Token IDs; expected ``[1, T]``, but scalars and other shapes
                are flattened to 1-D before appending.
            audio_out: Ignored.
            modality_flag: Ignored.
            history_tracking_mode: With ``ModelHistoryTrackingMode.AUDIO`` a
                warning is issued and nothing is appended.

        Returns:
            ``(number_of_text_tokens_appended, 0)``.
        """
        if history_tracking_mode == ModelHistoryTrackingMode.AUDIO:
            warnings.warn(
                "Requested AUDIO-only history tracking is not supported by the text-only connector. "
                "No tokens appended.",
                stacklevel=2,
            )
            return 0, 0

        # Expect [1, T]; accept other simple shapes
        if text.dim() == self._TWO_DIMS and text.shape[0] == self._SINGLE_BATCH:
            text = text.squeeze(0)
        elif text.dim() == 0:
            text = text.unsqueeze(0)
        # Anything still not 1-D (e.g. a multi-row batch) is flattened.
        if text.dim() != 1:
            text = text.reshape(-1)

        text = text.to(dtype=torch.long, device=self.torch_device)
        start = self._text_ids.shape[0]
        self._text_ids = torch.cat([self._text_ids, text], dim=0)

        # IMPORTANT: In text-only path there are no audio tokens; BaseMllmChat won’t refresh caches.
        # Do it here to keep tokens_modality_flag / input_tokens in sync with masks.
        self.refresh(full=True)
        return int(self._text_ids.shape[0] - start), 0

    def _new_turn(self, speaker: Role) -> None:
        """Do nothing; no turn markers are injected by this connector."""
        # No special turn markers injected for generic causal LMs.
        # Token accounting is handled in Base via _after_add.
        return

    def _end_turn(self) -> None:
        """Do nothing; turns have no closing marker in this connector."""
        return

    def _get_tokens_sequences_to_exclude(self, phrases_to_exclude: set[str]) -> list[Tensor]:
        """
        Encode each phrase (without special tokens) into a long tensor.

        Returns:
            One tensor of token IDs per phrase, in the set's iteration order.
        """
        seqs: list[Tensor] = []
        for phrase in phrases_to_exclude:
            ids = self.tokenizer.encode(phrase, add_special_tokens=False)
            seqs.append(torch.tensor(ids, dtype=torch.long, device=self.torch_device))
        return seqs