# app.py -- Hugging Face Space launcher (uploaded by broadfield-dev).
# Snapshot of commit 212332b ("Update app.py"), verified; file size 5.1 kB.
import os
import re
import subprocess
import sys
from pathlib import Path
# --- 0. Hardcoded Toggle for Execution Environment ---
# Set this to True to use Hugging Face ZeroGPU (recommended)
# Set this to False to use the slower, pure CPU environment
USE_ZEROGPU = True

# Where the upstream project lives and which files/models this launcher uses.
REPO_URL = "https://github.com/microsoft/VibeVoice.git"
REPO_DIR = "VibeVoice"
DEMO_SCRIPT = Path("demo/gradio_demo.py")  # relative to REPO_DIR
MODEL_ID = "microsoft/VibeVoice-1.5B"

# Decorator inserted above the generation method when running on ZeroGPU.
GPU_DECORATOR = "@spaces.GPU(duration=120)"

# Matches the first line of the generation method at any indentation, together
# with an immediately preceding @spaces.GPU(...) line if one is already there.
_GENERATION_DEF_RE = re.compile(
    r"^(?P<decorator>[ \t]*@spaces\.GPU\([^)\n]*\)[ \t]*\n)?"
    r"(?P<indent>[ \t]*)def generate_podcast_streaming\(self,",
    re.MULTILINE,
)

# Matches the CUDA/flash-attention model loading call; whitespace between the
# arguments is free-form so the match does not depend on exact indentation.
_CUDA_LOAD_RE = re.compile(
    r"^(?P<indent>[ \t]*)self\.model = "
    r"VibeVoiceForConditionalGenerationInference\.from_pretrained\(\s*"
    r"self\.model_path,\s*"
    r"torch_dtype=torch\.bfloat16,\s*"
    r"device_map='cuda',\s*"
    r'attn_implementation="flash_attention_2",\s*'
    r"\)",
    re.MULTILINE,
)

# Matches the CPU loading call produced by apply_cpu_patch (used to detect
# that a previous run already patched the file).
_CPU_LOAD_RE = re.compile(
    r"VibeVoiceForConditionalGenerationInference\.from_pretrained\(\s*"
    r"self\.model_path,\s*"
    r"torch_dtype=torch\.float32,[^\n]*\s*"
    r'device_map="cpu",\s*'
    r"\)"
)


class PatchError(Exception):
    """Raised when the demo script does not contain the text to be patched."""


def _decorate_generation_def(match):
    """Return the matched ``def`` line with the GPU decorator placed above it.

    A match that already includes the decorator is returned unchanged, which
    is what makes re-running the patch harmless.
    """
    if match.group("decorator"):
        return match.group(0)
    indent = match.group("indent")
    return f"{indent}{GPU_DECORATOR}\n{match.group(0)}"


def apply_zerogpu_patch(source: str) -> str:
    """Return *source* prepared for ZeroGPU execution.

    ``import spaces`` is prepended when the text does not already contain it,
    and every ``generate_podcast_streaming(self,`` definition gets a
    ``@spaces.GPU(duration=120)`` decorator at the same indentation as the
    ``def``. The model loading call is deliberately left untouched so that
    ``attn_implementation="flash_attention_2"`` stays in place.

    Applying the patch to already patched text returns it unchanged.

    Raises:
        PatchError: if no generation method definition is found.
    """
    patched, hits = _GENERATION_DEF_RE.subn(_decorate_generation_def, source)
    if hits == 0:
        raise PatchError(
            "Could not find the generation method signature to apply the GPU decorator."
        )
    if "import spaces" not in patched:
        patched = "import spaces\n" + patched
    return patched


def _cpu_load_block(indent: str) -> str:
    """Build the CPU model loading call, indented to match the original call."""
    arg_indent = indent + "    "
    return "\n".join(
        [
            f"{indent}self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(",
            f"{arg_indent}self.model_path,",
            f"{arg_indent}torch_dtype=torch.float32,  # Use float32 for CPU",
            f'{arg_indent}device_map="cpu",',
            f"{indent})",
        ]
    )


def apply_cpu_patch(source: str) -> str:
    """Return *source* with the CUDA model loading call replaced for CPU use.

    The replacement loads the model in ``torch.float32`` with
    ``device_map="cpu"`` and drops the ``attn_implementation`` argument.
    Text that already contains the CPU loading call is returned unchanged.

    Raises:
        PatchError: if neither the CUDA call nor the CPU call is present.
    """
    patched, hits = _CUDA_LOAD_RE.subn(
        lambda match: _cpu_load_block(match.group("indent")), source
    )
    if hits:
        return patched
    if _CPU_LOAD_RE.search(source):
        return source
    raise PatchError("The original model loading block was not found for CPU patching.")


# --- 1. Clone the VibeVoice Repository ---
def clone_repository() -> None:
    """Clone the VibeVoice repository unless REPO_DIR already exists.

    Exits the process with status 1 if ``git clone`` fails.
    """
    if os.path.exists(REPO_DIR):
        print("Repository already exists. Skipping clone.")
        return
    print("Cloning the VibeVoice repository...")
    try:
        subprocess.run(
            ["git", "clone", REPO_URL, REPO_DIR],
            check=True,
            capture_output=True,
            text=True,
        )
    except subprocess.CalledProcessError as e:
        print(f"Error cloning repository: {e.stderr}")
        sys.exit(1)
    print("Repository cloned successfully.")


# --- 2. Install the VibeVoice Package ---
# Note: Other dependencies are installed via requirements.txt
def install_package() -> None:
    """Change into REPO_DIR and install the package in editable mode.

    The working directory stays inside the repository afterwards; the later
    steps use paths relative to it. Exits with status 1 if pip fails.
    """
    os.chdir(REPO_DIR)
    print(f"Changed directory to: {os.getcwd()}")
    print("Installing the VibeVoice package in editable mode...")
    try:
        subprocess.run(
            [sys.executable, "-m", "pip", "install", "-e", "."],
            check=True,
            capture_output=True,
            text=True,
        )
    except subprocess.CalledProcessError as e:
        print(f"Error installing package: {e.stderr}")
        sys.exit(1)
    print("Package installed successfully.")


# --- 3. Modify the demo script to be environment-aware ---
def patch_demo_script() -> None:
    """Rewrite DEMO_SCRIPT for the environment selected by USE_ZEROGPU.

    The file is only written when the patch actually changed it. Exits with
    status 1 if the expected text is missing or the file cannot be read or
    written.
    """
    print(f"Reading {DEMO_SCRIPT} to apply environment-specific modifications...")
    try:
        # Explicit UTF-8 so the result does not depend on the locale encoding.
        original = DEMO_SCRIPT.read_text(encoding="utf-8")
        if USE_ZEROGPU:
            print("Configuring for ZeroGPU execution while keeping Flash Attention...")
            patched = apply_zerogpu_patch(original)
            success_messages = [
                "Successfully applied GPU decorator to the generation method.",
                "Model loading block remains unchanged to explicitly use Flash Attention.",
            ]
        else:  # Pure CPU execution
            print("Modifying for pure CPU execution...")
            patched = apply_cpu_patch(original)
            success_messages = ["Script modified for CPU successfully."]

        if patched == original:
            print("Demo script is already patched. Nothing to change.")
            return
        # Write the dynamically modified content back to the demo file
        DEMO_SCRIPT.write_text(patched, encoding="utf-8")
        for message in success_messages:
            print(message)
    except PatchError as e:
        print(f"\033[91mError: {e}\033[0m")
        sys.exit(1)
    except (OSError, UnicodeError) as e:
        print(f"An error occurred while modifying the script: {e}")
        sys.exit(1)


# --- 4. Launch the Gradio Demo ---
def launch_demo() -> int:
    """Run the patched demo script and return its exit status.

    The demo is started with the interpreter running this launcher
    (``sys.executable``), i.e. the same one the package was installed into.
    """
    # Construct the command to run the modified demo script
    command = [
        sys.executable,
        str(DEMO_SCRIPT),
        "--model_path",
        MODEL_ID,
        "--share",
    ]
    print(f"Launching Gradio demo with command: {' '.join(command)}")
    return subprocess.run(command).returncode


def main() -> None:
    """Clone, install, patch and launch the demo, in that order."""
    clone_repository()
    install_package()
    patch_demo_script()
    # Propagate the demo's exit status instead of always reporting success.
    sys.exit(launch_demo())


if __name__ == "__main__":
    main()