Source code for mllm_shap.shap.base.similarity

"""Base class for embedding similarity calculations."""

from abc import ABC, abstractmethod
from typing import Any

from torch import Tensor


# pylint: disable=too-few-public-methods
[docs] class BaseEmbeddingSimilarity(ABC): """Base class for embedding similarity calculations.""" operates_on_embeddings: bool = True """ Indicates that the similarity operates on embeddings. If False, it operates on raw tokens. Used to resolve input to :func:`__call__`. """ @abstractmethod def __call__(self, base: Any, other: Any) -> Tensor: """ Compute similarity between two embeddings. Args: base: Base object. other: Other objects to compare against the base. Returns: Similarity scores. """