Spaces:
Running
on
Zero
Running
on
Zero
| import os | |
| import time | |
| import numpy as np | |
| import gradio as gr | |
| import librosa | |
| import soundfile as sf | |
| import torch | |
| import traceback | |
| import threading | |
| from spaces import GPU | |
| from datetime import datetime | |
| from vibevoice.modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference | |
| from vibevoice.processor.vibevoice_processor import VibeVoiceProcessor | |
| from vibevoice.modular.streamer import AudioStreamer | |
| from transformers.utils import logging | |
| from transformers import set_seed | |
# Raise the verbosity of the transformers logging utilities to INFO.
logging.set_verbosity_info()
# Module-level logger obtained from the transformers logging helpers.
logger = logging.get_logger(__name__)
class VibeVoiceDemo:
    """Load a VibeVoice model and synthesize multi-speaker podcast audio.

    Construction loads the processor and model from ``model_path``, indexes
    the audio files found in the ``voices`` directory next to this file, and
    reads the example scripts stored in ``text_examples``.
    """

    # File extensions accepted as voice presets.
    # NOTE(review): read_audio() decodes with soundfile -- confirm the installed
    # libsndfile can actually read every extension listed here.
    _VOICE_EXTENSIONS = ('.wav', '.mp3', '.flac', '.ogg', '.m4a', '.aac')
    # Sample rate used for the generated audio.
    _SAMPLE_RATE = 24000
    # Upper bound on the number of speakers a script may use.
    _MAX_SPEAKERS = 4

    def __init__(self, model_path: str, device: str = "cuda", inference_steps: int = 5):
        """Store the configuration, then load the model, voices and examples.

        Args:
            model_path: Identifier or path passed to ``from_pretrained``.
            device: Value forwarded as ``device_map`` when loading the model.
            inference_steps: Number of DDPM inference steps set on the model.
        """
        self.model_path = model_path
        self.device = device
        self.inference_steps = inference_steps
        self.is_generating = False
        # Declared here so every attribute exists right after construction.
        self.stop_generation = False
        self.processor = None
        self.model = None
        self.available_voices = {}
        self.example_scripts = []
        self.load_model()
        self.setup_voice_presets()
        self.load_example_scripts()

    def load_model(self):
        """Load the processor and the inference model in bfloat16, in eval mode."""
        print(f"Loading processor & model from {self.model_path}")
        self.processor = VibeVoiceProcessor.from_pretrained(self.model_path)
        self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
            self.model_path,
            torch_dtype=torch.bfloat16,
            device_map=self.device,
        )
        self.model.eval()
        self.model.set_ddpm_inference_steps(num_steps=self.inference_steps)

    def setup_voice_presets(self):
        """Map each audio file in the ``voices`` directory to a preset name.

        The preset name is the file name without its extension. If the
        directory is missing, a warning is printed and no presets are added.
        """
        voices_dir = os.path.join(os.path.dirname(__file__), "voices")
        # isdir (not exists): os.listdir would raise if the path were a file.
        if not os.path.isdir(voices_dir):
            print(f"Warning: Voices directory not found at {voices_dir}")
            return
        # Sorted so the preset order does not depend on the file system.
        for file_name in sorted(os.listdir(voices_dir)):
            if file_name.lower().endswith(self._VOICE_EXTENSIONS):
                name = os.path.splitext(file_name)[0]
                self.available_voices[name] = os.path.join(voices_dir, file_name)
        print(f"Voices loaded: {list(self.available_voices.keys())}")

    def read_audio(self, audio_path: str, target_sr: int = 24000) -> np.ndarray:
        """Read an audio file as a mono waveform at ``target_sr``.

        Multi-channel audio is averaged across channels, and the waveform is
        resampled when the file's sample rate differs from ``target_sr``.

        Returns:
            The waveform, or an empty array if reading or resampling failed
            (the error is printed, not raised, so callers must check length).
        """
        try:
            wav, sr = sf.read(audio_path)
            if wav.ndim > 1:
                wav = np.mean(wav, axis=1)
            if sr != target_sr:
                wav = librosa.resample(wav, orig_sr=sr, target_sr=target_sr)
            return wav
        except Exception as e:
            # Deliberate best-effort: the caller reports a single UI error.
            print(f"Error reading audio {audio_path}: {e}")
            return np.array([])

    @staticmethod
    def _format_script(script: str, num_speakers: int) -> str:
        """Prefix unlabeled lines with ``Speaker N:``, rotating over the speakers.

        Blank lines are dropped and lines that already start with
        ``"Speaker "`` are kept as-is. The rotation counts emitted lines, so
        blank lines in the input do not disturb the alternation.
        """
        formatted = []
        for line in script.strip().split("\n"):
            line = line.strip()
            if not line:
                continue
            if line.startswith("Speaker "):
                formatted.append(line)
            else:
                formatted.append(f"Speaker {len(formatted) % num_speakers}: {line}")
        return "\n".join(formatted)

    @staticmethod
    def _count_speakers(script: str, max_speakers: int = 4) -> int:
        """Count the distinct ``Speaker N:`` labels in a script, clamped to 1..max."""
        import re  # local import, like the file's other function-scope imports

        labels = re.findall(r"^\s*Speaker\s+(\d+)\s*:", script, flags=re.MULTILINE)
        distinct = {int(label) for label in labels}
        return min(max(len(distinct), 1), max_speakers)

    def generate_podcast(self, num_speakers: int, script: str,
                         speaker_1: str = None, speaker_2: str = None,
                         speaker_3: str = None, speaker_4: str = None,
                         cfg_scale: float = 1.3):
        """Generate the complete podcast audio for a script (no streaming to the UI).

        Args:
            num_speakers: Number of speakers to use, 1 to 4.
            script: The conversation script; unlabeled lines are assigned to
                speakers in rotation.
            speaker_1, speaker_2, speaker_3, speaker_4: Voice preset names;
                only the first ``num_speakers`` are used.
            cfg_scale: Value forwarded to ``model.generate`` as ``cfg_scale``.

        Returns:
            ``((sample_rate, audio), log)`` where ``audio`` is the 16-bit
            waveform and ``log`` is a human-readable summary. The same audio is
            also written to ``outputs/podcast_<timestamp>.wav``.

        Raises:
            gr.Error: If the script is empty, the speaker count or a speaker
                selection is invalid, a voice sample cannot be loaded, or the
                model produced no audio.
        """
        self.is_generating = True
        self.stop_generation = False
        try:
            if not script or not script.strip():
                raise gr.Error("Please provide a script.")
            if num_speakers < 1 or num_speakers > self._MAX_SPEAKERS:
                raise gr.Error("Number of speakers must be 1–4.")

            # Collect selected speakers
            selected = [speaker_1, speaker_2, speaker_3, speaker_4][:num_speakers]
            for i, sp in enumerate(selected):
                if not sp or sp not in self.available_voices:
                    raise gr.Error(f"Invalid speaker {i+1} selection.")

            # Load voices into memory
            voice_samples = [self.read_audio(self.available_voices[sp]) for sp in selected]
            if any(len(v) == 0 for v in voice_samples):
                raise gr.Error("Failed to load one or more voice samples.")

            formatted_script = self._format_script(script, num_speakers)

            # Prepare processor inputs
            inputs = self.processor(
                text=[formatted_script],
                voice_samples=[voice_samples],
                padding=True,
                return_tensors="pt",
                return_attention_mask=True,
            )

            start = time.time()
            sample_rate = self._SAMPLE_RATE
            audio_streamer = AudioStreamer(batch_size=1)

            # Generation runs to completion here; the streamer is drained afterwards.
            self.model.generate(
                **inputs,
                max_new_tokens=None,
                cfg_scale=cfg_scale,
                tokenizer=self.processor.tokenizer,
                generation_config={'do_sample': False},
                audio_streamer=audio_streamer,
                verbose=False,
            )

            # Collect all audio chunks
            all_chunks = []
            for audio_chunk in audio_streamer.get_stream(0):
                if torch.is_tensor(audio_chunk):
                    audio_chunk = audio_chunk.float().cpu().numpy()
                # atleast_1d: a one-sample chunk squeezes to 0-d, which
                # np.concatenate refuses.
                all_chunks.append(np.atleast_1d(np.squeeze(np.asarray(audio_chunk))))

            if not all_chunks:
                raise gr.Error("❌ No audio was generated by the model.")

            complete_audio = np.concatenate(all_chunks)
            # Use the 16-bit conversion for both the file and the returned
            # audio so the saved file and the UI player carry the same samples.
            audio16 = convert_to_16_bit_wav(complete_audio)

            # Save automatically to disk
            os.makedirs("outputs", exist_ok=True)
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            file_path = os.path.join("outputs", f"podcast_{timestamp}.wav")
            sf.write(file_path, audio16, sample_rate)
            print(f"💾 Saved podcast to {file_path}")

            total_dur = len(complete_audio) / sample_rate
            log = (
                f"✅ Generation complete in {time.time()-start:.1f}s, "
                f"{total_dur:.1f}s audio\nSaved to {file_path}"
            )
            return (sample_rate, audio16), log
        finally:
            # Reset on every exit path, including validation and model errors.
            self.is_generating = False

    def load_example_scripts(self):
        """Load ``[num_speakers, script]`` examples from ``text_examples/*.txt``.

        Files are read in sorted order; empty or unreadable files are skipped.
        The speaker count of each example is taken from the ``Speaker N:``
        labels it contains (1 when there are none).
        """
        examples_dir = os.path.join(os.path.dirname(__file__), "text_examples")
        self.example_scripts = []
        if not os.path.isdir(examples_dir):
            return
        txt_files = sorted(f for f in os.listdir(examples_dir)
                           if f.lower().endswith('.txt'))
        for txt_file in txt_files:
            try:
                with open(os.path.join(examples_dir, txt_file), 'r', encoding='utf-8') as f:
                    script_content = f.read().strip()
            except (OSError, UnicodeError) as e:
                print(f"Error loading {txt_file}: {e}")
                continue
            if script_content:
                num_speakers = self._count_speakers(script_content, self._MAX_SPEAKERS)
                self.example_scripts.append([num_speakers, script_content])
def convert_to_16_bit_wav(data):
    """Convert audio samples to a 16-bit integer NumPy array.

    Samples are scaled by 32767. When the peak absolute value exceeds 1.0 the
    signal is first divided by that peak, so the result never overflows int16.

    Args:
        data: A torch tensor, NumPy array or sequence of samples.

    Returns:
        An ``int16`` NumPy array with the same shape as the input; empty input
        yields an empty ``int16`` array.
    """
    if torch.is_tensor(data):
        # Cast to float32 first: NumPy has no bfloat16, so .numpy() would
        # raise for tensors in that dtype.
        data = data.detach().cpu().float().numpy()
    data = np.asarray(data)
    if data.size == 0:
        # np.max raises on an empty array, so there is nothing to scale.
        return data.astype(np.int16)
    peak = np.max(np.abs(data))
    if peak > 1.0:
        data = data / peak
    return (data * 32767).astype(np.int16)
def create_demo_interface(demo_instance: VibeVoiceDemo):
    """Create the Gradio interface (final audio only, no streaming).

    Args:
        demo_instance: Provides the available voices, the example scripts and
            the ``generate_podcast`` method invoked by the Generate button.

    Returns:
        The assembled ``gr.Blocks`` interface, not yet launched.
    """
    # Custom CSS for high-end aesthetics
    custom_css = """ ... """  # (keep your CSS unchanged)

    max_speakers = 4

    with gr.Blocks(
        title="VibeVoice - AI Podcast Generator",
        css=custom_css,
        theme=gr.themes.Soft(
            primary_hue="blue",
            secondary_hue="purple",
            neutral_hue="slate",
        )
    ) as interface:
        # Header
        gr.HTML("""
        <div class="main-header">
            <h1>🎙️ Vibe Podcasting</h1>
            <p>Generating Long-form Multi-speaker AI Podcast with VibeVoice</p>
        </div>
        """)

        with gr.Row():
            # Left column - Settings
            with gr.Column(scale=1, elem_classes="settings-card"):
                gr.Markdown("### 🎛️ **Podcast Settings**")
                num_speakers = gr.Slider(
                    minimum=1, maximum=max_speakers, value=2, step=1,
                    label="Number of Speakers",
                    elem_classes="slider-container"
                )

                gr.Markdown("### 🎭 **Speaker Selection**")
                available_speaker_names = list(demo_instance.available_voices.keys())
                default_speakers = ['en-Alice_woman', 'en-Carter_man', 'en-Frank_man', 'en-Maya_woman']

                def pick_default(index):
                    """Return a default voice that is actually among the choices."""
                    if index < len(default_speakers) and default_speakers[index] in available_speaker_names:
                        return default_speakers[index]
                    # Preferred preset is not installed: fall back to whatever
                    # voice sits at this position, if any, so the dropdown
                    # never starts with a value outside its choices.
                    if index < len(available_speaker_names):
                        return available_speaker_names[index]
                    return None

                speaker_selections = []
                for i in range(max_speakers):
                    speaker = gr.Dropdown(
                        choices=available_speaker_names,
                        value=pick_default(i),
                        label=f"Speaker {i+1}",
                        visible=(i < 2),  # matches the slider's initial value of 2
                        elem_classes="speaker-item"
                    )
                    speaker_selections.append(speaker)

                gr.Markdown("### ⚙️ **Advanced Settings**")
                with gr.Accordion("Generation Parameters", open=False):
                    cfg_scale = gr.Slider(
                        minimum=1.0, maximum=2.0, value=1.3, step=0.05,
                        label="CFG Scale (Guidance Strength)",
                        elem_classes="slider-container"
                    )

            # Right column - Generation
            with gr.Column(scale=2, elem_classes="generation-card"):
                gr.Markdown("### 📝 **Script Input**")
                script_input = gr.Textbox(
                    label="Conversation Script",
                    placeholder="Enter your podcast script here...",
                    lines=12,
                    max_lines=20,
                    elem_classes="script-input"
                )

                with gr.Row():
                    random_example_btn = gr.Button(
                        "🎲 Random Example", size="lg",
                        variant="secondary", elem_classes="random-btn", scale=1
                    )
                    generate_btn = gr.Button(
                        "🚀 Generate Podcast", size="lg",
                        variant="primary", elem_classes="generate-btn", scale=2
                    )

                # Output section
                gr.Markdown("### 🎵 **Generated Podcast**")
                complete_audio_output = gr.Audio(
                    label="Complete Podcast (Download)",
                    type="numpy",
                    elem_classes="audio-output complete-audio-section",
                    autoplay=False,
                    show_download_button=True,
                    visible=True
                )
                log_output = gr.Textbox(
                    label="Generation Log",
                    lines=8, max_lines=15,
                    interactive=False,
                    elem_classes="log-output"
                )

        # === logic ===
        def update_speaker_visibility(num_speakers):
            """Show only the first ``num_speakers`` speaker dropdowns."""
            return [gr.update(visible=(i < num_speakers)) for i in range(max_speakers)]

        num_speakers.change(
            fn=update_speaker_visibility,
            inputs=[num_speakers],
            outputs=speaker_selections
        )

        def generate_podcast_wrapper(num_speakers, script, *speakers_and_params):
            """Unpack the UI inputs, run generation, and report errors in the log box."""
            try:
                speakers = speakers_and_params[:max_speakers]
                cfg_scale = speakers_and_params[max_speakers]
                audio, log = demo_instance.generate_podcast(
                    num_speakers=int(num_speakers),
                    script=script,
                    speaker_1=speakers[0],
                    speaker_2=speakers[1],
                    speaker_3=speakers[2],
                    speaker_4=speakers[3],
                    cfg_scale=cfg_scale
                )
                return audio, log
            except Exception as e:
                # UI boundary: print the traceback, clear the audio output and
                # return the error text for the log box.
                traceback.print_exc()
                return None, f"❌ Error: {str(e)}"

        generate_btn.click(
            fn=generate_podcast_wrapper,
            inputs=[num_speakers, script_input] + speaker_selections + [cfg_scale],
            outputs=[complete_audio_output, log_output],
            queue=True
        )

        # Looked up once and shared by the random-example button and gr.Examples.
        loaded_examples = getattr(demo_instance, "example_scripts", [])

        def load_random_example():
            """Return a random ``(num_speakers, script)`` pair for the inputs."""
            import random
            examples = loaded_examples or [
                [2, "Speaker 0: Welcome to our AI podcast demo!\nSpeaker 1: Thanks, excited to be here!"]
            ]
            num_speakers_value, script_value = random.choice(examples)
            return num_speakers_value, script_value

        random_example_btn.click(
            fn=load_random_example,
            inputs=[],
            outputs=[num_speakers, script_input],
            queue=False
        )

        gr.Markdown("### 📚 **Example Scripts**")
        examples = loaded_examples or [
            [1, "Speaker 1: Welcome to our AI podcast demo. This is a sample script."]
        ]
        gr.Examples(
            examples=examples,
            inputs=[num_speakers, script_input],
            label="Try these example scripts:"
        )

    return interface
def run_demo(
    model_path: str = "microsoft/VibeVoice-1.5B",
    device: str = "cuda",
    inference_steps: int = 5,
    share: bool = True,
):
    """Seed the RNGs, build the demo and its interface, and launch the app.

    Args:
        model_path: Model identifier or path handed to ``VibeVoiceDemo``.
        device: Device handed to ``VibeVoiceDemo``.
        inference_steps: Inference step count handed to ``VibeVoiceDemo``.
        share: Forwarded to ``launch``; also selects the bind address
            (``0.0.0.0`` when true, ``127.0.0.1`` otherwise).
    """
    set_seed(42)

    demo = VibeVoiceDemo(model_path, device, inference_steps)
    app = create_demo_interface(demo).queue()

    if share:
        host = "0.0.0.0"
    else:
        host = "127.0.0.1"

    app.launch(
        share=share,
        server_name=host,
        show_error=True,
        show_api=False,
    )


if __name__ == "__main__":
    run_demo()