broadfield-dev committed on
Commit
6a0b1a5
·
verified ·
1 Parent(s): fb12e2c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -36
app.py CHANGED
@@ -49,11 +49,12 @@ demo_script_path = Path("demo/gradio_demo.py")
49
  print(f"Reading {demo_script_path} to apply environment-specific modifications...")
50
 
51
  try:
52
- file_content = demo_script_path.read_text()
53
 
54
- # Define the original model loading block using a list of lines for robustness.
55
- # This avoids issues with indentation in multi-line string literals.
56
- original_lines = [
 
57
  ' self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(',
58
  ' self.model_path,',
59
  ' torch_dtype=torch.bfloat16,',
@@ -61,61 +62,77 @@ try:
61
  ' attn_implementation="flash_attention_2",',
62
  ' )'
63
  ]
64
- original_block = "\n".join(original_lines)
65
 
66
- # Check if the block to be patched exists in the file
67
- if original_block not in file_content:
68
- print("\033[91mError: The original code block to be patched was not found.\033[0m")
69
- print("The demo script may have changed, or there might be a whitespace mismatch.")
70
- print("Please verify the contents of demo/gradio_demo.py.")
71
- sys.exit(1)
 
 
 
 
 
 
72
 
73
  if USE_ZEROGPU:
74
  print("Optimizing for ZeroGPU execution...")
75
 
76
- # New block for ZeroGPU: We remove the problematic `attn_implementation` line.
77
- replacement_lines_gpu = [
 
 
 
 
78
  ' self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(',
79
  ' self.model_path,',
80
  ' torch_dtype=torch.bfloat16,',
81
  " device_map='cuda',",
82
  ' )'
83
  ]
84
- replacement_block_gpu = "\n".join(replacement_lines_gpu)
85
 
86
- # Add 'import spaces' at the beginning of the file for the @spaces.GPU decorator
87
- if "import spaces" not in file_content:
88
- modified_content = "import spaces\n" + file_content
 
 
 
 
 
 
89
  else:
90
- modified_content = file_content
91
-
92
- # Decorate the main interface class to request a GPU from the Spaces infrastructure
93
- if "@spaces.GPU" not in modified_content:
94
- modified_content = modified_content.replace(
95
- "class VibeVoiceDemo:",
96
- "@spaces.GPU(duration=120)\nclass VibeVoiceDemo:"
97
- )
98
-
99
- # Replace the model loading block
100
- modified_content = modified_content.replace(original_block, replacement_block_gpu)
101
- print("Script modified for ZeroGPU successfully.")
102
 
103
  else: # Pure CPU execution
104
  print("Modifying for pure CPU execution...")
105
 
106
- # New block for CPU: Use float32 and map directly to the CPU.
107
- replacement_lines_cpu = [
108
  ' self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(',
109
  ' self.model_path,',
110
  ' torch_dtype=torch.float32, # Use float32 for CPU',
111
  ' device_map="cpu",',
112
  ' )'
113
  ]
114
- replacement_block_cpu = "\n".join(replacement_lines_cpu)
115
 
116
- # Replace the original model loading block with the CPU version
117
- modified_content = file_content.replace(original_block, replacement_block_cpu)
118
- print("Script modified for CPU successfully.")
 
 
 
 
119
 
120
  # Write the dynamically modified content back to the demo file
121
  demo_script_path.write_text(modified_content)
@@ -125,7 +142,7 @@ except Exception as e:
125
  sys.exit(1)
126
 
127
  # --- 4. Launch the Gradio Demo ---
128
- model_id = "microsoft/VibeVoice-1.5B"
129
 
130
  # Construct the command to run the modified demo script
131
  command = [
 
49
  print(f"Reading {demo_script_path} to apply environment-specific modifications...")
50
 
51
  try:
52
+ modified_content = demo_script_path.read_text()
53
 
54
+ # --- Patch Definitions ---
55
+
56
+ # Define the original model loading block to be replaced.
57
+ original_model_lines = [
58
  ' self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(',
59
  ' self.model_path,',
60
  ' torch_dtype=torch.bfloat16,',
 
62
  ' attn_implementation="flash_attention_2",',
63
  ' )'
64
  ]
65
+ original_model_block = "\n".join(original_model_lines)
66
 
67
+ # Define the generation method signature to add the GPU decorator to.
68
+ original_method_lines = [
69
+ ' def generate_podcast_streaming(self, ',
70
+ ' num_speakers: int,',
71
+ ' script: str,',
72
+ ' speaker_1: str = None,',
73
+ ' speaker_2: str = None,',
74
+ ' speaker_3: str = None,',
75
+ ' speaker_4: str = None,',
76
+ ' cfg_scale: float = 1.3) -> Iterator[tuple]:'
77
+ ]
78
+ original_method_signature = "\n".join(original_method_lines)
79
 
80
  if USE_ZEROGPU:
81
  print("Optimizing for ZeroGPU execution...")
82
 
83
+ # Add 'import spaces' if it's not already there.
84
+ if "import spaces" not in modified_content:
85
+ modified_content = "import spaces\n" + modified_content
86
+
87
+ # New block for ZeroGPU model loading: remove `attn_implementation`.
88
+ replacement_model_lines_gpu = [
89
  ' self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(',
90
  ' self.model_path,',
91
  ' torch_dtype=torch.bfloat16,',
92
  " device_map='cuda',",
93
  ' )'
94
  ]
95
+ replacement_model_block_gpu = "\n".join(replacement_model_lines_gpu)
96
 
97
+ # Add the @spaces.GPU decorator to the generation method instead of the class.
98
+ replacement_method_signature_gpu = "@spaces.GPU(duration=120)\n" + original_method_signature
99
+
100
+ # --- Apply Patches for GPU ---
101
+
102
+ # Patch 1: Decorate the generation method
103
+ if original_method_signature in modified_content:
104
+ modified_content = modified_content.replace(original_method_signature, replacement_method_signature_gpu)
105
+ print("Successfully applied GPU decorator to the generation method.")
106
  else:
107
+ print("\033[91mWarning: Could not find the generation method signature to apply the GPU decorator.\033[0m")
108
+
109
+ # Patch 2: Modify the model loading
110
+ if original_model_block in modified_content:
111
+ modified_content = modified_content.replace(original_model_block, replacement_model_block_gpu)
112
+ print("Successfully patched the model loading block for ZeroGPU.")
113
+ else:
114
+ print("\033[91mWarning: The original model loading block was not found. Patching may have failed.\033[0m")
 
 
 
 
115
 
116
  else: # Pure CPU execution
117
  print("Modifying for pure CPU execution...")
118
 
119
+ # New block for CPU: Use float32 and map to CPU.
120
+ replacement_model_lines_cpu = [
121
  ' self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(',
122
  ' self.model_path,',
123
  ' torch_dtype=torch.float32, # Use float32 for CPU',
124
  ' device_map="cpu",',
125
  ' )'
126
  ]
127
+ replacement_model_block_cpu = "\n".join(replacement_model_lines_cpu)
128
 
129
+ # Apply patch for CPU
130
+ if original_model_block in modified_content:
131
+ modified_content = modified_content.replace(original_model_block, replacement_model_block_cpu)
132
+ print("Script modified for CPU successfully.")
133
+ else:
134
+ print("\033[91mWarning: The original model loading block was not found. Patching may have failed.\033[0m")
135
+
136
 
137
  # Write the dynamically modified content back to the demo file
138
  demo_script_path.write_text(modified_content)
 
142
  sys.exit(1)
143
 
144
  # --- 4. Launch the Gradio Demo ---
145
+ model_id = "microsoft/VibeVoice-1.5B"
146
 
147
  # Construct the command to run the modified demo script
148
  command = [