# Hugging Face Space app — running on ZeroGPU
import spaces
import os, pathlib

CACHE_ROOT = "/home/user/app/cache"  # any folder you own
cache_dirs = {
    "TORCH_HOME":         f"{CACHE_ROOT}/torch",
    "XDG_CACHE_HOME":     f"{CACHE_ROOT}/xdg",  # torch fallback
    "PYANNOTE_CACHE":     f"{CACHE_ROOT}/pyannote",
    "HF_HOME":            f"{CACHE_ROOT}/huggingface",
    "TRANSFORMERS_CACHE": f"{CACHE_ROOT}/transformers",
    "MPLCONFIGDIR":       f"{CACHE_ROOT}/mpl",
}
os.environ.update(cache_dirs)

# make sure the cache directories exist (only the ones defined above,
# not every value in os.environ)
for path in cache_dirs.values():
    pathlib.Path(path).mkdir(parents=True, exist_ok=True)
import gradio as gr
import torch
import torchaudio
import numpy as np
import pandas as pd
import time
import datetime
import re
import subprocess
import tempfile
import importlib.util

from faster_whisper import WhisperModel
from faster_whisper.vad import VadOptions
import requests
import base64
from pyannote.audio import Pipeline
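# Make the cuDNN libraries bundled with the nvidia-cudnn pip wheel (if it is
# installed) discoverable at runtime; faster-whisper's CTranslate2 backend
# loads cuDNN dynamically. Best-effort: if the wheel is absent, nothing changes.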
spec = importlib.util.find_spec("nvidia.cudnn")
if spec is not None:  # cuDNN wheel is installed
    cudnn_dir = pathlib.Path(spec.origin).parent / "lib"
    os.environ["LD_LIBRARY_PATH"] = (
        f"{cudnn_dir}:{os.environ.get('LD_LIBRARY_PATH', '')}"
    )
# Lazy global holders ----------------------------------------------------------
_whisper = None
_diarizer = None

# Create the global diarization pipeline at import time
# (on ZeroGPU the module-level .to("cuda") below is allowed).
try:
    print("Loading diarization model...")
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.set_float32_matmul_precision("high")
    _diarizer = Pipeline.from_pretrained(
        "pyannote/speaker-diarization-3.1",
        use_auth_token=os.getenv("HF_TOKEN"),
    ).to(torch.device("cuda"))
    # Optional speed-ups (FP16, compact LSTM weights, torch.compile).
    # Run them best-effort so a failure here does not disable diarization.
    try:
        _diarizer.model.half()                       # FP16
        for m in _diarizer.model.modules():          # compact LSTM weights
            if isinstance(m, torch.nn.LSTM):
                m.flatten_parameters()
        _diarizer.model = torch.compile(_diarizer.model, mode="reduce-overhead")
    except Exception as opt_err:
        print(f"Skipping optional diarizer optimizations: {opt_err}")
    print("Diarization model loaded successfully")
except Exception as e:
    import traceback
    traceback.print_exc()
    print(f"Could not load diarization model: {e}")
    _diarizer = None
# GPU is guaranteed to exist *inside* this function
def _load_models():
    global _whisper, _diarizer
    if _whisper is None:
        print("Loading Whisper model...")
        _whisper = WhisperModel(
            "large-v3-turbo",
            device="cuda",
            compute_type="float16",
        )
        print("Whisper model loaded successfully")
    return _whisper, _diarizer

# -----------------------------------------------------------------------------
class WhisperTranscriber:
    def __init__(self):
        # do **not** create the models here!
        pass

    def convert_audio_format(self, audio_path):
        """Convert audio to 16 kHz mono WAV format."""
        temp_wav = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
        temp_wav_path = temp_wav.name
        temp_wav.close()
        try:
            subprocess.run([
                "ffmpeg", "-i", audio_path,
                "-ar", "16000", "-ac", "1", "-c:a", "pcm_s16le",
                temp_wav_path, "-y"
            ], check=True, capture_output=True)
            return temp_wav_path
        except subprocess.CalledProcessError as e:
            raise RuntimeError(f"Audio conversion failed: {e}")
    def cut_audio_segments(self, audio_path, diarization_segments):
        """Cut audio into segments based on diarization results."""
        print("Cutting audio into segments...")
        # Load the full audio
        waveform, sample_rate = torchaudio.load(audio_path)

        audio_segments = []
        for segment in diarization_segments:
            start_sample = int(segment["start"] * sample_rate)
            end_sample = int(segment["end"] * sample_rate)

            # Extract the segment
            segment_waveform = waveform[:, start_sample:end_sample]

            # Create a temporary file for this segment
            temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
            temp_file.close()

            # Save the segment
            torchaudio.save(temp_file.name, segment_waveform, sample_rate)

            audio_segments.append({
                "audio_path": temp_file.name,
                "start": segment["start"],
                "end": segment["end"],
                "speaker": segment["speaker"]
            })
        return audio_segments
    # each call gets a GPU slice
    def transcribe_audio_segments(self, audio_segments, language=None, translate=False, prompt=None):
        """Transcribe multiple audio segments using faster_whisper."""
        whisper, diarizer = _load_models()  # models live on the GPU
        print(f"Transcribing {len(audio_segments)} audio segments...")
        start_time = time.time()

        # Prepare options similar to replicate.py
        options = dict(
            language=language,
            beam_size=5,
            vad_filter=True,
            vad_parameters=VadOptions(
                max_speech_duration_s=whisper.feature_extractor.chunk_length,
                min_speech_duration_ms=100,
                speech_pad_ms=100,
                threshold=0.25,
                neg_threshold=0.2,
            ),
            word_timestamps=True,
            initial_prompt=prompt,
            language_detection_segments=1,
            task="translate" if translate else "transcribe",
        )

        results = []
        detected_language = None
        for i, segment in enumerate(audio_segments):
            print(f"Processing segment {i + 1}/{len(audio_segments)}")

            # Transcribe this segment
            segments, transcript_info = whisper.transcribe(segment["audio_path"], **options)
            segments = list(segments)

            # Get the detected language from the first segment
            if detected_language is None:
                detected_language = transcript_info.language

            # Process each transcribed segment
            for seg in segments:
                # Build a result entry in the detailed format used by replicate.py
                words_list = []
                if seg.words:
                    for word in seg.words:
                        words_list.append({
                            "start": float(word.start) + segment["start"],
                            "end": float(word.end) + segment["start"],
                            "word": word.word,
                            "probability": word.probability,
                            "speaker": segment["speaker"]
                        })
                results.append({
                    "start": float(seg.start) + segment["start"],
                    "end": float(seg.end) + segment["start"],
                    "text": seg.text,
                    "speaker": segment["speaker"],
                    "avg_logprob": seg.avg_logprob,
                    "words": words_list,
                    "duration": float(seg.end - seg.start)
                })

        # Clean up temporary files
        for segment in audio_segments:
            if os.path.exists(segment["audio_path"]):
                os.unlink(segment["audio_path"])

        transcription_time = time.time() - start_time
        print(f"All segments transcribed in {transcription_time:.2f} seconds")
        return results, detected_language
    # each call gets a GPU slice
    def perform_diarization(self, audio_path, num_speakers=None):
        """Perform speaker diarization."""
        whisper, diarizer = _load_models()  # models live on the GPU
        if diarizer is None:
            print("Diarization model not available, creating single speaker segment")
            # Load audio to get the duration
            waveform, sample_rate = torchaudio.load(audio_path)
            duration = waveform.shape[1] / sample_rate
            return [{
                "start": 0.0,
                "end": duration,
                "speaker": "SPEAKER_00"
            }], 1

        print("Starting diarization...")
        start_time = time.time()

        # Load audio for diarization
        waveform, sample_rate = torchaudio.load(audio_path)

        # Perform diarization
        diarization = diarizer(
            {"waveform": waveform, "sample_rate": sample_rate},
            num_speakers=num_speakers,
        )

        # Convert to list format
        diarize_segments = []
        for turn, _, speaker in diarization.itertracks(yield_label=True):
            diarize_segments.append({
                "start": turn.start,
                "end": turn.end,
                "speaker": speaker
            })

        unique_speakers = {segment["speaker"] for segment in diarize_segments}
        detected_num_speakers = len(unique_speakers)

        diarization_time = time.time() - start_time
        print(f"Diarization completed in {diarization_time:.2f} seconds")
        return diarize_segments, detected_num_speakers
    def group_segments_by_speaker(self, segments, max_gap=1.0, max_duration=30.0):
        """Group consecutive segments from the same speaker."""
        if not segments:
            return segments

        grouped_segments = []
        current_group = segments[0].copy()
        sentence_end_pattern = r"[.!?]+"

        for segment in segments[1:]:
            time_gap = segment["start"] - current_group["end"]
            current_duration = current_group["end"] - current_group["start"]

            # Conditions for combining segments
            can_combine = (
                segment["speaker"] == current_group["speaker"] and
                time_gap <= max_gap and
                current_duration < max_duration and
                not re.search(sentence_end_pattern, current_group["text"][-1:])
            )

            if can_combine:
                # Merge segments
                current_group["end"] = segment["end"]
                current_group["text"] += " " + segment["text"]
                current_group["words"].extend(segment["words"])
                current_group["duration"] = current_group["end"] - current_group["start"]
            else:
                # Start a new group
                grouped_segments.append(current_group)
                current_group = segment.copy()

        grouped_segments.append(current_group)

        # Clean up text
        for segment in grouped_segments:
            segment["text"] = re.sub(r"\s+", " ", segment["text"]).strip()
            segment["text"] = re.sub(r"\s+([.,!?])", r"\1", segment["text"])

        return grouped_segments
    # each call gets a GPU slice
    @spaces.GPU
    def process_audio(self, audio_file, num_speakers=None, language=None,
                      translate=False, prompt=None, group_segments=True):
| """Main processing function - diarization first, then transcription""" | |
| if audio_file is None: | |
| return {"error": "No audio file provided"} | |
| converted_audio_path = None | |
| try: | |
| print("Starting new processing pipeline...") | |
| # Step 1: Convert audio format first | |
| print("Converting audio format...") | |
| converted_audio_path = self.convert_audio_format(audio_file) | |
| # Step 2: Perform diarization on converted audio | |
| diarization_segments, detected_num_speakers = self.perform_diarization( | |
| converted_audio_path, num_speakers | |
| ) | |
| # Step 3: Cut audio into segments based on diarization | |
| audio_segments = self.cut_audio_segments(converted_audio_path, diarization_segments) | |
| # Step 4: Transcribe each segment | |
| transcription_results, detected_language = self.transcribe_audio_segments( | |
| audio_segments, language, translate, prompt | |
| ) | |
| # Step 5: Group segments if requested | |
| if group_segments: | |
| transcription_results = self.group_segments_by_speaker(transcription_results) | |
| # Step 6: Return in replicate.py format | |
| return { | |
| "segments": transcription_results, | |
| "language": detected_language, | |
| "num_speakers": detected_num_speakers | |
| } | |
| except Exception as e: | |
| import traceback | |
| traceback.print_exc() | |
| return {"error": f"Processing failed: {str(e)}"} | |
| finally: | |
| # Clean up converted audio file | |
| if converted_audio_path and os.path.exists(converted_audio_path): | |
| os.unlink(converted_audio_path) | |
| print("Cleaned up converted audio file") | |

# Initialize transcriber
transcriber = WhisperTranscriber()


def format_segments_for_display(result):
    """Format segments for display in Gradio."""
    if "error" in result:
        return f"❌ Error: {result['error']}"

    segments = result.get("segments", [])
    language = result.get("language", "unknown")
    num_speakers = result.get("num_speakers", 1)

    output = "🎯 **Detection Results:**\n"
    output += f"- Language: {language}\n"
    output += f"- Speakers: {num_speakers}\n"
    output += f"- Segments: {len(segments)}\n\n"
    output += "📝 **Transcription:**\n\n"

    for i, segment in enumerate(segments, 1):
        start_time = str(datetime.timedelta(seconds=int(segment["start"])))
        end_time = str(datetime.timedelta(seconds=int(segment["end"])))
        speaker = segment.get("speaker", "SPEAKER_00")
        text = segment["text"]
        output += f"**{speaker}** ({start_time} → {end_time})\n"
        output += f"{text}\n\n"

    return output

def process_audio_gradio(audio_file, num_speakers, language, translate, prompt, group_segments):
    """Gradio interface function."""
    result = transcriber.process_audio(
        audio_file=audio_file,
        num_speakers=num_speakers if num_speakers > 0 else None,
        language=language if language != "auto" else None,
        translate=translate,
        prompt=prompt if prompt and prompt.strip() else None,
        group_segments=group_segments
    )
    formatted_output = format_segments_for_display(result)
    return formatted_output, result

# Create Gradio interface
demo = gr.Blocks(
    title="🎙️ Whisper Transcription with Speaker Diarization",
    theme="default"
)

with demo:
    gr.Markdown("""
    # 🎙️ Advanced Audio Transcription & Speaker Diarization

    Upload an audio file to get accurate transcription with speaker identification, powered by:
    - **Whisper Large V3 Turbo** with Flash Attention for fast transcription
    - **Pyannote 3.1** for speaker diarization
    - **ZeroGPU** acceleration for optimal performance
    """)
    with gr.Row():
        with gr.Column():
            audio_input = gr.Audio(
                label="🎵 Upload Audio File",
                type="filepath",
                # source="upload"
            )

            with gr.Accordion("⚙️ Advanced Settings", open=False):
                num_speakers = gr.Slider(
                    minimum=0,
                    maximum=20,
                    value=0,
                    step=1,
                    label="Number of Speakers (0 = auto-detect)"
                )
                language = gr.Dropdown(
                    choices=["auto", "en", "es", "fr", "de", "it", "pt", "ru", "ja", "ko", "zh"],
                    value="auto",
                    label="Language"
                )
                translate = gr.Checkbox(
                    label="Translate to English",
                    value=False
                )
                prompt = gr.Textbox(
                    label="Vocabulary Prompt (names, acronyms, etc.)",
                    placeholder="Enter names, technical terms, or context...",
                    lines=2
                )
                group_segments = gr.Checkbox(
                    label="Group segments by speaker",
                    value=True
                )

            process_btn = gr.Button("🚀 Transcribe Audio", variant="primary")

        with gr.Column():
            output_text = gr.Markdown(
                label="📝 Transcription Results",
                value="Upload an audio file and click 'Transcribe Audio' to get started!"
            )
            output_json = gr.JSON(
                label="🔧 Raw Output (JSON)",
                visible=False
            )
    # Event handlers
    process_btn.click(
        fn=process_audio_gradio,
        inputs=[
            audio_input,
            num_speakers,
            language,
            translate,
            prompt,
            group_segments
        ],
        outputs=[output_text, output_json]
    )
    # Examples
    gr.Markdown("### 💡 Usage Tips:")
    gr.Markdown("""
    - **Supported formats**: MP3, WAV, M4A, FLAC, OGG, and more
    - **Max duration**: Recommended under 10 minutes for optimal performance
    - **Speaker detection**: Works best with clear, distinct voices
    - **Languages**: Supports 100+ languages with auto-detection
    - **Vocabulary**: Add names and technical terms in the prompt for better accuracy
    """)

if __name__ == "__main__":
    demo.launch(debug=True)
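
# Example (illustrative sketch, not part of the app): querying this Space
# programmatically with gradio_client. The Space id ("<user>/<space-name>") and
# the endpoint name below are placeholders/assumptions — check
# `client.view_api()` on the running Space for the real endpoint.
#
#   from gradio_client import Client, handle_file
#
#   client = Client("<user>/<space-name>")
#   markdown, raw_json = client.predict(
#       handle_file("meeting.wav"),   # audio_input
#       0,                            # num_speakers (0 = auto-detect)
#       "auto",                       # language
#       False,                        # translate
#       "",                           # vocabulary prompt
#       True,                         # group segments by speaker
#       api_name="/process_audio_gradio",
#   )
#   print(markdown)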