Source code for mllm_shap.connectors.transformers_text.model

"""Transformers text-only model connector."""

from __future__ import annotations

import warnings
from copy import deepcopy
from typing import Any, cast

import torch
from torch import Tensor
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    GenerationConfig,
    PreTrainedModel,
    PreTrainedTokenizerBase,
)

from ..base.chat import BaseMllmChat
from ..base.model import BaseMllmModel
from ..base.model_response import ModelResponse
from ..config import ModelConfig
from ..enums import ModelHistoryTrackingMode, Role, ModalityFlag
from .chat import TransformersTextChat
from .config import CONFIG


class TransformersCausalText(BaseMllmModel):
    """
    Connector for classic Hugging Face causal LMs (text-only).

    The tokenizer and model are loaded from ``CONFIG.repo_id`` at the pinned
    ``CONFIG.revision``. The model is loaded with 4-bit bitsandbytes
    quantization, moved to ``device`` and switched to eval mode. History
    tracking is always forced to ``ModelHistoryTrackingMode.TEXT``.

    Fields:
        processor (PreTrainedTokenizerBase): the tokenizer
        model (PreTrainedModel): the causal LM (AutoModelForCausalLM)
    """

    processor: PreTrainedTokenizerBase
    model: PreTrainedModel

    # Keyword argument through which callers may request a history tracking mode.
    _KW_HISTORY_TRACKING_MODE = "history_tracking_mode"
    # Rank of an unbatched per-token embedding tensor ([T, hidden]).
    _TOKEN_EMB_RANK = 2

    def __init__(self, device: torch.device, **kwargs: Any) -> None:
        """
        Load the pinned tokenizer and model and initialise the base connector.

        Args:
            device: device the model is moved to and tensors are created on.
            **kwargs: only ``history_tracking_mode`` is inspected. Any value
                other than ``ModelHistoryTrackingMode.TEXT`` triggers a warning
                and is replaced by TEXT. Other keyword arguments are accepted
                but not forwarded anywhere.

        Raises:
            ValueError: if ``config``, ``model`` or ``processor`` is passed.
        """
        # Disallow overriding these to keep parity with LiquidAudio pattern.
        forbidden = {"config", "model", "processor"}
        if any(k in kwargs for k in forbidden):
            raise ValueError("Do not pass 'config', 'model', or 'processor'—they are set automatically.")

        tokenizer = cast(Any, AutoTokenizer).from_pretrained(
            CONFIG.repo_id,
            revision=CONFIG.revision,
        )  # nosec: B615 - pinned to immutable commit
        # Passing ``load_in_4bit=True`` straight to ``from_pretrained`` is
        # deprecated in transformers; the BitsAndBytesConfig form is equivalent.
        _model = AutoModelForCausalLM.from_pretrained(
            CONFIG.repo_id,
            revision=CONFIG.revision,
            quantization_config=BitsAndBytesConfig(load_in_4bit=True),
        )  # nosec: B615
        model = cast(PreTrainedModel, _model)
        # NOTE(review): calling ``.to`` on a bitsandbytes-quantized model is only
        # supported by sufficiently recent bitsandbytes/transformers — confirm.
        cast(Any, model).to(device)
        cast(Any, model).eval()

        # Force text-only history tracking; warn if something else was requested.
        requested_mode = kwargs.pop(self._KW_HISTORY_TRACKING_MODE, ModelHistoryTrackingMode.TEXT)
        if requested_mode != ModelHistoryTrackingMode.TEXT:
            warnings.warn(
                "Non-TEXT history tracking requested but this connector is text-only. "
                "Forcing TEXT mode.",
                stacklevel=2,
            )

        super().__init__(
            config=CONFIG,
            device=device,
            processor=tokenizer,
            model=model,
            history_tracking_mode=ModelHistoryTrackingMode.TEXT,
        )

        # Fall back to the EOS token as padding token when the tokenizer has none.
        if (
            getattr(self.processor, "pad_token_id", None) is None
            and getattr(self.processor, "eos_token_id", None) is not None
            and getattr(self.processor, "eos_token", None) is not None
        ):  # guarded, no broad-except
            self.processor.pad_token = self.processor.eos_token

        # Ensure the model carries a GenerationConfig with pad/eos ids filled in.
        gen_cfg = self.model.generation_config
        if not isinstance(gen_cfg, GenerationConfig):
            gen_cfg = cast(Any, GenerationConfig)()
            setattr(self.model, "generation_config", gen_cfg)
        if getattr(gen_cfg, "pad_token_id", None) is None and self.processor.pad_token_id is not None:
            gen_cfg.pad_token_id = self.processor.pad_token_id
        if getattr(gen_cfg, "eos_token_id", None) is None and self.processor.eos_token_id is not None:
            gen_cfg.eos_token_id = self.processor.eos_token_id

    def get_new_chat(self, **kwargs: Any) -> TransformersTextChat:
        """
        Create an empty text chat bound to this connector.

        A caller-supplied ``device`` is discarded in favour of the connector's
        device, and ``tokenizer`` is always set to this connector's tokenizer.

        Args:
            **kwargs: forwarded to ``TransformersTextChat``.

        Returns:
            A new ``TransformersTextChat`` on ``self.device``.
        """
        chat_kwargs = dict(kwargs)
        chat_kwargs.pop("device", None)
        chat_kwargs["tokenizer"] = self.processor
        return TransformersTextChat(device=self.device, **chat_kwargs)

    # pylint: disable=too-many-locals
    def generate(
        self,
        chat: BaseMllmChat,
        max_new_tokens: int = 128,
        model_config: ModelConfig | None = None,
        keep_history: bool = False,
    ) -> ModelResponse:
        """
        Generate an assistant reply for ``chat`` with the wrapped causal LM.

        The input chat is deep-copied and never modified. An assistant turn is
        opened on the copy and the copy's ``text_tokens`` are used as prompt.

        Args:
            chat: conversation whose text tokens form the prompt.
            max_new_tokens: maximum number of tokens to generate.
            model_config: decoding settings; ``None`` means a fresh default
                ``ModelConfig()``. Sampling (``do_sample=True``) is enabled
                only when ``text_temperature`` is a positive number; then
                ``text_top_k`` is applied if set. Audio settings are ignored
                with a warning.
            keep_history: when True, the generated tokens are written into the
                copied chat and that copy is returned in the response.

        Returns:
            ModelResponse with the generated text tokens (1-D), an empty
            ``(0, 0)`` audio tensor and one TEXT modality flag per token.
        """
        # Build the default per call rather than sharing one instance that was
        # created once at function-definition time.
        if model_config is None:
            model_config = ModelConfig()
        super().generate(chat=chat, max_new_tokens=max_new_tokens, model_config=model_config, keep_history=keep_history)

        # Enforce text-only semantics and surface warnings for audio knobs.
        # Implicit literal concatenation keeps the message free of the source
        # indentation that a backslash continuation inside the string would embed.
        if model_config.audio_temperature is not None or model_config.audio_top_k is not None:
            warnings.warn(
                "Audio generation parameters were provided but this connector is text-only; "
                "audio settings are ignored.",
                stacklevel=2,
            )

        # Copy chat (immutable input contract), mark assistant reply turn.
        chat = deepcopy(chat)
        chat.new_turn(Role.ASSISTANT)

        # Build input ids from chat history (pure text).
        input_ids = chat.text_tokens.unsqueeze(0)  # [1, prompt_len]
        input_ids = input_ids.to(dtype=torch.long, device=self.device)
        prompt_len = int(input_ids.shape[1])
        # An explicit all-ones mask avoids the HF "attention mask not set"
        # warning and is correct for a single unpadded sequence.
        attention_mask = torch.ones_like(input_ids, dtype=torch.long, device=self.device)

        text_temperature = model_config.text_temperature
        do_sample = text_temperature is not None and text_temperature > 0.0
        temperature: float | None = (
            float(text_temperature) if do_sample and text_temperature is not None else None
        )
        top_k: int | None = (
            int(model_config.text_top_k) if do_sample and model_config.text_top_k is not None else None
        )

        gen_out = self.model.generate(  # type: ignore[operator]
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            temperature=temperature,
            top_k=top_k,
            return_dict_in_generate=True,
            pad_token_id=self.processor.pad_token_id,
            eos_token_id=self.processor.eos_token_id,
        )

        sequences: Tensor = gen_out.sequences  # [1, prompt_len + seq_len]
        # Slicing past the prompt yields an empty tensor when nothing was generated.
        generated = sequences[0, prompt_len:].to(dtype=torch.long, device=self.device)  # [seq_len]

        # All generated tokens are TEXT.
        modality_flag = torch.full((generated.shape[0],), ModalityFlag.TEXT, dtype=torch.long, device=self.device)

        if keep_history:
            # For API parity with LiquidAudio: pass [1, T] text and [0, 0] audio tensors.
            text_tokens_2d = generated.unsqueeze(0)  # [1, seq_len]
            empty_audio = torch.empty((0, 0), dtype=torch.long, device=self.device)
            self._set_chat_history(chat, text_tokens_2d, empty_audio, modality_flag)

        return ModelResponse(
            chat=chat if keep_history else None,
            generated_text_tokens=generated,  # [seq_len]
            generated_audio_tokens=torch.empty((0, 0), dtype=torch.long, device=self.device),  # [0, 0]
            generated_modality_flag=modality_flag,  # [seq_len]
        )

    # -- embeddings API --
    def get_static_embeddings(self, responses: list[ModelResponse]) -> list[Tensor]:
        """
        Look up input (token) embeddings for each response's generated text.

        Args:
            responses: model responses whose ``generated_text_tokens`` are embedded.

        Returns:
            One ``[T, hidden]`` tensor per response, from the model's input
            embedding layer.
        """
        super().get_static_embeddings(responses=responses)

        emb_layer = self.model.get_input_embeddings()  # standard HF API
        static_embeddings: list[Tensor] = []
        for response in responses:
            ids = response.generated_text_tokens.to(device=self.device, dtype=torch.long).unsqueeze(0)  # [1, T]
            emb = emb_layer(ids)  # [1, T, hidden]
            static_embeddings.append(emb.squeeze(0))  # [T, hidden]
        return static_embeddings

    def _get_contextual_embeddings(self, static_embeddings: list[Tensor]) -> list[Tensor]:
        """
        Run static embeddings through the base transformer to contextualise them.

        Args:
            static_embeddings: ``[T, hidden]`` (or already batched ``[1, T, hidden]``)
                embedding tensors.

        Returns:
            The base model's ``last_hidden_state`` for each input, with the
            leading batch dimension squeezed.
        """
        # The base (headless) model exposes last_hidden_state; resolve it once.
        base = getattr(self.model, "base_model", self.model)
        contextual: list[Tensor] = []
        for emb in static_embeddings:
            if emb.dim() == self._TOKEN_EMB_RANK:
                emb = emb.unsqueeze(0)  # [1, T, hidden]
            if emb.shape[-2] == 0:
                # Empty sequence (e.g. nothing was generated): skip the forward
                # pass, which may not accept zero-length inputs. Assumes the
                # output hidden size equals the embedding size — TODO confirm.
                contextual.append(emb.squeeze(0))
                continue
            outputs = base(inputs_embeds=emb, use_cache=False)
            contextual.append(outputs.last_hidden_state.squeeze(0))  # [T, hidden]
        return contextual