Source code for mllm_shap.shap.masks.mask_space

"""Mask-space utilities for explainable feature indexing."""

from dataclasses import dataclass

import torch
from torch import Tensor


@dataclass(frozen=True)
class MaskSpace:
    """Describe the explainable feature subset inside a full token mask.

    ``shap_values_mask`` selects which positions of a length-``target_length``
    mask are explainable features. :meth:`materialize` scatters a split over
    those features back into a full mask, keeping every other position True.
    """

    # Mask over the full sequence; truthy entries mark explainable features.
    # Assumed 1-D with ``target_length`` entries — TODO confirm with callers.
    shap_values_mask: Tensor
    # Length of the full mask produced by ``materialize``.
    target_length: int

    @property
    def n_features(self) -> int:
        """Number of explainable features (truthy entries of the mask).

        The mask is interpreted as bool so the count agrees with the
        positions :meth:`materialize` writes into.
        """
        return int(self.shap_values_mask.bool().sum().item())

    def materialize(self, split: Tensor, device: torch.device) -> Tensor:
        """Project a split over the explainable subset back to the full chat mask.

        Args:
            split: One value per explainable feature, in mask order. It is
                converted to ``torch.bool`` on ``device`` before use, so
                0/1 or float tensors (on any device) are accepted.
            device: Device on which the resulting mask is allocated.

        Returns:
            A bool tensor of shape ``(target_length,)``. Explainable
            positions take the values of ``split``; all other positions
            are True.
        """
        # Normalize to bool on the target device. A non-bool mask would
        # otherwise be used for integer indexing (and ``~`` would be bitwise
        # NOT), and a mask/split on another device makes indexing fail.
        mask = self.shap_values_mask.to(device=device, dtype=torch.bool)
        values = torch.as_tensor(split, dtype=torch.bool, device=device)

        # Non-explainable positions are always kept (True); only the
        # explainable slots are overwritten from ``split``.
        prepared = torch.ones((self.target_length,), dtype=torch.bool, device=device)
        prepared[mask] = values
        return prepared