# Source code for mllm_shap.utils.audio

"""Utility functions for audio processing and display."""

from io import BytesIO
from typing import TYPE_CHECKING

import numpy as np
import soundfile as sf
import torch
import torchaudio.transforms as T
from pydub import AudioSegment
from torch import Tensor

if TYPE_CHECKING:
    from IPython.display import Audio

TARGET_SAMPLE_RATE = 24_000
PCM16_SAMPLE_WIDTH_BYTES = 2
WAVEFORM_2D_DIM = 2
MONO_CHANNELS = 1
AUDIO_FORMAT_WAV = "wav"
AUDIO_FORMAT_MP3 = "mp3"


def display_audio(audio_content: bytes, autoplay: bool = True) -> "Audio":
    """
    Display audio content in a Jupyter notebook.

    Args:
        audio_content: The audio content in bytes.
        autoplay: Whether playback starts automatically (default is True).

    Returns:
        An IPython ``Audio`` display object wrapping the audio content.
    """
    # Import here to avoid dependency if not used in notebook
    from IPython.display import Audio  # pylint: disable=import-outside-toplevel

    return Audio(data=audio_content, autoplay=autoplay)  # type: ignore
class TorchAudioHandler:
    """Utility class for handling audio content with TorchAudio."""

    @staticmethod
    def from_bytes(
        audio_content: bytes, audio_format: str = AUDIO_FORMAT_MP3
    ) -> tuple[Tensor, int]:
        """
        Prepare audio content for processing.

        Args:
            audio_content: The audio content in bytes.
            audio_format: The format of the audio content (default is "mp3").
                Note: soundfile auto-detects the container from the bytes; this
                value is only reported in the error message on failure.

        Returns:
            A tuple containing the mono audio tensor of shape
            (1, num_frames) and the sample rate.

        Raises:
            Exception: Re-raises any decoding error from soundfile after
                printing a diagnostic message.
        """
        try:
            waveform_np, sample_rate = sf.read(BytesIO(audio_content))
            waveform = torch.from_numpy(waveform_np).float()
            if waveform.dim() == MONO_CHANNELS:
                # (num_frames,) -> (1, num_frames)
                waveform = waveform.unsqueeze(0)
            elif waveform.dim() == WAVEFORM_2D_DIM:
                # soundfile yields (num_frames, num_channels); torch audio
                # convention is (num_channels, num_frames).
                waveform = waveform.T
                if waveform.shape[0] > MONO_CHANNELS:
                    # Downmix multi-channel audio to mono by averaging.
                    waveform = waveform.mean(dim=0, keepdim=True)
            return waveform, int(sample_rate)
        except Exception as e:
            print(f"Error loading with soundfile: {e}, for format: {audio_format}.")
            raise

    @staticmethod
    def to_bytes(
        waveform: torch.Tensor,
        sample_rate: int = TARGET_SAMPLE_RATE,
        audio_format: str = AUDIO_FORMAT_WAV,
        mp3_bitrate: str = "192k",
    ) -> bytes:
        """
        Convert a waveform tensor to audio content in bytes.

        Args:
            waveform: The audio waveform tensor, 1-D (num_frames,) or 2-D
                (num_channels, num_frames); multi-channel input is downmixed
                to mono.
            sample_rate: The sample rate of the audio (default is
                TARGET_SAMPLE_RATE).
            audio_format: The desired audio format ("wav" or "mp3").
            mp3_bitrate: The bitrate for MP3 encoding (default is "192k").

        Returns:
            The audio content in bytes.

        Raises:
            ValueError: If ``audio_format`` is neither "wav" nor "mp3".
        """
        with torch.no_grad():
            wf = waveform.detach().cpu()
            # force mono
            if wf.dim() == WAVEFORM_2D_DIM and wf.size(0) > 1:
                wf = wf.mean(dim=0, keepdim=True)
            if wf.dim() == WAVEFORM_2D_DIM and wf.size(0) == 1:
                wf = wf.squeeze(0)
            elif wf.dim() == 1:
                pass
            else:
                # >2-D input: collapse the leading dimension as a fallback.
                wf = wf.mean(dim=0)
            wf = wf.to(torch.float32)
            # replace NaN/Inf, then clamp into valid PCM range [-1, 1]
            wf = torch.nan_to_num(wf, nan=0.0, posinf=0.0, neginf=0.0).clamp_(-1.0, 1.0)
            wf = wf.contiguous()

        fmt = audio_format.lower()
        if fmt == AUDIO_FORMAT_WAV:
            buf = BytesIO()
            sf.write(buf, wf.numpy(), int(sample_rate), format="WAV", subtype="PCM_16")
            buf.seek(0)
            return buf.read()
        if fmt == AUDIO_FORMAT_MP3:
            # requires ffmpeg on PATH; pydub delegates to ffmpeg for MP3 encoding
            # (if ffmpeg is missing will get an error from pydub)
            pcm16 = (wf.numpy() * 32767.0).astype(np.int16)
            seg = AudioSegment(
                pcm16.tobytes(),
                frame_rate=int(sample_rate),
                sample_width=PCM16_SAMPLE_WIDTH_BYTES,
                channels=MONO_CHANNELS,
            )
            buf = BytesIO()
            seg.export(buf, format=AUDIO_FORMAT_MP3, bitrate=mp3_bitrate)
            return buf.getvalue()
        raise ValueError(f"Unsupported audio_format: {audio_format!r}")

    @staticmethod
    def combine(
        audio_segments: list[AudioSegment], target_audio_format: str = AUDIO_FORMAT_WAV
    ) -> bytes:
        """
        Combine multiple audio segments into a single audio byte stream.

        Args:
            audio_segments: A list of segment objects exposing ``.audio``
                (bytes) and ``.audio_format`` (str).
                NOTE(review): the annotation says pydub's ``AudioSegment``,
                but pydub's class has neither attribute — this looks like a
                project-local segment type; confirm and fix the annotation.
            target_audio_format: The desired audio format for the output
                (default is "wav").

        Returns:
            A bytes object containing the combined audio data, or ``b""`` if
            no segments were given.
        """
        waveforms: list[Tensor] = []
        sample_rates: list[int] = []
        for segment in audio_segments:
            seg_waveform, seg_sr = TorchAudioHandler.from_bytes(
                segment.audio, audio_format=segment.audio_format
            )
            waveforms.append(seg_waveform)
            sample_rates.append(seg_sr)

        if not waveforms:
            return b""

        # Resample everything to the first segment's rate, then concatenate
        # along the time axis (dim=1 of the (1, num_frames) tensors).
        target_sr = sample_rates[0]
        resampled_waveforms: list[Tensor] = []
        for wf, sr in zip(waveforms, sample_rates):
            if sr != target_sr:
                resampler = T.Resample(orig_freq=sr, new_freq=target_sr)
                wf = resampler(wf)
            resampled_waveforms.append(wf)
        combined_waveform = torch.cat(resampled_waveforms, dim=1)

        return TorchAudioHandler.to_bytes(
            combined_waveform, sample_rate=target_sr, audio_format=target_audio_format
        )