"""ASR pipeline for audio-to-text transcription with optional timestamps and diarization."""

import re
from pathlib import Path
from typing import Any

import numpy as np
import torch
import transformers

# Support both package-relative imports and direct script execution.
try:
    from .alignment import ForcedAligner
    from .asr_modeling import ASRModel
    from .diarization import SpeakerDiarizer
except ImportError:
    from alignment import ForcedAligner
    from asr_modeling import ASRModel
    from diarization import SpeakerDiarizer

__all__ = ["ForcedAligner", "SpeakerDiarizer", "ASRPipeline"]


class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
    """ASR Pipeline for audio-to-text transcription."""

    model: ASRModel

    def __init__(self, model: ASRModel, **kwargs):
        """Initialize ASR pipeline.

        Args:
            model: ASRModel instance for transcription
            **kwargs: Additional arguments (feature_extractor, tokenizer, device)
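
        If feature_extractor is not provided, it falls back to
        model.get_processor().feature_extractor; the tokenizer defaults to
        model.tokenizer.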
        """
        feature_extractor = kwargs.pop("feature_extractor", None)
        tokenizer = kwargs.pop("tokenizer", model.tokenizer)

        if feature_extractor is None:
            feature_extractor = model.get_processor().feature_extractor

        super().__init__(
            model=model, feature_extractor=feature_extractor, tokenizer=tokenizer, **kwargs
        )
        self._current_audio = None

    def _sanitize_parameters(self, **kwargs):
        """Intercept our custom parameters before the parent class validates them."""
        kwargs.pop("return_timestamps", None)
        kwargs.pop("return_speakers", None)
        kwargs.pop("num_speakers", None)
        kwargs.pop("min_speakers", None)
        kwargs.pop("max_speakers", None)
        kwargs.pop("hf_token", None)
        kwargs.pop("user_prompt", None)
        kwargs.pop("diarization_backend", None)
        return super()._sanitize_parameters(**kwargs)

    def __call__(
        self,
        inputs,
        **kwargs,
    ):
        """Transcribe audio with optional word-level timestamps and speaker diarization.

        Args:
            inputs: Audio input (file path, dict with array/sampling_rate, etc.)
            return_timestamps: If True, return word-level timestamps using forced alignment
            return_speakers: If True, return speaker labels for each word
            user_prompt: Custom transcription prompt (default: "Transcribe: ")
            num_speakers: Exact number of speakers (if known, for diarization)
            min_speakers: Minimum number of speakers (for diarization)
            max_speakers: Maximum number of speakers (for diarization)
            **kwargs: Additional arguments passed to the pipeline

        Returns:
            Dict with 'text' key, 'words' key if return_timestamps=True,
            and speaker labels on words if return_speakers=True
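
        Example (a minimal sketch; assumes `model` is a loaded ASRModel and
        "audio.wav" is a placeholder path):
            pipe = ASRPipeline(model=model)
            out = pipe("audio.wav", return_timestamps=True, return_speakers=True)
            out["text"]   # transcript string
            out["words"]  # word entries with timestamps and speaker labels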
        """
        return_timestamps = kwargs.pop("return_timestamps", False)
        return_speakers = kwargs.pop("return_speakers", False)
        user_prompt = kwargs.pop("user_prompt", None)
        diarization_params = {
            "num_speakers": kwargs.pop("num_speakers", None),
            "min_speakers": kwargs.pop("min_speakers", None),
            "max_speakers": kwargs.pop("max_speakers", None),
        }

        # Speaker assignment needs word-level timestamps to map words to segments.
        if return_speakers:
            return_timestamps = True

        # Temporarily swap in a custom transcription prompt, if given.
        original_prompt = None
        if user_prompt:
            original_prompt = self.model.TRANSCRIBE_PROMPT
            self.model.TRANSCRIBE_PROMPT = user_prompt

        # Cache the raw audio so alignment/diarization can reuse it after decoding.
        if return_timestamps or return_speakers:
            self._current_audio = self._extract_audio(inputs)

        try:
            result = super().__call__(inputs, **kwargs)

            if return_timestamps and self._current_audio is not None:
                text = result.get("text", "")
                if text:
                    try:
                        words = ForcedAligner.align(
                            self._current_audio["array"],
                            text,
                            sample_rate=self._current_audio.get("sampling_rate", 16000),
                        )
                        result["words"] = words
                    except Exception as e:
                        result["words"] = []
                        result["timestamp_error"] = str(e)
                else:
                    result["words"] = []

            if return_speakers and self._current_audio is not None:
                try:
                    speaker_segments = SpeakerDiarizer.diarize(
                        self._current_audio["array"],
                        sample_rate=self._current_audio.get("sampling_rate", 16000),
                        **{k: v for k, v in diarization_params.items() if v is not None},
                    )
                    result["speaker_segments"] = speaker_segments

                    if result.get("words"):
                        result["words"] = SpeakerDiarizer.assign_speakers_to_words(
                            result["words"],
                            speaker_segments,
                        )
                except Exception as e:
                    result["speaker_segments"] = []
                    result["diarization_error"] = str(e)
        finally:
            # Restore pipeline state even if transcription raises.
            self._current_audio = None
            if original_prompt is not None:
                self.model.TRANSCRIBE_PROMPT = original_prompt

        return result

    def _extract_audio(self, inputs) -> dict | None:
        """Extract audio array from various input formats using HF utilities."""
        from transformers.pipelines.audio_utils import ffmpeg_read

        if isinstance(inputs, dict):
            if "array" in inputs:
                return {
                    "array": inputs["array"],
                    "sampling_rate": inputs.get("sampling_rate", 16000),
                }
            if "raw" in inputs:
                return {
                    "array": inputs["raw"],
                    "sampling_rate": inputs.get("sampling_rate", 16000),
                }
        elif isinstance(inputs, str):
            with Path(inputs).open("rb") as f:
                audio = ffmpeg_read(f.read(), sampling_rate=16000)
            return {"array": audio, "sampling_rate": 16000}
        elif isinstance(inputs, bytes):
            audio = ffmpeg_read(inputs, sampling_rate=16000)
            return {"array": audio, "sampling_rate": 16000}
        elif isinstance(inputs, np.ndarray):
            return {"array": inputs, "sampling_rate": 16000}

        return None

    def preprocess(self, inputs, **preprocess_params):
        """Preprocess audio inputs for the model.

        Args:
            inputs: Audio input (dict with array, file path, etc.)
            **preprocess_params: Additional preprocessing parameters

        Yields:
            Model input dicts with input_features and attention_mask
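
        Note:
            Inputs given as datasets-style {"array": ..., "sampling_rate": ...}
            dicts are normalized to the {"raw": ...} form the parent pipeline
            expects before being handed off.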
        """
        if isinstance(inputs, dict) and "array" in inputs:
            inputs = {
                "raw": inputs["array"],
                "sampling_rate": inputs.get("sampling_rate", self.feature_extractor.sampling_rate),
            }

        for item in super().preprocess(inputs, **preprocess_params):
            # Ensure the "is_last" chunk flag expected downstream is always present.
            if "is_last" not in item:
                item["is_last"] = True
            yield item

    def _forward(self, model_inputs, **generate_kwargs) -> dict[str, Any]:
        """Run model forward pass to generate transcription.

        Args:
            model_inputs: Dict with input_features and attention_mask
            **generate_kwargs: Generation parameters

        Returns:
            Dict with generated token IDs
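
        Note:
            The model's generate method is assumed to accept the custom
            audio_attention_mask keyword alongside input_features.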
        """
        is_last = model_inputs.pop("is_last", True) if isinstance(model_inputs, dict) else True

        input_features = model_inputs["input_features"].to(self.model.device)
        audio_attention_mask = model_inputs["attention_mask"].to(self.model.device)

        generated_ids = self.model.generate(
            input_features=input_features,
            audio_attention_mask=audio_attention_mask,
            **generate_kwargs,
        )

        return {"tokens": generated_ids, "is_last": is_last}

    def postprocess(self, model_outputs, **kwargs) -> dict[str, str]:
        """Convert model output tokens to text.

        Args:
            model_outputs: Dict with 'tokens' key containing generated IDs
            **kwargs: Additional postprocessing parameters

        Returns:
            Dict with 'text' key containing transcription
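
        Decoding also strips any <think>...</think> reasoning span and
        truncates trailing word/phrase repetitions before returning the text.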
        """
        if isinstance(model_outputs, list):
            model_outputs = model_outputs[0] if model_outputs else {}

        tokens = model_outputs.get("tokens")
        if tokens is None:
            return super().postprocess(model_outputs, **kwargs)

        if torch.is_tensor(tokens):
            tokens = tokens.cpu()
            if tokens.dim() > 1:
                tokens = tokens[0]

        # Drop EOS tokens explicitly; skip_special_tokens alone may miss
        # model-specific EOS ids configured on the generation config.
        if hasattr(self, "model") and hasattr(self.model, "generation_config"):
            eos_ids = self.model.generation_config.eos_token_id
            if eos_ids is not None:
                eos_set = set(eos_ids) if isinstance(eos_ids, list) else {eos_ids}
                token_list = tokens.tolist() if torch.is_tensor(tokens) else list(tokens)
                tokens = [t for t in token_list if t not in eos_set]

        text = self.tokenizer.decode(tokens, skip_special_tokens=True).strip()

        # Remove any chain-of-thought span emitted by the model.
        text = re.sub(r"<think>.*?</think>\s*", "", text, flags=re.DOTALL).strip()

        text = _truncate_repetitions(text)
        return {"text": text}


def _truncate_repetitions(text: str, min_repeats: int = 3) -> str:
    """Truncate repeated words/phrases/characters at the end of text.

    Detects patterns like:
    - Repeated words: "the the the the" -> "the"
    - Repeated phrases: "i am sorry i am sorry i am sorry" -> "i am sorry"
    - Repeated characters: "444444" -> "4"

    Args:
        text: Input text to process
        min_repeats: Minimum repetitions to trigger truncation (default 3)

    Returns:
        Text with trailing repetitions removed
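
    Example:
        >>> _truncate_repetitions("the the the the")
        'the'
        >>> _truncate_repetitions("i am sorry i am sorry i am sorry")
        'i am sorry'
        >>> _truncate_repetitions("444444")
        '4'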
    """
    if not text:
        return text

    # Collapse a run of min_repeats+ identical trailing characters to one.
    char_pattern = re.compile(r"(.)\1{" + str(min_repeats - 1) + r",}$")
    text = char_pattern.sub(r"\1", text)

    # Collapse a single word repeated min_repeats+ times at the end.
    word_pattern = re.compile(
        r"\b(\w+)(?:\s+\1){" + str(min_repeats - 1) + r",}\s*$", re.IGNORECASE
    )
    while word_pattern.search(text):
        text = word_pattern.sub(r"\1", text)

    # Collapse a multi-word phrase (2 to 20 words) repeated min_repeats+ times
    # at the end.
    words = text.split()
    if len(words) >= min_repeats * 2:
        for phrase_len in range(2, min(21, len(words) // min_repeats + 1)):
            # Take the trailing candidate phrase and test whether the text
            # ends with min_repeats or more copies of it.
            phrase = " ".join(words[-phrase_len:])
            phrase_escaped = re.escape(phrase)
            phrase_pattern = re.compile(
                r"(^|.*?\s)("
                + phrase_escaped
                + r")(?:\s+"
                + phrase_escaped
                + r"){"
                + str(min_repeats - 1)
                + r",}\s*$",
                re.IGNORECASE,
            )
            match = phrase_pattern.match(text)
            if match:
                # Keep the prefix plus a single copy of the phrase.
                text = (match.group(1) + match.group(2)).strip()
                words = text.split()
                break

    return text