# %%writefile /content/VibeVoice/demo/colab.py # Original Code: https://github.com/microsoft/VibeVoice/blob/main/demo/gradio_demo.py """ VibeVoice Gradio Demo """ import json import os import sys import tempfile import time from pathlib import Path from typing import List, Dict, Any, Iterator from datetime import datetime import threading import numpy as np import gradio as gr import librosa import soundfile as sf import torch import os import traceback import shutil import re import uuid from vibevoice.modular.configuration_vibevoice import VibeVoiceConfig from vibevoice.modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference, VibeVoiceGenerationOutput from vibevoice.processor.vibevoice_processor import VibeVoiceProcessor from vibevoice.modular.streamer import AudioStreamer from transformers import set_seed from pydub import AudioSegment from pydub.silence import split_on_silence def drive_save(file_copy): drive_path = "/content/gdrive/MyDrive" save_folder = os.path.join(drive_path, "VibeVoice_Podcast") if os.path.exists(drive_path): print("Running on Google Colab and auto-saving to Google Drive...") os.makedirs(save_folder, exist_ok=True) dest_path = os.path.join(save_folder, os.path.basename(file_copy)) shutil.copy2(file_copy, dest_path) print(f"File saved to: {dest_path}") return dest_path else: print("Not running on Google Colab (or Google Drive not mounted). Skipping auto-save.") return None import os, requests, urllib.request, urllib.error from tqdm.auto import tqdm def download_file(url, download_file_path, redownload=False): base_path = os.path.dirname(download_file_path) os.makedirs(base_path, exist_ok=True) if os.path.exists(download_file_path): if redownload: os.remove(download_file_path) tqdm.write(f"♻️ Redownloading: {os.path.basename(download_file_path)}") elif os.path.getsize(download_file_path) > 0: tqdm.write(f"✔️ Skipped (already exists): {os.path.basename(download_file_path)}") return True try: request = urllib.request.urlopen(url) total = int(request.headers.get('Content-Length', 0)) except urllib.error.URLError as e: print(f"❌ Error: Unable to open URL: {url}") print(f"Reason: {e.reason}") return False with tqdm(total=total, desc=os.path.basename(download_file_path), unit='B', unit_scale=True, unit_divisor=1024) as progress: try: urllib.request.urlretrieve( url, download_file_path, reporthook=lambda count, block_size, total_size: progress.update(block_size) ) except urllib.error.URLError as e: print(f"❌ Error: Failed to download {url}") print(f"Reason: {e.reason}") return False tqdm.write(f"⬇️ Downloaded: {os.path.basename(download_file_path)}") return True def download_model(repo_id, download_folder="./", redownload=False): if not download_folder.strip(): download_folder = "." url = f"https://huggingface.co/api/models/{repo_id}" download_dir = os.path.abspath(f"{download_folder.rstrip('/')}/{repo_id.split('/')[-1]}") os.makedirs(download_dir, exist_ok=True) print(f"📂 Download directory: {download_dir}") response = requests.get(url) if response.status_code != 200: print("❌ Error:", response.status_code, response.text) return None data = response.json() siblings = data.get("siblings", []) files = [f["rfilename"] for f in siblings] print(f"📦 Found {len(files)} files in repo '{repo_id}'. Checking cache ...") for file in tqdm(files, desc="Processing files", unit="file"): file_url = f"https://huggingface.co/{repo_id}/resolve/main/{file}" file_path = os.path.join(download_dir, file) download_file(file_url, file_path, redownload=redownload) return download_dir def generate_file_name(text): output_dir = "./podcast_audio" os.makedirs(output_dir, exist_ok=True) cleaned = re.sub(r"^\s*speaker\s*\d+\s*:\s*", "", text, flags=re.IGNORECASE) short = cleaned[:30].strip() short = re.sub(r'[^a-zA-Z0-9\s]', '', short) short = short.lower().strip().replace(" ", "_") if not short: short = "podcast_output" unique_name = f"{short}_{uuid.uuid4().hex[:6]}" return os.path.join(output_dir, unique_name) class VibeVoiceDemo: def __init__(self, model_path: str, device: str = "cuda", inference_steps: int = 5): self.model_path = model_path self.device = device self.inference_steps = inference_steps self.is_generating = False self.stop_generation = False self.load_model() self.setup_voice_presets() self.load_example_scripts() def load_model(self): print(f"Loading processor & model from {self.model_path}") self.processor = VibeVoiceProcessor.from_pretrained(self.model_path) if self.device == "cuda": self.model = VibeVoiceForConditionalGenerationInference.from_pretrained( self.model_path, torch_dtype=torch.bfloat16, device_map=self.device, ) else: self.model = VibeVoiceForConditionalGenerationInference.from_pretrained( self.model_path, torch_dtype=torch.float32, device_map="cpu", ) self.model.eval() self.model.model.noise_scheduler = self.model.model.noise_scheduler.from_config( self.model.model.noise_scheduler.config, algorithm_type='sde-dpmsolver++', beta_schedule='squaredcos_cap_v2' ) self.model.set_ddpm_inference_steps(num_steps=self.inference_steps) if hasattr(self.model.model, 'language_model'): print(f"Language model attention: {self.model.model.language_model.config._attn_implementation}") def setup_voice_presets(self): voices_dir = os.path.join(os.path.dirname(__file__), "voices") if not os.path.exists(voices_dir): print(f"Warning: Voices directory not found at {voices_dir}, creating it.") os.makedirs(voices_dir, exist_ok=True) self.voice_presets = {} audio_files = [f for f in os.listdir(voices_dir) if f.lower().endswith(('.wav', '.mp3', '.flac', '.ogg', '.m4a', '.aac')) and os.path.isfile(os.path.join(voices_dir, f))] for audio_file in audio_files: name = os.path.splitext(audio_file)[0] self.voice_presets[name] = os.path.join(voices_dir, audio_file) self.voice_presets = dict(sorted(self.voice_presets.items())) self.available_voices = {name: path for name, path in self.voice_presets.items() if os.path.exists(path)} if not self.available_voices: print("Warning: No voice presets found.") print(f"Found {len(self.available_voices)} voice files in {voices_dir}") def read_audio(self, audio_path: str, target_sr: int = 24000) -> np.ndarray: try: wav, sr = sf.read(audio_path) if len(wav.shape) > 1: wav = np.mean(wav, axis=1) if sr != target_sr: wav = librosa.resample(wav, orig_sr=sr, target_sr=target_sr) return wav except Exception as e: print(f"Error reading audio {audio_path}: {e}") return np.array([]) def trim_silence_from_numpy(self, audio_np: np.ndarray, sample_rate: int, silence_thresh: int = -45, min_silence_len: int = 100, keep_silence: int = 50) -> np.ndarray: audio_int16 = (audio_np * 32767).astype(np.int16) sound = AudioSegment(data=audio_int16.tobytes(), sample_width=audio_int16.dtype.itemsize, frame_rate=sample_rate, channels=1) audio_chunks = split_on_silence(sound, min_silence_len=min_silence_len, silence_thresh=silence_thresh, keep_silence=keep_silence) if not audio_chunks: return np.array([0.0], dtype=np.float32) combined = sum(audio_chunks) samples = np.array(combined.get_array_of_samples()) return samples.astype(np.float32) / 32767.0 def generate_podcast_with_timestamps(self, num_speakers: int, script: str, speaker_1: str, speaker_2: str, speaker_3: str, speaker_4: str, cfg_scale: float, remove_silence: bool, progress=gr.Progress()): # Initial UI state: Clear previous results, show stop button yield None, None, None, gr.update(visible=False), gr.update(visible=True) final_audio_path, final_json_path = None, None try: self.stop_generation = False self.is_generating = True if not script.strip(): raise gr.Error("Error: Please provide a script.") script = script.replace("’", "'") if not 1 <= num_speakers <= 4: raise gr.Error("Error: Number of speakers must be between 1 and 4.") selected_speakers = [speaker_1, speaker_2, speaker_3, speaker_4][:num_speakers] for i, speaker in enumerate(selected_speakers): if not speaker or speaker not in self.available_voices: raise gr.Error(f"Error: Please select a valid speaker for Speaker {i+1}.") voice_samples = [self.read_audio(self.available_voices[name]) for name in selected_speakers] if any(len(vs) == 0 for vs in voice_samples): raise gr.Error("Error: Failed to load one or more audio files.") lines = script.strip().split('\n') formatted_script_lines = [] for line in lines: line = line.strip() if not line: continue if re.match(r'Speaker\s*\d+:', line, re.IGNORECASE): formatted_script_lines.append(line) else: speaker_id = len(formatted_script_lines) % num_speakers formatted_script_lines.append(f"Speaker {speaker_id+1}: {line}") if not formatted_script_lines: raise gr.Error("Error: Script is empty after formatting.") timestamps = {} current_time = 0.0 sample_rate = 24000 base_filename = generate_file_name(formatted_script_lines[0]) final_audio_path = base_filename + ".wav" final_json_path = base_filename + ".json" with sf.SoundFile(final_audio_path, 'w', samplerate=sample_rate, channels=1, subtype='PCM_16') as audio_file: for i, line in enumerate(formatted_script_lines): if self.stop_generation: print("\n🚫 Generation interrupted by user. Finalizing partial files...") break progress(i / len(formatted_script_lines), desc=f"Generating line {i+1}/{len(formatted_script_lines)}") match = re.match(r'Speaker\s*(\d+):\s*(.*)', line, re.IGNORECASE) if not match: continue speaker_idx = int(match.group(1)) - 1 text_content = match.group(2).strip() if not (0 <= speaker_idx < len(voice_samples)): continue inputs = self.processor(text=[line], voice_samples=[voice_samples[speaker_idx]], padding=True, return_tensors="pt") output_waveform = self.model.generate(**inputs, max_new_tokens=None, cfg_scale=cfg_scale, tokenizer=self.processor.tokenizer, generation_config={'do_sample': False}, verbose=False, refresh_negative=True) audio_np = output_waveform.speech_outputs[0].cpu().float().numpy().squeeze() if remove_silence: audio_np = self.trim_silence_from_numpy(audio_np, sample_rate) duration = len(audio_np) / sample_rate audio_file.write((audio_np * 32767).astype(np.int16)) timestamps[str(i + 1)] = {"text": text_content, "speaker_id": speaker_idx + 1, "start": current_time, "end": current_time + duration} current_time += duration if not timestamps: self.is_generating = False if os.path.exists(final_audio_path): os.remove(final_audio_path) yield None, None, None, gr.update(visible=True), gr.update(visible=False) return progress(1.0, desc="Saving generated files...") with open(final_json_path, "w") as f: json.dump(timestamps, f, indent=2) try: drive_save(final_audio_path) drive_save(final_json_path) except Exception as e: print(f"Error saving files to Google Drive: {e}") message = "Partial" if self.stop_generation else "Full" print(f"\n✨ {message} generation successful!\n🎵 Audio: {final_audio_path}\n📄 Timestamps: {final_json_path}\n") self.is_generating = False yield final_audio_path, final_audio_path, final_json_path, gr.update(visible=True), gr.update(visible=False) except Exception as e: self.is_generating = False print(f"❌ An unexpected error occurred: {str(e)}") traceback.print_exc() try: if final_audio_path and os.path.exists(final_audio_path): os.remove(final_audio_path) if final_json_path and os.path.exists(final_json_path): os.remove(final_json_path) except Exception as cleanup_e: print(f"Error during cleanup after exception: {cleanup_e}") yield None, None, None, gr.update(visible=True), gr.update(visible=False) def stop_audio_generation(self): if self.is_generating: self.stop_generation = True print("🛑 Audio generation stop requested") def load_example_scripts(self): examples_dir = os.path.join(os.path.dirname(__file__), "text_examples") self.example_scripts = [] if not os.path.exists(examples_dir): return txt_files = sorted([f for f in os.listdir(examples_dir) if f.lower().endswith('.txt')]) for txt_file in txt_files: try: with open(os.path.join(examples_dir, txt_file), 'r', encoding='utf-8') as f: script = f.read().strip() if script: self.example_scripts.append([self._get_num_speakers_from_script(script), script]) except Exception as e: print(f"Error loading example {txt_file}: {e}") def _get_num_speakers_from_script(self, script: str) -> int: speakers = set(re.findall(r'^Speaker\s+(\d+)\s*:', script, re.MULTILINE | re.IGNORECASE)) return max(int(s) for s in speakers) if speakers else 1 def create_demo_interface(demo_instance: VibeVoiceDemo): with gr.Blocks(title="VibeVoice AI Podcast Generator") as interface: gr.HTML("""
Generate Long-form Multi-speaker AI Podcasts with VibeVoice
🥳 Run on Google ColabPaste the prompt into any LLM, and customize the propmt if you want.