import os
import subprocess
import sys
from pathlib import Path

# --- 0. Hardcoded Toggle for Execution Environment ---
# Set this to True to use Hugging Face ZeroGPU (recommended)
# Set this to False to use the slower, pure CPU environment
USE_ZEROGPU = True
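# (Optional, not enabled here) the same toggle could also be read from an
# environment variable so the file does not need editing per environment, e.g.:
#   USE_ZEROGPU = os.environ.get("USE_ZEROGPU", "1") != "0"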

# --- 1. Clone the VibeVoice Repository ---
repo_dir = "VibeVoice"
if not os.path.exists(repo_dir):
    print("Cloning the VibeVoice repository...")
    try:
        subprocess.run(
            ["git", "clone", "https://github.com/microsoft/VibeVoice.git"],
            check=True,
            capture_output=True,
            text=True
        )
        print("Repository cloned successfully.")
    except subprocess.CalledProcessError as e:
        print(f"Error cloning repository: {e.stderr}")
        sys.exit(1)
else:
    print("Repository already exists. Skipping clone.")
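
# Note: an existing clone is reused as-is. If the upstream demo script has
# changed since it was cloned, the string replacement in step 3 may no longer
# match; updating the clone manually (e.g. `git -C VibeVoice pull`) resolves that.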

# --- 2. Install Dependencies ---
os.chdir(repo_dir)
print(f"Changed directory to: {os.getcwd()}")

# Install the main package
print("Installing the VibeVoice package...")
try:
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "-e", "."],
        check=True,
        capture_output=True,
        text=True
    )
    print("Package installed successfully.")
except subprocess.CalledProcessError as e:
    print(f"Error installing package: {e.stderr}")
    sys.exit(1)

# Install 'spaces' if using ZeroGPU, as it's required for the decorator
if USE_ZEROGPU:
    print("Installing the 'spaces' library for ZeroGPU...")
    try:
        subprocess.run(
            [sys.executable, "-m", "pip", "install", "spaces"],
            check=True,
            capture_output=True,
            text=True
        )
        print("'spaces' library installed successfully.")
    except subprocess.CalledProcessError as e:
        print(f"Error installing 'spaces' library: {e.stderr}")
        sys.exit(1)

# --- 3. Modify the demo script based on the toggle ---
demo_script_path = Path("demo/gradio_demo.py")
print(f"Reading {demo_script_path}...")

try:
    file_content = demo_script_path.read_text()

    # Define the original GPU-specific model loading block we want to replace
    # This block is problematic because it hardcodes FlashAttention
    original_block = """        self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
            self.model_path,
            torch_dtype=torch.bfloat16,
            device_map='cuda',
            attn_implementation="flash_attention_2",
        )"""

    if USE_ZEROGPU:
        print("Optimizing for ZeroGPU execution...")
        
        # New block for ZeroGPU: We remove the problematic flash_attention line.
        # Transformers will automatically use the best available attention mechanism.
        replacement_block_gpu = """        self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
            self.model_path,
            torch_dtype=torch.bfloat16,
            device_map='cuda',
        )"""
        
        # Add 'import spaces' at the beginning of the file
        modified_content = "import spaces\n" + file_content
        
        # Decorate the main class with @spaces.GPU to request a GPU
        modified_content = modified_content.replace(
            "class VibeVoiceGradioInterface:",
            "@spaces.GPU(duration=120)\nclass VibeVoiceGradioInterface:"
        )
        
        # Replace the model loading block
        modified_content = modified_content.replace(original_block, replacement_block_gpu)
        print("Script modified for ZeroGPU successfully.")

    else: # Pure CPU execution
        print("Modifying for pure CPU execution...")
        
        # New block for CPU: Use float32 and map directly to CPU.
        # FlashAttention is not compatible with CPU.
        replacement_block_cpu = """        self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
            self.model_path,
            torch_dtype=torch.float32,  # Use float32 for CPU
            device_map="cpu",
        )"""
        
        # Replace the model loading block
        modified_content = file_content.replace(original_block, replacement_block_cpu)
        print("Script modified for CPU successfully.")

    # Write the modified content back to the file
    demo_script_path.write_text(modified_content)

except Exception as e:
    print(f"An error occurred while modifying the script: {e}")
    sys.exit(1)

# --- 4. Launch the Gradio Demo ---
model_id = "microsoft/VibeVoice-1.5B"

# Construct the launch command (following the README invocation), but use the
# current interpreter so the demo runs in the same environment the package was
# installed into.
command = [
    sys.executable,
    str(demo_script_path),
    "--model_path",
    model_id,
    "--share",
]

print(f"Launching Gradio demo with command: {' '.join(command)}")
subprocess.run(command)
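
# Usage note (assumed setup): run this launcher from an environment with git
# and pip available, e.g. `python launch_vibevoice_demo.py` (the file name is
# illustrative). Keep USE_ZEROGPU = True on a Hugging Face ZeroGPU Space; set
# it to False for a local CPU-only run. The --share flag makes Gradio print a
# public URL once the demo starts.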