File size: 5,226 Bytes
b7658fb
6b64262
3f97053
4324db0
3f97053
1cb26d6
c1b8cf8
 
b2be989
1cb26d6
4324db0
6b64262
 
 
3f97053
6b64262
 
 
 
 
 
 
 
 
 
 
 
 
8a1f431
 
6b64262
 
4324db0
8a1f431
6b64262
 
 
 
 
 
 
 
 
 
 
 
8a1f431
4324db0
8a1f431
21a831f
4324db0
 
 
fb12e2c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4324db0
c1b8cf8
 
 
8a1f431
fb12e2c
 
 
 
 
 
 
 
c1b8cf8
8a1f431
fb12e2c
 
 
 
c1b8cf8
8a1f431
fb12e2c
 
 
 
 
c1b8cf8
 
 
 
 
 
 
 
8a1f431
fb12e2c
 
 
 
 
 
 
 
c1b8cf8
8a1f431
c1b8cf8
 
21a831f
8a1f431
c1b8cf8
4324db0
 
 
 
21a831f
4324db0
6b64262
 
8a1f431
6b64262
 
4324db0
6b64262
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import os
import subprocess
import sys
from pathlib import Path

# --- 0. Execution Environment Toggle ---
# True  -> use Hugging Face ZeroGPU (recommended).
# False -> use the slower, pure CPU environment.
# The default below can be overridden without editing this file by setting the
# VIBEVOICE_USE_ZEROGPU environment variable (e.g. "1"/"true" or "0"/"false").
_TRUE_STRINGS = {"1", "true", "yes", "on"}
_FALSE_STRINGS = {"0", "false", "no", "off"}


def env_flag(name: str, default: bool) -> bool:
    """Read a boolean switch from the environment.

    Args:
        name: Environment variable to read.
        default: Value returned when the variable is unset or blank.

    Returns:
        True for "1"/"true"/"yes"/"on", False for "0"/"false"/"no"/"off"
        (case-insensitive, surrounding whitespace ignored), else *default*
        when the variable is unset or empty.

    Raises:
        ValueError: If the variable holds any other value, so a typo fails
            loudly instead of silently picking an execution environment.
    """
    raw = os.environ.get(name, "").strip().lower()
    if not raw:
        return default
    if raw in _TRUE_STRINGS:
        return True
    if raw in _FALSE_STRINGS:
        return False
    raise ValueError(
        f"{name}={raw!r} is not a recognised boolean; "
        f"use one of {sorted(_TRUE_STRINGS | _FALSE_STRINGS)}"
    )


USE_ZEROGPU = env_flag("VIBEVOICE_USE_ZEROGPU", default=True)

# --- 1. Clone the VibeVoice Repository ---
# The clone is created relative to the current working directory. An existing
# path named `repo_dir` is reused as-is (nothing is pulled or updated), so
# delete it to force a fresh clone.
repo_dir = "VibeVoice"
if not os.path.exists(repo_dir):
    print("Cloning the VibeVoice repository...")
    try:
        subprocess.run(
            # Pass repo_dir as the destination explicitly so the clone always
            # lands where step 2 will chdir, rather than relying on the URL's
            # basename happening to match it.
            ["git", "clone", "https://github.com/microsoft/VibeVoice.git", repo_dir],
            check=True,
            capture_output=True,
            text=True,
        )
        print("Repository cloned successfully.")
    except FileNotFoundError:
        # subprocess.run raises this (not CalledProcessError) when the `git`
        # executable itself cannot be found.
        print("Error cloning repository: the 'git' executable was not found on PATH.")
        sys.exit(1)
    except subprocess.CalledProcessError as e:
        print(f"Error cloning repository: {e.stderr}")
        sys.exit(1)
else:
    print("Repository already exists. Skipping clone.")

# --- 2. Install the VibeVoice Package ---
# Only the VibeVoice package itself is installed here; every other dependency
# comes from requirements.txt. After the chdir below, all later relative paths
# (e.g. demo/gradio_demo.py) resolve inside the cloned repository.
os.chdir(repo_dir)
print(f"Changed directory to: {os.getcwd()}")

print("Installing the VibeVoice package in editable mode...")
# Use the running interpreter's pip so the package lands in this environment.
pip_install_cmd = [sys.executable, "-m", "pip", "install", "-e", "."]
try:
    subprocess.run(pip_install_cmd, check=True, capture_output=True, text=True)
except subprocess.CalledProcessError as install_error:
    print(f"Error installing package: {install_error.stderr}")
    sys.exit(1)
else:
    print("Package installed successfully.")

# --- 3. Modify the demo script to be environment-aware ---
demo_script_path = Path("demo/gradio_demo.py")
# Unmodified copy of the upstream demo, saved the first time the patch is
# applied. Later runs patch from this copy rather than from the already-patched
# demo file. That makes this step idempotent: re-running against an existing
# clone (step 1 skips cloning) or after flipping USE_ZEROGPU keeps working
# instead of failing the "original block not found" check below.
pristine_demo_path = demo_script_path.with_name(demo_script_path.name + ".orig")
print(f"Reading {demo_script_path} to apply environment-specific modifications...")


def _add_spaces_import(source: str) -> str:
    """Return *source* with an ``import spaces`` line added, if it lacks one.

    The import goes at the very top of the module, except that it is placed
    after any ``from __future__`` imports: Python requires those to precede
    every other statement, so prepending blindly would be a SyntaxError.
    """
    if "import spaces" in source:
        return source
    if "__future__" not in source:
        return "import spaces\n" + source

    import ast  # only needed on this rarely taken path

    # 1-based line on which the last top-level __future__ import ends (0 = none).
    future_end = max(
        (
            node.end_lineno
            for node in ast.parse(source).body
            if isinstance(node, ast.ImportFrom) and node.module == "__future__"
        ),
        default=0,
    )
    if future_end == 0:
        return "import spaces\n" + source
    lines = source.split("\n")
    head = "\n".join(lines[:future_end])
    tail = "\n".join(lines[future_end:])
    return head + "\nimport spaces\n" + tail


try:
    # Define the original model loading block using a list of lines for robustness.
    # This avoids issues with indentation in multi-line string literals.
    original_lines = [
        '        self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(',
        '            self.model_path,',
        '            torch_dtype=torch.bfloat16,',
        "            device_map='cuda',",
        '            attn_implementation="flash_attention_2",',
        '        )'
    ]
    original_block = "\n".join(original_lines)

    # Start from the pristine copy when a previous run saved one.
    source_path = pristine_demo_path if pristine_demo_path.exists() else demo_script_path
    # Explicit UTF-8 so reading/writing does not depend on the host locale.
    file_content = source_path.read_text(encoding="utf-8")

    # Check if the block to be patched exists in the file
    if original_block not in file_content:
        print("\033[91mError: The original code block to be patched was not found.\033[0m")
        print("The demo script may have changed, or there might be a whitespace mismatch.")
        print("Please verify the contents of demo/gradio_demo.py.")
        sys.exit(1)

    # file_content is now known to be unpatched upstream code, so it is safe to
    # keep as the pristine copy for future runs.
    if not pristine_demo_path.exists():
        pristine_demo_path.write_text(file_content, encoding="utf-8")

    if USE_ZEROGPU:
        print("Optimizing for ZeroGPU execution...")

        # New block for ZeroGPU: We remove the problematic `attn_implementation` line.
        replacement_lines_gpu = [
            '        self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(',
            '            self.model_path,',
            '            torch_dtype=torch.bfloat16,',
            "            device_map='cuda',",
            '        )'
        ]
        replacement_block_gpu = "\n".join(replacement_lines_gpu)

        # Make `spaces` importable for the @spaces.GPU decorator added below.
        modified_content = _add_spaces_import(file_content)

        # Decorate the main interface class to request a GPU from the Spaces infrastructure.
        # NOTE(review): spaces.GPU is normally applied to functions; applied to a
        # class, the name VibeVoiceDemo is presumably rebound to the decorator's
        # wrapper rather than the class itself — confirm this behaves as intended.
        if "@spaces.GPU" not in modified_content:
            if "class VibeVoiceDemo:" not in modified_content:
                # str.replace would silently do nothing; make that visible.
                print("\033[93mWarning: 'class VibeVoiceDemo:' not found; @spaces.GPU was not applied.\033[0m")
            modified_content = modified_content.replace(
                "class VibeVoiceDemo:",
                "@spaces.GPU(duration=120)\nclass VibeVoiceDemo:"
            )

        # Replace the model loading block
        modified_content = modified_content.replace(original_block, replacement_block_gpu)
        print("Script modified for ZeroGPU successfully.")

    else:  # Pure CPU execution
        print("Modifying for pure CPU execution...")

        # New block for CPU: Use float32 and map directly to the CPU.
        replacement_lines_cpu = [
            '        self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(',
            '            self.model_path,',
            '            torch_dtype=torch.float32,  # Use float32 for CPU',
            '            device_map="cpu",',
            '        )'
        ]
        replacement_block_cpu = "\n".join(replacement_lines_cpu)

        # Replace the original model loading block with the CPU version
        modified_content = file_content.replace(original_block, replacement_block_cpu)
        print("Script modified for CPU successfully.")

    # Write the dynamically modified content back to the demo file
    demo_script_path.write_text(modified_content, encoding="utf-8")

except Exception as e:
    # Top-level boundary for this step: file I/O, decoding or parse errors all
    # abort the launch. (sys.exit raises SystemExit, which this does not catch.)
    print(f"An error occurred while modifying the script: {e}")
    sys.exit(1)

# --- 4. Launch the Gradio Demo ---
model_id = "microsoft/VibeVoice-1.5B"

# Construct the command to run the modified demo script. sys.executable (rather
# than a bare "python" looked up on PATH) guarantees the demo runs under the
# same interpreter that VibeVoice was pip-installed into in step 2.
command = [
    sys.executable,
    str(demo_script_path),
    "--model_path",
    model_id,
    "--share",
]

print(f"Launching Gradio demo with command: {' '.join(command)}")
# Propagate the demo's exit status so a crash is visible to whatever started
# this script, instead of this launcher always exiting successfully.
result = subprocess.run(command)
sys.exit(result.returncode)