Source code for mllm_shap.utils.other

"""General utility functions."""

from typing import Any, Callable

import torch
from torch import Tensor


def raise_connector_error(callable_: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
    """
    Wrapper to raise connector errors with more context.

    Args:
        callable_: The callable to wrap.
        *args: Positional arguments for the callable.
        **kwargs: Keyword arguments for the callable.

    Returns:
        The result of the callable.

    Raises:
        RuntimeError: If an error occurs in the callable.
    """
    try:
        return callable_(*args, **kwargs)
    except Exception as e:
        raise RuntimeError("Error occurred in connector implementation.") from e
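
A minimal usage sketch (not part of the module; flaky_decode is a hypothetical callable used only for illustration). It shows that the original exception survives as the RuntimeError's __cause__:

def flaky_decode(ids: list[int]) -> str:  # hypothetical connector method
    raise KeyError(ids[0])

try:
    raise_connector_error(flaky_decode, [42])
except RuntimeError as err:
    # The original KeyError is chained, so full context is preserved.
    assert isinstance(err.__cause__, KeyError)
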
def safe_mask(tensor: Tensor, mask: Tensor) -> Tensor:
    """
    Mask the tensor with the given mask.

    If mask is empty, return empty Tensor while maintaining the original
    tensor properties.

    Args:
        tensor: The input tensor to be masked.
        mask: The boolean mask tensor.

    Returns:
        The masked tensor.
    """
    masked = tensor[..., mask]
    if masked.numel() == 0:
        # Rebuild an explicitly empty tensor so device and dtype are kept:
        # higher-rank inputs keep the leading dimension with an empty last
        # dimension; a 1-D input yields shape (0,).
        target_shape = (tensor.shape[0], 0) if tensor.dim() > 1 else (0,)
        masked = torch.empty(target_shape, device=tensor.device, dtype=tensor.dtype)
    return masked
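
A short example of the two paths (assumed input shapes, not from the module): a partial mask selects columns, while an all-False mask yields an empty result that keeps the leading dimension, device, and dtype:

x = torch.arange(12.0).reshape(3, 4)
keep = torch.tensor([True, False, True, False])
print(safe_mask(x, keep).shape)                               # torch.Size([3, 2])
print(safe_mask(x, torch.zeros(4, dtype=torch.bool)).shape)   # torch.Size([3, 0])
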
def safe_mask_unsqueeze(tensor: Tensor, mask: Tensor) -> Tensor:
    """
    Mask the tensor with the given mask.

    If mask is empty, return empty Tensor while maintaining the original
    tensor properties, and unsqueeze to maintain batch dimension.

    Args:
        tensor: The input tensor to be masked.
        mask: The boolean mask tensor.

    Returns:
        The masked tensor with batch dimension.
    """
    masked = tensor[0][mask]
    if masked.numel() == 0:
        # Preserve the batch dimension for batched input, giving shape
        # (batch, 0); a 1-D input yields shape (0,). Device and dtype
        # follow the original tensor.
        target_shape = (tensor.shape[0], 0) if tensor.dim() > 1 else (0,)
        masked = torch.empty(target_shape, device=tensor.device, dtype=tensor.dtype)
    else:
        masked = masked.unsqueeze(0)
    return masked
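
A usage sketch with an assumed batch-of-one input (the shapes here are illustrative, not prescribed by the module). Both paths return a tensor with a leading batch dimension:

batch = torch.tensor([[10, 20, 30]])
keep = torch.tensor([False, True, True])
print(safe_mask_unsqueeze(batch, keep).shape)   # torch.Size([1, 2])
none = torch.zeros(3, dtype=torch.bool)
print(safe_mask_unsqueeze(batch, none).shape)   # torch.Size([1, 0])
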
def make_consecutive_ids_ignore_zero(t: torch.Tensor) -> torch.Tensor:
    """
    Renumber non-zero tensor values to consecutive integers starting from 1,
    preserving the order of first appearance. Zeros remain unchanged.

    Args:
        t: Input tensor with integer IDs.

    Returns:
        Tensor with renumbered IDs.
    """
    # Get unique consecutive non-zero values in order of appearance
    nonzero_mask = t != 0
    unique_vals = torch.unique_consecutive(t[nonzero_mask])

    # Map original IDs to new consecutive ones, keeping the number assigned
    # at the first appearance if an ID occurs in more than one run
    mapping: dict[int, int] = {}
    for v in unique_vals.tolist():
        if v not in mapping:
            mapping[v] = len(mapping) + 1

    # Apply mapping (zeros unchanged)
    out = t.clone()
    for old_id, new_id in mapping.items():
        out[t == old_id] = new_id
    return out
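
A small worked example (the input values are arbitrary, chosen to show both the zero passthrough and the renumbering): 7 becomes 1, 3 becomes 2, 9 becomes 3, following first appearance, while zeros stay put:

ids = torch.tensor([0, 7, 7, 3, 0, 3, 9])
print(make_consecutive_ids_ignore_zero(ids))  # tensor([0, 1, 1, 2, 0, 2, 3])
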
def extend_tensor(t: Tensor, target_length: int, fill_value: Any) -> Tensor:
    """
    Extend a tensor to the target length by appending the fill value.

    Args:
        t: The input tensor to be extended.
        target_length: The desired length of the output tensor.
        fill_value: The value to use for extension.

    Returns:
        The extended tensor.
    """
    current_length = t.shape[0]
    if current_length >= target_length:
        return t

    extension_size = target_length - current_length
    extension = torch.full(
        (extension_size,),
        fill_value,
        dtype=t.dtype,
        device=t.device,
    )
    return torch.cat([t, extension], dim=0)
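
A quick usage sketch (values are illustrative): padding appends fill_value up to the target length, and a tensor already at or beyond the target is returned unchanged:

t = torch.tensor([1, 2, 3])
print(extend_tensor(t, 5, -1))  # tensor([ 1,  2,  3, -1, -1])
print(extend_tensor(t, 2, -1))  # tensor([1, 2, 3]) -- already long enough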