Source code for mllm_shap.shap.embeddings

# pylint: disable=too-few-public-methods

"""Embedding calculation and reduction strategies for SHAP explanations."""

import re
from typing import Any, List, Sequence, cast

import torch
from torch import Tensor
from transformers import (
    AutoModel,
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizerBase,
)

from ..connectors.base.model_response import ModelResponse
from .base.embeddings import BaseEmbeddingReducer, BaseExternalEmbedding


class ZeroReducer(BaseEmbeddingReducer):
    """Dummy reducer that returns embeddings unchanged (stacked)."""

    def __call__(self, embeddings: list[Tensor]) -> Tensor:
        embeddings = self._prepare(embeddings)
        shapes = [tuple(e.shape) for e in embeddings]
        if len(set(shapes)) != 1:
            raise ValueError(
                f"All embeddings must have the same shape for ZeroReducer. "
                f"Got shapes: {shapes}"
            )
        return torch.stack(embeddings, dim=0)


class MeanReducer(BaseEmbeddingReducer):
    """Reducer that computes the mean of embeddings."""

    def __call__(self, embeddings: list[Tensor]) -> Tensor:
        embeddings = self._prepare(embeddings)
        for i, emb in enumerate(embeddings):
            embeddings[i] = emb.mean(dim=0)
        return torch.stack(embeddings, dim=0)


class MaxReducer(BaseEmbeddingReducer):
    """Reducer that computes the max of embeddings."""

    def __call__(self, embeddings: list[Tensor]) -> Tensor:
        embeddings = self._prepare(embeddings)
        for i, emb in enumerate(embeddings):
            embeddings[i] = emb.max(dim=0).values
        return torch.stack(embeddings, dim=0)


class MinReducer(BaseEmbeddingReducer):
    """Reducer that computes the min of embeddings."""

    def __call__(self, embeddings: list[Tensor]) -> Tensor:
        embeddings = self._prepare(embeddings)
        for i, emb in enumerate(embeddings):
            embeddings[i] = emb.min(dim=0).values
        return torch.stack(embeddings, dim=0)


class SumReducer(BaseEmbeddingReducer):
    """Reducer that computes the sum of embeddings."""

    def __call__(self, embeddings: list[Tensor]) -> Tensor:
        embeddings = self._prepare(embeddings)
        for i, emb in enumerate(embeddings):
            embeddings[i] = emb.sum(dim=0)
        return torch.stack(embeddings, dim=0)


class FirstReducer(BaseEmbeddingReducer):
    """
    Reducer that selects the first embedding.

    :attr:`n` parameter is ignored in this reducer.
    """

    def __call__(self, embeddings: list[Tensor]) -> Tensor:
        embeddings = self._prepare(embeddings)
        for i, emb in enumerate(embeddings):
            embeddings[i] = emb[..., 0, :]
        return torch.stack(embeddings, dim=0)
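
# Illustrative usage of the reducers above (a minimal sketch, not part of the
# module API). It assumes the reducers can be constructed without arguments,
# that ``_prepare`` only validates/converts the inputs, and that each response
# contributes a [T, hidden] tensor; the shapes below are made up for illustration:
#
#   >>> import torch
#   >>> from mllm_shap.shap.embeddings import MeanReducer, FirstReducer
#   >>> embeddings = [torch.randn(4, 768), torch.randn(7, 768)]  # two responses
#   >>> MeanReducer()(embeddings).shape   # mean-pool each response over its tokens
#   torch.Size([2, 768])
#   >>> FirstReducer()(embeddings).shape  # keep only the first token's embedding
#   torch.Size([2, 768])
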
class CustomEmbedding(BaseExternalEmbedding):  # pylint: disable=too-many-instance-attributes
    """
    External embeddings using a **local** encoder model (e.g., E5/SBERT).

    For each :class:`ModelResponse`, we:

    1) take ``generated_text_tokens`` (shape [T]),
    2) decode **each token id** with the **generation tokenizer** to a short text piece,
    3) embed each piece independently with the **embedding encoder**,
    4) return a tensor of shape **[T, hidden]** per response (aligned 1:1 with tokens).
    """

    tokenizer_decode: PreTrainedTokenizerBase
    emb_tokenizer: PreTrainedTokenizerBase
    emb_model: PreTrainedModel
    device: torch.device
    max_length: int
    batch_size: int
    l2_normalize: bool
    _hidden_size: int

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        *,
        generation_tokenizer: PreTrainedTokenizerBase,
        embed_model_id: str,
        embed_revision: str,
        device: torch.device,
        max_length: int = 64,
        batch_size: int = 64,
        l2_normalize: bool = True,
        local_files_only: bool = True,
    ) -> None:
        """
        Args:
            generation_tokenizer: Tokenizer used by the **generation LM**
                (for token-id -> text decoding).
            embed_model_id: Local HF model id/path for the **encoder**
                (e.g., ``"intfloat/e5-base-v2"``).
            embed_revision: **40-hex commit SHA** for the encoder model (immutable).
            device: Torch device for inference.
            max_length: Max tokens per decoded piece for the embedding tokenizer.
            batch_size: Per-forward batch size for encoder inference.
            l2_normalize: L2-normalize embeddings.
            local_files_only: Enforce local loading (no network); set False if you
                allow online fetch.
        """
        # Security: enforce immutable commit SHA (satisfies Bandit B615 when loading with revision).
        if not isinstance(embed_revision, str) or not re.fullmatch(r"[0-9a-f]{40}", embed_revision):
            raise ValueError("embed_revision must be a 40-character hex commit SHA.")

        self.tokenizer_decode = generation_tokenizer
        self.emb_tokenizer = cast(Any, AutoTokenizer).from_pretrained(
            embed_model_id,
            revision=embed_revision,  # nosec: B615 - pinned commit
            local_files_only=local_files_only,
        )
        _model = AutoModel.from_pretrained(
            embed_model_id,
            revision=embed_revision,  # nosec: B615 - pinned commit
            local_files_only=local_files_only,
        )
        self.emb_model = cast(PreTrainedModel, _model)
        cast(Any, self.emb_model).to(device)
        cast(Any, self.emb_model).eval()

        self.device = device
        self.max_length = int(max_length)
        self.batch_size = int(batch_size)
        self.l2_normalize = bool(l2_normalize)

        cfg = getattr(self.emb_model, "config", None)
        hidden_size = getattr(cfg, "hidden_size", None)
        self._hidden_size = int(hidden_size) if isinstance(hidden_size, int) else 768

    def __call__(self, responses: list[ModelResponse]) -> list[Tensor]:
        """
        Compute **per-token** embeddings for each response.

        Returns:
            list[Tensor]: For each response, a tensor of shape **[T, hidden]**
                (T = tokens in response).
        """
        result: list[Tensor] = []
        for resp in responses:
            token_ids: Tensor = resp.generated_text_tokens  # [T]
            if token_ids.numel() == 0:
                result.append(
                    torch.empty(0, self._hidden_size, dtype=torch.float32, device=self.device)
                )
                continue

            # Decode each token id to its text piece (keep specials to preserve alignment)
            pieces: List[str] = []
            for tid in token_ids.tolist():
                piece = self.tokenizer_decode.decode([int(tid)], skip_special_tokens=False)
                if isinstance(piece, list):
                    pieces.append("".join(piece))
                else:
                    pieces.append(piece)

            emb = self._embed_texts(pieces)  # [T, hidden]
            result.append(emb)
        return result

    @torch.inference_mode()
    def _embed_texts(self, texts: Sequence[str]) -> Tensor:
        """Embed a list of short texts -> [len(texts), hidden]."""
        vecs: list[Tensor] = []
        for i in range(0, len(texts), self.batch_size):
            batch = list(texts[i : i + self.batch_size])
            inputs = self.emb_tokenizer(
                batch,
                padding=True,
                truncation=True,
                max_length=self.max_length,
                return_tensors="pt",
            )
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            outputs = self.emb_model(**inputs)
            last_hidden: Tensor = outputs.last_hidden_state  # [B, L, H]
            attn: Tensor = inputs["attention_mask"].unsqueeze(-1).type_as(last_hidden)  # [B, L, 1]

            # Mean pool over non-padding tokens
            summed: Tensor = (last_hidden * attn).sum(dim=1)  # [B, H]
            lengths: Tensor = attn.sum(dim=1).clamp(min=1.0)  # [B, 1]
            mean: Tensor = summed / lengths  # [B, H]

            if self.l2_normalize:
                mean = torch.nn.functional.normalize(mean, p=2, dim=1)
            vecs.append(mean)
        return torch.cat(vecs, dim=0)