# Hugging Face Spaces: Running on Zero (ZeroGPU hardware)
import os
import re
import subprocess
import sys
import textwrap
from pathlib import Path
# --- 1. Clone the VibeVoice repository (only if it is not already present) ---
# An existing checkout is reused untouched. Otherwise the upstream repository is
# cloned into the current working directory; with no explicit target, git names
# the new directory after the repository, i.e. "VibeVoice" (== repo_dir).
repo_dir = "VibeVoice"
if os.path.exists(repo_dir):
    print("Repository already exists. Skipping clone.")
else:
    print("Cloning the VibeVoice repository...")
    clone_command = ["git", "clone", "https://github.com/microsoft/VibeVoice.git"]
    try:
        subprocess.run(clone_command, check=True, capture_output=True, text=True)
    except subprocess.CalledProcessError as clone_error:
        # git's output was captured, so surface its stderr before aborting.
        print(f"Error cloning repository: {clone_error.stderr}")
        sys.exit(1)
    else:
        print("Repository cloned successfully.")
# --- 2. Install dependencies ---
# Everything after this point runs from inside the checkout, so relative paths
# used later in the script resolve against the repository root.
os.chdir(repo_dir)
print(f"Changed directory to: {os.getcwd()}")

print("Installing the VibeVoice package in editable mode...")
# Use the interpreter running this script, so the package is installed into
# that interpreter's environment.
pip_install_command = [sys.executable, "-m", "pip", "install", "-e", "."]
try:
    subprocess.run(pip_install_command, check=True, capture_output=True, text=True)
except subprocess.CalledProcessError as install_error:
    # pip's output was captured; print its stderr so the failure is diagnosable.
    print(f"Error installing package: {install_error.stderr}")
    sys.exit(1)
else:
    print("Package installed successfully.")
# --- 3. Patch the demo script for ZeroGPU lazy loading ---
# On ZeroGPU Spaces, GPU work is meant to run inside functions decorated with
# @spaces.GPU. The demo is therefore rewritten to (a) skip model loading in its
# constructor and (b) load the model lazily inside the decorated generation
# method.
#
# Anchors are found with regular expressions that take their indentation from
# the demo file itself. Each patch is skipped when its marker is already
# present, so re-running this script against an existing checkout (step 1
# reuses one) is a no-op instead of a hard failure.
demo_script_path = Path("demo/gradio_demo.py")

# Written in place of the eager self.load_model() call in the demo's __init__;
# its presence also marks the startup patch as already applied.
_INIT_PATCH_MARKER = "# self.load_model() # Patched: Defer model loading"
# First line of the injected lazy-load block; marks the method patch as applied.
_LAZY_LOAD_MARKER = "# Patched: Lazy-load model and processor on the GPU worker"
_GPU_DECORATOR = "@spaces.GPU(duration=120)"

# Code inserted at the top of generate_podcast_streaming. Indentation here is
# relative; the method body's real indentation is prepended at patch time.
# NOTE(review): VibeVoiceProcessor, VibeVoiceForConditionalGenerationInference,
# torch, self.model_path and self.inference_steps are expected to exist in the
# demo script (they are not defined in this file).
_LAZY_LOAD_BLOCK = _LAZY_LOAD_MARKER + """
if self.model is None or self.processor is None:
    print("Loading processor & model for the first time on GPU worker...")
    self.processor = VibeVoiceProcessor.from_pretrained(self.model_path)
    self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
        self.model_path,
        torch_dtype=torch.bfloat16,  # Use 16-bit precision for quality
        device_map="auto",
    )
    self.model.eval()
    self.model.model.noise_scheduler = self.model.model.noise_scheduler.from_config(
        self.model.model.noise_scheduler.config,
        algorithm_type='sde-dpmsolver++',
        beta_schedule='squaredcos_cap_v2'
    )
    self.model.set_ddpm_inference_steps(num_steps=self.inference_steps)
    print("Model and processor loaded successfully on GPU worker.")
"""

# A real top-level `import spaces` statement (not merely the text in a comment).
_IMPORT_SPACES_RE = re.compile(r"^import spaces\b", re.MULTILINE)
# The eager load call at the start of a line; captures its indentation.
_LOAD_MODEL_CALL_RE = re.compile(r"^(?P<indent>[ \t]*)self\.load_model\(\)", re.MULTILINE)
# The full (possibly multi-line) method signature followed directly by `try:`.
# `[^)]*` keeps the match inside the parameter list, so it cannot run on into
# the method body looking for some later `):` / `try:` pair.
_GENERATE_HEADER_RE = re.compile(
    r"^(?P<indent>[ \t]*)"
    r"(?P<signature>def generate_podcast_streaming\(self,[^)]*\)[^\n]*:)[ \t]*\n"
    r"(?P<body_indent>[ \t]*)try:",
    re.MULTILINE,
)


class _PatchAnchorError(Exception):
    """Raised when the text a patch must replace cannot be found in the demo."""


def _defer_model_load(match: re.Match) -> str:
    """Replace the eager ``self.load_model()`` call, keeping its indentation."""
    indent = match.group("indent")
    return (
        f"{indent}{_INIT_PATCH_MARKER}\n"
        f"{indent}self.model = None\n"
        f"{indent}self.processor = None"
    )


def _add_gpu_lazy_loading(match: re.Match) -> str:
    """Decorate generate_podcast_streaming and insert the lazy-load block before its ``try:``."""
    indent = match.group("indent")
    body_indent = match.group("body_indent")
    lazy_block = textwrap.indent(_LAZY_LOAD_BLOCK, body_indent)
    return (
        f"{indent}{_GPU_DECORATOR}\n"
        f"{indent}{match.group('signature')}\n"
        f"{lazy_block}"
        f"{body_indent}try:"
    )


def _patch_demo_source(source: str) -> str:
    """Return *source* with the ZeroGPU patches applied.

    Each patch is skipped when its marker is already present, so applying this
    to an already-patched file returns it unchanged.

    Raises:
        _PatchAnchorError: if an unpatched anchor cannot be located.
    """
    if not _IMPORT_SPACES_RE.search(source):
        source = "import spaces\n" + source

    # Patch 1: defer model loading in __init__.
    if _INIT_PATCH_MARKER in source:
        print("__init__ already patched; skipping startup patch.")
    else:
        source, replaced = _LOAD_MODEL_CALL_RE.subn(_defer_model_load, source, count=1)
        if not replaced:
            raise _PatchAnchorError(
                "Could not find 'self.load_model()' to patch. Startup patch failed."
            )
        print("Successfully patched __init__ to prevent startup model load.")

    # Patch 2: add the @spaces.GPU decorator and lazy loading to the generator.
    if _LAZY_LOAD_MARKER in source:
        print("Generation method already patched; skipping lazy-loading patch.")
    else:
        source, replaced = _GENERATE_HEADER_RE.subn(_add_gpu_lazy_loading, source, count=1)
        if not replaced:
            raise _PatchAnchorError(
                "Could not find the method definition for 'generate_podcast_streaming' "
                "(a signature followed directly by 'try:') to patch. Please check the demo script."
            )
        print("Successfully patched generation method for lazy loading.")

    return source


print(f"Refactoring {demo_script_path} for ZeroGPU lazy loading...")
try:
    original_source = demo_script_path.read_text(encoding="utf-8")
    patched_source = _patch_demo_source(original_source)
    # Only touch the file when something actually changed (e.g. not on re-runs).
    if patched_source != original_source:
        demo_script_path.write_text(patched_source, encoding="utf-8")
    print("Script patching complete.")
except _PatchAnchorError as e:
    print(f"\033[91mError: {e}\033[0m")
    sys.exit(1)
except Exception as e:
    print(f"An error occurred while modifying the script: {e}")
    import traceback
    traceback.print_exc()
    sys.exit(1)
# --- 4. Launch the Gradio demo ---
# Run the patched demo with the same interpreter (sys.executable) that ran the
# editable install in step 2, so the child process sees the installed package.
# A bare "python" resolved from PATH could belong to a different environment.
model_id = "microsoft/VibeVoice-1.5B"
command = [sys.executable, str(demo_script_path), "--model_path", model_id, "--share"]
print(f"Launching Gradio demo with command: {' '.join(command)}")
demo_result = subprocess.run(command)
# Propagate a failure so the host sees the crash instead of a clean exit.
# A negative return code (child killed by a signal) is reported as 1.
if demo_result.returncode != 0:
    print(f"Gradio demo exited with status {demo_result.returncode}.")
    sys.exit(demo_result.returncode if demo_result.returncode > 0 else 1)