Spaces: Running on Zero
liuyang committed · Commit 28823e9 · Parent(s): 5d33cf4
Update speaker diarization model and refactor WhisperTranscriber alignment process. Introduce align_timestamp method for improved word-level alignment and streamline segment handling. Adjust print statements for clarity and remove unnecessary comments.
app.py CHANGED
@@ -433,7 +433,7 @@ def _preload_alignment_and_diarization_models():
     torch.set_float32_matmul_precision('high')

     _diarizer = Pipeline.from_pretrained(
-        "pyannote/speaker-diarization-
+        "pyannote/speaker-diarization-community-1",
         use_auth_token=os.getenv("HF_TOKEN"),
     ).to(torch.device("cuda"))

@@ -538,24 +538,24 @@ class WhisperTranscriber:
                     whisperx_model_name,
                     device=device,
                     compute_type=compute_type,
-
-
+                    download_root=CACHE_ROOT,
+                    asr_options=transcribe_options
                 )
                 _whipser_x_transcribe_models[model_name] = whisper_model
                 print(f"WhisperX transcribe model '{model_name}' loaded successfully")
             else:
                 whisper_model = _whipser_x_transcribe_models[model_name]
-
-
-
-
-
-
-
-
-
-
-
+
+            print(f"Transcribing full audio with WhisperX model '{model_name}' and batch size {batch_size}...")
+            result = whisper_model.transcribe(
+                audio,
+                language=language,
+                batch_size=batch_size,
+                #initial_prompt=prompt,
+                #task="translate" if translate else "transcribe"
+            )
+            detected_language = result.get("language", detected_language)
+            initial_segments = result.get("segments", [])

         elif engine == "faster_whisper":
             # Lazy-load Faster-Whisper model on first use
@@ -671,28 +671,24 @@ class WhisperTranscriber:
             raise ValueError(f"Unknown engine '{engine}'. Supported: 'whisperx', 'faster_whisper'")

         print(f"Detected language: {detected_language}, segments: {len(initial_segments)}, transcribing done in {time.time() - start_time:.2f} seconds")
-        # Align with
+        # Align with centralized alignment method when available
         segments = initial_segments
         if detected_language in _whipser_x_align_models:
-            print(f"Performing WhisperX alignment for language '{detected_language}'...")
-            align_start = time.time()
             try:
-
-
-
-
-
-
-                "cuda",
-                return_char_alignments=False
+                align_out = self.align_timestamp(
+                    audio_url=audio_path,
+                    text=None,
+                    language=detected_language,
+                    engine="whisperx",
+                    options={"segments": initial_segments},
                 )
-
-
+                if isinstance(align_out, dict) and align_out.get("segments"):
+                    segments = align_out["segments"]
             except Exception as e:
-                print(f"
+                print(f"Alignment via align_timestamp failed: {e}, using original timestamps")
         else:
             print(f"No WhisperX alignment model available for language '{detected_language}', using original timestamps")
-
+
         # Process segments into the expected format
         results = []
         for seg in segments:
@@ -706,7 +702,7 @@ class WhisperTranscriber:
                     "probability": word.get("score", 1.0),
                     "speaker": "SPEAKER_00"
                 })
-
+
             results.append({
                 "start": float(seg.get("start", 0.0)) + float(base_offset_s),
                 "end": float(seg.get("end", 0.0)) + float(base_offset_s),
@@ -714,13 +710,107 @@ class WhisperTranscriber:
                 "speaker": "SPEAKER_00",
                 "avg_logprob": seg.get("avg_logprob", 0.0) if "avg_logprob" in seg else 0.0,
                 "words": words_list,
-                "duration": float(seg.get("end", 0.0)) - float(seg.get("start", 0.0))
+                "duration": float(seg.get("end", 0.0)) - float(seg.get("start", 0.0)),
+                "language": detected_language,
             })

         print(results)
         transcription_time = time.time() - start_time
         print(f"Full audio transcribed and aligned in {transcription_time:.2f} seconds using batch size {batch_size}")
         return results, detected_language
+
+    @spaces.GPU  # alignment requires GPU
+    def align_timestamp(self, audio_url, text, language, engine="whisperx", options: dict = None):
+        """Return word-level alignment for the given text/audio using the specified engine.
+
+        Args:
+            audio_url: Path or URL to the audio file.
+            text: String text to align. If options contains 'segments', this can be None.
+            language: Language code (e.g., 'en'). Must be supported by WhisperX align models.
+            engine: Currently only 'whisperx' is supported.
+            options: Optional dict. Recognized keys:
+                - 'segments': list of {start, end, text} to align (preferred for segment-aware alignment)
+
+        Returns:
+            dict with keys:
+                - 'segments': aligned segments including word timings (if available)
+                - 'words': flat list of aligned words across all segments
+        """
+        global _whipser_x_align_models
+
+        if engine != "whisperx":
+            raise ValueError(f"align_timestamp engine '{engine}' not supported. Only 'whisperx' is supported")
+
+        if language not in _whipser_x_align_models:
+            raise ValueError(f"No WhisperX alignment model available for language '{language}'")
+
+        # Resolve audio path (download if URL)
+        local_path = None
+        tmp_file = None
+        try:
+            if isinstance(audio_url, str) and audio_url.startswith(("http://", "https://")):
+                resp = requests.get(audio_url, stream=True, timeout=60)
+                resp.raise_for_status()
+                tmp_f = tempfile.NamedTemporaryFile(suffix=".audio", delete=False)
+                for chunk in resp.iter_content(chunk_size=8192):
+                    if chunk:
+                        tmp_f.write(chunk)
+                tmp_f.flush()
+                tmp_f.close()
+                tmp_file = tmp_f.name
+                local_path = tmp_file
+            else:
+                local_path = audio_url
+
+            # Load audio and decide segments to align
+            audio = whisperx.load_audio(local_path)
+            sr = 16000.0  # whisperx loads at 16k
+            audio_duration = float(len(audio)) / sr if hasattr(audio, "__len__") else None
+            segments_to_align = None
+
+            if options and isinstance(options, dict) and options.get("segments"):
+                segments_to_align = options.get("segments")
+            else:
+                if not text or not str(text).strip():
+                    raise ValueError("align_timestamp requires 'text' when 'segments' are not provided in options")
+                if audio_duration is None:
+                    raise ValueError("Could not determine audio duration for alignment")
+                segments_to_align = [{
+                    "text": str(text),
+                    "start": 0.0,
+                    "end": audio_duration,
+                }]
+
+            # Perform alignment
+            align_info = _whipser_x_align_models[language]
+            aligned = whisperx.align(
+                segments_to_align,
+                align_info["model"],
+                align_info["metadata"],
+                audio,
+                "cuda",
+                return_char_alignments=False,
+            )
+
+            aligned_segments = aligned.get("segments", segments_to_align)
+            words_flat = []
+            for seg in aligned_segments:
+                for w in seg.get("words", []) or []:
+                    words_flat.append({
+                        "start": float(w.get("start", 0.0)),
+                        "end": float(w.get("end", 0.0)),
+                        "word": w.get("word", ""),
+                        "probability": w.get("score", 1.0)
+                    })
+
+            return {"segments": aligned_segments, "words": words_flat, "language": language}
+
+        finally:
+            if tmp_file:
+                try:
+                    os.unlink(tmp_file)
+                except Exception:
+                    pass

     # Removed audio cutting; transcription is done once on the full (preprocessed) audio

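For context on where the alignment models come from: the new method reads align_info["model"] and align_info["metadata"] out of the module-level _whipser_x_align_models registry, which is populated during preload and is not shown in this diff. A minimal sketch of the presumed registry shape, using the public whisperx.load_align_model API; the language list and device are assumptions:

import whisperx

# Presumed shape of the registry consumed by align_timestamp; the actual
# preload lives in _preload_alignment_and_diarization_models() and may differ.
_whipser_x_align_models = {}
for lang in ("en", "zh"):  # hypothetical language list
    # load_align_model returns the phoneme alignment model and its metadata
    model, metadata = whisperx.load_align_model(language_code=lang, device="cuda")
    _whipser_x_align_models[lang] = {"model": model, "metadata": metadata}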
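And a minimal usage sketch of the new align_timestamp method, following its docstring; the transcriber instance, file names, and language availability are assumptions, not part of the commit:

# Hypothetical caller; WhisperTranscriber construction is not shown in this diff.
transcriber = WhisperTranscriber()

# Segment-aware alignment, the path used by the transcribe flow above:
out = transcriber.align_timestamp(
    audio_url="sample.wav",  # placeholder local file
    text=None,               # not needed when segments are given
    language="en",
    engine="whisperx",
    options={"segments": [{"start": 0.0, "end": 2.5, "text": "hello there"}]},
)
for w in out["words"]:
    print(f'{w["word"]} [{w["start"]:.2f}s-{w["end"]:.2f}s] p={w["probability"]:.2f}')

# Plain-text alignment: the method wraps the text in one segment spanning the
# whole audio before calling whisperx.align; remote URLs are downloaded first.
out = transcriber.align_timestamp(
    audio_url="https://example.com/sample.wav",  # placeholder URL
    text="hello there",
    language="en",
)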
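Two details of the new method worth noting: alignment in the transcribe path is best-effort, since any exception raised by align_timestamp is caught and the original segment timestamps are kept; and when audio_url is a remote URL, the download goes to a NamedTemporaryFile that the finally block always unlinks, so repeated calls do not accumulate temp files.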