"""Transformers text-only model connector."""
from __future__ import annotations
import warnings
from copy import deepcopy
from typing import Any, cast
import torch
from torch import Tensor
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    GenerationConfig,
    PreTrainedModel,
    PreTrainedTokenizerBase,
)
from ..base.chat import BaseMllmChat
from ..base.model import BaseMllmModel
from ..base.model_response import ModelResponse
from ..config import ModelConfig
from ..enums import ModelHistoryTrackingMode, Role, ModalityFlag
from .chat import TransformersTextChat
from .config import CONFIG


class TransformersCausalText(BaseMllmModel):
"""
Connector for classic Hugging Face causal LMs (text-only).
Fields:
processor (PreTrainedTokenizerBase): the tokenizer
model (PreTrainedModel): the causal LM (AutoModelForCausalLM)
"""

    processor: PreTrainedTokenizerBase
    model: PreTrainedModel

    _KW_HISTORY_TRACKING_MODE = "history_tracking_mode"
    # Expected rank of a per-response token-embedding tensor: [T, hidden].
    _TOKEN_EMB_RANK = 2

    def __init__(self, device: torch.device, **kwargs: Any) -> None:
        # Disallow overriding these to keep parity with the LiquidAudio pattern.
        forbidden = {"config", "model", "processor"}
        if any(k in kwargs for k in forbidden):
            raise ValueError(
                "Do not pass 'config', 'model', or 'processor'; they are set automatically."
            )
        tokenizer = cast(Any, AutoTokenizer).from_pretrained(
            CONFIG.repo_id,
            revision=CONFIG.revision,
        )  # nosec: B615 - pinned to immutable commit
        _model = AutoModelForCausalLM.from_pretrained(
            CONFIG.repo_id,
            revision=CONFIG.revision,
            load_in_4bit=True,
            # Place the quantized weights directly on the target device;
            # .to() is not supported for 4-bit bitsandbytes models.
            device_map={"": device},
        )  # nosec: B615
        model = cast(PreTrainedModel, _model)
        model.eval()
        # Force text-only history tracking.
        if (
            self._KW_HISTORY_TRACKING_MODE in kwargs
            and kwargs[self._KW_HISTORY_TRACKING_MODE] != ModelHistoryTrackingMode.TEXT
        ):
            warnings.warn(
                "Non-TEXT history tracking requested but this connector is "
                "text-only. Forcing TEXT mode.",
                stacklevel=2,
            )
            kwargs[self._KW_HISTORY_TRACKING_MODE] = ModelHistoryTrackingMode.TEXT
        super().__init__(
            config=CONFIG,
            device=device,
            processor=tokenizer,
            model=model,
            history_tracking_mode=kwargs.pop(
                self._KW_HISTORY_TRACKING_MODE, ModelHistoryTrackingMode.TEXT
            ),
        )
        # Default the pad token to EOS when the tokenizer does not define one.
        if (
            getattr(self.processor, "pad_token_id", None) is None
            and getattr(self.processor, "eos_token_id", None) is not None
            and getattr(self.processor, "eos_token", None) is not None
        ):
            # Guarded by the attribute checks above; no broad except needed.
            self.processor.pad_token = self.processor.eos_token
        gen_cfg = self.model.generation_config
        if not isinstance(gen_cfg, GenerationConfig):
            gen_cfg = cast(Any, GenerationConfig)()
            setattr(self.model, "generation_config", gen_cfg)
        if (
            getattr(gen_cfg, "pad_token_id", None) is None
            and self.processor.pad_token_id is not None
        ):
            gen_cfg.pad_token_id = self.processor.pad_token_id
        if (
            getattr(gen_cfg, "eos_token_id", None) is None
            and self.processor.eos_token_id is not None
        ):
            gen_cfg.eos_token_id = self.processor.eos_token_id
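        # With pad/eos resolved here, generate() below can pass explicit
        # pad_token_id/eos_token_id instead of relying on per-call resolution.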

    def get_new_chat(self, **kwargs: Any) -> TransformersTextChat:
        """Return a fresh TransformersTextChat bound to this model's tokenizer."""
        kwargs = dict(kwargs)
        kwargs.pop("device", None)
        kwargs["tokenizer"] = self.processor
        return TransformersTextChat(device=self.device, **kwargs)

    def generate(
        self,
        chat: BaseMllmChat,
        max_new_tokens: int = 128,
        model_config: ModelConfig = ModelConfig(),
        keep_history: bool = False,
    ) -> ModelResponse:
        """Generate a text-only assistant reply for ``chat``."""
        # Defensive copy to avoid cross-call mutation via the shared default object.
        model_config = model_config.model_copy(deep=True)
        super().generate(
            chat=chat,
            max_new_tokens=max_new_tokens,
            model_config=model_config,
            keep_history=keep_history,
        )
        # Enforce text-only semantics and surface a warning for audio knobs.
        if (
            model_config.audio_temperature is not None
            or model_config.audio_top_k is not None
        ):
            warnings.warn(
                "Audio generation parameters were provided but this connector "
                "is text-only; audio settings are ignored.",
                stacklevel=2,
            )
        # Copy the chat (immutable input contract) and mark the assistant reply turn.
        chat = deepcopy(chat)
        chat.new_turn(Role.ASSISTANT)
        # Build input ids from the chat history (pure text).
        input_ids = chat.text_tokens.unsqueeze(0)  # [1, prompt_len]
        input_ids = input_ids.to(dtype=torch.long, device=self.device)
        prompt_len = int(input_ids.shape[1])
        # An explicit attention mask avoids the "attention mask is not set" warning
        # and is correct for unpadded 1xT inputs.
        attention_mask = torch.ones_like(
            input_ids, dtype=torch.long, device=self.device
        )
        do_sample = (
            model_config.text_temperature is not None
            and model_config.text_temperature > 0.0
        )
        temperature: float | None = (
            float(model_config.text_temperature)
            if do_sample and model_config.text_temperature is not None
            else None
        )
        top_k: int | None = (
            int(model_config.text_top_k)
            if do_sample and model_config.text_top_k is not None
            else None
        )
        gen_out = self.model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            temperature=temperature,
            top_k=top_k,
            return_dict_in_generate=True,
            pad_token_id=self.processor.pad_token_id,
            eos_token_id=self.processor.eos_token_id,
        )
        sequences: Tensor = gen_out.sequences  # [1, prompt_len + seq_len]
        generated = (
            sequences[0, prompt_len:]
            if sequences.shape[1] > prompt_len
            else sequences.new_empty((0,))
        )
        generated = generated.to(dtype=torch.long, device=self.device)  # [seq_len]
        # All generated tokens are TEXT.
        modality_flag = torch.full(
            (generated.shape[0],),
            ModalityFlag.TEXT,
            dtype=torch.long,
            device=self.device,
        )
        # History update.
        if keep_history:
            # For API parity with LiquidAudio: pass [1, T] tensors.
            text_tokens_2d = generated.unsqueeze(0)  # [1, seq_len]
            empty_audio = torch.empty(
                (0, 0), dtype=torch.long, device=self.device
            )  # [0, 0]
            self._set_chat_history(chat, text_tokens_2d, empty_audio, modality_flag)
        return ModelResponse(
            chat=chat if keep_history else None,
            generated_text_tokens=generated,  # [seq_len]
            generated_audio_tokens=torch.empty(
                (0, 0), dtype=torch.long, device=self.device
            ),  # [0, 0]
            generated_modality_flag=modality_flag,  # [seq_len]
        )
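
    # Decoding sketch (illustrative values): sampling is enabled only when
    # text_temperature is set and > 0, otherwise decoding is greedy:
    #
    #     response = model.generate(
    #         chat, model_config=ModelConfig(text_temperature=0.8, text_top_k=50)
    #     )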

    # -- embeddings API --

    def get_static_embeddings(self, responses: list[ModelResponse]) -> list[Tensor]:
        """Look up input-embedding vectors for each response's generated tokens."""
        super().get_static_embeddings(responses=responses)
        emb_layer = self.model.get_input_embeddings()  # standard HF API
        static_embeddings: list[Tensor] = []
        for response in responses:
            ids = response.generated_text_tokens.to(
                device=self.device, dtype=torch.long
            ).unsqueeze(0)  # [1, T]
            with torch.no_grad():  # inference only; skip autograd bookkeeping
                emb = emb_layer(ids)  # [1, T, hidden]
            static_embeddings.append(emb.squeeze(0))  # [T, hidden]
        return static_embeddings
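
    # Embedding pipeline sketch (normally driven through the BaseMllmModel
    # embeddings API; shapes are per response):
    #
    #     static = self.get_static_embeddings([response])        # each [T, hidden]
    #     contextual = self._get_contextual_embeddings(static)   # each [T, hidden]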

    def _get_contextual_embeddings(
        self, static_embeddings: list[Tensor]
    ) -> list[Tensor]:
        # Run the base (decoder) model to obtain last_hidden_state
        # (see the HF model-outputs contract).
        base = getattr(self.model, "base_model", self.model)
        contextual: list[Tensor] = []
        for emb in static_embeddings:
            if emb.dim() == self._TOKEN_EMB_RANK:
                emb = emb.unsqueeze(0)  # [T, hidden] -> [1, T, hidden]
            with torch.no_grad():  # inference only
                outputs = base(inputs_embeds=emb, use_cache=False)
            contextual.append(outputs.last_hidden_state.squeeze(0))  # [T, hidden]
        return contextual