# Author: broadfield-dev
# Last commit: c5f9e51 ("Update app.py", verified)
import os
import re
import subprocess
import sys
from pathlib import Path
# --- 1. Clone the VibeVoice Repository ---
# Fetch the upstream sources once. If the directory is already present (for
# example from an earlier run) it is reused as-is and no clone is attempted.
repo_dir = "VibeVoice"
if os.path.exists(repo_dir):
    print("Repository already exists. Skipping clone.")
else:
    print("Cloning the VibeVoice repository...")
    try:
        # git clones into a directory named after the repo, i.e. `repo_dir`.
        subprocess.run(
            ["git", "clone", "https://github.com/microsoft/VibeVoice.git"],
            check=True,
            capture_output=True,
            text=True,
        )
    except subprocess.CalledProcessError as err:
        # Output was captured, so git's diagnostics are only available here.
        print(f"Error cloning repository: {err.stderr}")
        sys.exit(1)
    print("Repository cloned successfully.")
# --- 2. Install Dependencies ---
# All remaining steps work relative to the checkout, so move into it first.
os.chdir(repo_dir)
print(f"Changed directory to: {os.getcwd()}")
print("Installing the VibeVoice package in editable mode...")
try:
    # Use the interpreter running this script so the package lands in the
    # same environment.
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "-e", "."],
        check=True,
        capture_output=True,
        text=True,
    )
except subprocess.CalledProcessError as err:
    print(f"Error installing package: {err.stderr}")
    sys.exit(1)
else:
    print("Package installed successfully.")
# --- 3. Refactor the demo script for ZeroGPU lazy loading ---
# On ZeroGPU a GPU is only attached while a function decorated with
# @spaces.GPU runs. The demo is therefore rewritten so that (1) its
# constructor no longer loads the model eagerly and (2)
# generate_podcast_streaming is GPU-decorated and loads the model/processor
# on its first call.
# The anchors are located with regexes that capture the demo's own
# indentation rather than whitespace-exact strings, so harmless formatting
# differences upstream do not break the patch.
demo_script_path = Path("demo/gradio_demo.py")

# First line of the injected lazy-load block. Its presence means an earlier
# run already patched and saved the script (e.g. the clone from step 1 was
# reused). Patching again must then be skipped, because the anchors are gone.
_PATCH_MARKER = "# Patched: Lazy-load model and processor on the GPU worker"

# First `self.load_model()` statement line, optionally with a trailing comment.
_INIT_LOAD_RE = re.compile(
    r"^(?P<indent>[ \t]*)self\.load_model\(\)[ \t]*(?:#[^\n]*)?$",
    re.MULTILINE,
)

# The (possibly multi-line) signature of generate_podcast_streaming, followed
# directly by the `try:` line that opens its body. `[^)]*` keeps the match
# inside the parameter list so it can never run on into a later method.
_GENERATE_HEADER_RE = re.compile(
    r"^(?P<indent>[ \t]*)"
    r"(?P<signature>def generate_podcast_streaming\(self\b[^)]*\)(?:[ \t]*->[^:\n]*)?:)"
    r"[ \t]*\n"
    r"(?P<body_indent>[ \t]*)try:",
    re.MULTILINE,
)

# Statements injected before the method's `try:`. Indentation is relative to
# the method body; nested lines add 4 spaces.
_LAZY_LOAD_LINES = (
    _PATCH_MARKER,
    "if self.model is None or self.processor is None:",
    '    print("Loading processor & model for the first time on GPU worker...")',
    "    self.processor = VibeVoiceProcessor.from_pretrained(self.model_path)",
    "    self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(",
    "        self.model_path,",
    "        torch_dtype=torch.bfloat16,  # Use 16-bit precision for quality",
    '        device_map="auto",',
    "    )",
    "    self.model.eval()",
    "    self.model.model.noise_scheduler = self.model.model.noise_scheduler.from_config(",
    "        self.model.model.noise_scheduler.config,",
    "        algorithm_type='sde-dpmsolver++',",
    "        beta_schedule='squaredcos_cap_v2'",
    "    )",
    "    self.model.set_ddpm_inference_steps(num_steps=self.inference_steps)",
    '    print("Model and processor loaded successfully on GPU worker.")',
)


def _defer_init_load(match: re.Match) -> str:
    """Replace the matched eager ``self.load_model()`` line with ``None`` placeholders.

    The original call is kept as a comment so the patched demo shows what
    was changed. The replacement reuses the matched line's indentation.
    """
    indent = match.group("indent")
    return (
        f"{indent}# self.load_model() # Patched: Defer model loading\n"
        f"{indent}self.model = None\n"
        f"{indent}self.processor = None"
    )


def _add_gpu_lazy_load(match: re.Match) -> str:
    """Decorate the matched method with ``@spaces.GPU`` and inject the lazy loader.

    The signature text is reused verbatim. The decorator takes the ``def``
    line's indentation, and the injected statements take the indentation of
    the method's own ``try:`` line.
    """
    indent = match.group("indent")
    body_indent = match.group("body_indent")
    lazy_block = "".join(f"{body_indent}{line}\n" for line in _LAZY_LOAD_LINES)
    return (
        f"{indent}@spaces.GPU(duration=120)\n"
        f"{indent}{match.group('signature')}\n"
        f"{lazy_block}"
        f"{body_indent}try:"
    )


print(f"Refactoring {demo_script_path} for ZeroGPU lazy loading...")
try:
    modified_content = demo_script_path.read_text(encoding="utf-8")

    if _PATCH_MARKER in modified_content:
        print("Demo script is already patched for lazy loading. Skipping refactor.")
    else:
        # The injected decorator needs the `spaces` module in the demo.
        if "import spaces" not in modified_content:
            modified_content = "import spaces\n" + modified_content

        # Patch 1: stop __init__ from loading the model at startup.
        modified_content, replaced = _INIT_LOAD_RE.subn(
            _defer_init_load, modified_content, count=1
        )
        if not replaced:
            print("\033[91mError: Could not find a 'self.load_model()' call to patch. Startup patch failed.\033[0m")
            sys.exit(1)
        print("Successfully patched __init__ to prevent startup model load.")

        # Patch 2: add the GPU decorator and the lazy-loading logic to the
        # generation method.
        modified_content, replaced = _GENERATE_HEADER_RE.subn(
            _add_gpu_lazy_load, modified_content, count=1
        )
        if not replaced:
            print("\033[91mError: Could not find the method definition for 'generate_podcast_streaming' (followed by 'try:') to patch. Its signature may have changed upstream. Please check the demo script.\033[0m")
            sys.exit(1)
        print("Successfully patched generation method for lazy loading.")

        # Only write once both patches succeeded, so the file is never left
        # half-patched.
        demo_script_path.write_text(modified_content, encoding="utf-8")
        print("Script patching complete.")
except Exception as e:
    # sys.exit() above raises SystemExit, which is not an Exception subclass
    # and therefore passes through this handler unchanged.
    print(f"An error occurred while modifying the script: {e}")
    import traceback
    traceback.print_exc()
    sys.exit(1)
# --- 4. Launch the Gradio Demo ---
# Run the patched demo with the same interpreter that installed VibeVoice in
# step 2. A bare "python" from PATH may belong to a different environment
# (without the editable install) or may not exist at all.
model_id = "microsoft/VibeVoice-1.5B"
command = [sys.executable, str(demo_script_path), "--model_path", model_id, "--share"]
print(f"Launching Gradio demo with command: {' '.join(command)}")
result = subprocess.run(command)
# Propagate the demo's exit status, so a crash is reported as a failure
# instead of this launcher exiting successfully.
sys.exit(result.returncode)