Source code for mllm_shap.shap.masks.mask_space
"""Mask-space utilities for explainable feature indexing."""
from dataclasses import dataclass
import torch
from torch import Tensor
[docs]
@dataclass(frozen=True)
class MaskSpace:
    """Describe which positions of a full token mask are explainable.

    NOTE(review): ``frozen=True`` keeps the default ``eq=True``, so the
    generated ``__eq__`` compares ``shap_values_mask`` with ``==``; for
    distinct multi-element tensors that yields a tensor whose truth value
    is ambiguous and raises. Avoid comparing instances by value.
    """

    # Mask over the full sequence; non-zero / True entries mark explainable
    # positions. ``materialize`` boolean-indexes a 1-D tensor of length
    # ``target_length`` with it, so it must be 1-D with that length.
    shap_values_mask: Tensor
    # Length of the full mask produced by ``materialize``.
    target_length: int

    @property
    def n_features(self) -> int:
        """Number of explainable features (non-zero entries of ``shap_values_mask``)."""
        # count_nonzero matches the bool interpretation used in ``materialize``
        # and gives the same answer as ``sum()`` for bool and 0/1 masks.
        return int(torch.count_nonzero(self.shap_values_mask).item())

    def materialize(self, split: Tensor, device: torch.device) -> Tensor:
        """Project a split over the explainable subset back onto the full mask.

        Positions selected by ``shap_values_mask`` receive the entries of
        ``split`` in order; every other position is set to ``True``.

        Args:
            split: Values for the explainable positions, expected to have
                ``n_features`` entries. Non-bool values are converted to
                ``bool`` (non-zero -> ``True``).
            device: Device on which the returned mask is allocated.

        Returns:
            A 1-D ``torch.bool`` tensor of length ``target_length`` on ``device``.
        """
        # Normalise the mask to bool: an integer 0/1 mask would otherwise be
        # interpreted as integer indices rather than a selection mask.
        mask = self.shap_values_mask.to(device=device, dtype=torch.bool)
        # index_put requires the value dtype/device to match the destination.
        values = torch.as_tensor(split, dtype=torch.bool, device=device)
        # Start from all-True so non-explainable positions need no extra write.
        prepared = torch.ones((self.target_length,), dtype=torch.bool, device=device)
        prepared[mask] = values
        return prepared