File size: 5,731 Bytes
b7658fb
6b64262
3f97053
4324db0
3f97053
1cb26d6
c1b8cf8
 
b2be989
1cb26d6
4324db0
6b64262
 
 
3f97053
6b64262
 
 
 
 
 
 
 
 
 
 
 
 
8a1f431
 
6b64262
 
4324db0
8a1f431
6b64262
 
 
 
 
 
 
 
 
 
 
 
8a1f431
4324db0
8a1f431
21a831f
4324db0
6a0b1a5
4324db0
6d6254a
6a0b1a5
fb12e2c
 
 
 
 
 
 
6a0b1a5
6d6254a
 
 
 
4324db0
c1b8cf8
 
 
6a0b1a5
 
 
 
 
 
fb12e2c
 
 
 
 
 
6a0b1a5
c1b8cf8
6d6254a
 
6a0b1a5
 
 
 
 
 
 
fb12e2c
6d6254a
 
6a0b1a5
 
 
 
 
 
6d6254a
 
c1b8cf8
 
 
 
6a0b1a5
 
fb12e2c
 
 
 
 
 
6a0b1a5
c1b8cf8
6a0b1a5
 
 
 
 
6d6254a
 
21a831f
8a1f431
c1b8cf8
4324db0
 
 
 
21a831f
4324db0
6d6254a
6b64262
8a1f431
6b64262
 
4324db0
6b64262
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import os
import subprocess
import sys
from pathlib import Path

# --- 0. Hardcoded Toggle for Execution Environment ---
# Set this to True to use Hugging Face ZeroGPU (recommended)
# Set this to False to use the slower, pure CPU environment
USE_ZEROGPU = True

# --- 1. Clone the VibeVoice Repository ---
# Skip the clone entirely when a previous run already created the checkout.
repo_dir = "VibeVoice"
if os.path.exists(repo_dir):
    print("Repository already exists. Skipping clone.")
else:
    print("Cloning the VibeVoice repository...")
    clone_cmd = ["git", "clone", "https://github.com/microsoft/VibeVoice.git"]
    try:
        # capture_output keeps git's progress noise out of the log; stderr is
        # surfaced explicitly below only if the clone fails.
        subprocess.run(clone_cmd, check=True, capture_output=True, text=True)
    except subprocess.CalledProcessError as e:
        print(f"Error cloning repository: {e.stderr}")
        sys.exit(1)
    print("Repository cloned successfully.")

# --- 2. Install the VibeVoice Package ---
# Note: Other dependencies are installed via requirements.txt
os.chdir(repo_dir)
print(f"Changed directory to: {os.getcwd()}")

print("Installing the VibeVoice package in editable mode...")
# Use the current interpreter's pip so the package lands in this environment.
pip_cmd = [sys.executable, "-m", "pip", "install", "-e", "."]
try:
    subprocess.run(pip_cmd, check=True, capture_output=True, text=True)
except subprocess.CalledProcessError as e:
    print(f"Error installing package: {e.stderr}")
    sys.exit(1)
print("Package installed successfully.")

# --- 3. Modify the demo script to be environment-aware ---
# The upstream demo hard-codes CUDA + flash-attention; rewrite its model-loading
# block (and, for ZeroGPU, decorate the generation method) before launching it.
demo_script_path = Path("demo/gradio_demo.py")
print(f"Reading {demo_script_path} to apply environment-specific modifications...")

try:
    modified_content = demo_script_path.read_text()

    # Define the original model loading block using a list of lines for robustness.
    # These lines must match the upstream source byte-for-byte (including indent)
    # for the string replacement below to find them.
    original_model_lines = [
        '        self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(',
        '            self.model_path,',
        '            torch_dtype=torch.bfloat16,',
        "            device_map='cuda',",
        '            attn_implementation="flash_attention_2",',
        '        )'
    ]
    original_model_block = "\n".join(original_model_lines)

    # More robustly define the generation method signature to patch.
    # We only need the first line to find our target.
    original_method_signature = "    def generate_podcast_streaming(self,"

    if USE_ZEROGPU:
        print("Optimizing for ZeroGPU execution...")

        # Add 'import spaces' if it's not already there.
        if "import spaces" not in modified_content:
            modified_content = "import spaces\n" + modified_content

        # New block for ZeroGPU model loading: remove `attn_implementation`
        # (flash-attention wheels are unavailable on ZeroGPU workers).
        replacement_model_lines_gpu = [
            '        self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(',
            '            self.model_path,',
            '            torch_dtype=torch.bfloat16,',
            "            device_map='cuda',",
            '        )'
        ]
        replacement_model_block_gpu = "\n".join(replacement_model_lines_gpu)

        # Add the @spaces.GPU decorator *with correct indentation* before the method.
        replacement_method_signature_gpu = "    @spaces.GPU(duration=120)\n" + original_method_signature

        # --- Apply Patches for GPU ---

        # Patch 1: Decorate the generation method so ZeroGPU allocates a GPU
        # for the duration of each generation call.
        if original_method_signature in modified_content:
            modified_content = modified_content.replace(original_method_signature, replacement_method_signature_gpu)
            print("Successfully applied GPU decorator to the generation method.")
        else:
            print("\033[91mError: Could not find the generation method signature to apply the GPU decorator.\033[0m")
            sys.exit(1)

        # Patch 2: Modify the model loading
        if original_model_block in modified_content:
            modified_content = modified_content.replace(original_model_block, replacement_model_block_gpu)
            print("Successfully patched the model loading block for ZeroGPU.")
        else:
            print("\033[91mError: The original model loading block was not found. Patching may have failed.\033[0m")
            sys.exit(1)

    else: # Pure CPU execution
        print("Modifying for pure CPU execution...")

        # New block for CPU: Use float32 and map to CPU.
        replacement_model_lines_cpu = [
            '        self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(',
            '            self.model_path,',
            '            torch_dtype=torch.float32,  # Use float32 for CPU',
            '            device_map="cpu",',
            '        )'
        ]
        replacement_model_block_cpu = "\n".join(replacement_model_lines_cpu)

        # Apply patch for CPU
        if original_model_block in modified_content:
            modified_content = modified_content.replace(original_model_block, replacement_model_block_cpu)
            print("Script modified for CPU successfully.")
        else:
            # BUGFIX: the ANSI reset sequence here was corrupted ("\03-3[0m"),
            # which would print literal garbage and leave the terminal red.
            # Now matches the identical message on the GPU path above.
            print("\033[91mError: The original model loading block was not found. Patching may have failed.\033[0m")
            sys.exit(1)

    # Write the dynamically modified content back to the demo file
    demo_script_path.write_text(modified_content)

except Exception as e:
    print(f"An error occurred while modifying the script: {e}")
    sys.exit(1)

# --- 4. Launch the Gradio Demo ---
model_id = "microsoft/VibeVoice-1.5B"

# Construct the command to run the modified demo script.
# BUGFIX: use sys.executable (same as the pip install in step 2) rather than a
# bare "python" — on systems where PATH resolves "python" to a different
# interpreter, the freshly installed VibeVoice package would not be importable.
command = [
    sys.executable,
    str(demo_script_path),
    "--model_path",
    model_id,
    "--share"
]

print(f"Launching Gradio demo with command: {' '.join(command)}")
subprocess.run(command)