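# Transcription step: run ASR on the separated vocal track (audio_vocals.wav),
# merge fragmented segments into sentence-level chunks, save transcript.json,
# and extract per-speaker reference audio under SPEAKER/ for downstream steps.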
import os
import torch
import numpy as np
from dotenv import load_dotenv
# from .step021_asr_whisperx import whisperx_transcribe_audio
# NOTE: funasr_transcribe_audio is used below but was never imported; the module
# name here is assumed from the step021/step023 naming pattern and may need adjusting.
from .step022_asr_funasr import funasr_transcribe_audio
from .step023_asr_higgs import higgs_transcribe_audio
from .utils import save_wav
import json
import librosa
from loguru import logger

load_dotenv()

def merge_segments(
    segments,
    max_gap: float = 0.40,   # seconds: merge if next.start - prev.end <= max_gap
    max_chars: int = 120,    # don't let merged text grow too long
    joiner: str = " "
):
    """
    Merge short/adjacent ASR segments into larger sentences.
    Safely handles empty text to avoid IndexError.

    Input: List[{"start": float, "end": float, "text": str, "speaker": str}]
    Output: Same shape, merged.
    """
    if not segments:
        return []

    # sort just in case
    segs = sorted(segments, key=lambda x: float(x.get("start", 0.0)))

    # sentence-ending characters across languages
    ending = set(list(".!?。!?…」』”’】》") + ["]", ")", ")"])

    merged = []
    buffer = None

    def _clean_text(s):
        # normalize text to avoid None and trailing/leading spaces
        return (s or "").strip()

    for seg in segs:
        text = _clean_text(seg.get("text", ""))

        # Skip segments with no text at all
        if not text:
            continue

        start = float(seg.get("start", 0.0))
        end = float(seg.get("end", start))
        spk = seg.get("speaker", "SPEAKER_00")

        if buffer is None:
            buffer = {
                "start": start,
                "end": end,
                "text": text,
                "speaker": spk,
            }
            continue

        # Only merge if:
        #   1) temporal gap is small
        #   2) same speaker (optional but typical for diarized streams)
        #   3) previous buffer doesn't already end with sentence punctuation
        #   4) max length constraint respected
        gap = max(0.0, start - float(buffer["end"]))
        prev_text = _clean_text(buffer["text"])
        prev_last = prev_text[-1] if prev_text else ""
        prev_ends_sentence = prev_last in ending

        can_merge = (
            gap <= max_gap
            and spk == buffer["speaker"]
            and not prev_ends_sentence
            and (len(prev_text) + 1 + len(text) <= max_chars)
        )

        if can_merge:
            buffer["text"] = (prev_text + joiner + text).strip()
            buffer["end"] = max(float(buffer["end"]), end)
        else:
            merged.append(buffer)
            buffer = {
                "start": start,
                "end": end,
                "text": text,
                "speaker": spk,
            }

    if buffer is not None:
        merged.append(buffer)

    return merged
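
# Example (hypothetical values) for merge_segments: given two SPEAKER_00 segments
#   {"start": 0.0, "end": 1.0, "text": "so we trained", "speaker": "SPEAKER_00"}
#   {"start": 1.2, "end": 2.2, "text": "the model overnight.", "speaker": "SPEAKER_00"}
# the 0.2 s gap is within max_gap and the first segment does not end with sentence
# punctuation, so they merge into one segment spanning 0.0-2.2 whose text is
# "so we trained the model overnight."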

def generate_speaker_audio(folder, transcript):
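    """
    Cut each transcript segment (padded by `delay` seconds on both sides) out of
    audio_vocals.wav, concatenate the pieces per speaker, and write one
    SPEAKER/<speaker>.wav file per speaker.
    """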
    wav_path = os.path.join(folder, 'audio_vocals.wav')
    audio_data, samplerate = librosa.load(wav_path, sr=24000)
    speaker_dict = dict()
    length = len(audio_data)
    delay = 0.05
    for segment in transcript:
        start = max(0, int((segment['start'] - delay) * samplerate))
        end = min(int((segment['end'] + delay) * samplerate), length)
        speaker_segment_audio = audio_data[start:end]
        speaker_dict[segment['speaker']] = np.concatenate(
            (speaker_dict.get(segment['speaker'], np.zeros((0,))), speaker_segment_audio))

    speaker_folder = os.path.join(folder, 'SPEAKER')
    if not os.path.exists(speaker_folder):
        os.makedirs(speaker_folder)
    for speaker, audio in speaker_dict.items():
        speaker_file_path = os.path.join(speaker_folder, f"{speaker}.wav")
        save_wav(audio, speaker_file_path)

def transcribe_audio(method, folder, model_name: str = 'large', download_root='models/ASR/whisper',
                     device='auto', batch_size=32, diarization=True, min_speakers=None, max_speakers=None):
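    """
    Transcribe `folder`/audio_vocals.wav with the selected ASR backend (`method`),
    merge the resulting segments, write transcript.json, and extract per-speaker
    audio. Returns the transcript (a list of segment dicts), the existing transcript
    if one is already on disk, or False if the vocal track is missing.
    """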
    transcript_path = os.path.join(folder, 'transcript.json')
    if os.path.exists(transcript_path):
        logger.info(f'Transcript already exists in {folder}')
        with open(transcript_path, 'r', encoding='utf-8') as f:
            return json.load(f)

    wav_path = os.path.join(folder, 'audio_vocals.wav')
    if not os.path.exists(wav_path):
        return False

    logger.info(f'Transcribing {wav_path}')
    if device == 'auto':
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # if method == 'WhisperX':
    #     transcript = whisperx_transcribe_audio(wav_path, model_name, download_root, device, batch_size, diarization, min_speakers, max_speakers)
    if method == 'FunASR':
        transcript = funasr_transcribe_audio(wav_path, device, batch_size, diarization)
    elif method == 'Higgs':
        transcript = higgs_transcribe_audio(wav_path, device, batch_size, diarization)
    else:
        logger.error('Invalid ASR method')
        raise ValueError('Invalid ASR method')

    transcript = merge_segments(transcript)
    with open(transcript_path, 'w', encoding='utf-8') as f:
        json.dump(transcript, f, indent=4, ensure_ascii=False)
    logger.info(f'Transcribed {wav_path} successfully, and saved to {transcript_path}')

    generate_speaker_audio(folder, transcript)
    return transcript

def transcribe_all_audio_under_folder(folder, asr_method, whisper_model_name: str = 'large', device='auto',
                                      batch_size=32, diarization=False, min_speakers=None, max_speakers=None):
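    """
    Walk `folder` recursively: transcribe every directory that contains an
    audio_vocals.wav but no transcript.json, and load the transcript where one
    already exists. Returns a status string and the most recently seen transcript.
    """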
    transcribe_json = None
    for root, dirs, files in os.walk(folder):
        if 'audio_vocals.wav' in files and 'transcript.json' not in files:
            transcribe_json = transcribe_audio(asr_method, root, whisper_model_name, 'models/ASR/whisper',
                                               device, batch_size, diarization, min_speakers, max_speakers)
        elif 'transcript.json' in files:
            with open(os.path.join(root, 'transcript.json'), 'r', encoding='utf-8') as f:
                transcribe_json = json.load(f)
            # logger.info(f'Transcript already exists in {root}')
    return f'Transcribed all audio under {folder}', transcribe_json

if __name__ == '__main__':
    # 'WhisperX' is disabled above, so use one of the active backends here.
    _, transcribe_json = transcribe_all_audio_under_folder('videos', 'Higgs')
    print(transcribe_json)

    # _, transcribe_json = transcribe_all_audio_under_folder('videos', 'FunASR')
    # print(transcribe_json)