import spaces
import boto3
from botocore.exceptions import NoCredentialsError, ClientError
from botocore.client import Config
import os, pathlib

CACHE_ROOT = "/home/user/app/cache"   # any folder you own

# Cache locations for the various libraries; the directories are created below.
CACHE_DIRS = {
    "TORCH_HOME":         f"{CACHE_ROOT}/torch",
    "XDG_CACHE_HOME":     f"{CACHE_ROOT}/xdg",        # torch fallback
    "PYANNOTE_CACHE":     f"{CACHE_ROOT}/pyannote",
    "HF_HOME":            f"{CACHE_ROOT}/huggingface",
    "TRANSFORMERS_CACHE": f"{CACHE_ROOT}/transformers",
    "MPLCONFIGDIR":       f"{CACHE_ROOT}/mpl",
}
os.environ.update(CACHE_DIRS)

INITIAL_PROMPT = '''
Use normal punctuation; end sentences properly.
'''

# make sure the cache directories exist (only the ones defined above,
# not every value in os.environ)
for path in CACHE_DIRS.values():
    pathlib.Path(path).mkdir(parents=True, exist_ok=True)
import gradio as gr
import torch
import torchaudio
import numpy as np
import pandas as pd
import time
import datetime
import re
import subprocess
import os
import tempfile
import spaces
from faster_whisper import WhisperModel, BatchedInferencePipeline
from faster_whisper.vad import VadOptions
import requests
import base64
from pyannote.audio import Pipeline, Inference, Model
from pyannote.core import Segment
import os, sys, importlib.util, pathlib, ctypes, tempfile, wave, math
import json
import webrtcvad
spec = importlib.util.find_spec("nvidia.cudnn")
if spec is None:
    sys.exit("❌ nvidia-cudnn-cu12 wheel not found. Run: pip install nvidia-cudnn-cu12")

cudnn_dir = pathlib.Path(spec.origin).parent / "lib"
cnn_so = cudnn_dir / "libcudnn_cnn.so.9"
try:
    ctypes.CDLL(cnn_so, mode=ctypes.RTLD_GLOBAL)
    print(f"✅ Pre-loaded {cnn_so}")
except OSError as e:
    sys.exit(f"❌ Could not load {cnn_so} : {e}")
S3_ENDPOINT = os.getenv("S3_ENDPOINT")
S3_ACCESS_KEY = os.getenv("S3_ACCESS_KEY")
S3_SECRET_KEY = os.getenv("S3_SECRET_KEY")

# Function to upload file to Cloudflare R2
def upload_data_to_r2(data, bucket_name, object_name, content_type='application/octet-stream'):
    """
    Upload data directly to a Cloudflare R2 bucket.

    :param data: Data to upload (bytes or string).
    :param bucket_name: Name of the R2 bucket.
    :param object_name: Name of the object to save in the bucket.
    :param content_type: MIME type of the data.
    :return: True if data was uploaded, else False.
    """
    try:
        # Convert string to bytes if necessary
        if isinstance(data, str):
            data = data.encode('utf-8')

        # Initialize a session using Cloudflare R2 credentials
        session = boto3.session.Session()
        s3 = session.client(
            's3',
            endpoint_url=f'https://{S3_ENDPOINT}',
            aws_access_key_id=S3_ACCESS_KEY,
            aws_secret_access_key=S3_SECRET_KEY,
            config=Config(
                s3={"addressing_style": "virtual", "payload_signing_enabled": False},
                signature_version='v4',
                request_checksum_calculation='when_required',
                response_checksum_validation='when_required',
            ),
        )

        # Upload the data to the R2 bucket
        s3.put_object(
            Bucket=bucket_name,
            Key=object_name,
            Body=data,
            ContentType=content_type,
            ContentLength=len(data),  # make length explicit to avoid streaming
        )
        print(f"Data uploaded to R2 bucket '{bucket_name}' as '{object_name}'")
        return True
    except NoCredentialsError:
        print("Credentials not available")
        return False
    except ClientError as e:
        print(f"Failed to upload data to R2 bucket: {e}")
        return False
    except Exception as e:
        print(f"An unexpected error occurred: {e}")
        return False
from huggingface_hub import snapshot_download

MODEL_REPO = "deepdml/faster-whisper-large-v3-turbo-ct2"   # CT2 format
LOCAL_DIR = f"{CACHE_ROOT}/whisper_turbo"

# -----------------------------------------------------------------------------
# Audio preprocess helper (from input_and_preprocess rule)
# -----------------------------------------------------------------------------
TRIM_THRESHOLD_MS = 10_000   # 10 seconds
DEFAULT_PAD_MS = 250         # safety context around detected speech
FRAME_MS = 30                # VAD frame
HANG_MS = 240                # hangover (keep speech "on" after silence)
VAD_LEVEL = 2                # 0-3
def _decode_chunk_to_pcm(task: dict) -> bytes:
    """Use ffmpeg to decode the chunk to s16le mono @ 16k PCM bytes."""
    src = task["source_uri"]
    ing = task["ingest_recipe"]
    seek = task["ffmpeg_seek"]
    cmd = [
        "ffmpeg", "-nostdin", "-hide_banner", "-v", "error",
        "-ss", f"{max(0.0, float(seek['pre_ss_sec'])):.3f}",
        "-i", src,
        "-map", "0:a:0",
        "-ss", f"{float(seek['post_ss_sec']):.2f}",
        "-t", f"{float(seek['t_sec']):.3f}",
    ]
    # Optional L/R extraction
    if ing.get("channel_extract_filter"):
        cmd += ["-af", ing["channel_extract_filter"]]
    # Force mono 16k s16le to stdout
    cmd += ["-ar", "16000", "-ac", "1", "-c:a", "pcm_s16le", "-f", "s16le", "pipe:1"]
    p = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    pcm, err = p.communicate()
    if p.returncode != 0:
        raise RuntimeError(f"ffmpeg failed: {err.decode('utf-8', 'ignore')}")
    return pcm
def _find_head_tail_speech_ms(
    pcm: bytes,
    sr: int = 16000,
    frame_ms: int = FRAME_MS,
    vad_level: int = VAD_LEVEL,
    hang_ms: int = HANG_MS,
):
    """Return (first_ms, last_ms) speech boundaries using webrtcvad with hangover."""
    if not pcm:
        return None, None
    vad = webrtcvad.Vad(int(vad_level))
    bpf = 2                      # bytes per sample (s16)
    samples_per_ms = sr // 1000  # 16
    bytes_per_frame = samples_per_ms * bpf * frame_ms
    n_frames = len(pcm) // bytes_per_frame
    if n_frames == 0:
        return None, None
    first_ms, last_ms = None, None
    t_ms = 0
    in_speech = False
    silence_run = 0
    view = memoryview(pcm)[: n_frames * bytes_per_frame]
    for i in range(n_frames):
        frame = view[i * bytes_per_frame : (i + 1) * bytes_per_frame]
        if vad.is_speech(frame, sr):
            if first_ms is None:
                first_ms = t_ms
            in_speech = True
            silence_run = 0
        else:
            if in_speech:
                silence_run += frame_ms
                if silence_run >= hang_ms:
                    last_ms = t_ms - (silence_run - hang_ms)
                    in_speech = False
                    silence_run = 0
        t_ms += frame_ms
    if in_speech:
        last_ms = t_ms
    return first_ms, last_ms
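
# Worked example (hypothetical numbers, for illustration only): for a 60000 ms
# chunk where this function returns first_ms=12400 and last_ms=58230,
# prepare_and_save_audio_for_model (below) derives
#     head_silence_ms = 12400
#     tail_silence_ms = 60000 - 58230 = 1770
# Only the head exceeds TRIM_THRESHOLD_MS (10000 ms), so the chunk is trimmed
# to start at max(0, 12400 - DEFAULT_PAD_MS) = 12150 ms.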
def _write_wav(path: str, pcm: bytes, sr: int = 16000):
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with wave.open(path, "wb") as w:
        w.setnchannels(1)
        w.setsampwidth(2)  # s16
        w.setframerate(sr)
        w.writeframes(pcm)

def prepare_and_save_audio_for_model(task: dict, out_dir: str) -> dict:
    """
    1) Decode chunk to mono 16k PCM.
    2) Run VAD to locate head/tail silence.
    3) Trim only if head or tail >= 10s.
    4) Save the (possibly trimmed) WAV to a local file.
    5) Return timing metadata, including 'trimmed_start_ms' to preserve global timestamps.
    """
    # 0) Names & constants
    sr = 16000
    bpf = 2
    samples_per_ms = sr // 1000

    def bytes_from_ms(ms: int) -> int:
        return int(ms * samples_per_ms) * bpf

    ch = task["channel"]
    ck = task["chunk"]
    job = task.get("job_id", "job")
    idx = str(ck["idx"])

    # 1) Decode chunk
    pcm = _decode_chunk_to_pcm(task)
    planned_dur_ms = int(ck["dur_ms"])

    # 2) VAD head/tail detection
    first_ms, last_ms = _find_head_tail_speech_ms(pcm, sr=sr)
    head_sil_ms = int(first_ms) if first_ms is not None else planned_dur_ms
    tail_sil_ms = int(planned_dur_ms - last_ms) if last_ms is not None else planned_dur_ms

    # 3) Decide trimming (only if head or tail >= 10s)
    trim_applied = False
    eff_start_ms = 0
    eff_end_ms = planned_dur_ms
    trimmed_pcm = pcm
    if (head_sil_ms >= TRIM_THRESHOLD_MS) or (tail_sil_ms >= TRIM_THRESHOLD_MS):
        # If no speech found at all, mark skip
        if first_ms is None or last_ms is None or last_ms <= first_ms:
            out_wav_path = os.path.join(out_dir, f"{job}_{ch}_{idx}_nospeech.wav")
            _write_wav(out_wav_path, b"", sr)
            return {
                "out_wav_path": out_wav_path,
                "sr": sr,
                "trim_applied": False,
                "trimmed_start_ms": 0,
                "head_silence_ms": head_sil_ms,
                "tail_silence_ms": tail_sil_ms,
                "effective_start_ms": 0,
                "effective_dur_ms": 0,
                "abs_start_ms": ck["global_offset_ms"],
                "chunk_idx": idx,
                "channel": ch,
                "skip": True,
            }
        # Apply padding & slice
        start_ms = max(0, int(first_ms) - DEFAULT_PAD_MS)
        end_ms = min(planned_dur_ms, int(last_ms) + DEFAULT_PAD_MS)
        if end_ms > start_ms:
            eff_start_ms = start_ms
            eff_end_ms = end_ms
            trimmed_pcm = pcm[bytes_from_ms(start_ms) : bytes_from_ms(end_ms)]
            trim_applied = True

    # 4) Write WAV to local file (trimmed or original)
    tag = "trim" if trim_applied else "full"
    out_wav_path = os.path.join(out_dir, f"{job}_{ch}_{idx}_{tag}.wav")
    _write_wav(out_wav_path, trimmed_pcm, sr)

    # 5) Return metadata
    return {
        "out_wav_path": out_wav_path,
        "sr": sr,
        "trim_applied": trim_applied,
        "trimmed_start_ms": eff_start_ms if trim_applied else 0,
        "head_silence_ms": head_sil_ms,
        "tail_silence_ms": tail_sil_ms,
        "effective_start_ms": eff_start_ms,
        "effective_dur_ms": eff_end_ms - eff_start_ms,
        "abs_start_ms": int(ck["global_offset_ms"]) + eff_start_ms,
        "chunk_idx": idx,
        "channel": ch,
        "job_id": job,
        "skip": not (trim_applied or len(pcm) > 0),
    }
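
# For reference, a minimal sketch of the per-chunk task JSON that
# prepare_and_save_audio_for_model() and _decode_chunk_to_pcm() read. All values
# below are made-up placeholders; only the key names come from the code above,
# and the upstream "input_and_preprocess rule" may define additional fields.
#
#   {
#     "job_id": "job-0001",
#     "channel": "L",
#     "source_uri": "https://example.com/audio/call.wav",
#     "ingest_recipe": {
#       "channel_extract_filter": "pan=mono|c0=c0"
#     },
#     "ffmpeg_seek": {
#       "pre_ss_sec": 0.0,
#       "post_ss_sec": 0.0,
#       "t_sec": 60.0
#     },
#     "chunk": {
#       "idx": 0,
#       "dur_ms": 60000,
#       "global_offset_ms": 0
#     }
#   }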
# Download once; later runs are instant
snapshot_download(
    repo_id=MODEL_REPO,
    local_dir=LOCAL_DIR,
    local_dir_use_symlinks=True,   # saves disk space
    resume_download=True,
)
model_cache_path = LOCAL_DIR   # <-- this is what we pass to WhisperModel

# Lazy global holder ----------------------------------------------------------
_whisper = None
_batched_whisper = None
_diarizer = None
_embedder = None
# Create global diarization pipeline
try:
    print("Loading diarization model...")
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.set_float32_matmul_precision('high')
    _diarizer = Pipeline.from_pretrained(
        "pyannote/speaker-diarization-3.1",
        use_auth_token=os.getenv("HF_TOKEN"),
    ).to(torch.device("cuda"))
    print("Diarization model loaded successfully")
except Exception as e:
    import traceback
    traceback.print_exc()
    print(f"Could not load diarization model: {e}")
    _diarizer = None

# GPU is guaranteed to exist *inside* this function
def _load_models():
    global _whisper, _batched_whisper, _diarizer
    if _whisper is None:
        print("Loading Whisper model...")
        _whisper = WhisperModel(
            model_cache_path,
            device="cuda",
            compute_type="float16",
        )
        # Create batched inference pipeline for improved performance
        _batched_whisper = BatchedInferencePipeline(model=_whisper)
        print("Whisper model and batched pipeline loaded successfully")
    return _whisper, _batched_whisper, _diarizer
# -----------------------------------------------------------------------------
class WhisperTranscriber:
    def __init__(self):
        # do **not** create the models here!
        pass

    def preprocess_from_task_json(self, task_json: str) -> dict:
        """Parse task JSON and run prepare_and_save_audio_for_model, returning metadata."""
        try:
            task = json.loads(task_json)
        except Exception as e:
            raise RuntimeError(f"Invalid JSON: {e}")
        out_dir = os.path.join(CACHE_ROOT, "preprocessed")
        os.makedirs(out_dir, exist_ok=True)
        meta = prepare_and_save_audio_for_model(task, out_dir)
        return meta
    # each call gets a GPU slice
    def transcribe_full_audio(self, audio_path, language=None, translate=False, prompt=None, batch_size=16, base_offset_s: float = 0.0):
        """Transcribe the entire audio file without speaker diarization using batched inference"""
        whisper, batched_whisper, _ = _load_models()  # models live on the GPU
        print(f"Transcribing full audio with batch size {batch_size}...")
        start_time = time.time()

        # Prepare options for batched inference
        options = dict(
            language=language,
            beam_size=5,
            vad_filter=True,  # VAD is enabled by default for batched transcription
            vad_parameters=VadOptions(
                max_speech_duration_s=whisper.feature_extractor.chunk_length,
                min_speech_duration_ms=150,   # ignore ultra-short blips
                min_silence_duration_ms=150,  # split on short Mandarin pauses (if supported)
                speech_pad_ms=100,
                threshold=0.25,
                neg_threshold=0.2,
            ),
            word_timestamps=True,
            initial_prompt=prompt,
            condition_on_previous_text=False,  # avoid runaway context
            language_detection_segments=1,
            task="translate" if translate else "transcribe",
        )
        if batch_size > 1:
            # Use batched inference for better performance
            segments, transcript_info = batched_whisper.transcribe(
                audio_path,
                batch_size=batch_size,
                **options
            )
        else:
            segments, transcript_info = whisper.transcribe(
                audio_path,
                **options
            )
        segments = list(segments)
        detected_language = transcript_info.language
        print("Detected language: ", detected_language, "segments: ", len(segments))

        # Process segments
        results = []
        for seg in segments:
            # Create result entry with detailed format
            words_list = []
            if seg.words:
                for word in seg.words:
                    words_list.append({
                        "start": float(word.start) + float(base_offset_s),
                        "end": float(word.end) + float(base_offset_s),
                        "word": word.word,
                        "probability": word.probability,
                        "speaker": "SPEAKER_00"  # No speaker identification in full transcription
                    })
            results.append({
                "start": float(seg.start) + float(base_offset_s),
                "end": float(seg.end) + float(base_offset_s),
                "text": seg.text,
                "speaker": "SPEAKER_00",  # Single speaker assumption
                "avg_logprob": seg.avg_logprob,
                "words": words_list,
                "duration": float(seg.end - seg.start)
            })

        transcription_time = time.time() - start_time
        print(f"Full audio transcribed in {transcription_time:.2f} seconds using batch size {batch_size}")
        # print(results)
        return results, detected_language
    # Removed audio cutting; transcription is done once on the full (preprocessed) audio
    # Removed segment-wise transcription; using single full-audio transcription

    # each call gets a GPU slice
    def perform_diarization(self, audio_path, num_speakers=None, base_offset_s: float = 0.0):
        """Perform speaker diarization; return segments with global timestamps and per-speaker embeddings."""
        _, _, diarizer = _load_models()  # models live on the GPU
        if diarizer is None:
            print("Diarization model not available, creating single speaker segment")
            # Load audio to get duration
            waveform, sample_rate = torchaudio.load(audio_path)
            duration = waveform.shape[1] / sample_rate
            # Try to compute a single-speaker embedding
            speaker_embeddings = {}
            try:
                embedder = self._load_embedder()
                # Provide waveform as (channel, time) and pad if too short
                min_embed_duration_sec = 3.0
                min_samples = int(min_embed_duration_sec * sample_rate)
                if waveform.shape[1] < min_samples:
                    pad_len = min_samples - waveform.shape[1]
                    pad = torch.zeros(waveform.shape[0], pad_len, dtype=waveform.dtype, device=waveform.device)
                    waveform = torch.cat([waveform, pad], dim=1)
                emb = embedder({"waveform": waveform, "sample_rate": sample_rate})
                speaker_embeddings["SPEAKER_00"] = emb.squeeze().tolist()
            except Exception:
                pass
            return [{
                "start": 0.0 + float(base_offset_s),
                "end": duration + float(base_offset_s),
                "speaker": "SPEAKER_00"
            }], 1, speaker_embeddings
print("Starting diarization...") | |
start_time = time.time() | |
# Load audio for diarization | |
waveform, sample_rate = torchaudio.load(audio_path) | |
# Perform diarization | |
diarization = diarizer( | |
{"waveform": waveform, "sample_rate": sample_rate}, | |
num_speakers=num_speakers, | |
) | |
# Convert to list format | |
diarize_segments = [] | |
diarization_list = list(diarization.itertracks(yield_label=True)) | |
#print(diarization_list) | |
for turn, _, speaker in diarization_list: | |
diarize_segments.append({ | |
"start": float(turn.start) + float(base_offset_s), | |
"end": float(turn.end) + float(base_offset_s), | |
"speaker": speaker | |
}) | |
unique_speakers = {speaker for segment in diarize_segments for speaker in [segment["speaker"]]} | |
detected_num_speakers = len(unique_speakers) | |
# Compute per-speaker embeddings by averaging segment embeddings | |
speaker_embeddings = {} | |
try: | |
embedder = self._load_embedder() | |
spk_to_embs = {spk: [] for spk in unique_speakers} | |
# Primary path: slice in-memory waveform and zero-pad short segments | |
min_embed_duration_sec = 3.0 | |
audio_duration_sec = float(waveform.shape[1]) / float(sample_rate) | |
for turn, _, speaker in diarization_list: | |
seg_start = float(turn.start) | |
seg_end = float(turn.end) | |
if seg_end <= seg_start: | |
continue | |
start_sample = max(0, int(seg_start * sample_rate)) | |
end_sample = min(waveform.shape[1], int(seg_end * sample_rate)) | |
if end_sample <= start_sample: | |
continue | |
seg_wav = waveform[:, start_sample:end_sample].contiguous() | |
min_samples = int(min_embed_duration_sec * sample_rate) | |
if seg_wav.shape[1] < min_samples: | |
pad_len = min_samples - seg_wav.shape[1] | |
pad = torch.zeros(seg_wav.shape[0], pad_len, dtype=seg_wav.dtype, device=seg_wav.device) | |
seg_wav = torch.cat([seg_wav, pad], dim=1) | |
try: | |
emb = embedder({"waveform": seg_wav, "sample_rate": sample_rate}) | |
except Exception: | |
# Fallback: use crop on the file with expanded window to minimum duration | |
desired_end = min(seg_start + min_embed_duration_sec, audio_duration_sec) | |
desired_start = max(0.0, desired_end - min_embed_duration_sec) | |
emb = embedder.crop(audio_path, Segment(desired_start, desired_end)) | |
spk_to_embs[speaker].append(emb.squeeze()) | |
# average | |
for spk, embs in spk_to_embs.items(): | |
if len(embs) == 0: | |
continue | |
# stack and mean | |
try: | |
import torch as _torch | |
embs_tensor = _torch.stack([_torch.as_tensor(e) for e in embs], dim=0) | |
centroid = embs_tensor.mean(dim=0) | |
# L2 normalize | |
centroid = centroid / (centroid.norm(p=2) + 1e-12) | |
speaker_embeddings[spk] = centroid.cpu().tolist() | |
except Exception: | |
# fallback to first embedding | |
speaker_embeddings[spk] = embs[0].cpu().tolist() | |
#print(speaker_embeddings[spk]) | |
except Exception as e: | |
print(f"Error during embedding calculation: {e}") | |
print(f"Diarization segments: {diarize_segments}") | |
pass | |
diarization_time = time.time() - start_time | |
print(f"Diarization completed in {diarization_time:.2f} seconds") | |
return diarize_segments, detected_num_speakers, speaker_embeddings | |
    def _load_embedder(self):
        """Lazy-load speaker embedding inference model on GPU."""
        global _embedder
        if _embedder is None:
            # window="whole" to compute one embedding per provided chunk
            token = os.getenv("HF_TOKEN")
            model = Model.from_pretrained("pyannote/embedding", use_auth_token=token)
            _embedder = Inference(model, window="whole", device=torch.device("cuda"))
        return _embedder
    def assign_speakers_to_transcription(self, transcription_results, diarization_segments):
        """Assign speakers to words and segments based on overlap with diarization segments."""
        if not diarization_segments:
            return transcription_results

        # simple helper to find speaker at given time
        def speaker_at(t: float):
            for seg in diarization_segments:
                if seg["start"] <= t < seg["end"]:
                    return seg["speaker"]
            # if not inside, return closest segment's speaker
            closest = None
            best = float("inf")
            for seg in diarization_segments:
                if t < seg["start"]:
                    d = seg["start"] - t
                elif t > seg["end"]:
                    d = t - seg["end"]
                else:
                    d = 0.0
                if d < best:
                    best = d
                    closest = seg
            return closest["speaker"] if closest else "SPEAKER_00"

        for seg in transcription_results:
            # Assign per-word speakers
            if seg.get("words"):
                speaker_counts = {}
                for w in seg["words"]:
                    mid = (float(w["start"]) + float(w["end"])) / 2.0
                    spk = speaker_at(mid)
                    w["speaker"] = spk
                    speaker_counts[spk] = speaker_counts.get(spk, 0) + (float(w["end"]) - float(w["start"]))
                # Segment speaker = speaker with max accumulated word duration
                if speaker_counts:
                    seg["speaker"] = max(speaker_counts.items(), key=lambda kv: kv[1])[0]
            else:
                mid = (float(seg["start"]) + float(seg["end"])) / 2.0
                seg["speaker"] = speaker_at(mid)
        return transcription_results
    def group_segments_by_speaker(self, segments, max_gap=1.0, max_duration=30.0):
        """Group consecutive segments from the same speaker"""
        if not segments:
            return segments

        grouped_segments = []
        current_group = segments[0].copy()
        sentence_end_pattern = r"[.!?]+"

        for segment in segments[1:]:
            time_gap = segment["start"] - current_group["end"]
            current_duration = current_group["end"] - current_group["start"]

            # Conditions for combining segments
            can_combine = (
                segment["speaker"] == current_group["speaker"] and
                time_gap <= max_gap and
                current_duration < max_duration and
                not re.search(sentence_end_pattern, current_group["text"][-1:])
            )

            if can_combine:
                # Merge segments
                current_group["end"] = segment["end"]
                current_group["text"] += " " + segment["text"]
                current_group["words"].extend(segment["words"])
                current_group["duration"] = current_group["end"] - current_group["start"]
            else:
                # Start new group
                grouped_segments.append(current_group)
                current_group = segment.copy()

        grouped_segments.append(current_group)

        # Clean up text
        for segment in grouped_segments:
            segment["text"] = re.sub(r"\s+", " ", segment["text"]).strip()
            segment["text"] = re.sub(r"\s+([.,!?])", r"\1", segment["text"])
        return grouped_segments
    # each call gets a GPU slice
    @spaces.GPU  # ZeroGPU: allocate a GPU for the duration of this call
    def process_audio_full(self, task_json, language=None, translate=False, prompt=None, group_segments=True, batch_size=16):
        """Process a single chunk using task JSON (no diarization)."""
        if not task_json or not str(task_json).strip():
            return {"error": "No JSON provided"}
        pre_meta = None
        try:
            print("Starting full transcription pipeline...")

            # Step 1: Preprocess per chunk JSON
            print("Preprocessing chunk JSON...")
            pre_meta = self.preprocess_from_task_json(task_json)
            if pre_meta.get("skip"):
                return {"segments": [], "language": "unknown", "num_speakers": 1, "transcription_method": "full_audio_batched", "batch_size": batch_size}
            wav_path = pre_meta["out_wav_path"]
            # Adjust timestamps by trimmed_start_ms: abs_start_ms is already the global start for the saved file
            base_offset_s = float(pre_meta.get("abs_start_ms", 0)) / 1000.0

            # Step 2: Transcribe the entire audio with batching
            transcription_results, detected_language = self.transcribe_full_audio(
                wav_path, language, translate, prompt, batch_size, base_offset_s=base_offset_s
            )

            # Step 3: Group segments if requested (based on time gaps and sentence endings)
            if group_segments:
                transcription_results = self.group_segments_by_speaker(transcription_results)

            # Step 4: Return results
            return {
                "segments": transcription_results,
                "language": detected_language,
                "num_speakers": 1,  # Single speaker assumption
                "transcription_method": "full_audio_batched",
                "batch_size": batch_size
            }
        except Exception as e:
            import traceback
            traceback.print_exc()
            return {"error": f"Processing failed: {str(e)}"}
        finally:
            # Clean up preprocessed wav
            if pre_meta and pre_meta.get("out_wav_path") and os.path.exists(pre_meta["out_wav_path"]):
                try:
                    os.unlink(pre_meta["out_wav_path"])
                except Exception:
                    pass
    # each call gets a GPU slice
    @spaces.GPU  # ZeroGPU: allocate a GPU for the duration of this call
    def process_audio(self, task_json, num_speakers=None, language=None,
                      translate=False, prompt=None, group_segments=True, batch_size=8):
        """Main processing function with diarization using task JSON for a single chunk.

        Transcribes the full (preprocessed) audio once, performs diarization, and merges speakers into the transcription.
        """
        if not task_json or not str(task_json).strip():
            return {"error": "No JSON provided"}
        pre_meta = None
        try:
            print("Starting new processing pipeline...")

            # Step 1: Preprocess per chunk JSON
            print("Preprocessing chunk JSON...")
            pre_meta = self.preprocess_from_task_json(task_json)
            if pre_meta.get("skip"):
                return {"segments": [], "language": "unknown", "num_speakers": 0, "transcription_method": "diarized_segments_batched", "batch_size": batch_size}
            wav_path = pre_meta["out_wav_path"]
            base_offset_s = float(pre_meta.get("abs_start_ms", 0)) / 1000.0

            # Step 2: Transcribe full audio once
            transcription_results, detected_language = self.transcribe_full_audio(
                wav_path, language, translate, prompt, batch_size, base_offset_s=base_offset_s
            )

            # Step 3: Perform diarization with global offset
            diarization_segments, detected_num_speakers, speaker_embeddings = self.perform_diarization(
                wav_path, num_speakers, base_offset_s=base_offset_s
            )

            # Step 4: Merge diarization into transcription (assign speakers)
            transcription_results = self.assign_speakers_to_transcription(transcription_results, diarization_segments)

            # Step 5: Group segments if requested
            if group_segments:
                transcription_results = self.group_segments_by_speaker(transcription_results)

            # Step 6: Upload results to R2 and return the file key
            result = {
                "segments": transcription_results,
                "language": detected_language,
                "num_speakers": detected_num_speakers,
                "transcription_method": "diarized_segments_batched",
                "batch_size": batch_size,
                "speaker_embeddings": speaker_embeddings,
            }
            job_id = pre_meta["job_id"]
            task_id = pre_meta["chunk_idx"]
            filekey = f"ai-transcribe/split/{job_id}-{task_id}.json"
            ret = upload_data_to_r2(json.dumps(result), "intermediate", filekey)
            if ret:
                return {"filekey": filekey}
            else:
                return {"error": "Failed to upload to R2"}
        except Exception as e:
            import traceback
            traceback.print_exc()
            return {"error": f"Processing failed: {str(e)}"}
        finally:
            # Clean up preprocessed wav
            if pre_meta and pre_meta.get("out_wav_path") and os.path.exists(pre_meta["out_wav_path"]):
                try:
                    os.unlink(pre_meta["out_wav_path"])
                except Exception:
                    pass
# Initialize transcriber
transcriber = WhisperTranscriber()

def format_segments_for_display(result):
    """Format segments for display in Gradio"""
    if "error" in result:
        return f"❌ Error: {result['error']}"

    segments = result.get("segments", [])
    language = result.get("language", "unknown")
    num_speakers = result.get("num_speakers", 1)
    method = result.get("transcription_method", "unknown")
    batch_size = result.get("batch_size", "N/A")

    output = f"🎯 **Detection Results:**\n"
    output += f"- Language: {language}\n"
    output += f"- Speakers: {num_speakers}\n"
    output += f"- Segments: {len(segments)}\n"
    output += f"- Method: {method}\n"
    output += f"- Batch Size: {batch_size}\n\n"
    output += "📝 **Transcription:**\n\n"

    for i, segment in enumerate(segments, 1):
        start_time = str(datetime.timedelta(seconds=int(segment["start"])))
        end_time = str(datetime.timedelta(seconds=int(segment["end"])))
        speaker = segment.get("speaker", "SPEAKER_00")
        text = segment["text"]
        output += f"**{speaker}** ({start_time} → {end_time})\n"
        output += f"{text}\n\n"
    return output
def process_audio_gradio(task_json, num_speakers, language, translate, prompt, group_segments, use_diarization, batch_size):
    """Gradio interface function"""
    if use_diarization:
        result = transcriber.process_audio(
            task_json=task_json,
            num_speakers=num_speakers if num_speakers > 0 else None,
            language=language if language != "auto" else None,
            translate=translate,
            prompt=prompt if prompt and prompt.strip() else None,
            group_segments=group_segments,
            batch_size=batch_size
        )
    else:
        result = transcriber.process_audio_full(
            task_json=task_json,
            language=language if language != "auto" else None,
            translate=translate,
            prompt=prompt if prompt and prompt.strip() else None,
            group_segments=group_segments,
            batch_size=batch_size
        )
    formatted_output = format_segments_for_display(result)
    return formatted_output, result
# Create Gradio interface
demo = gr.Blocks(
    title="🎙️ Whisper Transcription with Speaker Diarization",
    theme="default"
)

with demo:
    gr.Markdown("""
    # 🎙️ Advanced Audio Transcription & Speaker Diarization

    Paste a per-chunk task JSON to get accurate transcription with speaker identification, powered by:
    - **Faster-Whisper Large V3 Turbo** with batched inference for optimal performance
    - **Pyannote 3.1** for speaker diarization
    - **ZeroGPU** acceleration for optimal performance
    """)
    with gr.Row():
        with gr.Column():
            task_json_input = gr.Textbox(
                label="🧾 Paste Task JSON",
                placeholder="Paste the per-chunk task JSON here...",
                lines=16,
            )
            with gr.Accordion("⚙️ Advanced Settings", open=False):
                use_diarization = gr.Checkbox(
                    label="Enable Speaker Diarization",
                    value=True,
                    info="Uncheck for faster transcription without speaker identification"
                )
                batch_size = gr.Slider(
                    minimum=1,
                    maximum=128,
                    value=16,
                    step=1,
                    label="Batch Size",
                    info="Higher values = faster processing but more GPU memory usage. Recommended: 8-24"
                )
                num_speakers = gr.Slider(
                    minimum=0,
                    maximum=20,
                    value=0,
                    step=1,
                    label="Number of Speakers (0 = auto-detect)",
                    visible=True
                )
                language = gr.Dropdown(
                    choices=["auto", "en", "es", "fr", "de", "it", "pt", "ru", "ja", "ko", "zh"],
                    value="auto",
                    label="Language"
                )
                translate = gr.Checkbox(
                    label="Translate to English",
                    value=False
                )
                prompt = gr.Textbox(
                    label="Vocabulary Prompt (names, acronyms, etc.)",
                    placeholder="Enter names, technical terms, or context...",
                    lines=2
                )
                group_segments = gr.Checkbox(
                    label="Group segments by speaker/time",
                    value=True
                )
            process_btn = gr.Button("🚀 Transcribe Audio", variant="primary")

        with gr.Column():
            output_text = gr.Markdown(
                label="📋 Transcription Results",
                value="Paste task JSON and click 'Transcribe Audio' to get started!"
            )
            output_json = gr.JSON(
                label="🔧 Raw Output (JSON)",
                visible=False
            )

    # Update visibility of num_speakers based on diarization toggle
    use_diarization.change(
        fn=lambda x: gr.update(visible=x),
        inputs=[use_diarization],
        outputs=[num_speakers]
    )

    # Event handlers
    process_btn.click(
        fn=process_audio_gradio,
        inputs=[
            task_json_input,
            num_speakers,
            language,
            translate,
            prompt,
            group_segments,
            use_diarization,
            batch_size
        ],
        outputs=[output_text, output_json]
    )
    # Examples
    gr.Markdown("### 💡 Usage Tips:")
    gr.Markdown("""
    - Paste a single-chunk task JSON matching the preprocess schema
    - Batch Size: Higher values (16-24) = faster but uses more GPU memory
    - Speaker diarization: Enable for speaker identification (slower)
    - Languages: Supports 100+ languages with auto-detection
    - Vocabulary: Add names and technical terms in the prompt for better accuracy
    """)

if __name__ == "__main__":
    demo.launch(debug=True)