broadfield-dev committed
Commit c5f9e51 · verified · 1 Parent(s): 88cdeea

Update app.py

Files changed (1)
  1. app.py +52 -49
app.py CHANGED
@@ -34,16 +34,55 @@ except subprocess.CalledProcessError as e:
     print(f"Error installing package: {e.stderr}")
     sys.exit(1)
 
-# --- 3. Refactor the demo script using a robust state-machine patcher ---
+# --- 3. Refactor the demo script using a direct replacement strategy ---
 demo_script_path = Path("demo/gradio_demo.py")
 print(f"Refactoring {demo_script_path} for ZeroGPU lazy loading...")
 
 try:
     with open(demo_script_path, 'r') as f:
-        lines = f.readlines()
+        modified_content = f.read()
 
-    # --- Prepare the code blocks to be inserted ---
-    lazy_load_code = """
+    # --- Add 'import spaces' at the top ---
+    if "import spaces" not in modified_content:
+        modified_content = "import spaces\n" + modified_content
+
+    # --- Patch 1: Defer model loading in __init__ ---
+    original_init_call = "        self.load_model()"
+    replacement_init_block = (
+        "        # self.load_model() # Patched: Defer model loading\n"
+        "        self.model = None\n"
+        "        self.processor = None"
+    )
+    if original_init_call in modified_content:
+        modified_content = modified_content.replace(original_init_call, replacement_init_block, 1)
+        print("Successfully patched __init__ to prevent startup model load.")
+    else:
+        print(f"\033[91mError: Could not find '{original_init_call}' to patch. Startup patch failed.\033[0m")
+        sys.exit(1)
+
+    # --- Patch 2: Add decorator and lazy-loading logic to the generation method ---
+    # Define the exact block to find, spanning the full method signature down to the 'try:'.
+    # This is sensitive to whitespace but is the most direct way to replace.
+    original_method_header = """    def generate_podcast_streaming(self,
+                                   num_speakers: int,
+                                   script: str,
+                                   speaker_1: str = None,
+                                   speaker_2: str = None,
+                                   speaker_3: str = None,
+                                   speaker_4: str = None,
+                                   cfg_scale: float = 1.3) -> Iterator[tuple]:
+        try:"""
+
+    # Define the full replacement block with correct indentation.
+    replacement_method_header = """    @spaces.GPU(duration=120)
+    def generate_podcast_streaming(self,
+                                   num_speakers: int,
+                                   script: str,
+                                   speaker_1: str = None,
+                                   speaker_2: str = None,
+                                   speaker_3: str = None,
+                                   speaker_4: str = None,
+                                   cfg_scale: float = 1.3) -> Iterator[tuple]:
         # Patched: Lazy-load model and processor on the GPU worker
         if self.model is None or self.processor is None:
             print("Loading processor & model for the first time on GPU worker...")
@@ -55,61 +94,25 @@ try:
             )
             self.model.eval()
             self.model.model.noise_scheduler = self.model.model.noise_scheduler.from_config(
-            self.model.model.noise_scheduler.config,
+                self.model.model.noise_scheduler.config,
                 algorithm_type='sde-dpmsolver++',
                 beta_schedule='squaredcos_cap_v2'
             )
             self.model.set_ddpm_inference_steps(num_steps=self.inference_steps)
             print("Model and processor loaded successfully on GPU worker.")
-    """
-
-    # --- Perform the line-by-line modifications using a state machine ---
-    new_lines = []
-    # Add 'import spaces' at the top if it doesn't exist
-    if not any("import spaces" in line for line in lines):
-        new_lines.append("import spaces\n")
 
-    # State machine variables
-    in_generate_method = False
-    patched_generate_method = False
-
-    for line in lines:
-        # Defer the initial model loading to prevent PicklingError
-        if "self.load_model()" in line and "def __init__" in "".join(lines[lines.index(line)-2:lines.index(line)]):
-            new_lines.append("        # self.load_model() # Patched: Defer model loading\n")
-            new_lines.append("        self.model = None\n")
-            new_lines.append("        self.processor = None\n")
-            print("Successfully patched __init__ to prevent startup model load.")
-
-        # Start of the target method
-        elif "def generate_podcast_streaming(self," in line and not patched_generate_method:
-            new_lines.append("    @spaces.GPU(duration=120)\n")
-            new_lines.append(line)
-            in_generate_method = True
-
-        # End of the target method signature
-        elif "-> Iterator[tuple]:" in line and in_generate_method:
-            new_lines.append(line)
-            # Indent and insert the lazy load code
-            for code_line in lazy_load_code.strip().split('\n'):
-                new_lines.append(' ' * 8 + code_line + '\n')
-
-            # Reset state
-            in_generate_method = False
-            patched_generate_method = True
-            print("Successfully patched generation method for lazy loading.")
-
-        # All other lines
-        else:
-            new_lines.append(line)
-
-    if not patched_generate_method:
-        print("\033[91mError: Failed to apply the lazy-loading patch. The target method signature may have changed.\033[0m")
+        try:"""
+
+    if original_method_header in modified_content:
+        modified_content = modified_content.replace(original_method_header, replacement_method_header, 1)
+        print("Successfully patched generation method for lazy loading.")
+    else:
+        print(f"\033[91mError: Could not find the method definition for 'generate_podcast_streaming' to patch. This is likely due to a whitespace mismatch. Please check the demo script.\033[0m")
        sys.exit(1)
 
     # --- Write the modified content back to the file ---
     with open(demo_script_path, 'w') as f:
-        f.writelines(new_lines)
+        f.write(modified_content)
 
     print("Script patching complete.")
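
For orientation, here is a minimal sketch of roughly what the patched demo/gradio_demo.py should look like after this script runs. The class name, constructor arguments, load_model body, and generation logic are placeholders (the diff above does not show them); only import spaces, the deferred __init__ assignments, the @spaces.GPU(duration=120) decorator, and the lazy-load check come from the patch itself.

    import spaces  # added at the top of the demo by the patcher
    from typing import Iterator

    class PodcastDemo:  # hypothetical stand-in for the demo's actual class
        def __init__(self, inference_steps: int = 10):
            self.inference_steps = inference_steps
            # self.load_model()  # Patched: defer model loading so startup stays CPU-only
            self.model = None
            self.processor = None

        def load_model(self):
            # Placeholder: the real demo builds its processor and model here.
            self.processor = object()
            self.model = object()

        @spaces.GPU(duration=120)  # run this call on a ZeroGPU worker, up to 120 s
        def generate_podcast_streaming(self,
                                       num_speakers: int,
                                       script: str,
                                       speaker_1: str = None,
                                       speaker_2: str = None,
                                       speaker_3: str = None,
                                       speaker_4: str = None,
                                       cfg_scale: float = 1.3) -> Iterator[tuple]:
            # Patched: lazy-load model and processor the first time the GPU worker runs
            if self.model is None or self.processor is None:
                print("Loading processor & model for the first time on GPU worker...")
                self.load_model()
            try:
                # ... the demo's original streaming generation logic continues here ...
                yield ("", None)
            except Exception as e:
                print(f"Generation failed: {e}")
                raise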