# Source code for mllm_shap.shap.base.embeddings
"""Base class for embedding calculation reduction strategies."""
from abc import ABC, abstractmethod
import torch
from torch import Tensor
from ...connectors.base.model_response import ModelResponse
# pylint: disable=too-few-public-methods
class BaseEmbeddingReducer(ABC):
"""Base class for embedding reduction strategies."""
n: int | None
"""Maximum number of embeddings to sample before reduction. None means no sampling."""
def __init__(self, n: int | None = None):
"""
Initialize the BaseEmbeddingReducer.
Args:
n: Maximum number of embeddings to sample before reduction. None means no sampling.
Raises:
ValueError: If n is not None or a positive integer.
"""
if not (n is None or n > 0):
raise ValueError("n must be None or a positive integer.")
self.n = n
def _prepare(self, embeddings: list[Tensor]) -> list[Tensor]:
"""
Prepare the embeddings for reduction.
Args:
embeddings: The input embeddings to be reduced of size (N, d, k),
where n is number of samples, d is single embedding vector size,
k is number of vectors per sample.
Returns:
The prepared embeddings.
Raises:
ValueError: If any embedding is not a Tensor.
"""
for i, emb in enumerate(embeddings):
if not isinstance(emb, Tensor):
raise ValueError(f"Embedding at index {i} is not a Tensor.")
if self.n is None or self.n >= emb.shape[-1]:
continue
indices = torch.randperm(emb.shape[0])[: self.n]
embeddings[i] = emb[..., indices]
return embeddings
@abstractmethod
def __call__(self, embeddings: list[Tensor]) -> Tensor:
"""
Reduce the embeddings according to the specific strategy.
Args:
embeddings: The input embeddings to be reduced of size (N, d, k),
where n is number of samples, d is single embedding vector size,
k is number of vectors per sample.
Returns:
The reduced embeddings of size (N, d)
"""
class BaseExternalEmbedding(ABC):
    """Abstract interface for producing embeddings from an external source."""

    @abstractmethod
    def __call__(self, responses: list[ModelResponse]) -> list[Tensor]:
        """
        Compute external embeddings for the given model responses.

        Args:
            responses: The model responses to embed.

        Returns:
            The external embeddings for the text and audio tokens.
        """