Source code for mllm_shap.connectors.base.model

"""Base model connector class."""

from abc import ABC, abstractmethod
from logging import Logger
from typing import Any, cast

import torch
from torch import Tensor

from ...utils.logger import get_logger
from ...utils.other import raise_connector_error
from ..config import HuggingFaceModelConfig, ModelConfig
from ..enums import ModelHistoryTrackingMode
from ._validators import BaseModelConfig, BaseModelGenerateConfig
from .chat import BaseMllmChat
from .model_response import ModelResponse

logger: Logger = get_logger(__name__)


# pylint: disable=duplicate-code
class BaseMllmModel(ABC):
    """Base class for model connectors."""

    config: HuggingFaceModelConfig
    """The model configuration."""
    device: torch.device
    """The device to run the model on."""
    processor: Any
    """The model processor (tokenizer)."""
    model: Any
    """The model instance."""
    history_tracking_mode: ModelHistoryTrackingMode
    """The mode for tracking chat history."""

    # pylint: disable=too-many-positional-arguments,too-many-arguments
    def __init__(
        self,
        config: HuggingFaceModelConfig,
        device: torch.device,
        processor: Any,
        model: Any,
        history_tracking_mode: ModelHistoryTrackingMode = ModelHistoryTrackingMode.TEXT,
    ) -> None:
        """
        Initialize the model connector.

        Args:
            config: The model configuration.
            device: The device to run the model on.
            processor: The model processor.
            model: The model instance.
            history_tracking_mode: The mode for tracking chat history.
        """
        # validation
        __config = BaseModelConfig(
            config=config,
            device=device,
            processor=processor,
            model=model,
            history_tracking_mode=history_tracking_mode,
        )
        self.device = __config.device
        self.config = __config.config
        self.processor = __config.processor
        self.model = __config.model
        self.history_tracking_mode = __config.history_tracking_mode
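    # Illustrative construction, assuming a hypothetical concrete subclass
    # ``SomeConnector`` and an already-loaded HuggingFace processor/model pair
    # (none of these names are part of the library):
    #
    #     connector = SomeConnector(
    #         config=HuggingFaceModelConfig(...),
    #         device=torch.device("cuda"),
    #         processor=processor,
    #         model=hf_model,
    #     )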
    @abstractmethod
    def get_new_chat(self) -> BaseMllmChat:
        """Get a new chat state for the model."""
    @abstractmethod
    def generate(  # type: ignore[return]
        self,
        chat: BaseMllmChat,
        max_new_tokens: int = 128,
        model_config: ModelConfig = ModelConfig(),
        keep_history: bool = False,
    ) -> ModelResponse:
        """
        Generate audio based on the current chat state.

        Args:
            chat: The current chat state.
            max_new_tokens: The maximum number of new tokens to generate (default is 128).
            model_config: Additional model configuration parameters.
            keep_history: Whether the returned chat state includes the full history or
                only the generated content.

        Returns:
            ModelResponse: The updated chat state after generation.
        """
        logger.debug("Generating audio with max_new_tokens=%d, keep_history=%s", max_new_tokens, keep_history)
        # validation
        _ = BaseModelGenerateConfig(
            max_new_tokens=max_new_tokens,
            model_config_=model_config,
            keep_history=keep_history,
        )
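    # Illustrative call pattern, assuming a hypothetical ``connector`` instance
    # of a concrete subclass:
    #
    #     chat = connector.get_new_chat()
    #     response = connector.generate(chat, max_new_tokens=64, keep_history=True)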
    @abstractmethod
    def get_static_embeddings(self, responses: list[ModelResponse]) -> list[Tensor]:  # type: ignore[return]
        """
        Get static embeddings for the current chat state.

        Args:
            responses: The model responses to get embeddings for.

        Returns:
            The static embeddings for the text and audio tokens.

        Raises:
            ValueError: If responses is not a list of ModelResponse.
        """
        logger.debug("Getting static embeddings.")
        if not isinstance(responses, list) or not all(isinstance(r, ModelResponse) for r in responses):
            raise ValueError(f"responses must be a list of ModelResponse, got {type(responses)}")
    def get_contextual_embeddings(
        self, *args: Any, static_embeddings: list[Tensor] | None = None, **kwargs: Any
    ) -> list[Tensor]:
        """
        Get contextual embeddings for the current chat state.

        Args:
            static_embeddings: Precomputed static embeddings (if any).
            *args: Additional positional arguments for :func:`get_static_embeddings`.
                Used if static_embeddings is None.
            **kwargs: Additional keyword arguments for :func:`get_static_embeddings`.
                Used if static_embeddings is None.

        Returns:
            The contextual embeddings for the text and audio tokens, in the same format
            as :func:`get_static_embeddings`.

        Raises:
            ValueError: If static_embeddings is not a list of Tensor.
        """
        logger.debug("Getting contextual embeddings.")
        if static_embeddings is None:
            static_embeddings = self.get_static_embeddings(*args, **kwargs)
        if not isinstance(static_embeddings, list):
            raise ValueError(f"static_embeddings must be an instance of list, got {type(static_embeddings)}")
        for emb in static_embeddings:
            if not isinstance(emb, Tensor):
                raise ValueError(f"Each item in static_embeddings must be an instance of Tensor, got {type(emb)}")
        with torch.no_grad():
            return cast(list[Tensor], raise_connector_error(self._get_contextual_embeddings, static_embeddings))
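    # Illustrative usage of the embedding helpers (hypothetical ``connector``
    # and ``responses`` names): the first call computes static embeddings via
    # :func:`get_static_embeddings` before contextualizing them; the second
    # reuses a precomputed list of tensors.
    #
    #     embeddings = connector.get_contextual_embeddings(responses)
    #     embeddings = connector.get_contextual_embeddings(static_embeddings=precomputed)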
    @abstractmethod
    def _get_contextual_embeddings(self, static_embeddings: list[Tensor]) -> list[Tensor]:
        """
        Get contextual embeddings for the current chat state.

        Args:
            static_embeddings: Precomputed static embeddings.

        Returns:
            The contextual embeddings for the text and audio tokens.
        """

    def _set_chat_history(
        self,
        chat: BaseMllmChat,
        text_tokens: Tensor,
        audio_tokens: Tensor,
        modality_flag: Tensor,
    ) -> None:
        """
        Set the chat history with the provided text and audio tokens.

        Args:
            chat: The chat instance to update.
            text_tokens: The text tokens to set.
            audio_tokens: The audio tokens to set.
            modality_flag: The modality flags corresponding to the tokens.
        """
        logger.debug(
            "Setting chat history with text tokens (%d), audio tokens (%d), modality flags (%d).",
            text_tokens.shape[0],
            audio_tokens.shape[0],
            modality_flag.shape[0],
        )
        chat.append(
            text=text_tokens,
            audio_out=audio_tokens,
            modality_flag=modality_flag,
            history_tracking_mode=self.history_tracking_mode,
        )
        chat.end_turn()
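
# ---------------------------------------------------------------------------
# A minimal sketch (not part of the library) of how a concrete connector might
# satisfy this interface. Every name and body below is an illustrative
# placeholder; a real subclass wraps an actual HuggingFace processor/model.


class _SketchMllmModel(BaseMllmModel):
    """Illustrative connector skeleton; every backend call is a placeholder."""

    def get_new_chat(self) -> BaseMllmChat:
        # A real connector returns its own BaseMllmChat subclass here.
        raise NotImplementedError

    def generate(
        self,
        chat: BaseMllmChat,
        max_new_tokens: int = 128,
        model_config: ModelConfig = ModelConfig(),
        keep_history: bool = False,
    ) -> ModelResponse:
        # Reuse the base class's logging and argument validation.
        super().generate(chat, max_new_tokens, model_config, keep_history)
        # A real connector would run ``self.model.generate(...)`` here and wrap
        # the resulting tokens in a ModelResponse.
        raise NotImplementedError

    def get_static_embeddings(self, responses: list[ModelResponse]) -> list[Tensor]:
        # Reuse the base class's input validation.
        super().get_static_embeddings(responses)
        # A real connector would look up embedding-table rows for each token.
        raise NotImplementedError

    def _get_contextual_embeddings(self, static_embeddings: list[Tensor]) -> list[Tensor]:
        # A real connector would pass the static embeddings through the model's
        # transformer layers to contextualize them.
        raise NotImplementedError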