liuyang committed on
Commit 28823e9 · 1 Parent(s): 5d33cf4

Update the speaker diarization model and refactor the WhisperTranscriber alignment process. Introduce an align_timestamp method for improved word-level alignment and streamline segment handling. Adjust print statements for clarity and remove unnecessary comments.

Files changed (1)
  1. app.py +121 -31
app.py CHANGED
@@ -433,7 +433,7 @@ def _preload_alignment_and_diarization_models():
     torch.set_float32_matmul_precision('high')
 
     _diarizer = Pipeline.from_pretrained(
-        "pyannote/speaker-diarization-3.1",
+        "pyannote/speaker-diarization-community-1",
         use_auth_token=os.getenv("HF_TOKEN"),
     ).to(torch.device("cuda"))
 
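The commit swaps only the model name, treating the new checkpoint as a drop-in replacement. As a minimal standalone sketch of the hunk above (the audio path is illustrative, and HF_TOKEN must grant access to the gated pyannote model):

import os

import torch
from pyannote.audio import Pipeline

# Load the diarization pipeline named in the hunk above.
diarizer = Pipeline.from_pretrained(
    "pyannote/speaker-diarization-community-1",
    use_auth_token=os.getenv("HF_TOKEN"),
).to(torch.device("cuda"))

# "meeting.wav" is a placeholder input file.
diarization = diarizer("meeting.wav")
for turn, _, speaker in diarization.itertracks(yield_label=True):
    print(f"{turn.start:.2f}s-{turn.end:.2f}s {speaker}")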
@@ -538,24 +538,24 @@ class WhisperTranscriber:
                     whisperx_model_name,
                     device=device,
                     compute_type=compute_type,
-                    download_root=CACHE_ROOT,
-                    asr_options=transcribe_options
+                    download_root=CACHE_ROOT,
+                    asr_options=transcribe_options
                 )
                 _whipser_x_transcribe_models[model_name] = whisper_model
                 print(f"WhisperX transcribe model '{model_name}' loaded successfully")
             else:
                 whisper_model = _whipser_x_transcribe_models[model_name]
-
-            print(f"Transcribing full audio with WhisperX model '{model_name}' and batch size {batch_size}...")
-            result = whisper_model.transcribe(
-                audio,
-                language=language,
-                batch_size=batch_size,
-                #initial_prompt=prompt,
-                #task="translate" if translate else "transcribe"
-            )
-            detected_language = result.get("language", detected_language)
-            initial_segments = result.get("segments", [])
+
+            print(f"Transcribing full audio with WhisperX model '{model_name}' and batch size {batch_size}...")
+            result = whisper_model.transcribe(
+                audio,
+                language=language,
+                batch_size=batch_size,
+                #initial_prompt=prompt,
+                #task="translate" if translate else "transcribe"
+            )
+            detected_language = result.get("language", detected_language)
+            initial_segments = result.get("segments", [])
 
         elif engine == "faster_whisper":
             # Lazy-load Faster-Whisper model on first use
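For readers unfamiliar with WhisperX: load_model wraps a batched faster-whisper backend, and transcribe returns a dict carrying 'language' and 'segments', which is what the code above reads back. A minimal sketch under assumed defaults (model size, audio path, and batch size are illustrative):

import whisperx

# Standalone version of the load/transcribe path shown in the hunk above.
audio = whisperx.load_audio("sample.wav")        # decoded to 16 kHz mono
model = whisperx.load_model("large-v3", "cuda", compute_type="float16")
result = model.transcribe(audio, batch_size=16)  # omitting language => auto-detect
print(result["language"], len(result["segments"]))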
@@ -671,28 +671,24 @@ class WhisperTranscriber:
             raise ValueError(f"Unknown engine '{engine}'. Supported: 'whisperx', 'faster_whisper'")
 
         print(f"Detected language: {detected_language}, segments: {len(initial_segments)}, transcribing done in {time.time() - start_time:.2f} seconds")
-        # Align with WhisperX if supported for detected language (always attempt when available)
+        # Align with centralized alignment method when available
        segments = initial_segments
         if detected_language in _whipser_x_align_models:
-            print(f"Performing WhisperX alignment for language '{detected_language}'...")
-            align_start = time.time()
             try:
-                align_info = _whipser_x_align_models[detected_language]
-                align_result = whisperx.align(
-                    initial_segments,
-                    align_info["model"],
-                    align_info["metadata"],
-                    audio,
-                    "cuda",
-                    return_char_alignments=False
+                align_out = self.align_timestamp(
+                    audio_url=audio_path,
+                    text=None,
+                    language=detected_language,
+                    engine="whisperx",
+                    options={"segments": initial_segments},
                 )
-                segments = align_result.get("segments", segments)
-                print(f"WhisperX alignment completed in {time.time() - align_start:.2f} seconds")
+                if isinstance(align_out, dict) and align_out.get("segments"):
+                    segments = align_out["segments"]
             except Exception as e:
-                print(f"WhisperX alignment failed: {e}, using original timestamps")
+                print(f"Alignment via align_timestamp failed: {e}, using original timestamps")
         else:
             print(f"No WhisperX alignment model available for language '{detected_language}', using original timestamps")
-
+
         # Process segments into the expected format
         results = []
         for seg in segments:
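The branch above no longer calls whisperx.align inline; it forwards the raw Whisper output to the new align_timestamp method through options. Per that method's docstring later in this diff, the expected payload is a list of {start, end, text} dicts, for example (timings and text are illustrative):

# Shape of the payload passed as options={"segments": initial_segments} above.
initial_segments = [
    {"start": 0.00, "end": 4.20, "text": "Hello and welcome to the show."},
    {"start": 4.20, "end": 9.75, "text": "Today we look at speaker diarization."},
]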
@@ -706,7 +702,7 @@ class WhisperTranscriber:
                     "probability": word.get("score", 1.0),
                     "speaker": "SPEAKER_00"
                 })
-
+
             results.append({
                 "start": float(seg.get("start", 0.0)) + float(base_offset_s),
                 "end": float(seg.get("end", 0.0)) + float(base_offset_s),
@@ -714,13 +710,107 @@ class WhisperTranscriber:
                 "speaker": "SPEAKER_00",
                 "avg_logprob": seg.get("avg_logprob", 0.0) if "avg_logprob" in seg else 0.0,
                 "words": words_list,
-                "duration": float(seg.get("end", 0.0)) - float(seg.get("start", 0.0))
+                "duration": float(seg.get("end", 0.0)) - float(seg.get("start", 0.0)),
+                "language": detected_language,
             })
 
         print(results)
         transcription_time = time.time() - start_time
         print(f"Full audio transcribed and aligned in {transcription_time:.2f} seconds using batch size {batch_size}")
         return results, detected_language
+
+    @spaces.GPU  # alignment requires GPU
+    def align_timestamp(self, audio_url, text, language, engine="whisperx", options: dict = None):
+        """Return word-level alignment for the given text/audio using the specified engine.
+
+        Args:
+            audio_url: Path or URL to the audio file.
+            text: String text to align. If options contains 'segments', this can be None.
+            language: Language code (e.g., 'en'). Must be supported by WhisperX align models.
+            engine: Currently only 'whisperx' is supported.
+            options: Optional dict. Recognized keys:
+                - 'segments': list of {start, end, text} to align (preferred for segment-aware alignment)
+
+        Returns:
+            dict with keys:
+                - 'segments': aligned segments including word timings (if available)
+                - 'words': flat list of aligned words across all segments
+        """
+        global _whipser_x_align_models
+
+        if engine != "whisperx":
+            raise ValueError(f"align_timestamp engine '{engine}' not supported. Only 'whisperx' is supported")
+
+        if language not in _whipser_x_align_models:
+            raise ValueError(f"No WhisperX alignment model available for language '{language}'")
+
+        # Resolve audio path (download if URL)
+        local_path = None
+        tmp_file = None
+        try:
+            if isinstance(audio_url, str) and audio_url.startswith(("http://", "https://")):
+                resp = requests.get(audio_url, stream=True, timeout=60)
+                resp.raise_for_status()
+                tmp_f = tempfile.NamedTemporaryFile(suffix=".audio", delete=False)
+                for chunk in resp.iter_content(chunk_size=8192):
+                    if chunk:
+                        tmp_f.write(chunk)
+                tmp_f.flush()
+                tmp_f.close()
+                tmp_file = tmp_f.name
+                local_path = tmp_file
+            else:
+                local_path = audio_url
+
+            # Load audio and decide segments to align
+            audio = whisperx.load_audio(local_path)
+            sr = 16000.0  # whisperx loads at 16k
+            audio_duration = float(len(audio)) / sr if hasattr(audio, "__len__") else None
+            segments_to_align = None
+
+            if options and isinstance(options, dict) and options.get("segments"):
+                segments_to_align = options.get("segments")
+            else:
+                if not text or not str(text).strip():
+                    raise ValueError("align_timestamp requires 'text' when 'segments' are not provided in options")
+                if audio_duration is None:
+                    raise ValueError("Could not determine audio duration for alignment")
+                segments_to_align = [{
+                    "text": str(text),
+                    "start": 0.0,
+                    "end": audio_duration,
+                }]
+
+            # Perform alignment
+            align_info = _whipser_x_align_models[language]
+            aligned = whisperx.align(
+                segments_to_align,
+                align_info["model"],
+                align_info["metadata"],
+                audio,
+                "cuda",
+                return_char_alignments=False,
+            )
+
+            aligned_segments = aligned.get("segments", segments_to_align)
+            words_flat = []
+            for seg in aligned_segments:
+                for w in seg.get("words", []) or []:
+                    words_flat.append({
+                        "start": float(w.get("start", 0.0)),
+                        "end": float(w.get("end", 0.0)),
+                        "word": w.get("word", ""),
+                        "probability": w.get("score", 1.0)
+                    })
+
+            return {"segments": aligned_segments, "words": words_flat, "language": language}
+
+        finally:
+            if tmp_file:
+                try:
+                    os.unlink(tmp_file)
+                except Exception:
+                    pass
 
     # Removed audio cutting; transcription is done once on the full (preprocessed) audio
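Taken together, the new method supports two call modes: align raw text across the whole clip, or re-align pre-computed segments. A hypothetical caller (the instance name, paths, and text below are illustrative, not part of this commit) might look like:

# Mode 1: align a known transcript against a local file.
out = transcriber.align_timestamp(
    audio_url="sample.wav",
    text="hello world this is a test",
    language="en",
)
print(out["words"])      # flat list of {start, end, word, probability}

# Mode 2: re-align existing segments; 'text' may be None.
out = transcriber.align_timestamp(
    audio_url="https://example.com/sample.wav",
    text=None,
    language="en",
    options={"segments": [{"start": 0.0, "end": 3.0, "text": "hello world"}]},
)
print(out["segments"])   # segments enriched with word-level timings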