Source code for mllm_shap.utils.jupyter
"""Utility functions for Jupyter Notebook visualization."""
from typing import Any
import pandas as pd
import torch
from pandas.io.formats.style import Styler
from torch import Tensor
from .audio import display_audio
[docs]
def audio_html(content: bytes) -> str:
    """
    Build the HTML markup that embeds an audio clip.

    Args:
        content: Audio content in bytes.

    Returns:
        str: HTML string obtained from the object returned by ``display_audio``.
    """
    # Ask the display object for its HTML form rather than rendering it as
    # cell output, so the markup can be embedded elsewhere (e.g. a table cell).
    player = display_audio(content)
    markup = player._repr_html_()  # type: ignore[no-untyped-call] # pylint: disable=protected-access
    return str(markup)
[docs]
def display_shap_colors_df(
df: pd.DataFrame,
shap_column_name: str = "Shapley Value",
cmap: str = "coolwarm",
low: float = 0.0,
high: float = 1.0,
**kwargs: Any,
) -> Styler:
"""Set background gradient colors for SHAP values in a DataFrame.
Args:
df: DataFrame containing SHAP values.
shap_column_name: Name of the column with SHAP values.
cmap: Colormap to use for the gradient.
low: Minimum value for the gradient.
high: Maximum value for the gradient.
**kwargs: Additional arguments for pandas Styler.background_gradient.
Returns:
pd.Styler: Styled DataFrame with background gradient.
"""
return df.style.background_gradient(subset=[shap_column_name], cmap=cmap, low=low, high=high, **kwargs)
[docs]
def display_shap_colors_df_audio(df: pd.DataFrame, audio_column_name: str = "Audio", **kwargs: Any) -> Styler:
    """
    Set background gradient colors for SHAP values in a DataFrame with audio.

    Each value of the audio column is converted with ``audio_html`` so that it
    is rendered as audio in Jupyter notebooks. The conversion is done on a copy:
    the DataFrame passed in is left unmodified, so the function can be called
    repeatedly on the same frame (e.g. when a notebook cell is re-run).

    Args:
        df: DataFrame containing SHAP values and audio.
        audio_column_name: Name of the column with audio.
        **kwargs: Additional arguments for display_shap_colors_df.

    Returns:
        Styler: Styled DataFrame with background gradient and the audio
        column replaced by its HTML representation.
    """
    # Work on a copy: assigning into ``df`` directly would replace the caller's
    # audio values with HTML strings, and a second call would then pass those
    # strings to ``audio_html`` instead of the original audio content.
    rendered = df.copy()
    rendered[audio_column_name] = rendered[audio_column_name].apply(audio_html)
    return display_shap_colors_df(rendered, **kwargs)
[docs]
def plot_distribution(values: Tensor, bins: int = 50, **kwargs: Any) -> None:
    """
    Draw a histogram showing how the SHAP values are distributed.

    Args:
        values: Tensor of SHAP values.
        bins: Number of bins for the histogram.
        **kwargs: Additional arguments for matplotlib.pyplot.hist.
    """
    # matplotlib is imported here, at call time, rather than at module import.
    from matplotlib import pyplot  # pylint: disable=import-outside-toplevel

    # Detach, move to CPU and cast to float32 before converting to a
    # one-dimensional NumPy array for plotting.
    cpu_values = values.detach().cpu()
    flat_values = cpu_values.to(torch.float32).numpy().flatten()

    pyplot.hist(flat_values, bins=bins, **kwargs)

    # Static labelling of the figure.
    pyplot.title("SHAP Values Distribution")
    pyplot.xlabel("SHAP Value")
    pyplot.ylabel("Frequency")

    pyplot.show()