import spaces
import gradio as gr
import torch
import torchaudio
import numpy as np
import pandas as pd
import time
import datetime
import re
import subprocess
import os
import tempfile
from transformers import pipeline
from pyannote.audio import Pipeline
import requests
import base64
# Optional: install flash-attn for acceleration (the block below is currently disabled and kept for reference)
'''
try:
subprocess.run(
"pip install flash-attn --no-build-isolation",
env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
shell=True,
check=True
)
except subprocess.CalledProcessError:
print("Warning: Could not install flash-attn, falling back to default attention")
'''
# Create global pipeline (similar to working HuggingFace example)
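# With ZeroGPU, `spaces` is imported before torch so the model can be placed on "cuda"
# at import time; the physical GPU is only attached while a @spaces.GPU-decorated call runs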
pipe = pipeline(
"automatic-speech-recognition",
model="openai/whisper-large-v3-turbo",
torch_dtype=torch.float16,
device="cuda",
model_kwargs={"attn_implementation": "flash_attention_2"},
return_timestamps=True,
)
def comprehensive_flash_attention_verification():
"""Comprehensive verification of flash attention setup"""
print("πŸ” Running Flash Attention Verification...")
print("=" * 50)
verification_results = {}
# Check 1: Package Installation
print("πŸ” Checking Python packages...")
try:
import flash_attn
print(f"βœ… flash-attn: {flash_attn.__version__}")
verification_results["flash_attn_installed"] = True
except ImportError:
print("❌ flash-attn: Not installed")
verification_results["flash_attn_installed"] = False
try:
import transformers
print(f"βœ… transformers: {transformers.__version__}")
verification_results["transformers_available"] = True
except ImportError:
print("❌ transformers: Not installed")
verification_results["transformers_available"] = False
# Check 2: CUDA Availability
print("\nπŸ” Checking CUDA availability...")
cuda_available = torch.cuda.is_available()
print(f"βœ… CUDA available: {cuda_available}")
if cuda_available:
print(f"βœ… CUDA version: {torch.version.cuda}")
print(f"βœ… GPU count: {torch.cuda.device_count()}")
for i in range(torch.cuda.device_count()):
print(f"βœ… GPU {i}: {torch.cuda.get_device_name(i)}")
verification_results["cuda_available"] = cuda_available
# Check 3: Flash Attention Import
print("\nπŸ” Testing flash attention imports...")
try:
from flash_attn import flash_attn_func
print("βœ… flash_attn_func imported successfully")
if flash_attn_func is None:
print("❌ flash_attn_func is None")
verification_results["flash_attn_import"] = False
else:
print("βœ… flash_attn_func is callable")
verification_results["flash_attn_import"] = True
except ImportError as e:
print(f"❌ Import error: {e}")
verification_results["flash_attn_import"] = False
except Exception as e:
print(f"❌ Unexpected error: {e}")
verification_results["flash_attn_import"] = False
# Check 4: Flash Attention Functionality Test
print("\nπŸ” Testing flash attention functionality...")
if not cuda_available:
print("⚠️ Skipping functionality test - CUDA not available")
verification_results["flash_attn_functional"] = False
elif not verification_results.get("flash_attn_import", False):
print("⚠️ Skipping functionality test - Import failed")
verification_results["flash_attn_functional"] = False
else:
try:
from flash_attn import flash_attn_func
# Create small dummy tensors
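# flash_attn_func expects q/k/v shaped (batch, seq_len, num_heads, head_dim) in fp16/bf16 on a CUDA device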
batch_size, seq_len, num_heads, head_dim = 1, 16, 4, 32
device = "cuda:0"
dtype = torch.float16
print(f"Creating tensors: batch={batch_size}, seq_len={seq_len}, heads={num_heads}, dim={head_dim}")
q = torch.randn(batch_size, seq_len, num_heads, head_dim, dtype=dtype, device=device)
k = torch.randn(batch_size, seq_len, num_heads, head_dim, dtype=dtype, device=device)
v = torch.randn(batch_size, seq_len, num_heads, head_dim, dtype=dtype, device=device)
print("βœ… Tensors created successfully")
# Test flash attention
output = flash_attn_func(q, k, v, dropout_p=0.0, causal=False)
print(f"βœ… Flash attention output shape: {output.shape}")
print("βœ… Flash attention test passed!")
verification_results["flash_attn_functional"] = True
except Exception as e:
print(f"❌ Flash attention test failed: {e}")
import traceback
traceback.print_exc()
verification_results["flash_attn_functional"] = False
# Summary
print("\n" + "=" * 50)
print("πŸ“Š VERIFICATION SUMMARY")
print("=" * 50)
all_passed = True
for check_name, result in verification_results.items():
status = "βœ… PASS" if result else "❌ FAIL"
print(f"{check_name}: {status}")
if not result:
all_passed = False
if all_passed:
print("\nπŸŽ‰ All checks passed! Flash attention should work.")
return True
else:
print("\n⚠️ Some checks failed. Flash attention may not work properly.")
print("\nRecommendations:")
print("1. Try reinstalling flash-attn: pip uninstall flash-attn && pip install flash-attn --no-build-isolation")
print("2. Check CUDA compatibility with your PyTorch version")
print("3. Consider using default attention as fallback")
return False
class WhisperTranscriber:
def __init__(self):
self.pipe = pipe # Use global pipeline
self.diarization_model = None
#@spaces.GPU
def setup_models(self):
"""Initialize models with GPU acceleration"""
if self.pipe is None:
print("Loading Whisper model...")
self.pipe = pipeline(
"automatic-speech-recognition",
model="openai/whisper-large-v3-turbo",
torch_dtype=torch.float16,
device="cuda:0",
model_kwargs={"attn_implementation": "flash_attention_2"},
return_timestamps=True,
)
if self.diarization_model is None:
print("Loading diarization model...")
# Note: You'll need to set up authentication for pyannote models
# For demo purposes, we'll handle the case where it's not available
try:
self.diarization_model = Pipeline.from_pretrained(
"pyannote/speaker-diarization-3.1",
use_auth_token=os.getenv("HF_TOKEN")
).to(torch.device("cuda"))
except Exception as e:
print(f"Could not load diarization model: {e}")
self.diarization_model = None
def convert_audio_format(self, audio_path):
"""Convert audio to 16kHz mono WAV format"""
temp_wav = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
temp_wav_path = temp_wav.name
temp_wav.close()
try:
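# -ar 16000 / -ac 1 / pcm_s16le: resample to 16 kHz mono 16-bit PCM, the layout both
# Whisper and pyannote expect; -y overwrites the temp file if it already exists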
subprocess.run([
"ffmpeg", "-i", audio_path,
"-ar", "16000", "-ac", "1", "-c:a", "pcm_s16le",
temp_wav_path, "-y"
], check=True, capture_output=True)
return temp_wav_path
except subprocess.CalledProcessError as e:
raise RuntimeError(f"Audio conversion failed: {e}")
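# @spaces.GPU: on ZeroGPU a GPU is attached only for the duration of each decorated call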
@spaces.GPU
def transcribe_audio(self, audio_path, language=None, translate=False, prompt=None):
"""Transcribe audio using Whisper with flash attention"""
# Run comprehensive flash attention verification
#flash_attention_working = comprehensive_flash_attention_verification()
#if not flash_attention_working:
# print("⚠️ Flash attention verification failed, but proceeding with transcription...")
# print("You may encounter the TypeError: 'NoneType' object is not callable error")
'''
#if self.pipe is None:
# self.setup_models()
if next(self.pipe.model.parameters()).device.type != "cuda":
self.pipe.model.to("cuda")
'''
print("Starting transcription...")
start_time = time.time()
# Prepare generation kwargs
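# language pins the decode language, task="translate" makes Whisper emit English,
# and prompt_ids bias decoding toward the supplied vocabulary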
generate_kwargs = {}
if language:
generate_kwargs["language"] = language
if translate:
generate_kwargs["task"] = "translate"
if prompt:
generate_kwargs["prompt_ids"] = self.pipe.tokenizer.encode(prompt)
# Transcribe with timestamps
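# chunk_length_s=30 runs chunked long-form decoding over 30 s windows that are batched
# together; batch_size=128 assumes a large GPU, lower it if you hit CUDA out-of-memory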
result = self.pipe(
audio_path,
return_timestamps=True,
generate_kwargs=generate_kwargs,
chunk_length_s=30,
batch_size=128,
)
transcription_time = time.time() - start_time
print(f"Transcription completed in {transcription_time:.2f} seconds")
# Extract segments and detected language
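# each chunk looks like {"timestamp": (start, end), "text": ...}; the final chunk's end
# can be None when the audio is truncated, hence the `or 0` fallback below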
segments = []
if "chunks" in result:
for chunk in result["chunks"]:
segment = {
"start": float(chunk["timestamp"][0] or 0),
"end": float(chunk["timestamp"][1] or 0),
"text": chunk["text"].strip(),
}
segments.append(segment)
else:
# Fallback for different result format
segments = [{
"start": 0.0,
"end": 0.0,
"text": result["text"]
}]
# result is a dict, so use .get(); the pipeline does not always report a language
detected_language = result.get("language", language or "unknown")
total_time = time.time() - start_time
print(f"Transcription + parsing finished in {total_time:.2f} seconds total")
return segments, detected_language
def perform_diarization(self, audio_path, num_speakers=None):
"""Perform speaker diarization"""
if self.diarization_model is None:
print("Diarization model not available, assigning single speaker")
return [], 1
print("Starting diarization...")
start_time = time.time()
# Load audio for diarization
waveform, sample_rate = torchaudio.load(audio_path)
# Perform diarization
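# pyannote accepts an in-memory {"waveform": (channels, samples) tensor, "sample_rate": int}
# mapping, which avoids re-reading the file from disk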
diarization = self.diarization_model(
{"waveform": waveform, "sample_rate": sample_rate},
num_speakers=num_speakers,
)
# Convert to list format
diarize_segments = []
diarization_list = list(diarization.itertracks(yield_label=True))
for turn, _, speaker in diarization_list:
diarize_segments.append({
"start": turn.start,
"end": turn.end,
"speaker": speaker
})
unique_speakers = {speaker for _, _, speaker in diarization_list}
detected_num_speakers = len(unique_speakers)
diarization_time = time.time() - start_time
print(f"Diarization completed in {diarization_time:.2f} seconds")
return diarize_segments, detected_num_speakers
def merge_transcription_and_diarization(self, transcription_segments, diarization_segments):
"""Merge transcription segments with speaker information"""
if not diarization_segments:
# No diarization available, assign single speaker
for segment in transcription_segments:
segment["speaker"] = "SPEAKER_00"
return transcription_segments
print("Merging transcription and diarization...")
diarize_df = pd.DataFrame(diarization_segments)
final_segments = []
for segment in transcription_segments:
# Calculate intersection with diarization segments
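# overlap in seconds = max(0, min(ends) - max(starts)) between this transcript segment and every diarization turn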
diarize_df["intersection"] = np.maximum(0,
np.minimum(diarize_df["end"], segment["end"]) -
np.maximum(diarize_df["start"], segment["start"])
)
# Find speaker with maximum intersection
dia_tmp = diarize_df[diarize_df["intersection"] > 0]
if len(dia_tmp) > 0:
speaker = (
dia_tmp.groupby("speaker")["intersection"]
.sum()
.sort_values(ascending=False)
.index[0]
)
else:
speaker = "SPEAKER_00"
segment["speaker"] = speaker
segment["duration"] = segment["end"] - segment["start"]
final_segments.append(segment)
return final_segments
def group_segments_by_speaker(self, segments, max_gap=1.0, max_duration=30.0):
"""Group consecutive segments from the same speaker"""
if not segments:
return segments
grouped_segments = []
current_group = segments[0].copy()
sentence_end_pattern = r"[.!?]+\s*$"
for segment in segments[1:]:
time_gap = segment["start"] - current_group["end"]
current_duration = current_group["end"] - current_group["start"]
# Conditions for combining segments
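# merge only if: same speaker, gap <= max_gap seconds, merged duration still under
# max_duration, and the current text does not already end with sentence punctuation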
can_combine = (
segment["speaker"] == current_group["speaker"] and
time_gap <= max_gap and
current_duration < max_duration and
not re.search(sentence_end_pattern, current_group["text"])
)
if can_combine:
# Merge segments
current_group["end"] = segment["end"]
current_group["text"] += " " + segment["text"]
current_group["duration"] = current_group["end"] - current_group["start"]
else:
# Start new group
grouped_segments.append(current_group)
current_group = segment.copy()
grouped_segments.append(current_group)
# Clean up text
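# collapse runs of whitespace and drop any space left before punctuation after merging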
for segment in grouped_segments:
segment["text"] = re.sub(r"\s+", " ", segment["text"]).strip()
segment["text"] = re.sub(r"\s+([.,!?])", r"\1", segment["text"])
return grouped_segments
@spaces.GPU
def process_audio(self, audio_file, num_speakers=None, language=None,
translate=False, prompt=None, group_segments=True):
"""Main processing function"""
if audio_file is None:
return {"error": "No audio file provided"}
try:
# Setup models if not already done
#self.setup_models()
# Convert audio format
#wav_path = self.convert_audio_format(audio_file)
try:
# Transcribe audio
transcription_segments, detected_language = self.transcribe_audio(
audio_file, language, translate, prompt
)
# Perform diarization
diarization_segments, detected_num_speakers = self.perform_diarization(
audio_file, num_speakers
)
# Merge transcription and diarization
final_segments = self.merge_transcription_and_diarization(
transcription_segments, diarization_segments
)
# Group segments if requested
if group_segments:
final_segments = self.group_segments_by_speaker(final_segments)
return {
"segments": final_segments,
"language": detected_language,
"num_speakers": detected_num_speakers or 1,
"total_segments": len(final_segments)
}
finally:
# Clean up temporary file
if os.path.exists(audio_file):
os.unlink(audio_file)
except Exception as e:
import traceback
traceback.print_exc()
return {"error": f"Processing failed: {str(e)}"}
# Initialize transcriber
transcriber = WhisperTranscriber()
def format_segments_for_display(result):
"""Format segments for display in Gradio"""
if "error" in result:
return f"❌ Error: {result['error']}"
segments = result.get("segments", [])
language = result.get("language", "unknown")
num_speakers = result.get("num_speakers", 1)
output = f"🎯 **Detection Results:**\n"
output += f"- Language: {language}\n"
output += f"- Speakers: {num_speakers}\n"
output += f"- Segments: {len(segments)}\n\n"
output += "πŸ“ **Transcription:**\n\n"
for i, segment in enumerate(segments, 1):
start_time = str(datetime.timedelta(seconds=int(segment["start"])))
end_time = str(datetime.timedelta(seconds=int(segment["end"])))
speaker = segment.get("speaker", "SPEAKER_00")
text = segment["text"]
output += f"**{speaker}** ({start_time} β†’ {end_time})\n"
output += f"{text}\n\n"
return output
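# Rough usage sketch outside Gradio (assumes a local file "sample.wav"; note that
# process_audio() deletes its input file when it finishes):
#   result = transcriber.process_audio("sample.wav", num_speakers=2, language="en")
#   print(format_segments_for_display(result))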
@spaces.GPU
def process_audio_gradio(audio_file, num_speakers, language, translate, prompt, group_segments):
"""Gradio interface function"""
result = transcriber.process_audio(
audio_file=audio_file,
num_speakers=num_speakers if num_speakers > 0 else None,
language=language if language != "auto" else None,
translate=translate,
prompt=prompt if prompt and prompt.strip() else None,
group_segments=group_segments
)
formatted_output = format_segments_for_display(result)
return formatted_output, result
# Create Gradio interface
demo = gr.Blocks(
title="πŸŽ™οΈ Whisper Transcription with Speaker Diarization",
theme="default"
)
with demo:
gr.Markdown("""
# πŸŽ™οΈ Advanced Audio Transcription & Speaker Diarization
Upload an audio file to get accurate transcription with speaker identification, powered by:
- **Whisper Large V3 Turbo** with Flash Attention for fast transcription
- **Pyannote 3.1** for speaker diarization
- **ZeroGPU** acceleration for optimal performance
""")
with gr.Row():
with gr.Column():
audio_input = gr.Audio(
label="🎡 Upload Audio File",
type="filepath",
#source="upload"
)
with gr.Accordion("βš™οΈ Advanced Settings", open=False):
num_speakers = gr.Slider(
minimum=0,
maximum=20,
value=0,
step=1,
label="Number of Speakers (0 = auto-detect)"
)
language = gr.Dropdown(
choices=["auto", "en", "es", "fr", "de", "it", "pt", "ru", "ja", "ko", "zh"],
value="auto",
label="Language"
)
translate = gr.Checkbox(
label="Translate to English",
value=False
)
prompt = gr.Textbox(
label="Vocabulary Prompt (names, acronyms, etc.)",
placeholder="Enter names, technical terms, or context...",
lines=2
)
group_segments = gr.Checkbox(
label="Group segments by speaker",
value=True
)
process_btn = gr.Button("πŸš€ Transcribe Audio", variant="primary")
with gr.Column():
output_text = gr.Markdown(
label="πŸ“ Transcription Results",
value="Upload an audio file and click 'Transcribe Audio' to get started!"
)
output_json = gr.JSON(
label="πŸ”§ Raw Output (JSON)",
visible=False
)
# Event handlers
process_btn.click(
fn=process_audio_gradio,
inputs=[
audio_input,
num_speakers,
language,
translate,
prompt,
group_segments
],
outputs=[output_text, output_json]
)
# Examples
gr.Markdown("### πŸ“‹ Usage Tips:")
gr.Markdown("""
- **Supported formats**: MP3, WAV, M4A, FLAC, OGG, and more
- **Max duration**: Recommended under 10 minutes for optimal performance
- **Speaker detection**: Works best with clear, distinct voices
- **Languages**: Supports nearly 100 languages with auto-detection
- **Vocabulary**: Add names and technical terms in the prompt for better accuracy
""")
if __name__ == "__main__":
demo.launch(debug=True)