Source code for mllm_shap.shap.similarity

# pylint: disable=too-few-public-methods
"""Embedding similarity calculations for SHAP explanations."""

import hashlib
from typing import cast

import torch
from torch import Tensor
from sklearn.feature_extraction.text import TfidfVectorizer

from ..connectors.base.model_response import ModelResponse
from .base.similarity import BaseEmbeddingSimilarity


class EuclideanSimilarity(BaseEmbeddingSimilarity):
    """
    Euclidean similarity calculation, used in the implementation of the U4
    utility function from the paper.
    """

    def __call__(self, base: Tensor, other: Tensor) -> Tensor:
        """
        Calculate the Euclidean similarity between the base embedding and other embeddings.

        Args:
            base: The base embedding tensor, shape [embedding_dim].
            other: The other embeddings tensor to compare against,
                shape [num_embeddings, embedding_dim].

        Returns:
            A tensor of similarities in (0, 1], shape [num_embeddings].
        """
        # calculate Euclidean distances between the base embedding and each other embedding
        distances = torch.norm(other - base.unsqueeze(0), dim=-1)
        # convert distances to similarities: distance 0 maps to 1, large distances approach 0
        similarities = 1 / (1 + distances)
        return cast(Tensor, similarities)
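
A minimal usage sketch for EuclideanSimilarity (not part of the original module). The import path is taken from the module name above, and the embedding sizes are made up for illustration:

import torch

from mllm_shap.shap.similarity import EuclideanSimilarity

similarity = EuclideanSimilarity()
base = torch.randn(16)      # base embedding, shape [embedding_dim]
other = torch.randn(4, 16)  # candidate embeddings, shape [num_embeddings, embedding_dim]

# one score per candidate, in (0, 1]; an embedding identical to base scores exactly 1.0
scores = similarity(base=base, other=other)
print(scores.shape)  # torch.Size([4])
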
class CosineSimilarity(BaseEmbeddingSimilarity):
    """
    Cosine similarity calculation, used in the implementation of the U1 and U2
    utility functions from the paper.
    """

    def __call__(self, base: Tensor, other: Tensor) -> Tensor:
        """
        Calculate the Cosine similarity between the base embedding and other embeddings.

        Args:
            base: The base embedding tensor, shape [embedding_dim].
            other: The other embeddings tensor to compare against,
                shape [num_embeddings, embedding_dim].

        Returns:
            A tensor of cosine similarities in [-1, 1], shape [num_embeddings].
        """
        # normalize embeddings, clamping norms with epsilon to avoid division by zero
        eps = 1e-8
        base_norm_val = base.norm(dim=-1, keepdim=True).clamp(min=eps)
        other_norm_val = other.norm(dim=-1, keepdim=True).clamp(min=eps)
        base_norm = base / base_norm_val
        other_norm = other / other_norm_val
        # dot product of the unit vectors gives the cosine similarity
        return cast(Tensor, (other_norm * base_norm.unsqueeze(0)).sum(dim=-1))
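
A quick sanity-check sketch for CosineSimilarity (not part of the original module; the import path follows the module name above). Parallel, orthogonal, and opposite vectors should score roughly 1, 0, and -1:

import torch

from mllm_shap.shap.similarity import CosineSimilarity

similarity = CosineSimilarity()
base = torch.tensor([1.0, 0.0, 0.0])
other = torch.stack([
    torch.tensor([2.0, 0.0, 0.0]),   # same direction -> ~1.0
    torch.tensor([0.0, 3.0, 0.0]),   # orthogonal     -> ~0.0
    torch.tensor([-1.0, 0.0, 0.0]),  # opposite       -> ~-1.0
])

print(similarity(base=base, other=other))  # approximately tensor([ 1.,  0., -1.])
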
class TfIdfCosineSimilarity(BaseEmbeddingSimilarity):
    """
    TF-IDF weighted Cosine similarity calculation, used in the implementation
    of the U3 utility function from the paper.
    """

    operates_on_embeddings: bool = False

    __tokenize_map: dict[bytes, int] = {}
    __tokenize_counter: int = 0
    __vectorizer: TfidfVectorizer

    def __init__(self) -> None:
        """Initialize the TF-IDF vectorizer."""
        self.__vectorizer = TfidfVectorizer(analyzer=lambda x: x)

    def __call__(self, base: ModelResponse, other: list[ModelResponse]) -> Tensor:
        """
        Calculate the TF-IDF weighted Cosine similarity between the base response
        and other responses.

        Args:
            base: The base model response.
            other: The list of other model responses to compare against.

        Returns:
            A tensor containing the TF-IDF weighted Cosine similarities.
        """
        # check that other[0] == base
        if not (
            torch.equal(base.generated_text_tokens, other[0].generated_text_tokens)
            and torch.equal(base.generated_audio_tokens, other[0].generated_audio_tokens)
        ):
            raise ValueError("The first element of 'other' must be equal to 'base' tensor.")

        # map every token tensor to a stable integer id so each response becomes a "document"
        generated_text_tokens_hashes = [self.__tokenize(tensor=o.generated_text_tokens) for o in other]
        generated_audio_tokens_hashes = [self.__tokenize(tensor=o.generated_audio_tokens) for o in other]
        token_hashes_tensors = [
            torch.cat((text_hash, audio_hash), dim=0)
            for text_hash, audio_hash in zip(generated_text_tokens_hashes, generated_audio_tokens_hashes)
        ]

        # build TF-IDF vectors over the token-id documents and compare them with cosine similarity
        tf_idfs = Tensor(
            self.__vectorizer.fit_transform([o.numpy() for o in token_hashes_tensors]).toarray(),
        ).to(base.generated_text_tokens.device)
        return CosineSimilarity()(base=tf_idfs[0], other=tf_idfs)

    def __tokenize(self, tensor: Tensor) -> Tensor:
        """
        Tokenize the input tensor into a unique integer representation.

        Args:
            tensor: The input tensor to tokenize.

        Returns:
            A tensor containing the unique integer representation of the input tensor.
        """
        tensor_cpu = tensor.detach().to("cpu").contiguous()
        result = []
        for el in tensor_cpu:
            # hash both data and metadata so different dtypes/shapes never collide
            h = hashlib.sha256()
            h.update(el.numpy().tobytes())
            h.update(str(tuple(el.shape)).encode())
            h.update(str(tensor_cpu.dtype).encode())
            key = h.digest()
            # assign an incremental id if the hash has not been seen before
            if key not in self.__tokenize_map:
                self.__tokenize_map[key] = self.__tokenize_counter
                self.__tokenize_counter += 1
            result.append(self.__tokenize_map[key])
        return Tensor(result)
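
A hypothetical usage sketch for TfIdfCosineSimilarity (not part of the original module). Real ModelResponse objects come from the connectors package; here a SimpleNamespace stand-in exposes only the two token attributes this class actually reads, so the fake_response helper and its fields are assumptions for illustration only:

import torch
from types import SimpleNamespace

from mllm_shap.shap.similarity import TfIdfCosineSimilarity

def fake_response(text_ids: list[int], audio_ids: list[int]) -> SimpleNamespace:
    # hypothetical stand-in for ModelResponse with only the attributes used above
    return SimpleNamespace(
        generated_text_tokens=torch.tensor(text_ids),
        generated_audio_tokens=torch.tensor(audio_ids),
    )

base = fake_response([1, 2, 3], [10, 11])
others = [
    base,                                # other[0] must equal base
    fake_response([1, 2, 4], [10, 12]),  # partial token overlap
    fake_response([7, 8, 9], [20, 21]),  # disjoint tokens
]

similarity = TfIdfCosineSimilarity()
scores = similarity(base=base, other=others)
print(scores)  # first entry is the self-similarity, ~1.0
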