# %%writefile /content/VibeVoice/demo/colab.py
# Original Code: https://github.com/microsoft/VibeVoice/blob/main/demo/gradio_demo.py
"""
VibeVoice Gradio Demo
"""
import json
import os
import sys
import tempfile
import time
from pathlib import Path
from typing import List, Dict, Any, Iterator
from datetime import datetime
import threading
import numpy as np
import gradio as gr
import librosa
import soundfile as sf
import torch
import traceback
import shutil
import re    # Added for timestamp feature
import uuid  # Added for timestamp feature

from vibevoice.modular.configuration_vibevoice import VibeVoiceConfig
from vibevoice.modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference, VibeVoiceGenerationOutput
from vibevoice.processor.vibevoice_processor import VibeVoiceProcessor
from vibevoice.modular.streamer import AudioStreamer
from transformers import set_seed

from pydub import AudioSegment
from pydub.silence import split_on_silence


def drive_save(file_copy):
    drive_path = "/content/gdrive/MyDrive"
    save_folder = os.path.join(drive_path, "VibeVoice_Podcast")
    if os.path.exists(drive_path):
        print("Running on Google Colab and auto-saving to Google Drive...")
        os.makedirs(save_folder, exist_ok=True)
        dest_path = os.path.join(save_folder, os.path.basename(file_copy))
        shutil.copy2(file_copy, dest_path)  # preserves metadata
        print(f"File saved to: {dest_path}")
        return dest_path
    else:
        print("Not running on Google Colab (or Google Drive not mounted). Skipping auto-save.")
        return None


import requests, urllib.request, urllib.error
from tqdm.auto import tqdm


def download_file(url, download_file_path, redownload=False):
    """Download a single file with urllib + tqdm progress bar."""
    base_path = os.path.dirname(download_file_path)
    os.makedirs(base_path, exist_ok=True)

    # Skip logic
    if os.path.exists(download_file_path):
        if redownload:
            os.remove(download_file_path)
            tqdm.write(f"♻️ Redownloading: {os.path.basename(download_file_path)}")
        elif os.path.getsize(download_file_path) > 0:
            tqdm.write(f"✔️ Skipped (already exists): {os.path.basename(download_file_path)}")
            return True

    try:
        request = urllib.request.urlopen(url)
        total = int(request.headers.get('Content-Length', 0))
    except urllib.error.URLError as e:
        print(f"❌ Error: Unable to open URL: {url}")
        print(f"Reason: {e.reason}")
        return False

    with tqdm(total=total, desc=os.path.basename(download_file_path), unit='B', unit_scale=True, unit_divisor=1024) as progress:
        try:
            urllib.request.urlretrieve(
                url,
                download_file_path,
                reporthook=lambda count, block_size, total_size: progress.update(block_size)
            )
        except urllib.error.URLError as e:
            print(f"❌ Error: Failed to download {url}")
            print(f"Reason: {e.reason}")
            return False

    tqdm.write(f"⬇️ Downloaded: {os.path.basename(download_file_path)}")
    return True


def download_model(repo_id, download_folder="./", redownload=False):
    # Normalize empty string as current dir
    if not download_folder.strip():
        download_folder = "."

    url = f"https://huggingface.co/api/models/{repo_id}"
    download_dir = os.path.abspath(f"{download_folder.rstrip('/')}/{repo_id.split('/')[-1]}")
    os.makedirs(download_dir, exist_ok=True)
    print(f"📂 Download directory: {download_dir}")

    response = requests.get(url)
    if response.status_code != 200:
        print("❌ Error:", response.status_code, response.text)
        return None

    data = response.json()
    siblings = data.get("siblings", [])
    files = [f["rfilename"] for f in siblings]
    print(f"📦 Found {len(files)} files in repo '{repo_id}'. Checking cache ...")

    for file in tqdm(files, desc="Processing files", unit="file"):
        file_url = f"https://huggingface.co/{repo_id}/resolve/main/{file}"
        file_path = os.path.join(download_dir, file)
        download_file(file_url, file_path, redownload=redownload)

    return download_dir


# NEW FEATURE: Function to generate unique filenames for output
def generate_file_name(text):
    """Generates a unique, clean filename based on the script's first line."""
    output_dir = "./podcast_audio"
    os.makedirs(output_dir, exist_ok=True)

    # Clean the text to get a base for the filename
    cleaned = re.sub(r"^\s*speaker\s*\d+\s*:\s*", "", text, flags=re.IGNORECASE)
    short = cleaned[:30].strip()
    short = re.sub(r'[^a-zA-Z0-9\s]', '', short)
    short = short.lower().strip().replace(" ", "_")
    if not short:
        short = "podcast_output"

    # Add a unique identifier
    unique_name = f"{short}_{uuid.uuid4().hex[:6]}"
    return os.path.join(output_dir, unique_name)
class VibeVoiceDemo:
    def __init__(self, model_path: str, device: str = "cuda", inference_steps: int = 5):
        """Initialize the VibeVoice demo with model loading."""
        self.model_path = model_path
        self.device = device
        self.inference_steps = inference_steps
        self.is_generating = False    # Track generation state
        self.stop_generation = False  # Flag to stop generation
        self.load_model()
        self.setup_voice_presets()
        self.load_example_scripts()   # Load example scripts

    def load_model(self):
        """Load the VibeVoice model and processor."""
        print(f"Loading processor & model from {self.model_path}")
        self.processor = VibeVoiceProcessor.from_pretrained(self.model_path)
        if self.device == "cuda":
            self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
                self.model_path,
                torch_dtype=torch.bfloat16,
                device_map=self.device,
            )
        else:
            self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
                self.model_path,
                torch_dtype=torch.float32,  # Use float32 for CPU
                device_map="cpu",
            )
        self.model.eval()
        self.model.model.noise_scheduler = self.model.model.noise_scheduler.from_config(
            self.model.model.noise_scheduler.config,
            algorithm_type='sde-dpmsolver++',
            beta_schedule='squaredcos_cap_v2'
        )
        self.model.set_ddpm_inference_steps(num_steps=self.inference_steps)
        if hasattr(self.model.model, 'language_model'):
            print(f"Language model attention: {self.model.model.language_model.config._attn_implementation}")

    def setup_voice_presets(self):
        """Setup voice presets by scanning the voices directory."""
        voices_dir = os.path.join(os.path.dirname(__file__), "voices")
        if not os.path.exists(voices_dir):
            print(f"Warning: Voices directory not found at {voices_dir}, creating it.")
            os.makedirs(voices_dir, exist_ok=True)
        self.voice_presets = {}
        audio_files = [
            f for f in os.listdir(voices_dir)
            if f.lower().endswith(('.wav', '.mp3', '.flac', '.ogg', '.m4a', '.aac'))
            and os.path.isfile(os.path.join(voices_dir, f))
        ]
        for audio_file in audio_files:
            name = os.path.splitext(audio_file)[0]
            full_path = os.path.join(voices_dir, audio_file)
            self.voice_presets[name] = full_path
        self.voice_presets = dict(sorted(self.voice_presets.items()))
        self.available_voices = {name: path for name, path in self.voice_presets.items() if os.path.exists(path)}
        if not self.available_voices:
            print("Warning: No voice presets found.")
        print(f"Found {len(self.available_voices)} voice files in {voices_dir}")

    def read_audio(self, audio_path: str, target_sr: int = 24000) -> np.ndarray:
        """Read and preprocess audio file."""
        try:
            wav, sr = sf.read(audio_path)
            if len(wav.shape) > 1:
                wav = np.mean(wav, axis=1)
            if sr != target_sr:
                wav = librosa.resample(wav, orig_sr=sr, target_sr=target_sr)
            return wav
        except Exception as e:
            print(f"Error reading audio {audio_path}: {e}")
            return np.array([])

    def trim_silence_from_numpy(self, audio_np: np.ndarray, sample_rate: int,
                                silence_thresh: int = -45, min_silence_len: int = 100,
                                keep_silence: int = 50) -> np.ndarray:
        """Removes silence from a NumPy audio array using pydub."""
        audio_int16 = (audio_np * 32767).astype(np.int16)
        sound = AudioSegment(
            data=audio_int16.tobytes(),
            sample_width=audio_int16.dtype.itemsize,
            frame_rate=sample_rate,
            channels=1
        )
        audio_chunks = split_on_silence(
            sound,
            min_silence_len=min_silence_len,
            silence_thresh=silence_thresh,
            keep_silence=keep_silence
        )
        if not audio_chunks:
            return np.array([0.0], dtype=np.float32)
        combined = sum(audio_chunks)
        samples = np.array(combined.get_array_of_samples())
        trimmed_audio_np = samples.astype(np.float32) / 32767.0
        return trimmed_audio_np

    def generate_podcast_with_timestamps(self,
                                         num_speakers: int,
                                         script: str,
                                         speaker_1: str = None,
                                         speaker_2: str = None,
                                         speaker_3: str = None,
                                         speaker_4: str = None,
                                         cfg_scale: float = 1.3,
                                         remove_silence: bool = False,
                                         progress=gr.Progress()):
        try:
            self.stop_generation = False
            self.is_generating = True

            # --- Input Validation and Setup ---
            if not script.strip():
                raise gr.Error("Error: Please provide a script.")
            script = script.replace("’", "'")
            if not 1 <= num_speakers <= 4:
                raise gr.Error("Error: Number of speakers must be between 1 and 4.")

            selected_speakers = [speaker_1, speaker_2, speaker_3, speaker_4][:num_speakers]
            for i, speaker in enumerate(selected_speakers):
                if not speaker or speaker not in self.available_voices:
                    raise gr.Error(f"Error: Please select a valid speaker for Speaker {i+1}.")

            voice_samples = [self.read_audio(self.available_voices[name]) for name in selected_speakers]
            if any(len(vs) == 0 for vs in voice_samples):
                raise gr.Error("Error: Failed to load one or more audio files.")

            lines = script.strip().split('\n')
            formatted_script_lines = []
            for line in lines:
                line = line.strip()
                if not line:
                    continue
                if re.match(r'Speaker\s*\d+:', line, re.IGNORECASE):
                    formatted_script_lines.append(line)
                else:
                    speaker_id = len(formatted_script_lines) % num_speakers
                    formatted_script_lines.append(f"Speaker {speaker_id}: {line}")

            if not formatted_script_lines:
                raise gr.Error("Error: Script is empty after formatting.")

            # --- Prepare for Generation ---
            timestamps = {}
            current_time = 0.0
            sample_rate = 24000
            total_lines = len(formatted_script_lines)

            base_filename = generate_file_name(formatted_script_lines[0])
            final_audio_path = base_filename + ".wav"
            final_json_path = base_filename + ".json"

            # --- Open file and write chunks sequentially (MEMORY EFFICIENT) ---
            with sf.SoundFile(final_audio_path, 'w', samplerate=sample_rate, channels=1, subtype='PCM_16') as audio_file:
                for i, line in enumerate(formatted_script_lines):
                    if self.stop_generation:
                        break

                    progress(i / total_lines, desc=f"Generating line {i+1}/{total_lines}")

                    match = re.match(r'Speaker\s*(\d+):\s*(.*)', line, re.IGNORECASE)
                    if not match:
                        continue

                    speaker_idx = int(match.group(1)) - 1
                    text_content = match.group(2).strip()
                    if speaker_idx < 0 or speaker_idx >= len(voice_samples):
                        continue

                    inputs = self.processor(
                        text=[line],
                        voice_samples=[voice_samples],
                        padding=True,
                        return_tensors="pt"
                    )

                    output_waveform: VibeVoiceGenerationOutput = self.model.generate(
                        **inputs,
                        max_new_tokens=None,
                        cfg_scale=cfg_scale,
                        tokenizer=self.processor.tokenizer,
                        generation_config={'do_sample': False},
                        verbose=False,
                        refresh_negative=True
                    )
                    audio_np = output_waveform.speech_outputs[0].cpu().float().numpy().squeeze()

                    # NEW FEATURE: Remove silence if enabled
                    if remove_silence:
                        audio_np = self.trim_silence_from_numpy(audio_np, sample_rate)

                    duration = len(audio_np) / sample_rate
                    audio_int16 = (audio_np * 32767).astype(np.int16)
                    audio_file.write(audio_int16)

                    timestamps[str(i + 1)] = {
                        "text": text_content,
                        "speaker_id": speaker_idx,
                        "start": current_time,
                        "end": current_time + duration
                    }
                    current_time += duration

            # --- Finalize and Save JSON ---
            progress(1.0, desc="Saving timestamp file...")
            with open(final_json_path, "w") as f:
                json.dump(timestamps, f, indent=2)

            try:
                drive_save(final_audio_path)
                drive_save(final_json_path)
            except Exception as e:
                print(f"Error saving files to Google Drive: {e}")

            print(f"\n✨ Generation successful!\n🎵 Audio: {final_audio_path}\n📄 Timestamps: {final_json_path}\n")
            self.is_generating = False
            return final_audio_path, final_audio_path, final_json_path, gr.update(visible=True), gr.update(visible=False)

        except Exception as e:
            self.is_generating = False
            print(f"❌ An unexpected error occurred: {str(e)}")
            traceback.print_exc()
            return None, None, None, gr.update(visible=True), gr.update(visible=False)

    def stop_audio_generation(self):
        if self.is_generating:
            self.stop_generation = True
            print("🛑 Audio generation stop requested")

    def load_example_scripts(self):
        examples_dir = os.path.join(os.path.dirname(__file__), "text_examples")
        self.example_scripts = []
        if not os.path.exists(examples_dir):
            return
        txt_files = sorted([f for f in os.listdir(examples_dir) if f.lower().endswith('.txt')])
        for txt_file in txt_files:
            try:
                with open(os.path.join(examples_dir, txt_file), 'r', encoding='utf-8') as f:
                    script = f.read().strip()
                if script:
                    self.example_scripts.append([self._get_num_speakers_from_script(script), script])
            except Exception as e:
                print(f"Error loading example {txt_file}: {e}")

    def _get_num_speakers_from_script(self, script: str) -> int:
        speakers = set(re.findall(r'^Speaker\s+(\d+)\s*:', script, re.MULTILINE | re.IGNORECASE))
        return max(int(s) for s in speakers) if speakers else 1
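# Illustrative usage sketch (comments only, not executed): the Gradio interface built
# below is the normal entry point, but the class above can also be driven directly.
# The model path and voice preset names here are assumptions; real preset names come
# from whatever audio files exist in demo/voices.
#
#   demo = VibeVoiceDemo(model_path="/content/models/VibeVoice-1.5B", device="cuda", inference_steps=5)
#   audio_path, _, json_path, *_ = demo.generate_podcast_with_timestamps(
#       num_speakers=2,
#       script="Speaker 1: Hello and welcome.\nSpeaker 2: Great to be here.",
#       speaker_1="en-Alice_woman",
#       speaker_2="en-Carter_man",
#       cfg_scale=1.3,
#       remove_silence=False,
#   )
#   # json_path holds entries like:
#   #   {"1": {"text": "...", "speaker_id": 0, "start": 0.0, "end": 4.2}, ...}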
def create_demo_interface(demo_instance: VibeVoiceDemo):
    with gr.Blocks(
        title="VibeVoice AI Podcast Generator"
    ) as interface:
        gr.HTML("""
        Generate Long-form Multi-speaker AI Podcasts with VibeVoice
        🥳 Run on Google Colab
        Paste the prompt into any LLM, and customize the prompt if you want.