# Source code for mllm_shap.shap.masks.mask_codec

"""Bitset-oriented mask encoding and hashing utilities."""

from dataclasses import dataclass

import numpy as np
import torch
from torch import Tensor


@dataclass(frozen=True)
class PackedMask:
    """Immutable bit-packed form of a single boolean mask.

    Instances are produced by ``MaskCodec.pack`` and turned back into a tensor
    by ``MaskCodec.unpack``. Being a frozen dataclass, instances compare and
    hash by value.
    """

    # Packed mask bytes, eight mask entries per byte.
    words: bytes
    # Number of meaningful bits; padding bits beyond this count in the final
    # byte are discarded when unpacking.
    n_bits: int
class MaskCodec:
    """Encode/decode boolean masks to and from packed bytes.

    All methods are static; the class only groups them under one name.
    """

    @staticmethod
    def normalize(mask: Tensor) -> Tensor:
        """Normalize an incoming mask to a contiguous 1D bool tensor.

        Leading singleton dimensions are squeezed away, so masks of shape
        ``(N,)``, ``(1, N)`` and ``(1, 1, N)`` all normalize to ``(N,)``.
        A 0-d tensor becomes a one-element 1D tensor. Values are converted
        with ``Tensor.bool()``.

        Args:
            mask: Mask tensor of any dtype.

        Returns:
            Contiguous 1D ``torch.bool`` tensor on the same device as ``mask``.

        Raises:
            ValueError: If, after removing leading singleton dimensions, the
                mask is still not 1D (i.e. it has more than one row).
        """
        # Squeeze every leading singleton dim. A single squeeze(0) used to let
        # e.g. (1, 3, 4) through as a 2D tensor despite the 1D contract.
        while mask.ndim > 1 and mask.shape[0] == 1:
            mask = mask.squeeze(0)
        if mask.ndim > 1:
            raise ValueError("Mask must be 1D or have exactly one row.")
        # reshape(-1) is a no-op for 1D input and lifts a 0-d tensor to (1,).
        return mask.reshape(-1).contiguous().bool()

    @staticmethod
    def pack(mask: Tensor) -> PackedMask:
        """Pack a mask into bytes, least-significant bit first.

        Args:
            mask: Any mask accepted by ``MaskCodec.normalize``.

        Returns:
            PackedMask whose ``words`` hold ``ceil(n / 8)`` bytes and whose
            ``n_bits`` is the number ``n`` of mask entries.

        Raises:
            ValueError: Propagated from ``normalize`` for multi-row masks.
        """
        normalized = MaskCodec.normalize(mask)
        n_bits = int(normalized.numel())
        # NumPy needs host memory, so device tensors are copied to the CPU.
        cpu_bits = normalized.to(device="cpu").numpy().astype(np.uint8)
        # bitorder="little": entry i goes to bit (i % 8) of byte (i // 8);
        # numpy zero-pads the final byte.
        packed = np.packbits(cpu_bits, bitorder="little")
        return PackedMask(words=packed.tobytes(), n_bits=n_bits)

    @staticmethod
    def unpack(packed: PackedMask, device: torch.device | None = None) -> Tensor:
        """Unpack a PackedMask back into a 1D bool tensor.

        Args:
            packed: Packed mask, as produced by ``MaskCodec.pack``.
            device: Target device for the result; ``None`` leaves it on the CPU.

        Returns:
            1D ``torch.bool`` tensor of length ``packed.n_bits``.

        Raises:
            ValueError: If ``n_bits`` is negative or exceeds the number of bits
                stored in ``words`` (previously such input was silently
                truncated or sliced from the end).
        """
        n_bits = packed.n_bits
        available = len(packed.words) * 8
        if n_bits < 0:
            raise ValueError(f"n_bits must be non-negative, got {n_bits}.")
        if n_bits > available:
            raise ValueError(
                f"PackedMask stores only {available} bits but n_bits={n_bits}."
            )
        raw = np.frombuffer(packed.words, dtype=np.uint8)
        # Drop the zero padding from the final byte.
        bits = np.unpackbits(raw, bitorder="little")[:n_bits]
        # astype() always returns a fresh writable array, so no extra copy is
        # needed for torch.from_numpy to own its memory.
        return torch.from_numpy(bits.astype(np.bool_)).to(device=device)

    @staticmethod
    def hash(mask: Tensor) -> int:
        """Hash a mask by its packed bit representation.

        Masks that normalize to the same bits hash equally, regardless of the
        source device, dtype, or leading singleton dimensions. ``n_bits`` is
        part of the key so that masks differing only in trailing ``False``
        entries (which pack to identical padded bytes) stay distinct.

        Note:
            The value comes from Python's built-in ``hash``, and ``bytes``
            hashing is randomized per interpreter process (``PYTHONHASHSEED``).
            It is stable within one run but must not be persisted or compared
            across processes.
        """
        packed = MaskCodec.pack(mask)
        # This is the builtin ``hash``: class-body names are not visible inside
        # method bodies, so the call does not recurse into MaskCodec.hash.
        return hash((packed.words, packed.n_bits))