import spaces
import gradio as gr
import torch
import torchaudio
import numpy as np
import pandas as pd
import time
import datetime
import re
import subprocess
import os
import tempfile
from transformers import pipeline
from pyannote.audio import Pipeline
import requests
import base64

# Install flash attention for acceleration
'''
try:
    subprocess.run(
        "pip install flash-attn --no-build-isolation",
        env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
        shell=True,
        check=True,
    )
except subprocess.CalledProcessError:
    print("Warning: Could not install flash-attn, falling back to default attention")
'''

# Create global pipeline (similar to working HuggingFace example)
pipe = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-large-v3-turbo",
    torch_dtype=torch.float16,
    device="cuda",
    model_kwargs={"attn_implementation": "flash_attention_2"},
    return_timestamps=True,
)
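
# Note on the pipeline above: attn_implementation="flash_attention_2" only works when the
# flash-attn package is importable; if it is missing, "sdpa" (or "eager") is the usual
# fallback value for this kwarg.
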
def comprehensive_flash_attention_verification():
    """Comprehensive verification of flash attention setup"""
    print("🔍 Running Flash Attention Verification...")
    print("=" * 50)
    verification_results = {}

    # Check 1: Package Installation
    print("📦 Checking Python packages...")
    try:
        import flash_attn
        print(f"✅ flash-attn: {flash_attn.__version__}")
        verification_results["flash_attn_installed"] = True
    except ImportError:
        print("❌ flash-attn: Not installed")
        verification_results["flash_attn_installed"] = False

    try:
        import transformers
        print(f"✅ transformers: {transformers.__version__}")
        verification_results["transformers_available"] = True
    except ImportError:
        print("❌ transformers: Not installed")
        verification_results["transformers_available"] = False

    # Check 2: CUDA Availability
    print("\n🖥️ Checking CUDA availability...")
    cuda_available = torch.cuda.is_available()
    print(f"✅ CUDA available: {cuda_available}")
    if cuda_available:
        print(f"✅ CUDA version: {torch.version.cuda}")
        print(f"✅ GPU count: {torch.cuda.device_count()}")
        for i in range(torch.cuda.device_count()):
            print(f"✅ GPU {i}: {torch.cuda.get_device_name(i)}")
    verification_results["cuda_available"] = cuda_available

    # Check 3: Flash Attention Import
    print("\n📥 Testing flash attention imports...")
    try:
        from flash_attn import flash_attn_func
        print("✅ flash_attn_func imported successfully")
        if flash_attn_func is None:
            print("❌ flash_attn_func is None")
            verification_results["flash_attn_import"] = False
        else:
            print("✅ flash_attn_func is callable")
            verification_results["flash_attn_import"] = True
    except ImportError as e:
        print(f"❌ Import error: {e}")
        verification_results["flash_attn_import"] = False
    except Exception as e:
        print(f"❌ Unexpected error: {e}")
        verification_results["flash_attn_import"] = False

    # Check 4: Flash Attention Functionality Test
    print("\n🧪 Testing flash attention functionality...")
    if not cuda_available:
        print("⚠️ Skipping functionality test - CUDA not available")
        verification_results["flash_attn_functional"] = False
    elif not verification_results.get("flash_attn_import", False):
        print("⚠️ Skipping functionality test - Import failed")
        verification_results["flash_attn_functional"] = False
    else:
        try:
            from flash_attn import flash_attn_func

            # Create small dummy tensors
            batch_size, seq_len, num_heads, head_dim = 1, 16, 4, 32
            device = "cuda:0"
            dtype = torch.float16
            print(f"Creating tensors: batch={batch_size}, seq_len={seq_len}, heads={num_heads}, dim={head_dim}")
            q = torch.randn(batch_size, seq_len, num_heads, head_dim, dtype=dtype, device=device)
            k = torch.randn(batch_size, seq_len, num_heads, head_dim, dtype=dtype, device=device)
            v = torch.randn(batch_size, seq_len, num_heads, head_dim, dtype=dtype, device=device)
            print("✅ Tensors created successfully")

            # Test flash attention
            output = flash_attn_func(q, k, v, dropout_p=0.0, causal=False)
            print(f"✅ Flash attention output shape: {output.shape}")
            print("✅ Flash attention test passed!")
            verification_results["flash_attn_functional"] = True
        except Exception as e:
            print(f"❌ Flash attention test failed: {e}")
            import traceback
            traceback.print_exc()
            verification_results["flash_attn_functional"] = False

    # Summary
    print("\n" + "=" * 50)
    print("📊 VERIFICATION SUMMARY")
    print("=" * 50)
    all_passed = True
    for check_name, result in verification_results.items():
        status = "✅ PASS" if result else "❌ FAIL"
        print(f"{check_name}: {status}")
        if not result:
            all_passed = False

    if all_passed:
        print("\n🎉 All checks passed! Flash attention should work.")
        return True
    else:
        print("\n⚠️ Some checks failed. Flash attention may not work properly.")
        print("\nRecommendations:")
        print("1. Try reinstalling flash-attn: pip uninstall flash-attn && pip install flash-attn --no-build-isolation")
        print("2. Check CUDA compatibility with your PyTorch version")
        print("3. Consider using default attention as fallback")
        return False
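
# comprehensive_flash_attention_verification() is a debugging helper; nothing calls it
# automatically (the call in transcribe_audio below is commented out), so it can be run by
# hand when flash attention misbehaves.
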
class WhisperTranscriber:
    def __init__(self):
        self.pipe = pipe  # Use global pipeline
        self.diarization_model = None

    #@spaces.GPU
    def setup_models(self):
        """Initialize models with GPU acceleration"""
        if self.pipe is None:
            print("Loading Whisper model...")
            self.pipe = pipeline(
                "automatic-speech-recognition",
                model="openai/whisper-large-v3-turbo",
                torch_dtype=torch.float16,
                device="cuda:0",
                model_kwargs={"attn_implementation": "flash_attention_2"},
                return_timestamps=True,
            )

        if self.diarization_model is None:
            print("Loading diarization model...")
            # Note: You'll need to set up authentication for pyannote models.
            # For demo purposes, we'll handle the case where it's not available.
            try:
                self.diarization_model = Pipeline.from_pretrained(
                    "pyannote/speaker-diarization-3.1",
                    use_auth_token=os.getenv("HF_TOKEN"),
                ).to(torch.device("cuda"))
            except Exception as e:
                print(f"Could not load diarization model: {e}")
                self.diarization_model = None
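
    # Note on setup_models(): pyannote/speaker-diarization-3.1 is a gated model, so HF_TOKEN must
    # belong to an account that has accepted its terms on the Hub (on Spaces this is typically set
    # as a repository secret). setup_models() is not called in the current flow (the call in
    # process_audio() below is commented out), so diarization stays disabled until it is re-enabled.
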
    def convert_audio_format(self, audio_path):
        """Convert audio to 16kHz mono WAV format"""
        temp_wav = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
        temp_wav_path = temp_wav.name
        temp_wav.close()

        try:
            subprocess.run([
                "ffmpeg", "-i", audio_path,
                "-ar", "16000", "-ac", "1", "-c:a", "pcm_s16le",
                temp_wav_path, "-y",
            ], check=True, capture_output=True)
            return temp_wav_path
        except subprocess.CalledProcessError as e:
            raise RuntimeError(f"Audio conversion failed: {e}")
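
    # Note on convert_audio_format(): currently unused because the call in process_audio() is
    # commented out; the transformers ASR pipeline already decodes and resamples file paths via
    # ffmpeg on its own, so the explicit conversion is optional.
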
    def transcribe_audio(self, audio_path, language=None, translate=False, prompt=None):
        """Transcribe audio using Whisper with flash attention"""
        # Run comprehensive flash attention verification
        #flash_attention_working = comprehensive_flash_attention_verification()
        #if not flash_attention_working:
        #    print("⚠️ Flash attention verification failed, but proceeding with transcription...")
        #    print("You may encounter the TypeError: 'NoneType' object is not callable error")
        '''
        #if self.pipe is None:
        #    self.setup_models()
        if next(self.pipe.model.parameters()).device.type != "cuda":
            self.pipe.model.to("cuda")
        '''
        print("Starting transcription...")
        start_time = time.time()

        # Prepare generation kwargs
        generate_kwargs = {}
        if language:
            generate_kwargs["language"] = language
        if translate:
            generate_kwargs["task"] = "translate"
        if prompt:
            # Whisper prompt ids come from get_prompt_ids() rather than a plain tokenizer.encode()
            generate_kwargs["prompt_ids"] = self.pipe.tokenizer.get_prompt_ids(prompt, return_tensors="pt")

        # Transcribe with timestamps
        result = self.pipe(
            audio_path,
            return_timestamps=True,
            generate_kwargs=generate_kwargs,
            chunk_length_s=30,
            batch_size=128,
        )
        transcription_time = time.time() - start_time
        print(f"Transcription completed in {transcription_time:.2f} seconds")

        # Extract segments and detected language
        segments = []
        if "chunks" in result:
            for chunk in result["chunks"]:
                segment = {
                    "start": float(chunk["timestamp"][0] or 0),
                    "end": float(chunk["timestamp"][1] or 0),
                    "text": chunk["text"].strip(),
                }
                segments.append(segment)
        else:
            # Fallback for different result format
            segments = [{
                "start": 0.0,
                "end": 0.0,
                "text": result["text"],
            }]
        # result is a dict, so use .get(); the pipeline output has no "language" key by default
        detected_language = result.get("language", language or "unknown")

        transcription_time = time.time() - start_time
        print(f"Transcription parse completed in {transcription_time:.2f} seconds")
        return segments, detected_language
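
    # Note on transcribe_audio(): batch_size=128 assumes a large-memory GPU (e.g. an A100-class
    # card as used by ZeroGPU); lower it if transcription hits CUDA out-of-memory errors.
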
    def perform_diarization(self, audio_path, num_speakers=None):
        """Perform speaker diarization"""
        if self.diarization_model is None:
            print("Diarization model not available, assigning single speaker")
            return [], 1

        print("Starting diarization...")
        start_time = time.time()

        # Load audio for diarization
        waveform, sample_rate = torchaudio.load(audio_path)

        # Perform diarization
        diarization = self.diarization_model(
            {"waveform": waveform, "sample_rate": sample_rate},
            num_speakers=num_speakers,
        )

        # Convert to list format
        diarize_segments = []
        diarization_list = list(diarization.itertracks(yield_label=True))
        for turn, _, speaker in diarization_list:
            diarize_segments.append({
                "start": turn.start,
                "end": turn.end,
                "speaker": speaker,
            })

        unique_speakers = {speaker for _, _, speaker in diarization_list}
        detected_num_speakers = len(unique_speakers)

        diarization_time = time.time() - start_time
        print(f"Diarization completed in {diarization_time:.2f} seconds")
        return diarize_segments, detected_num_speakers
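
    # Note on perform_diarization(): with num_speakers=None the pyannote pipeline estimates the
    # speaker count itself; min_speakers / max_speakers can be passed instead to bound the estimate.
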
    def merge_transcription_and_diarization(self, transcription_segments, diarization_segments):
        """Merge transcription segments with speaker information"""
        if not diarization_segments:
            # No diarization available, assign single speaker
            for segment in transcription_segments:
                segment["speaker"] = "SPEAKER_00"
            return transcription_segments

        print("Merging transcription and diarization...")
        diarize_df = pd.DataFrame(diarization_segments)
        final_segments = []

        for segment in transcription_segments:
            # Calculate intersection with diarization segments
            diarize_df["intersection"] = np.maximum(
                0,
                np.minimum(diarize_df["end"], segment["end"]) -
                np.maximum(diarize_df["start"], segment["start"])
            )

            # Find speaker with maximum intersection
            dia_tmp = diarize_df[diarize_df["intersection"] > 0]
            if len(dia_tmp) > 0:
                speaker = (
                    dia_tmp.groupby("speaker")["intersection"]
                    .sum()
                    .sort_values(ascending=False)
                    .index[0]
                )
            else:
                speaker = "SPEAKER_00"

            segment["speaker"] = speaker
            segment["duration"] = segment["end"] - segment["start"]
            final_segments.append(segment)

        return final_segments

    def group_segments_by_speaker(self, segments, max_gap=1.0, max_duration=30.0):
        """Group consecutive segments from the same speaker"""
        if not segments:
            return segments

        grouped_segments = []
        current_group = segments[0].copy()
        sentence_end_pattern = r"[.!?]+\s*$"

        for segment in segments[1:]:
            time_gap = segment["start"] - current_group["end"]
            current_duration = current_group["end"] - current_group["start"]

            # Conditions for combining segments
            can_combine = (
                segment["speaker"] == current_group["speaker"] and
                time_gap <= max_gap and
                current_duration < max_duration and
                not re.search(sentence_end_pattern, current_group["text"])
            )

            if can_combine:
                # Merge segments
                current_group["end"] = segment["end"]
                current_group["text"] += " " + segment["text"]
                current_group["duration"] = current_group["end"] - current_group["start"]
            else:
                # Start new group
                grouped_segments.append(current_group)
                current_group = segment.copy()

        grouped_segments.append(current_group)

        # Clean up text
        for segment in grouped_segments:
            segment["text"] = re.sub(r"\s+", " ", segment["text"]).strip()
            segment["text"] = re.sub(r"\s+([.,!?])", r"\1", segment["text"])

        return grouped_segments

    def process_audio(self, audio_file, num_speakers=None, language=None,
                      translate=False, prompt=None, group_segments=True):
        """Main processing function"""
        if audio_file is None:
            return {"error": "No audio file provided"}

        try:
            # Setup models if not already done
            #self.setup_models()

            # Convert audio format
            #wav_path = self.convert_audio_format(audio_file)

            try:
                # Transcribe audio
                transcription_segments, detected_language = self.transcribe_audio(
                    audio_file, language, translate, prompt
                )

                # Perform diarization
                diarization_segments, detected_num_speakers = self.perform_diarization(
                    audio_file, num_speakers
                )

                # Merge transcription and diarization
                final_segments = self.merge_transcription_and_diarization(
                    transcription_segments, diarization_segments
                )

                # Group segments if requested
                if group_segments:
                    final_segments = self.group_segments_by_speaker(final_segments)

                return {
                    "segments": final_segments,
                    "language": detected_language,
                    "num_speakers": detected_num_speakers or 1,
                    "total_segments": len(final_segments),
                }
            finally:
                # Clean up temporary file
                if os.path.exists(audio_file):
                    os.unlink(audio_file)
        except Exception as e:
            import traceback
            traceback.print_exc()
            return {"error": f"Processing failed: {str(e)}"}

# Initialize transcriber
transcriber = WhisperTranscriber()


def format_segments_for_display(result):
    """Format segments for display in Gradio"""
    if "error" in result:
        return f"❌ Error: {result['error']}"

    segments = result.get("segments", [])
    language = result.get("language", "unknown")
    num_speakers = result.get("num_speakers", 1)

    output = f"🎯 **Detection Results:**\n"
    output += f"- Language: {language}\n"
    output += f"- Speakers: {num_speakers}\n"
    output += f"- Segments: {len(segments)}\n\n"
    output += "📝 **Transcription:**\n\n"

    for i, segment in enumerate(segments, 1):
        start_time = str(datetime.timedelta(seconds=int(segment["start"])))
        end_time = str(datetime.timedelta(seconds=int(segment["end"])))
        speaker = segment.get("speaker", "SPEAKER_00")
        text = segment["text"]
        output += f"**{speaker}** ({start_time} → {end_time})\n"
        output += f"{text}\n\n"

    return output
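
# Note: on ZeroGPU Spaces, the function doing GPU work is normally decorated with @spaces.GPU
# (compare the commented-out decorator on setup_models above) so a device is attached for the
# duration of the call.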
def process_audio_gradio(audio_file, num_speakers, language, translate, prompt, group_segments):
    """Gradio interface function"""
    result = transcriber.process_audio(
        audio_file=audio_file,
        num_speakers=num_speakers if num_speakers > 0 else None,
        language=language if language != "auto" else None,
        translate=translate,
        prompt=prompt if prompt and prompt.strip() else None,
        group_segments=group_segments,
    )

    formatted_output = format_segments_for_display(result)
    return formatted_output, result

# Create Gradio interface
demo = gr.Blocks(
    title="🎙️ Whisper Transcription with Speaker Diarization",
    theme="default",
)

with demo:
    gr.Markdown("""
    # 🎙️ Advanced Audio Transcription & Speaker Diarization

    Upload an audio file to get accurate transcription with speaker identification, powered by:
    - **Whisper Large V3 Turbo** with Flash Attention for fast transcription
    - **Pyannote 3.1** for speaker diarization
    - **ZeroGPU** acceleration for optimal performance
    """)

    with gr.Row():
        with gr.Column():
            audio_input = gr.Audio(
                label="🎵 Upload Audio File",
                type="filepath",
                #source="upload"
            )

            with gr.Accordion("⚙️ Advanced Settings", open=False):
                num_speakers = gr.Slider(
                    minimum=0,
                    maximum=20,
                    value=0,
                    step=1,
                    label="Number of Speakers (0 = auto-detect)",
                )
                language = gr.Dropdown(
                    choices=["auto", "en", "es", "fr", "de", "it", "pt", "ru", "ja", "ko", "zh"],
                    value="auto",
                    label="Language",
                )
                translate = gr.Checkbox(
                    label="Translate to English",
                    value=False,
                )
                prompt = gr.Textbox(
                    label="Vocabulary Prompt (names, acronyms, etc.)",
                    placeholder="Enter names, technical terms, or context...",
                    lines=2,
                )
                group_segments = gr.Checkbox(
                    label="Group segments by speaker",
                    value=True,
                )

            process_btn = gr.Button("🚀 Transcribe Audio", variant="primary")

        with gr.Column():
            output_text = gr.Markdown(
                label="📝 Transcription Results",
                value="Upload an audio file and click 'Transcribe Audio' to get started!",
            )
            output_json = gr.JSON(
                label="🔧 Raw Output (JSON)",
                visible=False,
            )

    # Event handlers
    process_btn.click(
        fn=process_audio_gradio,
        inputs=[
            audio_input,
            num_speakers,
            language,
            translate,
            prompt,
            group_segments,
        ],
        outputs=[output_text, output_json],
    )

    # Examples
    gr.Markdown("### 💡 Usage Tips:")
    gr.Markdown("""
    - **Supported formats**: MP3, WAV, M4A, FLAC, OGG, and more
    - **Max duration**: Recommended under 10 minutes for optimal performance
    - **Speaker detection**: Works best with clear, distinct voices
    - **Languages**: Supports 100+ languages with auto-detection
    - **Vocabulary**: Add names and technical terms in the prompt for better accuracy
    """)

if __name__ == "__main__":
    demo.launch(debug=True)