File size: 5,095 Bytes
b7658fb
6b64262
3f97053
4324db0
3f97053
1cb26d6
c1b8cf8
 
b2be989
1cb26d6
4324db0
6b64262
 
 
3f97053
6b64262
 
 
 
 
 
 
 
 
 
 
 
 
8a1f431
 
6b64262
 
4324db0
8a1f431
6b64262
 
 
 
 
 
 
 
 
 
 
 
8a1f431
4324db0
8a1f431
21a831f
4324db0
6a0b1a5
4324db0
c1b8cf8
212332b
c1b8cf8
6a0b1a5
 
 
 
212332b
 
 
c1b8cf8
212332b
6d6254a
6a0b1a5
 
212332b
 
6a0b1a5
 
 
212332b
fb12e2c
6d6254a
 
6a0b1a5
c1b8cf8
 
 
212332b
 
 
 
 
 
 
 
 
 
 
6a0b1a5
 
fb12e2c
 
 
 
 
 
6a0b1a5
c1b8cf8
6a0b1a5
 
 
 
 
212332b
6d6254a
21a831f
8a1f431
c1b8cf8
4324db0
 
 
 
21a831f
4324db0
6d6254a
6b64262
8a1f431
6b64262
 
4324db0
6b64262
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import os
import subprocess
import sys
from pathlib import Path

# --- 0. Hardcoded Toggle for Execution Environment ---
# Selects which patch step 3 applies to demo/gradio_demo.py:
#   True  -> Hugging Face ZeroGPU (recommended): adds `import spaces` and a
#            `@spaces.GPU` decorator, leaving the model loading call as-is.
#   False -> slower, pure CPU environment: rewrites the model loading call to
#            load float32 weights onto "cpu".
USE_ZEROGPU: bool = True

# --- 1. Clone the VibeVoice Repository ---
# Clone into `repo_dir` explicitly so the target directory is guaranteed to
# match the name checked here and used by os.chdir() in step 2, instead of
# relying on git deriving the same name from the URL.
repo_dir = "VibeVoice"
repo_url = "https://github.com/microsoft/VibeVoice.git"
if os.path.exists(repo_dir):
    print("Repository already exists. Skipping clone.")
else:
    print("Cloning the VibeVoice repository...")
    try:
        subprocess.run(
            ["git", "clone", repo_url, repo_dir],
            check=True,
            capture_output=True,
            text=True,
        )
    except subprocess.CalledProcessError as e:
        print(f"Error cloning repository: {e.stderr}")
        sys.exit(1)
    except FileNotFoundError:
        # subprocess raises this (not CalledProcessError) when the `git`
        # executable itself cannot be found.
        print("Error cloning repository: 'git' executable not found on PATH.")
        sys.exit(1)
    print("Repository cloned successfully.")

# --- 2. Install the VibeVoice Package ---
# Note: Other dependencies are installed via requirements.txt
# Everything after this point resolves paths relative to the repository root.
os.chdir(repo_dir)
print(f"Changed directory to: {os.getcwd()}")

print("Installing the VibeVoice package in editable mode...")
# Install into the interpreter running this script so later imports see it.
pip_install_cmd = [sys.executable, "-m", "pip", "install", "-e", "."]
try:
    subprocess.run(pip_install_cmd, check=True, capture_output=True, text=True)
except subprocess.CalledProcessError as err:
    print(f"Error installing package: {err.stderr}")
    sys.exit(1)
else:
    print("Package installed successfully.")

# --- 3. Modify the demo script to be environment-aware ---
# The patches below are idempotent. Step 1 skips cloning when the repository
# already exists, so a re-run sees an already-patched file; a patch that is
# already present is detected and skipped instead of being applied twice
# (GPU path) or failing the "original block not found" check (CPU path).
# NOTE: patches are never reverted, so switching USE_ZEROGPU between runs
# on the same checkout leaves the previous mode's edits in place.
demo_script_path = Path("demo/gradio_demo.py")
print(f"Reading {demo_script_path} to apply environment-specific modifications...")

try:
    # Explicit UTF-8 so behavior does not depend on the platform locale.
    modified_content = demo_script_path.read_text(encoding="utf-8")

    if USE_ZEROGPU:
        print("Configuring for ZeroGPU execution while keeping Flash Attention...")

        # Add 'import spaces' unless a line consisting of exactly that
        # statement already exists (a plain substring test would also match
        # e.g. 'import spaces_utils').
        has_spaces_import = any(
            line.strip() == "import spaces" for line in modified_content.splitlines()
        )
        if not has_spaces_import:
            modified_content = "import spaces\n" + modified_content

        # Define the generation method signature to add the decorator to.
        # We target only the first line for robustness.
        original_method_signature = "    def generate_podcast_streaming(self,"

        # Define the replacement with the correctly indented decorator.
        replacement_method_signature_gpu = "    @spaces.GPU(duration=120)\n" + original_method_signature

        # --- Apply Patches for GPU ---
        # The only change needed is to add the decorator. We will NOT modify the
        # from_pretrained call, leaving attn_implementation="flash_attention_2" in place.
        # The decorated form still contains the bare signature, so check for it
        # first; otherwise a re-run would stack a second decorator.
        if replacement_method_signature_gpu in modified_content:
            print("GPU decorator already present; skipping.")
        elif original_method_signature in modified_content:
            modified_content = modified_content.replace(original_method_signature, replacement_method_signature_gpu)
            print("Successfully applied GPU decorator to the generation method.")
            print("Model loading block remains unchanged to explicitly use Flash Attention.")
        else:
            print("\033[91mError: Could not find the generation method signature to apply the GPU decorator.\033[0m")
            sys.exit(1)

    else:  # Pure CPU execution
        print("Modifying for pure CPU execution...")

        # For the CPU path, we still need to replace the entire CUDA-specific block.
        original_model_lines = [
            '        self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(',
            '            self.model_path,',
            '            torch_dtype=torch.bfloat16,',
            "            device_map='cuda',",
            '            attn_implementation="flash_attention_2",',
            '        )'
        ]
        original_model_block = "\n".join(original_model_lines)

        # New block for CPU: Use float32 and map to CPU.
        replacement_model_lines_cpu = [
            '        self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(',
            '            self.model_path,',
            '            torch_dtype=torch.float32,  # Use float32 for CPU',
            '            device_map="cpu",',
            '        )'
        ]
        replacement_model_block_cpu = "\n".join(replacement_model_lines_cpu)

        # Apply patch for CPU (skipped if a previous run already applied it).
        if replacement_model_block_cpu in modified_content:
            print("Model loading block already patched for CPU; skipping.")
        elif original_model_block in modified_content:
            modified_content = modified_content.replace(original_model_block, replacement_model_block_cpu)
            print("Script modified for CPU successfully.")
        else:
            print("\033[91mError: The original model loading block was not found for CPU patching.\033[0m")
            sys.exit(1)

    # Write the dynamically modified content back to the demo file
    demo_script_path.write_text(modified_content, encoding="utf-8")

except (OSError, UnicodeError) as e:
    # Only file access and decoding can fail here; sys.exit() raises
    # SystemExit, which is deliberately not caught.
    print(f"An error occurred while modifying the script: {e}")
    sys.exit(1)

# --- 4. Launch the Gradio Demo ---
model_id = "microsoft/VibeVoice-1.5B"

# Construct the command to run the modified demo script.
# Use the interpreter running this script (the one pip installed VibeVoice
# into in step 2) rather than whatever "python" happens to resolve to on PATH.
command = [
    sys.executable,
    str(demo_script_path),
    "--model_path",
    model_id,
    "--share",
]

print(f"Launching Gradio demo with command: {' '.join(command)}")
result = subprocess.run(command)
# Propagate the demo's exit status so a crash is not reported as success.
sys.exit(result.returncode)