Source code for mllm_shap.shap.normalizers

# pylint: disable=too-few-public-methods

"""Normalizers for SHAP values."""

import torch
from torch import Tensor

from .base.normalizers import BaseNormalizer


[docs] class IdentityNormalizer(BaseNormalizer): """Normalizer that returns SHAP values unchanged.""" def __call__(self, shap_values: Tensor) -> Tensor: return shap_values
[docs] class AbsSumNormalizer(BaseNormalizer): """Normalizer that scales SHAP values by the sum of their absolute values.""" def __call__(self, shap_values: Tensor) -> Tensor: abs_sum = shap_values.abs().sum() if abs_sum == 0: return shap_values return shap_values / abs_sum
[docs] class PowerShiftNormalizer(BaseNormalizer): """Normalizer that applies power shift normalization to SHAP values.""" power: float """The power to which SHAP values are raised.""" def __init__(self, power: float = 1.0): """ Initialize the PowerShiftNormalizer. Args: power: The power to which SHAP values are raised. Raises: ValueError: If power is not a positive float. """ if isinstance(power, (int, float)) and power <= 0: raise ValueError("power must be a positive float.") self.power = power def __call__(self, shap_values: Tensor) -> Tensor: # Shift SHAP values to be non-negative shifted = shap_values - shap_values.min() # Apply power transformation powered = shifted.pow(self.power) # Normalize to sum to 1 total = powered.sum() if total == 0: return shap_values normalized = powered / total return normalized
[docs] class MinMaxNormalizer(BaseNormalizer): """ Normalizer that scales SHAP values to the [0, 1] range using min-max normalization then normalizes them to sum to 1. """ def __call__(self, shap_values: Tensor) -> Tensor: min_val = shap_values.min() max_val = shap_values.max() if max_val - min_val == 0: # all values are the same return torch.ones_like(shap_values) / len(shap_values) normalized = (shap_values - min_val) / (max_val - min_val) normalized /= normalized.sum() return normalized