import os
import torch
import numpy as np
from dotenv import load_dotenv
# from .step021_asr_whisperx import whisperx_transcribe_audio
from .step022_asr_funasr import funasr_transcribe_audio  # NOTE: assumed module path (step-numbered like the other ASR backends); required by the 'FunASR' branch below
from .step023_asr_higgs import higgs_transcribe_audio
from .utils import save_wav
import json
import librosa
from loguru import logger
load_dotenv()

def merge_segments(
    segments,
    max_gap: float = 0.40,          # seconds: merge if next.start - prev.end <= max_gap
    max_chars: int = 120,           # don't let merged text grow too long
    joiner: str = " "
):
    """
    Merge short/adjacent ASR segments into larger sentences.
    Safely handles empty text to avoid IndexError.
    Input:  List[{"start": float, "end": float, "text": str, "speaker": str}]
    Output: Same shape, merged.
    """
    if not segments:
        return []

    # sort just in case
    segs = sorted(segments, key=lambda x: float(x.get("start", 0.0)))

    # sentence-ending characters across languages
    ending = set(list(".!?。!?…」』”’】》") + ["]", ")", ")"])

    merged = []
    buffer = None

    def _clean_text(s):
        # normalize text to avoid None and trailing/leading spaces
        return (s or "").strip()

    for seg in segs:
        text = _clean_text(seg.get("text", ""))
        # Skip segments with no text at all
        if not text:
            continue

        start = float(seg.get("start", 0.0))
        end   = float(seg.get("end", start))
        spk   = seg.get("speaker", "SPEAKER_00")

        if buffer is None:
            buffer = {
                "start": start,
                "end": end,
                "text": text,
                "speaker": spk,
            }
            continue

        # Only merge if:
        #   1) temporal gap is small
        #   2) same speaker (optional but typical for diarized streams)
        #   3) previous buffer doesn't already end with sentence punctuation
        #   4) max length constraint respected
        gap = max(0.0, start - float(buffer["end"]))
        prev_text = _clean_text(buffer["text"])
        prev_last = prev_text[-1] if prev_text else ""
        prev_ends_sentence = prev_last in ending

        can_merge = (
            gap <= max_gap
            and spk == buffer["speaker"]
            and not prev_ends_sentence
            and (len(prev_text) + 1 + len(text) <= max_chars)
        )

        if can_merge:
            buffer["text"] = (prev_text + joiner + text).strip()
            buffer["end"] = max(float(buffer["end"]), end)
        else:
            merged.append(buffer)
            buffer = {
                "start": start,
                "end": end,
                "text": text,
                "speaker": spk,
            }

    if buffer is not None:
        merged.append(buffer)

    return merged
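
# Illustrative example (made-up timings, not taken from any real transcript): two short
# same-speaker fragments separated by a 0.1 s gap are merged, while a fragment whose text
# already ends with sentence punctuation is not extended further:
#
#   merge_segments([
#       {"start": 0.0, "end": 1.2, "text": "Hello",         "speaker": "SPEAKER_00"},
#       {"start": 1.3, "end": 2.0, "text": "there.",        "speaker": "SPEAKER_00"},
#       {"start": 2.1, "end": 3.0, "text": "Next sentence", "speaker": "SPEAKER_00"},
#   ])
#   # -> [{"start": 0.0, "end": 2.0, "text": "Hello there.", "speaker": "SPEAKER_00"},
#   #     {"start": 2.1, "end": 3.0, "text": "Next sentence", "speaker": "SPEAKER_00"}]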


def generate_speaker_audio(folder, transcript):
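    """
    Cut the vocal track into per-speaker reference audio.

    For every transcript segment, a slice of audio_vocals.wav (padded by ~50 ms on each
    side) is appended to that segment's speaker buffer; each buffer is then written to
    <folder>/SPEAKER/<speaker>.wav via save_wav.
    """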
    wav_path = os.path.join(folder, 'audio_vocals.wav')
    audio_data, samplerate = librosa.load(wav_path, sr=24000)
    speaker_dict = dict()
    length = len(audio_data)
    delay = 0.05
    for segment in transcript:
        start = max(0, int((segment['start'] - delay) * samplerate))
        end = min(int((segment['end']+delay) * samplerate), length)
        speaker_segment_audio = audio_data[start:end]
        prev_audio = speaker_dict.get(segment['speaker'], np.zeros((0,)))
        speaker_dict[segment['speaker']] = np.concatenate((prev_audio, speaker_segment_audio))

    speaker_folder = os.path.join(folder, 'SPEAKER')
    os.makedirs(speaker_folder, exist_ok=True)

    for speaker, audio in speaker_dict.items():
        speaker_file_path = os.path.join(
            speaker_folder, f"{speaker}.wav")
        save_wav(audio, speaker_file_path)


def transcribe_audio(method, folder, model_name: str = 'large', download_root='models/ASR/whisper', device='auto', batch_size=32, diarization=True, min_speakers=None, max_speakers=None):
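    """
    Transcribe <folder>/audio_vocals.wav with the selected ASR backend.

    Returns True if transcript.json already exists, False if the vocals file is missing,
    otherwise the list of merged segments (also written to transcript.json, with
    per-speaker reference audio generated alongside it).
    """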
    if os.path.exists(os.path.join(folder, 'transcript.json')):
        logger.info(f'Transcript already exists in {folder}')
        return True
    
    wav_path = os.path.join(folder, 'audio_vocals.wav')
    if not os.path.exists(wav_path):
        return False
    
    logger.info(f'Transcribing {wav_path}')
    if device == 'auto':
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    
    # if method == 'WhisperX':
    #     transcript = whisperx_transcribe_audio(wav_path, model_name, download_root, device, batch_size, diarization, min_speakers, max_speakers)
    if method == 'FunASR':
        transcript = funasr_transcribe_audio(wav_path, device, batch_size, diarization)
    elif method == 'Higgs':
        transcript = higgs_transcribe_audio(wav_path, device, batch_size, diarization)
    else:
        logger.error(f'Invalid ASR method: {method}')
        raise ValueError(f'Invalid ASR method: {method}')

    transcript = merge_segments(transcript)
    with open(os.path.join(folder, 'transcript.json'), 'w', encoding='utf-8') as f:
        json.dump(transcript, f, indent=4, ensure_ascii=False)
    logger.info(f'Transcribed {wav_path} successfully, and saved to {os.path.join(folder, "transcript.json")}')
    generate_speaker_audio(folder, transcript)
    return transcript

def transcribe_all_audio_under_folder(folder, asr_method, whisper_model_name: str = 'large', device='auto', batch_size=32, diarization=False, min_speakers=None, max_speakers=None):
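    """
    Walk `folder` and transcribe every audio_vocals.wav that does not have a
    transcript.json yet; existing transcripts are loaded instead of recomputed.

    Returns a status string and the most recently produced (or loaded) transcript.
    """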
    transcribe_json = None
    for root, dirs, files in os.walk(folder):
        if 'audio_vocals.wav' in files and 'transcript.json' not in files:
            transcribe_json = transcribe_audio(asr_method, root, whisper_model_name, 'models/ASR/whisper', device, batch_size, diarization, min_speakers, max_speakers)
        elif 'transcript.json' in files:
            with open(os.path.join(root, 'transcript.json'), 'r', encoding='utf-8') as f:
                transcribe_json = json.load(f)

            # logger.info(f'Transcript already exists in {root}')
    return f'Transcribed all audio under {folder}', transcribe_json

if __name__ == '__main__':
    _, transcribe_json = transcribe_all_audio_under_folder('videos', 'Higgs')
    print(transcribe_json)
    # _, transcribe_json = transcribe_all_audio_under_folder('videos', 'FunASR')    
    # print(transcribe_json)