# Source code for mllm_shap.benchmarks.bench_api_perf

"""Micro-benchmarks for API/performance hotspots."""

from __future__ import annotations

import argparse
import statistics
import time
from dataclasses import dataclass
from typing import Any, Callable, Iterable

import torch
from torch import Tensor

from ..connectors.base.model_response import ModelResponse
from ..shap.base._generate_responses import generate_responses
from ..shap.base._masks_manager import MasksManager
from ..shap.base.approx import BaseShapApproximation


def _time_many(fn: Any, repeats: int) -> tuple[float, float, float]:
    times: list[float] = []
    for _ in range(repeats):
        t0 = time.perf_counter()
        fn()
        times.append(time.perf_counter() - t0)
    return min(times), statistics.median(times), max(times)


def _fmt(label: str, values: tuple[float, float, float], unit: str = "ms") -> str:
    mul = 1000.0 if unit == "ms" else 1.0
    lo, med, hi = values
    return f"{label:<36} min={lo * mul:8.3f} {unit}  p50={med * mul:8.3f} {unit}  max={hi * mul:8.3f} {unit}"


def bench_mask_hash(iters: int, mask_len: int, repeats: int) -> None:
    """Benchmark current mask hash implementation."""
    # One fixed random boolean mask is hashed repeatedly; the loop is
    # what gets timed, so per-call overhead of get_hash dominates.
    sample_mask = torch.randint(0, 2, (mask_len,), dtype=torch.bool)

    def run() -> None:
        for _ in range(iters):
            _ = MasksManager.get_hash(sample_mask)

    timings = _time_many(run, repeats=repeats)
    print("== mask hash ==")
    print(_fmt(f"{iters}x hash(len={mask_len})", timings))
@dataclass class _FakeChat: cache: Any = None @classmethod def from_chat(cls, mask: Tensor, chat: "_FakeChat") -> "_FakeChat": del mask, chat return cls() class _FakeModel: def generate( self, *, chat: _FakeChat, keep_history: bool = False, **_: dict[str, Any], ) -> ModelResponse: del chat response_chat = _FakeChat() if keep_history else None return ModelResponse( chat=response_chat, generated_text_tokens=torch.tensor([1, 2, 3], dtype=torch.long), generated_audio_tokens=torch.empty((0, 0), dtype=torch.long), generated_modality_flag=torch.ones((3,), dtype=torch.long), ) class _FakeCacheManager: def __init__(self) -> None: self._cache: dict[int, ModelResponse] = {} def contains(self, mask_hash: int) -> bool: return mask_hash in self._cache def extract(self, mask_hash: int) -> ModelResponse: return self._cache[mask_hash] def _mask_gen(n_masks: int, mask_len: int) -> Iterable[tuple[Tensor, int]]: for i in range(n_masks): mask = torch.randint(0, 2, (mask_len,), dtype=torch.bool) yield mask, hash((i, mask_len))
def bench_response_generation(
    n_masks: int, mask_len: int, repeats: int, n_jobs: int, verbose: bool
) -> None:
    """Benchmark masked response generation orchestration."""

    def run() -> None:
        # Fresh accumulator lists per run so earlier repeats do not
        # inflate the work done by later ones.
        collected_masks: list[Tensor] = []
        collected_responses: list[ModelResponse] = []
        _ = generate_responses(
            masks=collected_masks,
            responses=collected_responses,
            gen=_mask_gen(n_masks=n_masks, mask_len=mask_len),
            source_chat=_FakeChat(),
            model=_FakeModel(),
            cache_manager=_FakeCacheManager(),
            n_generator_jobs=n_jobs,
            progress_bar=False,
            verbose=verbose,
        )

    timings = _time_many(run, repeats=repeats)
    print("== response generation ==")
    print(
        _fmt(
            f"{n_masks} masks len={mask_len} jobs={n_jobs} verbose={verbose}",
            timings,
        )
    )
class _DummyApprox(BaseShapApproximation):
    """Minimal approximation explainer for in-place budget update benchmark."""

    def _get_num_splits(self, n: int) -> int:
        del n
        # Fall back to a single split when no sampling budget is set.
        return 1 if self.num_samples is None else self.num_samples

    def _calculate_shap_values(
        self, masks: Tensor, similarities: Tensor, device: torch.device
    ) -> Tensor:
        del similarities, device
        # One zero attribution per masked position; values are irrelevant
        # for the benchmark, only the call path matters.
        return torch.zeros(masks.shape[-1], dtype=torch.float32)

    def _get_next_split(
        self,
        n: int,
        device: torch.device,
        generated_masks_num: int,
        existing_masks: list[Tensor] | None = None,
    ) -> Tensor | None:
        del n, device, generated_masks_num, existing_masks
        return None
def bench_linear_num_samples_update(iters: int, repeats: int) -> None:
    """Benchmark in-place sampling budget update path used by experiments runner."""
    approx = _DummyApprox(num_samples=100, fraction=0.6)

    def run() -> None:
        # Repeatedly rebind the budget attribute; this is the hot path
        # the experiments runner exercises between evaluations.
        for step in range(iters):
            approx.num_samples = 100 + step

    timings = _time_many(run, repeats=repeats)
    print("== linear num_samples update ==")
    print(_fmt(f"{iters}x in-place num_samples set", timings))
def main() -> None:
    """CLI entry point: parse options and run the selected micro-benchmarks."""
    parser = argparse.ArgumentParser(description=__doc__)
    # All sizing knobs are plain integers; register them uniformly.
    for flag, default in (
        ("--iters", 2000),
        ("--repeats", 5),
        ("--mask-len", 512),
        ("--n-masks", 256),
        ("--jobs", 1),
    ):
        parser.add_argument(flag, type=int, default=default)
    parser.add_argument(
        "--bench",
        choices=["all", "mask-hash", "responses", "linear-update"],
        default="all",
    )
    args = parser.parse_args()

    selected = args.bench
    if selected in ("all", "mask-hash"):
        bench_mask_hash(iters=args.iters, mask_len=args.mask_len, repeats=args.repeats)
        print()
    if selected in ("all", "responses"):
        bench_response_generation(
            n_masks=args.n_masks,
            mask_len=args.mask_len,
            repeats=args.repeats,
            n_jobs=args.jobs,
            verbose=False,
        )
        print()
    if selected in ("all", "linear-update"):
        bench_linear_num_samples_update(iters=args.iters, repeats=args.repeats)
if __name__ == "__main__":
    main()