yasserrmd committed on
Commit
f736395
·
verified ·
1 Parent(s): 9a86201

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +131 -69
app.py CHANGED
@@ -69,86 +69,148 @@ class VibeVoiceDemo:
69
  return np.array([])
70
 
71
  @GPU
72
- def generate_podcast(self, num_speakers: int, script: str,
73
- speaker_1: str = None, speaker_2: str = None,
74
- speaker_3: str = None, speaker_4: str = None,
75
- cfg_scale: float = 1.3):
76
- """Final audio generation only (no streaming)."""
77
- self.is_generating = True
78
-
79
- if not script.strip():
80
- raise gr.Error("Please provide a script.")
 
 
 
 
 
 
 
 
 
 
 
 
81
 
82
- if num_speakers < 1 or num_speakers > 4:
83
- raise gr.Error("Number of speakers must be 1–4.")
84
 
85
- selected = [speaker_1, speaker_2, speaker_3, speaker_4][:num_speakers]
86
- for i, sp in enumerate(selected):
87
- if not sp or sp not in self.available_voices:
88
- raise gr.Error(f"Invalid speaker {i+1} selection.")
 
89
 
90
- # load voices
91
- voice_samples = [self.read_audio(self.available_voices[sp]) for sp in selected]
92
- if any(len(v) == 0 for v in voice_samples):
93
- raise gr.Error("Failed to load one or more voice samples.")
 
 
 
 
 
 
 
 
 
 
 
 
94
 
95
- # format script
96
- lines = script.strip().split("\n")
97
- formatted = []
98
- for i, line in enumerate(lines):
99
- line = line.strip()
100
- if not line:
101
- continue
102
- if line.startswith("Speaker "):
103
- formatted.append(line)
104
- else:
105
- sp_id = i % num_speakers
106
- formatted.append(f"Speaker {sp_id}: {line}")
107
- formatted_script = "\n".join(formatted)
 
 
 
 
 
 
108
 
109
- # processor input
110
- inputs = self.processor(
111
- text=[formatted_script],
112
- voice_samples=[voice_samples],
113
- padding=True,
114
- return_tensors="pt",
115
- return_attention_mask=True,
116
- )
 
117
 
118
- start = time.time()
119
- outputs = self.model.generate(
120
- **inputs,
121
- max_new_tokens=None,
122
- cfg_scale=cfg_scale,
123
- tokenizer=self.processor.tokenizer,
124
- generation_config={'do_sample': False},
125
- verbose=False,
126
- )
 
 
 
127
 
128
- # --- FIX: pull from speech_outputs ---
129
- if outputs.speech_outputs and outputs.speech_outputs[0] is not None:
130
- audio = outputs.speech_outputs[0].cpu().numpy()
131
- else:
132
- self.is_generating = False
133
- raise gr.Error("❌ No audio was generated by the model.")
134
-
135
- if audio.ndim > 1:
136
- audio = audio.squeeze()
137
 
138
- sample_rate = 24000
 
 
 
 
139
 
140
- # Save automatically to disk
141
- os.makedirs("outputs", exist_ok=True)
142
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
143
- file_path = os.path.join("outputs", f"podcast_{timestamp}.wav")
144
- sf.write(file_path, audio, sample_rate)
145
- print(f"💾 Saved podcast to {file_path}")
 
 
 
146
 
147
- total_dur = len(audio) / sample_rate
148
- log = f"✅ Generation complete in {time.time()-start:.1f}s, {total_dur:.1f}s audio\nSaved to {file_path}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
 
150
- self.is_generating = False
151
- return (sample_rate, audio), log
 
 
 
 
 
 
152
 
153
 
154
 
 
69
  return np.array([])
70
 
71
  @GPU
72
+ def generate_podcast(self,
73
+ num_speakers: int,
74
+ script: str,
75
+ speaker_1: str = None,
76
+ speaker_2: str = None,
77
+ speaker_3: str = None,
78
+ speaker_4: str = None,
79
+ cfg_scale: float = 1.3):
80
+ """
81
+ Generates a podcast as a single audio file from a script and saves it.
82
+ This is a non-streaming function.
83
+ """
84
+ try:
85
+ # 1. Set generating state and validate inputs
86
+ self.is_generating = True
87
+
88
+ if not script.strip():
89
+ raise gr.Error("Error: Please provide a script.")
90
+
91
+ # Defend against common mistake with apostrophes
92
+ script = script.replace("’", "'")
93
 
94
+ if not 1 <= num_speakers <= 4:
95
+ raise gr.Error("Error: Number of speakers must be between 1 and 4.")
96
 
97
+ # 2. Collect and validate selected speakers
98
+ selected_speakers = [speaker_1, speaker_2, speaker_3, speaker_4][:num_speakers]
99
+ for i, speaker_name in enumerate(selected_speakers):
100
+ if not speaker_name or speaker_name not in self.available_voices:
101
+ raise gr.Error(f"Error: Please select a valid speaker for Speaker {i+1}.")
102
 
103
+ # 3. Build initial log
104
+ log = f"🎙️ Generating podcast with {num_speakers} speakers\n"
105
+ log += f"📊 Parameters: CFG Scale={cfg_scale}\n"
106
+ log += f"🎭 Speakers: {', '.join(selected_speakers)}\n"
107
+
108
+ # 4. Load voice samples
109
+ voice_samples = []
110
+ for speaker_name in selected_speakers:
111
+ audio_path = self.available_voices[speaker_name]
112
+ # Assuming self.read_audio is a method in your class that returns audio data
113
+ audio_data = self.read_audio(audio_path)
114
+ if len(audio_data) == 0:
115
+ raise gr.Error(f"Error: Failed to load audio for {speaker_name}")
116
+ voice_samples.append(audio_data)
117
+
118
+ log += f"✅ Loaded {len(voice_samples)} voice samples\n"
119
 
120
+ # 5. Parse and format the script
121
+ lines = script.strip().split('\n')
122
+ formatted_script_lines = []
123
+ for line in lines:
124
+ line = line.strip()
125
+ if not line:
126
+ continue
127
+
128
+ # Check if line already has speaker format (e.g., "Speaker 1: ...")
129
+ if line.startswith('Speaker ') and ':' in line:
130
+ formatted_script_lines.append(line)
131
+ else:
132
+ # Auto-assign speakers in rotation
133
+ speaker_id = len(formatted_script_lines) % num_speakers
134
+ formatted_script_lines.append(f"Speaker {speaker_id}: {line}")
135
+
136
+ formatted_script = '\n'.join(formatted_script_lines)
137
+ log += f"📝 Formatted script with {len(formatted_script_lines)} turns\n"
138
+ log += "🔄 Processing with VibeVoice...\n"
139
 
140
+ # 6. Prepare inputs for the model
141
+ # Assuming self.processor is an object available in your class
142
+ inputs = self.processor(
143
+ text=[formatted_script],
144
+ voice_samples=[voice_samples],
145
+ padding=True,
146
+ return_tensors="pt",
147
+ return_attention_mask=True,
148
+ )
149
 
150
+ # 7. Generate audio
151
+ start_time = time.time()
152
+ # Assuming self.model is an object available in your class
153
+ outputs = self.model.generate(
154
+ **inputs,
155
+ max_new_tokens=None,
156
+ cfg_scale=cfg_scale,
157
+ tokenizer=self.processor.tokenizer,
158
+ generation_config={'do_sample': False},
159
+ verbose=False, # Verbose is off for cleaner logs
160
+ )
161
+ generation_time = time.time() - start_time
162
 
163
+ # 8. Extract audio output
164
+ # The generated audio is often in speech_outputs or a similar attribute
165
+ if hasattr(outputs, 'speech_outputs') and outputs.speech_outputs[0] is not None:
166
+ audio_tensor = outputs.speech_outputs[0]
167
+ audio = audio_tensor.cpu().numpy()
168
+ else:
169
+ raise gr.Error("❌ Error: No audio was generated by the model. Please try again.")
 
 
170
 
171
+ # Ensure audio is a 1D array
172
+ if audio.ndim > 1:
173
+ audio = audio.squeeze()
174
+
175
+ sample_rate = 24000 # Standard sample rate for this model
176
 
177
+ # 9. Save the audio file
178
+ output_dir = "outputs"
179
+ os.makedirs(output_dir, exist_ok=True)
180
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
181
+ file_path = os.path.join(output_dir, f"podcast_{timestamp}.wav")
182
+
183
+ # Write the NumPy array to a WAV file
184
+ sf.write(file_path, audio, sample_rate)
185
+ print(f"💾 Podcast saved to {file_path}")
186
 
187
+ # 10. Finalize log and return
188
+ total_duration = len(audio) / sample_rate
189
+ log += f"⏱️ Generation completed in {generation_time:.2f} seconds\n"
190
+ log += f"🎡 Final audio duration: {total_duration:.2f} seconds\n"
191
+ log += f"✅ Successfully saved podcast to: {file_path}\n"
192
+
193
+ self.is_generating = False
194
+ return (sample_rate, audio), log
195
+
196
+ except gr.Error as e:
197
+ # Handle Gradio-specific errors (for user feedback)
198
+ self.is_generating = False
199
+ error_msg = f"❌ Input Error: {str(e)}"
200
+ print(error_msg)
201
+ # In Gradio, you would typically return an update to the UI
202
+ # For a pure function, we re-raise or handle it as needed.
203
+ # This return signature matches the success case but with error info.
204
+ return None, error_msg
205
 
206
+ except Exception as e:
207
+ # Handle all other unexpected errors
208
+ self.is_generating = False
209
+ error_msg = f"❌ An unexpected error occurred: {str(e)}"
210
+ print(error_msg)
211
+ import traceback
212
+ traceback.print_exc()
213
+ return None, error_msg
214
 
215
 
216