# Source code for mllm_shap.connectors.base.audio

"""Spectrogram-Guided Forced Aligner using Wav2Vec2."""

import io
import unicodedata
import warnings
import wave
from dataclasses import dataclass, field
from logging import Logger
from typing import Any, cast

import librosa
import numpy as np
import torch
import torchaudio
from torchaudio.functional import forced_align
from transformers import PreTrainedTokenizerBase, Wav2Vec2ForCTC, Wav2Vec2Processor
from transformers import logging as hf_logging

from ...utils.audio import TorchAudioHandler
from ...utils.logger import get_logger

hf_logging.set_verbosity_error()  # type: ignore[no-untyped-call]

logger: Logger = get_logger(__name__)

warnings.filterwarnings("ignore", message=".*forced_align has been deprecated.*")


UNICODE_CATEGORY_NONSPACING_MARK = "Mn"
ASCII_SPACE = " "


@dataclass
class AudioSegment:  # pylint: disable=too-many-instance-attributes
    """Represents a segment of audio aligned to a token/word."""

    token: str
    """The token/word associated with this segment."""

    start_time: float
    """Start time in seconds."""

    end_time: float
    """End time in seconds."""

    confidence: float
    """Confidence score of the alignment."""

    audio: bytes = field(default=b"")
    """Raw audio bytes for this segment."""

    audio_format: str = field(default="wav")
    """Audio format of the raw audio bytes."""

    sample_rate: int | None = field(default=None, repr=False)
    """Sample rate for `start_sample`/`end_sample` (if set)."""

    start_sample: int | None = field(default=None, repr=False)
    """Start sample index in the source waveform (if set)."""

    end_sample: int | None = field(default=None, repr=False)
    """End sample index in the source waveform (if set)."""

    @property
    def duration(self) -> float:
        """Duration of the segment in seconds."""
        return self.end_time - self.start_time

    def __repr__(self) -> str:
        """String representation of the AudioSegment."""
        return (
            f"AudioSegment(token='{self.token}', start={self.start_time:.3f}, "
            f"end={self.end_time:.3f}, dur={self.duration:.3f}s)"
        )

    def __add__(self, other: "AudioSegment") -> "AudioSegment":
        """Combine two AudioSegments into one."""
        if self.token != other.token:
            raise ValueError("Cannot combine AudioSegments with different tokens.")
        combined_audio = self.audio + other.audio

        return AudioSegment(
            token=self.token,
            start_time=min(self.start_time, other.start_time),
            end_time=max(self.end_time, other.end_time),
            confidence=(self.confidence + other.confidence) / 2,
            audio=combined_audio,
            audio_format=self.audio_format,
        )


# pylint: disable=too-few-public-methods,too-many-instance-attributes
# pylint: disable=too-few-public-methods,too-many-instance-attributes
class SpectrogramGuidedAligner:
    """
    Spectrogram-Guided Forced Aligner using Wav2Vec2 and Torchaudio.

    Takes raw audio bytes and a transcript (string or list of tokens) and
    produces time-aligned segments with refined boundaries.

    The alignment process consists of several phases:

    1. **Acoustic Modeling:** A CTC model (Wav2Vec2) maps audio to character
       probabilities.
    2. **Forced Alignment:** Dynamic programming finds the optimal alignment
       path.
    3. **Boundary Refinement:** Spectrogram features (Energy & Flux) refine
       boundaries.
    4. **Aggregation:** Character-level segments are grouped into
       user-defined tokens.
    """

    # pylint: disable=too-many-arguments, too-many-positional-arguments
    def __init__(
        self,
        device: torch.device,
        model_name: str = "facebook/wav2vec2-large-960h",
        model_revision: str = "main",
        sample_rate: int = 16000,
        ctc_separator: str = "|",
    ):
        """
        Initialize the aligner with the specified CTC model.

        Args:
            device: Torch device to run the model on (CPU or GPU).
            model_name: Hugging Face model name for the Wav2Vec2 CTC model.
            model_revision: Model revision or version to use.
            sample_rate: Expected sample rate for the model (default 16kHz).
            ctc_separator: Separator used in CTC decoding.

        Raises:
            ValueError: If the model cannot be loaded as a CTC model.
        """
        self.device = device
        self.sample_rate = sample_rate
        self.ctc_separator = ctc_separator

        logger.debug("Loading alignment model: %s on %s...", model_name, device)
        try:
            self.processor = cast(Any, Wav2Vec2Processor).from_pretrained(  # nosec B615
                model_name, revision=model_revision
            )
            ctc_model = Wav2Vec2ForCTC.from_pretrained(  # nosec B615
                model_name, revision=model_revision
            )
            self.model = cast(Wav2Vec2ForCTC, ctc_model.to(cast(Any, device)))
        except OSError as err:
            raise ValueError(
                f"Could not load '{model_name}'. Ensure it is a valid CTC model."
            ) from err

        self.tokenizer = cast(
            PreTrainedTokenizerBase, getattr(self.processor, "tokenizer")
        )
        self.vocab = self.tokenizer.get_vocab()
        # CTC blank token: fall back to id 0 when the tokenizer has no pad id.
        self.blank_id = self.tokenizer.pad_token_id or 0
[docs] @staticmethod def normalize_text(text: str) -> str: """ Normalize text for alignment by stripping diacritics and keeping only alphanumeric characters. This normalization is applied consistently to both the transcript used for forced alignment and the target segments used for aggregation to ensure matching. Args: text: The text to normalize. Returns: Normalized text (uppercase, no diacritics, only alphanumeric). """ # Decompose Unicode characters (e.g., "ñ" -> "n" + combining tilde) # and filter out combining marks (diacritics) text_nfd = unicodedata.normalize("NFD", text) text_no_diacritics = "".join( char for char in text_nfd if unicodedata.category(char) != UNICODE_CATEGORY_NONSPACING_MARK ) # Keep only alphanumeric characters and convert to uppercase return "".join(filter(str.isalnum, text_no_diacritics)).upper()
    def __compute_emissions(
        self, waveform: torch.Tensor, original_sr: int
    ) -> torch.Tensor:
        """
        Computes log-probability emissions from the acoustic model.

        Resamples to the model's expected rate if needed, squeezes multi-dim
        input to mono, and runs a forward pass under inference mode.

        Args:
            waveform: Audio waveform tensor.
            original_sr: Original sampling rate of the waveform.

        Returns:
            Emission log-probabilities tensor (log-softmax over the vocab).
        """
        if original_sr != self.sample_rate:
            resampler = torchaudio.transforms.Resample(
                orig_freq=original_sr, new_freq=self.sample_rate
            ).to(self.device)
            waveform = resampler(waveform.to(self.device))
        else:
            waveform = waveform.to(self.device)

        # Collapse any leading channel dimension; assumes mono-compatible
        # input — TODO confirm multi-channel audio is downmixed upstream.
        if waveform.dim() > 1:
            waveform = waveform.squeeze()

        # Model Forward Pass
        inputs = self.processor(
            waveform, sampling_rate=self.sample_rate, return_tensors="pt", padding=True
        )
        with torch.inference_mode():
            logits = self.model(inputs.input_values.to(self.device)).logits
        emissions = torch.log_softmax(logits, dim=-1)
        return emissions

    # pylint: disable=too-many-locals
    def __refine_boundary_smart(
        self, waveform: np.ndarray, sr: int, candidate_time: float
    ) -> float:
        """
        'Smart' refinement using Energy and Spectral Flux.

        Searches an 80ms window around the candidate boundary for the point
        with the lowest weighted cost of RMS energy (loudness) and spectral
        flux (change), i.e. the quietest, most stable instant.

        Args:
            waveform: Audio waveform as a numpy array.
            sr: Sampling rate of the waveform.
            candidate_time: Initial candidate boundary time in seconds.

        Returns:
            Refined boundary time in seconds (unchanged if the search window
            is too short to analyze).
        """
        window_samples = int(0.08 * sr)
        center_sample = int(candidate_time * sr)
        start_idx = max(0, center_sample - window_samples)
        end_idx = min(len(waveform), center_sample + window_samples)
        search_region = waveform[start_idx:end_idx]

        # Too short to analyze
        if len(search_region) < 256:  # pylint: disable=magic-value-comparison
            return candidate_time

        # Compute RMS Energy (Loudness)
        rms = librosa.feature.rms(y=search_region, frame_length=256, hop_length=64)[0]

        # Compute Spectral Flux (Change)
        stft = np.abs(librosa.stft(search_region, n_fft=256, hop_length=64))
        flux = np.sum(np.diff(stft, axis=1) ** 2, axis=0)

        # Pad flux to match rms length (diff drops one frame).
        flux = np.pad(flux, (0, len(rms) - len(flux)), mode="constant")

        # Normalize to [0, 1]; epsilon guards against constant regions.
        rms_norm = (rms - np.min(rms)) / (np.max(rms) - np.min(rms) + 1e-9)
        flux_norm = (flux - np.min(flux)) / (np.max(flux) - np.min(flux) + 1e-9)

        # Weighted cost function: prioritize silence (low RMS) and stability (low Flux)
        cost = 0.8 * rms_norm + 0.2 * flux_norm
        min_idx = np.argmin(cost)

        # Frame index -> sample offset within the search window (hop = 64).
        refined_sample = start_idx + (min_idx * 64)
        return refined_sample / sr

    def __save_wav_mem(self, tensor: torch.Tensor, sample_rate: int) -> bytes:
        """
        Saves a waveform tensor to WAV format in memory.

        Args:
            tensor: Audio waveform tensor; assumed float32 in [-1, 1] —
                TODO confirm against decoding path.
            sample_rate: Sampling rate for the WAV file.

        Returns:
            Raw WAV bytes (16-bit PCM, interleaved channels).
        """
        src = tensor.cpu()
        if src.dim() == 1:
            src = src.unsqueeze(0)
        n_channels = src.shape[0]

        # Convert float32 to int16 PCM
        src = (src * 32767).clamp(-32768, 32767).to(torch.int16)
        # Transpose to (samples, channels) so frames are interleaved.
        src = src.t().numpy()  # type: ignore[assignment]

        buffer = io.BytesIO()
        with wave.open(buffer, "wb") as wav_file:
            wav_file.setnchannels(n_channels)  # pylint: disable=no-member
            wav_file.setsampwidth(2)  # pylint: disable=no-member
            wav_file.setframerate(sample_rate)  # pylint: disable=no-member
            wav_file.writeframes(src.tobytes())  # type: ignore[attr-defined] # pylint: disable=no-member
        return buffer.getvalue()

    def __merge_tokens(
        self, alignment_path: torch.Tensor, blank_id: int
    ) -> list[tuple[int, int, int]]:
        """
        Merges frame-level alignment into (token, start, end) spans.

        Consecutive frames with the same non-blank token id collapse into a
        single half-open [start_frame, end_frame) span; blank frames are
        dropped.

        Args:
            alignment_path: Tensor of token IDs aligned per frame.
            blank_id: ID of the blank token in CTC.

        Returns:
            list of (token_id, start_frame, end_frame) tuples.
        """
        path = alignment_path.tolist()
        spans = []
        current_token = None
        start_frame = 0

        for i, token in enumerate(path):
            if token != current_token:
                if current_token is not None and current_token != blank_id:
                    spans.append((current_token, start_frame, i))
                current_token = token
                start_frame = i

        # Final span
        if current_token is not None and current_token != blank_id:
            spans.append((current_token, start_frame, len(path)))
        return spans

    def __prepare_transcript(
        self, transcript: str | list[str]
    ) -> tuple[str, list[str], str, list[int]]:
        """
        Prepares the transcript for alignment.

        Args:
            transcript: Either a single string (split on whitespace into
                target segments) or a list of tokens.

        Returns:
            tuple containing full transcript, target segments, clean text
            (separator-joined, uppercase, no diacritics), and valid token
            IDs for characters present in the model vocabulary.

        Raises:
            ValueError: If no character of the transcript maps into the
                model's vocabulary.
        """
        if isinstance(transcript, str):
            full_transcript = transcript
            target_segments = transcript.split()
        else:
            target_segments = transcript
            full_transcript = " ".join(transcript)

        # Normalize text: strip diacritics, uppercase, keep alphanumeric + spaces
        # Then replace spaces with CTC separator
        text_upper = full_transcript.upper()

        # Strip diacritics while preserving spaces (spaces become the CTC
        # separator below, so they must survive normalization here).
        text_nfd = unicodedata.normalize("NFD", text_upper)
        text_no_diacritics = "".join(
            char
            for char in text_nfd
            if unicodedata.category(char) != UNICODE_CATEGORY_NONSPACING_MARK
            # Remove combining marks (diacritics)
        )

        # Keep only alphanumeric and spaces
        text_clean = "".join(
            c for c in text_no_diacritics if c.isalnum() or c == ASCII_SPACE
        )

        # Replace spaces with CTC separator
        clean_text = text_clean.replace(ASCII_SPACE, self.ctc_separator)

        valid_tokens = [
            cast(int, self.tokenizer.convert_tokens_to_ids(c))
            for c in clean_text
            if c in self.vocab
        ]

        if not valid_tokens:
            raise ValueError("Transcript contains no valid characters for this model.")

        return full_transcript, target_segments, clean_text, valid_tokens

    def __perform_forced_alignment(
        self, waveform: torch.Tensor, original_sr: int, valid_tokens: list[int]
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Performs forced alignment using Torchaudio's forced_align.

        Ensures CPU fallback for compatibility (forced_align is run on CPU
        tensors; emissions stay on the model device for later scoring).

        Args:
            waveform: Audio waveform tensor.
            original_sr: Original sampling rate of the waveform.
            valid_tokens: Target token IDs to align against.

        Returns:
            tuple of (alignment_path, emission_log_probs).
        """
        emissions_gpu = self.__compute_emissions(waveform, original_sr).squeeze(0)

        # Move to CPU for forced_align to support MPS (Apple Silicon)
        emissions_cpu = emissions_gpu.unsqueeze(0).cpu()
        targets_cpu = torch.tensor([valid_tokens], dtype=torch.int32).cpu()
        emission_lens_cpu = torch.tensor([emissions_gpu.size(0)]).cpu()
        target_lens_cpu = torch.tensor([len(valid_tokens)]).cpu()

        aligned_tokens, _ = forced_align(
            emissions_cpu,
            targets_cpu,
            emission_lens_cpu,
            target_lens_cpu,
            blank=self.blank_id,
        )
        alignment_path = aligned_tokens[0]
        return alignment_path, emissions_gpu

    # pylint: disable=too-many-locals
    def __refine_token_spans(
        self,
        token_spans: list[tuple[int, int, int]],
        emissions_gpu: torch.Tensor,
        waveform: torch.Tensor,
        original_sr: int,
    ) -> list[dict[str, str | float]]:
        """
        Refines token boundary times using acoustic features.

        Converts frame spans to seconds via the samples-per-frame ratio,
        scores each span's mean probability from the emissions, and snaps
        both boundaries with the spectrogram-based refiner.

        Args:
            token_spans: list of (token_id, start_frame, end_frame) tuples.
            emissions_gpu: Emission log-probabilities tensor.
            waveform: Audio waveform tensor.
            original_sr: Original sampling rate of the waveform.

        Returns:
            list of refined character segments with start/end times and
            confidence ({"char", "start", "end", "confidence"}).
        """
        ratio = waveform.size(1) / emissions_gpu.size(0)
        numpy_wave = waveform.cpu().numpy().squeeze()
        refined_chars: list[dict[str, str | float]] = []

        for sp_token, sp_start, sp_end in token_spans:
            t_start = (sp_start * ratio) / original_sr
            t_end = (sp_end * ratio) / original_sr

            # Calculate confidence from GPU emissions (fast)
            conf = torch.exp(emissions_gpu[sp_start:sp_end, sp_token]).mean().item()

            # Refine boundaries
            r_start = self.__refine_boundary_smart(numpy_wave, original_sr, t_start)
            r_end = self.__refine_boundary_smart(numpy_wave, original_sr, t_end)

            char = cast(str, self.tokenizer.convert_ids_to_tokens(sp_token))
            refined_chars.append(
                {"char": char, "start": r_start, "end": r_end, "confidence": conf}
            )
        return refined_chars

    def __aggregate_chars_to_segments(
        self,
        char_segments: list[dict[str, str | float]],
        target_segments: list[str],
    ) -> list[AudioSegment]:
        """
        Aggregates character-level alignment into the provided target segments.

        Greedily consumes character segments in order, matching them against
        each normalized target token; unmatched characters are skipped.

        Args:
            char_segments: list of character-level segments with timings.
            target_segments: list of target tokens/words to align to.

        Returns:
            list of aggregated AudioSegment objects (targets with no matched
            characters are omitted).
        """
        # Aggregate chars to words/tokens
        final_segments = []
        current_char_idx = 0

        for segment_text in target_segments:
            # Normalize target segment for matching - use the same normalization as __prepare_transcript
            clean_target = self.normalize_text(segment_text)
            if not clean_target:
                continue

            start_time = None
            end_time = None
            confs = []
            found_chars = 0

            # Greedy matching
            while found_chars < len(clean_target) and current_char_idx < len(
                char_segments
            ):
                seg = char_segments[current_char_idx]
                seg_char = cast(str, seg["char"]).replace(self.ctc_separator, "")

                if seg_char == clean_target[found_chars]:
                    if start_time is None:
                        start_time = seg["start"]
                    end_time = seg["end"]
                    confs.append(seg["confidence"])
                    found_chars += 1
                current_char_idx += 1

            if start_time is not None:
                # Fallback for single characters
                if end_time is None:
                    end_time = cast(float, start_time) + 0.1

                avg_conf = sum(confs) / len(confs) if confs else 0.0  # type: ignore
                final_segments.append(
                    AudioSegment(
                        token=segment_text,
                        start_time=cast(float, start_time),
                        end_time=cast(float, end_time),
                        confidence=avg_conf,
                        audio_format="wav",
                    )
                )
        return final_segments

    def __set_segment_indices(
        self,
        final_segments: list[AudioSegment],
        waveform: torch.Tensor,
        original_sr: int,
    ) -> torch.Tensor:
        """
        Attach sample indices to segments based on timing information.

        Mutates each segment in place: sets `sample_rate`, `start_sample`,
        and `end_sample`, enforcing a 50ms minimum duration and clamping to
        the waveform bounds.

        Args:
            final_segments: list of AudioSegment objects.
            waveform: Audio waveform tensor.
            original_sr: Original sampling rate of the waveform.

        Returns:
            CPU waveform tensor in shape [channels, samples].
        """
        cpu_waveform = waveform.cpu()
        if cpu_waveform.dim() == 1:
            cpu_waveform = cpu_waveform.unsqueeze(0)

        for seg in final_segments:
            start_sample = int(seg.start_time * original_sr)
            end_sample = int(seg.end_time * original_sr)

            # Ensure minimum duration (50ms) to avoid degenerate files
            min_duration = int(0.05 * original_sr)
            if end_sample - start_sample < min_duration:
                end_sample = start_sample + min_duration

            # Clamp boundaries
            start_sample = max(0, start_sample)
            end_sample = min(cpu_waveform.size(1), end_sample)

            seg.sample_rate = original_sr
            seg.start_sample = start_sample
            seg.end_sample = end_sample
        return cpu_waveform

    def __attach_audio_to_segments(
        self,
        final_segments: list[AudioSegment],
        waveform: torch.Tensor,
        original_sr: int,
        attach_audio: bool = True,
    ) -> None:
        """
        Attaches segment indices (always) and optionally raw audio bytes.

        Args:
            final_segments: list of AudioSegment objects (mutated in place).
            waveform: Audio waveform tensor.
            original_sr: Original sampling rate of the waveform.
            attach_audio: When True, also render each segment's slice to WAV
                bytes; when False, only sample indices are set.
        """
        cpu_waveform = self.__set_segment_indices(
            final_segments, waveform, original_sr
        )
        if not attach_audio:
            return
        for seg in final_segments:
            segment_tensor = cpu_waveform[:, seg.start_sample : seg.end_sample]  # noqa: E203
            seg.audio = self.__save_wav_mem(segment_tensor, original_sr)
[docs] def attach_audio_to_segments( self, segments: list[AudioSegment], audio_content: bytes | None = None, waveform: torch.Tensor | None = None, original_sr: int | None = None, audio_format: str = "mp3", ) -> None: """ Attach raw audio bytes to existing AudioSegment objects. This is a helper method to materialize audio bytes for segments that were created without audio attachment (i.e., with attach_audio=False). Args: segments: list of AudioSegment objects to attach audio to. audio_content: Raw audio bytes. Either this or (waveform, original_sr) must be provided. waveform: Audio waveform tensor. Either this and original_sr or audio_content must be provided. original_sr: Original sampling rate of the waveform. audio_format: Format of the audio content (default is "mp3"). Raises: ValueError: If neither audio_content nor (waveform, original_sr) are provided. """ if audio_content is None and (waveform is None or original_sr is None): raise ValueError( "Either audio_content or both waveform and original_sr must be provided." ) if audio_content is not None: waveform, original_sr = TorchAudioHandler.from_bytes( audio_content, audio_format=audio_format ) original_sr = cast(int, original_sr) waveform = cast(torch.Tensor, waveform) # Materialize bytes on demand. self.__attach_audio_to_segments( segments, waveform, original_sr, attach_audio=True )
    def __call__(
        self,
        transcript: str | list[str],
        audio_content: bytes | None = None,
        waveform: torch.Tensor | None = None,
        original_sr: int | None = None,
        audio_format: str = "mp3",
        attach_audio: bool = False,
    ) -> list[AudioSegment]:
        """
        Main pipeline execution: prepare transcript, force-align, refine
        boundaries, and aggregate characters into token segments.

        Args:
            transcript: Either a single string (will be split by spaces) OR
                a list of tokens/utterances to align to.
            audio_content: Raw audio bytes. Either this or both `waveform`
                and `original_sr` must be provided.
            waveform: Pre-decoded audio waveform tensor (alternative to
                `audio_content`).
            original_sr: Sampling rate of `waveform`.
            audio_format: Format of the audio content (default is "mp3").
            attach_audio: Whether to attach raw audio bytes to each segment
                (default is False). If False, segments will have empty audio
                bytes to save memory; sample indices are still set.

        Returns:
            list of aligned AudioSegment objects.

        Raises:
            ValueError: If neither audio_content nor (waveform, original_sr)
                is provided, or if the transcript contains no characters
                valid for the model.
        """
        if audio_content is None and (waveform is None or original_sr is None):
            raise ValueError(
                "Either audio_content or both waveform and original_sr must be provided."
            )
        if audio_content is not None:
            waveform, original_sr = TorchAudioHandler.from_bytes(
                audio_content, audio_format=audio_format
            )
        original_sr = cast(int, original_sr)
        waveform = cast(torch.Tensor, waveform)

        _, target_segments, clean_text, valid_tokens = self.__prepare_transcript(
            transcript
        )
        logger.debug("Aligning to transcript: '%s'", clean_text)

        alignment_path, emissions_gpu = self.__perform_forced_alignment(
            waveform, original_sr, valid_tokens
        )
        token_spans = self.__merge_tokens(alignment_path, self.blank_id)
        refined_chars = self.__refine_token_spans(
            token_spans, emissions_gpu, waveform, original_sr
        )
        final_segments = self.__aggregate_chars_to_segments(
            refined_chars, target_segments
        )
        if attach_audio:
            self.__attach_audio_to_segments(
                final_segments, waveform, original_sr
            )
        else:
            self.__set_segment_indices(final_segments, waveform, original_sr)
        return final_segments