import spaces
import os, pathlib
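# Route every framework cache (torch, pyannote, Hugging Face hub/transformers,
# matplotlib) into a single folder owned by the Space user so model downloads land
# in a known, writable location.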
CACHE_ROOT = "/home/user/app/cache" # any folder you own
os.environ.update(
TORCH_HOME = f"{CACHE_ROOT}/torch",
XDG_CACHE_HOME = f"{CACHE_ROOT}/xdg", # torch fallback
PYANNOTE_CACHE = f"{CACHE_ROOT}/pyannote",
HF_HOME = f"{CACHE_ROOT}/huggingface",
TRANSFORMERS_CACHE= f"{CACHE_ROOT}/transformers",
MPLCONFIGDIR = f"{CACHE_ROOT}/mpl",
)
# make sure the cache directories exist (only the ones set above,
# not every value in os.environ)
for key in ("TORCH_HOME", "XDG_CACHE_HOME", "PYANNOTE_CACHE",
            "HF_HOME", "TRANSFORMERS_CACHE", "MPLCONFIGDIR"):
    pathlib.Path(os.environ[key]).mkdir(parents=True, exist_ok=True)
import gradio as gr
import torch
import torchaudio
import numpy as np
import pandas as pd
import time
import datetime
import re
import subprocess
import tempfile
from faster_whisper import WhisperModel
from faster_whisper.vad import VadOptions
import requests
import base64
from pyannote.audio import Pipeline
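# The GPU build of faster-whisper (CTranslate2) needs libcudnn at runtime; when the
# pip-installed nvidia-cudnn wheel is present, prepend its bundled lib/ directory to
# LD_LIBRARY_PATH so the shared libraries can be located.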
import importlib.util
spec = importlib.util.find_spec("nvidia.cudnn")
if spec is not None: # wheel is installed
cudnn_dir = pathlib.Path(spec.origin).parent / "lib"
os.environ["LD_LIBRARY_PATH"] = (
f"{cudnn_dir}:{os.environ.get('LD_LIBRARY_PATH','')}"
)
# Lazy global holder ----------------------------------------------------------
_whisper = None
_diarizer = None
# Create global diarization pipeline
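# If anything in this block fails (missing/invalid HF_TOKEN, gated model access,
# etc.), the except clause leaves _diarizer as None and perform_diarization()
# falls back to a single-speaker segment.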
try:
print("Loading diarization model...")
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision('high')
    _diarizer = Pipeline.from_pretrained(
        "pyannote/speaker-diarization-3.1",
        use_auth_token=os.getenv("HF_TOKEN"),
    ).to(torch.device("cuda"))
    # Optional speed-ups (FP16 + torch.compile). If any of them fails, keep the
    # working pipeline instead of losing diarization altogether.
    try:
        _diarizer.model.half()  # FP16
        for m in _diarizer.model.modules():  # compact LSTM weights
            if isinstance(m, torch.nn.LSTM):
                m.flatten_parameters()
        _diarizer.model = torch.compile(_diarizer.model, mode="reduce-overhead")
    except Exception as opt_err:
        print(f"Skipping optional diarizer optimizations: {opt_err}")
print("Diarization model loaded successfully")
except Exception as e:
import traceback
traceback.print_exc()
print(f"Could not load diarization model: {e}")
_diarizer = None
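# ZeroGPU only attaches a GPU while a @spaces.GPU-decorated function is running, so
# the Whisper model is built lazily on the first GPU call and cached in the
# module-level globals for reuse.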
@spaces.GPU # GPU is guaranteed to exist *inside* this function
def _load_models():
global _whisper, _diarizer
if _whisper is None:
print("Loading Whisper model...")
_whisper = WhisperModel(
"large-v3-turbo",
device="cuda",
compute_type="float16",
)
print("Whisper model loaded successfully")
return _whisper, _diarizer
# -----------------------------------------------------------------------------
class WhisperTranscriber:
def __init__(self):
# do **not** create the models here!
pass
def convert_audio_format(self, audio_path):
"""Convert audio to 16kHz mono WAV format"""
temp_wav = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
temp_wav_path = temp_wav.name
temp_wav.close()
try:
subprocess.run([
"ffmpeg", "-i", audio_path,
"-ar", "16000", "-ac", "1", "-c:a", "pcm_s16le",
temp_wav_path, "-y"
], check=True, capture_output=True)
return temp_wav_path
except subprocess.CalledProcessError as e:
raise RuntimeError(f"Audio conversion failed: {e}")
def cut_audio_segments(self, audio_path, diarization_segments):
"""Cut audio into segments based on diarization results"""
print("Cutting audio into segments...")
# Load the full audio
waveform, sample_rate = torchaudio.load(audio_path)
audio_segments = []
for segment in diarization_segments:
start_sample = int(segment["start"] * sample_rate)
end_sample = int(segment["end"] * sample_rate)
# Extract the segment
segment_waveform = waveform[:, start_sample:end_sample]
# Create temporary file for this segment
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
temp_file.close()
# Save the segment
torchaudio.save(temp_file.name, segment_waveform, sample_rate)
audio_segments.append({
"audio_path": temp_file.name,
"start": segment["start"],
"end": segment["end"],
"speaker": segment["speaker"]
})
return audio_segments
@spaces.GPU # each call gets a GPU slice
def transcribe_audio_segments(self, audio_segments, language=None, translate=False, prompt=None):
"""Transcribe multiple audio segments using faster_whisper"""
whisper, diarizer = _load_models() # models live on the GPU
print(f"Transcribing {len(audio_segments)} audio segments...")
start_time = time.time()
# Prepare options similar to replicate.py
options = dict(
language=language,
beam_size=5,
vad_filter=True,
vad_parameters=VadOptions(
max_speech_duration_s=whisper.feature_extractor.chunk_length,
min_speech_duration_ms=100,
speech_pad_ms=100,
threshold=0.25,
neg_threshold=0.2,
),
word_timestamps=True,
initial_prompt=prompt,
language_detection_segments=1,
task="translate" if translate else "transcribe",
)
results = []
detected_language = None
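        # Each clip is transcribed independently; word and segment timestamps are
        # shifted by the clip's start time so they stay absolute with respect to the
        # original audio file.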
for i, segment in enumerate(audio_segments):
print(f"Processing segment {i+1}/{len(audio_segments)}")
# Transcribe this segment
segments, transcript_info = whisper.transcribe(segment["audio_path"], **options)
segments = list(segments)
# Get detected language from first segment
if detected_language is None:
detected_language = transcript_info.language
# Process each transcribed segment
for seg in segments:
# Create result entry with detailed format like replicate.py
words_list = []
if seg.words:
for word in seg.words:
words_list.append({
"start": float(word.start) + segment["start"],
"end": float(word.end) + segment["start"],
"word": word.word,
"probability": word.probability,
"speaker": segment["speaker"]
})
results.append({
"start": float(seg.start) + segment["start"],
"end": float(seg.end) + segment["start"],
"text": seg.text,
"speaker": segment["speaker"],
"avg_logprob": seg.avg_logprob,
"words": words_list,
"duration": float(seg.end - seg.start)
})
# Clean up temporary files
for segment in audio_segments:
if os.path.exists(segment["audio_path"]):
os.unlink(segment["audio_path"])
transcription_time = time.time() - start_time
print(f"All segments transcribed in {transcription_time:.2f} seconds")
return results, detected_language
@spaces.GPU # each call gets a GPU slice
def perform_diarization(self, audio_path, num_speakers=None):
"""Perform speaker diarization"""
whisper, diarizer = _load_models() # models live on the GPU
if diarizer is None:
print("Diarization model not available, creating single speaker segment")
# Load audio to get duration
waveform, sample_rate = torchaudio.load(audio_path)
duration = waveform.shape[1] / sample_rate
return [{
"start": 0.0,
"end": duration,
"speaker": "SPEAKER_00"
}], 1
print("Starting diarization...")
start_time = time.time()
# Load audio for diarization
waveform, sample_rate = torchaudio.load(audio_path)
# Perform diarization
diarization = diarizer(
{"waveform": waveform, "sample_rate": sample_rate},
num_speakers=num_speakers,
)
# Convert to list format
diarize_segments = []
diarization_list = list(diarization.itertracks(yield_label=True))
for turn, _, speaker in diarization_list:
diarize_segments.append({
"start": turn.start,
"end": turn.end,
"speaker": speaker
})
        unique_speakers = {segment["speaker"] for segment in diarize_segments}
detected_num_speakers = len(unique_speakers)
diarization_time = time.time() - start_time
print(f"Diarization completed in {diarization_time:.2f} seconds")
return diarize_segments, detected_num_speakers
def group_segments_by_speaker(self, segments, max_gap=1.0, max_duration=30.0):
"""Group consecutive segments from the same speaker"""
if not segments:
return segments
grouped_segments = []
current_group = segments[0].copy()
sentence_end_pattern = r"[.!?]+"
for segment in segments[1:]:
time_gap = segment["start"] - current_group["end"]
current_duration = current_group["end"] - current_group["start"]
# Conditions for combining segments
can_combine = (
segment["speaker"] == current_group["speaker"] and
time_gap <= max_gap and
current_duration < max_duration and
not re.search(sentence_end_pattern, current_group["text"][-1:])
)
if can_combine:
# Merge segments
current_group["end"] = segment["end"]
current_group["text"] += " " + segment["text"]
current_group["words"].extend(segment["words"])
current_group["duration"] = current_group["end"] - current_group["start"]
else:
# Start new group
grouped_segments.append(current_group)
current_group = segment.copy()
grouped_segments.append(current_group)
# Clean up text
for segment in grouped_segments:
segment["text"] = re.sub(r"\s+", " ", segment["text"]).strip()
segment["text"] = re.sub(r"\s+([.,!?])", r"\1", segment["text"])
return grouped_segments
@spaces.GPU # each call gets a GPU slice
def process_audio(self, audio_file, num_speakers=None, language=None,
translate=False, prompt=None, group_segments=True):
"""Main processing function - diarization first, then transcription"""
if audio_file is None:
return {"error": "No audio file provided"}
converted_audio_path = None
try:
print("Starting new processing pipeline...")
# Step 1: Convert audio format first
print("Converting audio format...")
converted_audio_path = self.convert_audio_format(audio_file)
# Step 2: Perform diarization on converted audio
diarization_segments, detected_num_speakers = self.perform_diarization(
converted_audio_path, num_speakers
)
# Step 3: Cut audio into segments based on diarization
audio_segments = self.cut_audio_segments(converted_audio_path, diarization_segments)
# Step 4: Transcribe each segment
transcription_results, detected_language = self.transcribe_audio_segments(
audio_segments, language, translate, prompt
)
# Step 5: Group segments if requested
if group_segments:
transcription_results = self.group_segments_by_speaker(transcription_results)
# Step 6: Return in replicate.py format
return {
"segments": transcription_results,
"language": detected_language,
"num_speakers": detected_num_speakers
}
except Exception as e:
import traceback
traceback.print_exc()
return {"error": f"Processing failed: {str(e)}"}
finally:
# Clean up converted audio file
if converted_audio_path and os.path.exists(converted_audio_path):
os.unlink(converted_audio_path)
print("Cleaned up converted audio file")
# Initialize transcriber
transcriber = WhisperTranscriber()
def format_segments_for_display(result):
"""Format segments for display in Gradio"""
if "error" in result:
return f"❌ Error: {result['error']}"
segments = result.get("segments", [])
language = result.get("language", "unknown")
num_speakers = result.get("num_speakers", 1)
output = f"🎯 **Detection Results:**\n"
output += f"- Language: {language}\n"
output += f"- Speakers: {num_speakers}\n"
output += f"- Segments: {len(segments)}\n\n"
output += "πŸ“ **Transcription:**\n\n"
for i, segment in enumerate(segments, 1):
start_time = str(datetime.timedelta(seconds=int(segment["start"])))
end_time = str(datetime.timedelta(seconds=int(segment["end"])))
speaker = segment.get("speaker", "SPEAKER_00")
text = segment["text"]
output += f"**{speaker}** ({start_time} β†’ {end_time})\n"
output += f"{text}\n\n"
return output
@spaces.GPU
def process_audio_gradio(audio_file, num_speakers, language, translate, prompt, group_segments):
"""Gradio interface function"""
result = transcriber.process_audio(
audio_file=audio_file,
num_speakers=num_speakers if num_speakers > 0 else None,
language=language if language != "auto" else None,
translate=translate,
prompt=prompt if prompt and prompt.strip() else None,
group_segments=group_segments
)
formatted_output = format_segments_for_display(result)
return formatted_output, result
# Create Gradio interface
demo = gr.Blocks(
title="πŸŽ™οΈ Whisper Transcription with Speaker Diarization",
theme="default"
)
with demo:
gr.Markdown("""
# πŸŽ™οΈ Advanced Audio Transcription & Speaker Diarization
Upload an audio file to get accurate transcription with speaker identification, powered by:
    - **Whisper large-v3-turbo** (via faster-whisper) for fast transcription
- **Pyannote 3.1** for speaker diarization
- **ZeroGPU** acceleration for optimal performance
""")
with gr.Row():
with gr.Column():
audio_input = gr.Audio(
label="🎡 Upload Audio File",
type="filepath",
#source="upload"
)
with gr.Accordion("βš™οΈ Advanced Settings", open=False):
num_speakers = gr.Slider(
minimum=0,
maximum=20,
value=0,
step=1,
label="Number of Speakers (0 = auto-detect)"
)
language = gr.Dropdown(
choices=["auto", "en", "es", "fr", "de", "it", "pt", "ru", "ja", "ko", "zh"],
value="auto",
label="Language"
)
translate = gr.Checkbox(
label="Translate to English",
value=False
)
prompt = gr.Textbox(
label="Vocabulary Prompt (names, acronyms, etc.)",
placeholder="Enter names, technical terms, or context...",
lines=2
)
group_segments = gr.Checkbox(
label="Group segments by speaker",
value=True
)
process_btn = gr.Button("πŸš€ Transcribe Audio", variant="primary")
with gr.Column():
output_text = gr.Markdown(
label="πŸ“ Transcription Results",
value="Upload an audio file and click 'Transcribe Audio' to get started!"
)
output_json = gr.JSON(
label="πŸ”§ Raw Output (JSON)",
visible=False
)
# Event handlers
process_btn.click(
fn=process_audio_gradio,
inputs=[
audio_input,
num_speakers,
language,
translate,
prompt,
group_segments
],
outputs=[output_text, output_json]
)
# Examples
gr.Markdown("### πŸ“‹ Usage Tips:")
gr.Markdown("""
- **Supported formats**: MP3, WAV, M4A, FLAC, OGG, and more
- **Max duration**: Recommended under 10 minutes for optimal performance
- **Speaker detection**: Works best with clear, distinct voices
- **Languages**: Supports 100+ languages with auto-detection
- **Vocabulary**: Add names and technical terms in the prompt for better accuracy
""")
if __name__ == "__main__":
demo.launch(debug=True)