Source code for mllm_shap.connectors.base.audio

"""Spectrogram-Guided Forced Aligner using Wav2Vec2."""

import io
import unicodedata
import warnings
import wave
from dataclasses import dataclass, field
from logging import Logger
from typing import Any, cast

import librosa
import numpy as np
import torch
import torchaudio
from torchaudio.functional import forced_align
from transformers import PreTrainedTokenizerBase, Wav2Vec2ForCTC, Wav2Vec2Processor
from transformers import logging as hf_logging

from ...utils.audio import TorchAudioHandler
from ...utils.logger import get_logger

hf_logging.set_verbosity_error()

logger: Logger = get_logger(__name__)

warnings.filterwarnings("ignore", message=".*forced_align has been deprecated.*")


ASCII_SPACE: str = " "
"""ASCII space character, used for transcript normalization and CTC blank token replacement."""
UNICODE_CATEGORY_NONSPACING_MARK: str = "Mn"
"""Unicode category for non-spacing marks, used to strip diacritics during transcript normalization."""
_SILENCE_THRESHOLD_RATIO: float = 0.5
"""When refining boundaries, the minimum RMS must be less than this ratio of the mean
 RMS in the search window to be accepted as a valid silence point. Prevents false positives in continuous speech."""
_MIN_REFINE_SAMPLES: int = 256
"""Minimum number of audio samples required in the boundary refinement search region to perform analysis.
Prevents unreliable refinements in very short segments."""


@dataclass
class AudioSegment:
    """Represents a segment of audio aligned to a token/word."""

    token: str
    """The token/word associated with this segment."""

    start_time: float
    """Start time in seconds."""
    end_time: float
    """End time in seconds."""

    confidence: float
    """Confidence score of the alignment."""

    audio: bytes = field(default=b"")
    """Raw audio bytes for this segment."""

    audio_format: str = field(default="wav")
    """Audio format of the raw audio bytes."""
    sample_rate: int | None = field(default=None, repr=False)
    """Sample rate for `start_sample`/`end_sample` (if set)."""

    start_sample: int | None = field(default=None, repr=False)
    """Start sample index in the source waveform (if set)."""
    end_sample: int | None = field(default=None, repr=False)
    """End sample index in the source waveform (if set)."""

    @property
    def duration(self) -> float:
        """Duration of the segment in seconds."""
        return self.end_time - self.start_time

    def __repr__(self) -> str:
        """String representation of the AudioSegment."""
        return (
            f"AudioSegment(token='{self.token}', start={self.start_time:.3f}, "
            f"end={self.end_time:.3f}, dur={self.duration:.3f}s)"
        )

    def __add__(self, other: "AudioSegment") -> "AudioSegment":
        """Combine two AudioSegments into one."""
        if self.token != other.token:
            raise ValueError("Cannot combine AudioSegments with different tokens.")
        combined_audio = self.audio + other.audio

        return AudioSegment(
            token=self.token,
            start_time=min(self.start_time, other.start_time),
            end_time=max(self.end_time, other.end_time),
            confidence=(self.confidence + other.confidence) / 2,
            audio=combined_audio,
            audio_format=self.audio_format,
        )


[docs] class SpectrogramGuidedAligner: """ Spectrogram-Guided Forced Aligner using Wav2Vec2 and Torchaudio. It takes raw audio bytes and a transcript (string or list of tokens) and produces time-aligned segments with refined boundaries. The alignment process consists of several phases: 1. **Acoustic Modeling:** A CTC model (Wav2Vec2) maps audio to character probabilities. 2. **Forced Alignment:** Dynamic programming finds the optimal alignment path. 3. **Boundary Refinement:** Spectrogram features (Energy & Flux) refine boundaries, using gap midpoints as search anchors and joint negotiation to prevent inversions. 4. **Aggregation:** Character-level segments are grouped into user-defined tokens. """ def __init__( self, device: torch.device, model_name: str = "facebook/wav2vec2-large-960h", model_revision: str = "main", sample_rate: int = 16000, ctc_separator: str = "|", boundary_energy_weight: float = 0.8, boundary_flux_weight: float = 0.2, ): """Initialize the aligner with the specified configuration. Args: device: The torch device to run the model on. model_name: The Hugging Face model name for the Wav2Vec2 CTC model. model_revision: The revision of the model to load. sample_rate: The sample rate to use for audio processing. ctc_separator: The character used to replace spaces in the transcript for CTC processing. boundary_energy_weight: Weight for energy in boundary refinement (0 to 1). boundary_flux_weight: Weight for spectral flux in boundary refinement (0 to 1). """ if boundary_energy_weight < 0 or boundary_flux_weight < 0: raise ValueError("Boundary refinement weights must be non-negative.") if boundary_energy_weight + boundary_flux_weight <= 0: raise ValueError( "At least one boundary refinement weight must be greater than zero." ) self.device = device self.sample_rate = sample_rate self.ctc_separator = ctc_separator self.boundary_energy_weight = float(boundary_energy_weight) self.boundary_flux_weight = float(boundary_flux_weight) logger.debug("Loading alignment model: %s on %s...", model_name, device) try: self.processor = cast(Any, Wav2Vec2Processor).from_pretrained( model_name, revision=model_revision ) model = Wav2Vec2ForCTC.from_pretrained(model_name, revision=model_revision) self.model = cast(Wav2Vec2ForCTC, model.to(cast(Any, device))) except OSError as e: raise ValueError( f"Could not load '{model_name}'. Ensure it is a valid CTC model." ) from e self.tokenizer = cast( PreTrainedTokenizerBase, getattr(self.processor, "tokenizer") ) self.vocab = self.tokenizer.get_vocab() self.blank_id = self.tokenizer.pad_token_id or 0
[docs] @staticmethod def normalize_text(text: str) -> str: """Normalize text by removing diacritics, non-alphanumeric characters, and converting to uppercase.""" text_nfd = unicodedata.normalize("NFD", text) text_no_diacritics = "".join( char for char in text_nfd if unicodedata.category(char) != UNICODE_CATEGORY_NONSPACING_MARK ) return "".join(filter(str.isalnum, text_no_diacritics)).upper()
def __compute_emissions( self, waveform: torch.Tensor, original_sr: int ) -> torch.Tensor: """Compute the emission probabilities from the audio waveform using the Wav2Vec2 model.""" if original_sr != self.sample_rate: resampler = torchaudio.transforms.Resample( orig_freq=original_sr, new_freq=self.sample_rate ).to(self.device) waveform = resampler(waveform.to(self.device)) else: waveform = waveform.to(self.device) if waveform.dim() > 1: waveform = waveform.squeeze() inputs = self.processor( waveform, sampling_rate=self.sample_rate, return_tensors="pt", padding=True ) with torch.inference_mode(): logits = self.model(inputs.input_values.to(self.device)).logits emissions = torch.log_softmax(logits, dim=-1) return emissions def __refine_boundary_smart( self, waveform: np.ndarray, sr: int, candidate_time: float, left_time: float | None = None, right_time: float | None = None, ) -> tuple[float, bool]: """ Refine a single boundary timestamp using energy and spectral flux. Fix 1: When left_time and right_time are supplied (the raw endpoints of the gap between two adjacent character spans), the search window is centred on their midpoint so that long silences are searched symmetrically rather than being anchored to one character's tail. Fix 3: If no frame in the window is clearly quieter than the window mean (no genuine silence), the raw candidate_time is returned unchanged rather than forcing a cut inside an active phoneme. Fix 4: Returns a boolean indicating whether refinement was actually applied (False when the search region was too short or no silence was found). Args: waveform: Full audio waveform as a 1-D numpy array. sr: Sampling rate. candidate_time: Raw CTC boundary estimate in seconds. left_time: Start of the blank gap (seconds). Optional. right_time: End of the blank gap (seconds). Optional. Returns: (refined_time_seconds, was_refined) """ # use gap midpoint as centre when gap endpoints are known. if left_time is not None and right_time is not None: center_time = (left_time + right_time) / 2.0 half_window = (right_time - left_time) / 2.0 + 0.04 # gap ± 40 ms margin else: center_time = candidate_time half_window = 0.08 # original ±80 ms window_samples = int(half_window * sr) center_sample = int(center_time * sr) start_idx = max(0, center_sample - window_samples) end_idx = min(len(waveform), center_sample + window_samples) search_region = waveform[start_idx:end_idx] # flag when region is too short to analyse. if len(search_region) < _MIN_REFINE_SAMPLES: return candidate_time, False rms = librosa.feature.rms(y=search_region, frame_length=256, hop_length=64)[0] stft = np.abs(librosa.stft(search_region, n_fft=256, hop_length=64)) flux = np.sum(np.diff(stft, axis=1) ** 2, axis=0) flux = np.pad(flux, (0, len(rms) - len(flux)), mode="constant") rms_norm = (rms - np.min(rms)) / (np.max(rms) - np.min(rms) + 1e-9) flux_norm = (flux - np.min(flux)) / (np.max(flux) - np.min(flux) + 1e-9) cost = ( self.boundary_energy_weight * rms_norm + self.boundary_flux_weight * flux_norm ) min_idx = int(np.argmin(cost)) # only accept the refined position if it is genuinely quieter # than the window mean. If no such frame exists (continuous speech, # no inter-word pause), fall back to the raw candidate. min_rms = rms[min_idx] mean_rms = float(np.mean(rms)) if min_rms > _SILENCE_THRESHOLD_RATIO * mean_rms: return candidate_time, False refined_sample = start_idx + (min_idx * 64) return refined_sample / sr, True def __save_wav_mem(self, tensor: torch.Tensor, sample_rate: int) -> bytes: """Convert a tensor to WAV format and return as bytes.""" src = tensor.cpu() if src.dim() == 1: src = src.unsqueeze(0) n_channels = src.shape[0] src = (src * 32767).clamp(-32768, 32767).to(torch.int16) src = src.t().numpy() buffer = io.BytesIO() with wave.open(buffer, "wb") as wav_file: wav_file.setnchannels(n_channels) wav_file.setsampwidth(2) wav_file.setframerate(sample_rate) wav_file.writeframes(src.tobytes()) return buffer.getvalue() def __merge_tokens( self, alignment_path: torch.Tensor, blank_id: int ) -> list[tuple[int, int, int]]: """ Merge frame-level alignment into (token, start, end) spans. Blank-labelled frames are excluded from character spans by construction: a span ends at the frame where any token transition occurs, so blank regions appear as gaps between adjacent character spans. These gaps are resolved in __refine_token_spans using gap-midpoint anchoring (Fix 1 / Fix 2). Returns: list of (token_id, start_frame, end_frame) tuples. Unchanged from the original — blank handling is correct here; the original methodological problem was in how the resulting gap timestamps were used downstream. """ path = alignment_path.tolist() spans = [] current_token = None start_frame = 0 for i, token in enumerate(path): if token != current_token: if current_token is not None and current_token != blank_id: spans.append((current_token, start_frame, i)) current_token = token start_frame = i if current_token is not None and current_token != blank_id: spans.append((current_token, start_frame, len(path))) return spans def __prepare_transcript( self, transcript: str | list[str] ) -> tuple[str, list[str], str, list[int]]: """Prepare the transcript for alignment by normalizing and tokenizing.""" if isinstance(transcript, str): full_transcript = transcript target_segments = transcript.split() else: target_segments = transcript full_transcript = " ".join(transcript) text_upper = full_transcript.upper() text_nfd = unicodedata.normalize("NFD", text_upper) text_no_diacritics = "".join( char for char in text_nfd if unicodedata.category(char) != UNICODE_CATEGORY_NONSPACING_MARK ) text_clean = "".join( c for c in text_no_diacritics if c.isalnum() or c == ASCII_SPACE ) clean_text = text_clean.replace(ASCII_SPACE, self.ctc_separator) valid_tokens = [ cast(int, self.tokenizer.convert_tokens_to_ids(c)) for c in clean_text if c in self.vocab ] if not valid_tokens: raise ValueError("Transcript contains no valid characters for this model.") return full_transcript, target_segments, clean_text, valid_tokens def __perform_forced_alignment( self, waveform: torch.Tensor, original_sr: int, valid_tokens: list[int] ) -> tuple[torch.Tensor, torch.Tensor]: """Perform forced alignment between the audio waveform and the transcript.""" emissions_gpu = self.__compute_emissions(waveform, original_sr).squeeze(0) emissions_cpu = emissions_gpu.unsqueeze(0).cpu() targets_cpu = torch.tensor([valid_tokens], dtype=torch.int32).cpu() emission_lens_cpu = torch.tensor([emissions_gpu.size(0)]).cpu() target_lens_cpu = torch.tensor([len(valid_tokens)]).cpu() aligned_tokens, _ = forced_align( emissions_cpu, targets_cpu, emission_lens_cpu, target_lens_cpu, blank=self.blank_id, ) alignment_path = aligned_tokens[0] return alignment_path, emissions_gpu def __refine_token_spans( self, token_spans: list[tuple[int, int, int]], emissions_gpu: torch.Tensor, waveform: torch.Tensor, original_sr: int, ) -> list[dict[str, str | float | bool]]: """ Refine token boundary times using acoustic features. Fix 1 + Fix 2: Boundaries between adjacent character spans are resolved jointly. For each consecutive pair (span[i], span[i+1]), the raw gap is [span[i].end, span[i+1].start] in frame space. We convert that gap to seconds, compute its midpoint, and call __refine_boundary_smart once, centred on the midpoint with the full gap as the search window. The resulting single refined time is assigned as the end of span[i] AND the start of span[i+1], guaranteeing monotonicity and symmetric coverage of blank regions. Fix 4: The `boundary_refined` flag from __refine_boundary_smart is stored per character so callers can detect heterogeneous boundary quality. """ ratio = waveform.size(1) / emissions_gpu.size(0) numpy_wave = waveform.cpu().numpy().squeeze() total_samples = waveform.size(1) n = len(token_spans) if n == 0: return [] # --- Pass 1: convert all raw frame spans to seconds ------------------- raw_starts: list[float] = [] raw_ends: list[float] = [] confs: list[float] = [] for sp_token, sp_start, sp_end in token_spans: raw_starts.append((sp_start * ratio) / original_sr) raw_ends.append((sp_end * ratio) / original_sr) conf = torch.exp(emissions_gpu[sp_start:sp_end, sp_token]).mean().item() confs.append(conf) # --- Pass 2: resolve shared inter-character boundaries jointly -------- # refined_boundaries[i] is the shared boundary between span[i] and span[i+1]. # len == n-1. refined_boundaries: list[tuple[float, bool]] = [] for i in range(n - 1): gap_left = raw_ends[i] # end of left character span (seconds) gap_right = raw_starts[i + 1] # start of right character span (seconds) # Midpoint of the blank gap as the search anchor (Fix 1). gap_mid = (gap_left + gap_right) / 2.0 refined_time, was_refined = self.__refine_boundary_smart( numpy_wave, original_sr, candidate_time=gap_mid, left_time=gap_left, right_time=gap_right, ) # Safety clamp: must stay within [gap_left, gap_right] to prevent # the refined boundary from drifting into neighbouring phonemes. refined_time = float(np.clip(refined_time, gap_left, gap_right)) refined_boundaries.append((refined_time, was_refined)) # --- Pass 3: refine leading edge of first span and trailing edge of last span --- # These have no neighbour on one side, so we use the original ±80 ms window. first_start, first_refined = self.__refine_boundary_smart( numpy_wave, original_sr, raw_starts[0] ) last_end, last_refined = self.__refine_boundary_smart( numpy_wave, original_sr, min(raw_ends[-1], total_samples / original_sr) ) # --- Pass 4: assemble refined character records ----------------------- refined_chars: list[dict[str, str | float | bool]] = [] for i, (sp_token, _sp_start, _sp_end) in enumerate(token_spans): # Start time: refined leading edge (first span) or shared boundary with left neighbour if i == 0: start_time = first_start start_refined = first_refined else: start_time, start_refined = refined_boundaries[i - 1] # End time: shared boundary with right neighbour, or refined trailing edge (last span) if i == n - 1: end_time = last_end end_refined = last_refined else: end_time, end_refined = refined_boundaries[i] boundary_refined = start_refined and end_refined char = cast(str, self.tokenizer.convert_ids_to_tokens(sp_token)) refined_chars.append({ "char": char, "start": start_time, "end": end_time, "confidence": confs[i], "boundary_refined": boundary_refined, }) return refined_chars def __aggregate_chars_to_segments( self, char_segments: list[dict[str, str | float | bool]], target_segments: list[str], ) -> list[AudioSegment]: """Aggregate character-level segments into user-defined token segments.""" final_segments = [] current_char_idx = 0 for segment_text in target_segments: clean_target = self.normalize_text(segment_text) if not clean_target: continue start_time = None end_time = None seg_confs = [] all_refined = True found_chars = 0 while found_chars < len(clean_target) and current_char_idx < len( char_segments ): seg = char_segments[current_char_idx] seg_char = cast(str, seg["char"]).replace(self.ctc_separator, "") if seg_char == clean_target[found_chars]: if start_time is None: start_time = seg["start"] end_time = seg["end"] seg_confs.append(seg["confidence"]) if not seg.get("boundary_refined", True): all_refined = False found_chars += 1 current_char_idx += 1 if start_time is not None: if end_time is None: end_time = cast(float, start_time) + 0.1 avg_conf = sum(seg_confs) / len(seg_confs) if seg_confs else 0.0 audio_seg = AudioSegment( token=segment_text, start_time=cast(float, start_time), end_time=cast(float, end_time), confidence=avg_conf, audio_format="wav", ) audio_seg.boundary_refined = all_refined final_segments.append(audio_seg) return final_segments def __set_segment_indices( self, final_segments: list[AudioSegment], waveform: torch.Tensor, original_sr: int, ) -> torch.Tensor: """Set the start_sample and end_sample indices for each segment based on the refined start_time and end_time.""" cpu_waveform = waveform.cpu() if cpu_waveform.dim() == 1: cpu_waveform = cpu_waveform.unsqueeze(0) for seg in final_segments: start_sample = int(seg.start_time * original_sr) end_sample = int(seg.end_time * original_sr) min_duration = int(0.05 * original_sr) if end_sample - start_sample < min_duration: end_sample = start_sample + min_duration start_sample = max(0, start_sample) end_sample = min(cpu_waveform.size(1), end_sample) seg.sample_rate = original_sr seg.start_sample = start_sample seg.end_sample = end_sample return cpu_waveform def __attach_audio_to_segments( self, final_segments: list[AudioSegment], waveform: torch.Tensor, original_sr: int, attach_audio: bool = True, ) -> None: cpu_waveform = self.__set_segment_indices(final_segments, waveform, original_sr) if not attach_audio: return for seg in final_segments: segment_tensor = cpu_waveform[:, seg.start_sample : seg.end_sample] seg.audio = self.__save_wav_mem(segment_tensor, original_sr)
[docs] def attach_audio_to_segments( self, segments: list[AudioSegment], audio_content: bytes | None = None, waveform: torch.Tensor | None = None, original_sr: int | None = None, audio_format: str = "mp3", ) -> None: """Attach audio bytes to existing segments using provided audio input. This can be used when segments are generated without audio and need to be enriched later.""" if audio_content is None and (waveform is None or original_sr is None): raise ValueError( "Either audio_content or both waveform and original_sr must be provided." ) if audio_content is not None: waveform, original_sr = TorchAudioHandler.from_bytes( audio_content, audio_format=audio_format ) original_sr = cast(int, original_sr) waveform = cast(torch.Tensor, waveform) self.__attach_audio_to_segments( segments, waveform, original_sr, attach_audio=True )
def __call__( self, transcript: str | list[str], audio_content: bytes | None = None, waveform: torch.Tensor | None = None, original_sr: int | None = None, audio_format: str = "mp3", attach_audio: bool = False, ) -> list[AudioSegment]: """Align the provided transcript to the audio and return a list of AudioSegments with refined boundaries. Args: transcript: The transcript to align, either as a single string or a list of token strings. audio_content: Raw audio bytes (optional if waveform and original_sr are provided). waveform: Pre-loaded audio waveform tensor (optional if audio_content is provided). original_sr: Original sample rate of the audio (required if waveform is provided). audio_format: Format of the input audio bytes (default: "mp3"). attach_audio: Whether to attach raw audio bytes to each segment (default: False). Returns: A list of AudioSegment instances with aligned tokens and refined timestamps. If attach_audio is True, each segment will also include the corresponding audio bytes in WAV format. Raises: ValueError: If neither audio_content nor both waveform and original_sr are provided, or if the transcript contains no valid characters for the model. """ if audio_content is None and (waveform is None or original_sr is None): raise ValueError( "Either audio_content or both waveform and original_sr must be provided." ) if audio_content is not None: waveform, original_sr = TorchAudioHandler.from_bytes( audio_content, audio_format=audio_format ) original_sr = cast(int, original_sr) waveform = cast(torch.Tensor, waveform) _, target_segments, clean_text, valid_tokens = self.__prepare_transcript( transcript ) logger.debug("Aligning to transcript: '%s'", clean_text) alignment_path, emissions_gpu = self.__perform_forced_alignment( waveform, original_sr, valid_tokens ) token_spans = self.__merge_tokens(alignment_path, self.blank_id) refined_chars = self.__refine_token_spans( token_spans, emissions_gpu, waveform, original_sr ) final_segments = self.__aggregate_chars_to_segments( refined_chars, target_segments ) if attach_audio: self.__attach_audio_to_segments(final_segments, waveform, original_sr) else: self.__set_segment_indices(final_segments, waveform, original_sr) return final_segments