yasserrmd committed on
Commit 3ad533a · verified · 1 Parent(s): afcb698

Update app.py

Files changed (1):
  1. app.py +100 -844
app.py CHANGED
@@ -1,124 +1,61 @@
-"""
-VibeVoice Gradio Demo - High-Quality Dialogue Generation Interface with Streaming Support
-"""
-
 import argparse
-import json
 import os
-import sys
-import tempfile
 import time
-from pathlib import Path
-from typing import List, Dict, Any, Iterator
-from datetime import datetime
-import threading
 import numpy as np
 import gradio as gr
 import librosa
 import soundfile as sf
 import torch
-import os
 import traceback
-import spaces
+from spaces import GPU
 
 from vibevoice.modular.configuration_vibevoice import VibeVoiceConfig
 from vibevoice.modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference
 from vibevoice.processor.vibevoice_processor import VibeVoiceProcessor
-from vibevoice.modular.streamer import AudioStreamer
 from transformers.utils import logging
 from transformers import set_seed
 
 logging.set_verbosity_info()
 logger = logging.get_logger(__name__)
 
-# import os
-# os.environ["FLASH_ATTENTION_2"] = "0"
-
 
 class VibeVoiceDemo:
     def __init__(self, model_path: str, device: str = "cuda", inference_steps: int = 5):
-        """Initialize the VibeVoice demo with model loading."""
         self.model_path = model_path
         self.device = device
         self.inference_steps = inference_steps
-        self.is_generating = False  # Track generation state
-        self.stop_generation = False  # Flag to stop generation
-        self.current_streamer = None  # Track current audio streamer
+        self.is_generating = False
+        self.processor = None
+        self.model = None
+        self.available_voices = {}
         self.load_model()
         self.setup_voice_presets()
-        self.load_example_scripts()  # Load example scripts
-
+        self.load_example_scripts()
+
     def load_model(self):
-        """Load the VibeVoice model and processor."""
         print(f"Loading processor & model from {self.model_path}")
-
-        # Load processor
-        self.processor = VibeVoiceProcessor.from_pretrained(
-            self.model_path,
-        )
-
-        # Load model
+        self.processor = VibeVoiceProcessor.from_pretrained(self.model_path)
         self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
             self.model_path,
             torch_dtype=torch.bfloat16,
-            device_map='cuda'
+            device_map=self.device
         )
         self.model.eval()
-
-        # Use SDE solver by default
-        self.model.model.noise_scheduler = self.model.model.noise_scheduler.from_config(
-            self.model.model.noise_scheduler.config,
-            algorithm_type='sde-dpmsolver++',
-            beta_schedule='squaredcos_cap_v2'
-        )
         self.model.set_ddpm_inference_steps(num_steps=self.inference_steps)
-
-        if hasattr(self.model.model, 'language_model'):
-            print(f"Language model attention: {self.model.model.language_model.config._attn_implementation}")
-
+
     def setup_voice_presets(self):
-        """Setup voice presets by scanning the voices directory."""
         voices_dir = os.path.join(os.path.dirname(__file__), "voices")
-
-        # Check if voices directory exists
         if not os.path.exists(voices_dir):
             print(f"Warning: Voices directory not found at {voices_dir}")
-            self.voice_presets = {}
-            self.available_voices = {}
             return
-
-        # Scan for all WAV files in the voices directory
-        self.voice_presets = {}
-
-        # Get all .wav files in the voices directory
-        wav_files = [f for f in os.listdir(voices_dir)
-                     if f.lower().endswith(('.wav', '.mp3', '.flac', '.ogg', '.m4a', '.aac')) and os.path.isfile(os.path.join(voices_dir, f))]
-
-        # Create dictionary with filename (without extension) as key
+        wav_files = [f for f in os.listdir(voices_dir)
+                     if f.lower().endswith(('.wav', '.mp3', '.flac', '.ogg', '.m4a', '.aac'))]
         for wav_file in wav_files:
-            # Remove .wav extension to get the name
             name = os.path.splitext(wav_file)[0]
-            # Create full path
-            full_path = os.path.join(voices_dir, wav_file)
-            self.voice_presets[name] = full_path
-
-        # Sort the voice presets alphabetically by name for better UI
-        self.voice_presets = dict(sorted(self.voice_presets.items()))
-
-        # Filter out voices that don't exist (this is now redundant but kept for safety)
-        self.available_voices = {
-            name: path for name, path in self.voice_presets.items()
-            if os.path.exists(path)
-        }
-
-        if not self.available_voices:
-            raise gr.Error("No voice presets found. Please add .wav files to the demo/voices directory.")
-
-        print(f"Found {len(self.available_voices)} voice files in {voices_dir}")
-        print(f"Available voices: {', '.join(self.available_voices.keys())}")
-
+            self.available_voices[name] = os.path.join(voices_dir, wav_file)
+        print(f"Voices loaded: {list(self.available_voices.keys())}")
+
     def read_audio(self, audio_path: str, target_sr: int = 24000) -> np.ndarray:
-        """Read and preprocess audio file."""
         try:
             wav, sr = sf.read(audio_path)
             if len(wav.shape) > 1:
@@ -129,14 +66,13 @@ class VibeVoiceDemo:
         except Exception as e:
             print(f"Error reading audio {audio_path}: {e}")
             return np.array([])
-
-    @spaces.GPU
+
+    @GPU
     def generate_podcast(self, num_speakers: int, script: str,
                          speaker_1: str = None, speaker_2: str = None,
                          speaker_3: str = None, speaker_4: str = None,
                          cfg_scale: float = 1.3):
-        """Single GPU function for full generation (streaming + final)."""
-        self.stop_generation = False
+        """Final audio generation only (no streaming)."""
        self.is_generating = True
 
         if not script.strip():
@@ -150,7 +86,6 @@ class VibeVoiceDemo:
             if not sp or sp not in self.available_voices:
                 raise gr.Error(f"Invalid speaker {i+1} selection.")
 
-        # load voices
         voice_samples = [self.read_audio(self.available_voices[sp]) for sp in selected]
         if any(len(v) == 0 for v in voice_samples):
             raise gr.Error("Failed to load one or more voice samples.")
@@ -177,539 +112,103 @@ class VibeVoiceDemo:
             return_tensors="pt"
         )
 
-        # === direct generation with streamer ===
-        from vibevoice import AudioStreamer, convert_to_16_bit_wav
-        audio_streamer = AudioStreamer(batch_size=1)
         start = time.time()
         outputs = self.model.generate(
             **inputs,
             cfg_scale=cfg_scale,
             tokenizer=self.processor.tokenizer,
-            audio_streamer=audio_streamer,
             verbose=False
         )
 
-        sample_rate = 24000
-        audio_stream = audio_streamer.get_stream(0)
-        all_chunks, pending = [], []
-        min_chunk_size = sample_rate * 2
-        last_yield = time.time()
-
-        for chunk in audio_stream:
-            if torch.is_tensor(chunk):
-                chunk = chunk.float().cpu().numpy()
-            if chunk.ndim > 1:
-                chunk = chunk.squeeze()
-            chunk16 = convert_to_16_bit_wav(chunk)
-            all_chunks.append(chunk16)
-            pending.append(chunk16)
-            if sum(len(c) for c in pending) >= min_chunk_size or (time.time() - last_yield) > 5:
-                new_audio = np.concatenate(pending)
-                yield (sample_rate, new_audio), None, f"Streaming {len(all_chunks)} chunks..."
-                pending = []
-                last_yield = time.time()
-
-        if all_chunks:
-            complete = np.concatenate(all_chunks)
-            total_dur = len(complete) / sample_rate
-            log = f"✅ Generation complete in {time.time()-start:.1f}s, {total_dur:.1f}s audio"
-            yield None, (sample_rate, complete), log
+        # Extract audio
+        if isinstance(outputs, dict) and "audio" in outputs:
+            audio = outputs["audio"]
         else:
-            yield None, None, "❌ No audio generated."
+            audio = outputs
+
+        if torch.is_tensor(audio):
+            audio = audio.float().cpu().numpy()
+        if audio.ndim > 1:
+            audio = audio.squeeze()
+
+        sample_rate = 24000
+        audio16 = convert_to_16_bit_wav(audio)
+
+        total_dur = len(audio16) / sample_rate
+        log = f"✅ Generation complete in {time.time()-start:.1f}s, {total_dur:.1f}s audio"
 
         self.is_generating = False
-
+        return (sample_rate, audio16), log
 
-    def stop_audio_generation(self):
-        """Stop the current audio generation process."""
-        self.stop_generation = True
-        if self.current_streamer is not None:
-            try:
-                self.current_streamer.end()
-            except Exception as e:
-                print(f"Error stopping streamer: {e}")
-        print("🛑 Audio generation stop requested")
-
     def load_example_scripts(self):
-        """Load example scripts from the text_examples directory."""
         examples_dir = os.path.join(os.path.dirname(__file__), "text_examples")
         self.example_scripts = []
-
-        # Check if text_examples directory exists
         if not os.path.exists(examples_dir):
-            print(f"Warning: text_examples directory not found at {examples_dir}")
             return
-
-        # Get all .txt files in the text_examples directory
-        txt_files = sorted([f for f in os.listdir(examples_dir)
-                            if f.lower().endswith('.txt') and os.path.isfile(os.path.join(examples_dir, f))])
-
+        txt_files = sorted([f for f in os.listdir(examples_dir)
+                            if f.lower().endswith('.txt')])
         for txt_file in txt_files:
-            file_path = os.path.join(examples_dir, txt_file)
-
-            import re
-            # Check if filename contains a time pattern like "45min", "90min", etc.
-            time_pattern = re.search(r'(\d+)min', txt_file.lower())
-            if time_pattern:
-                minutes = int(time_pattern.group(1))
-                if minutes > 15:
-                    print(f"Skipping {txt_file}: duration {minutes} minutes exceeds 15-minute limit")
-                    continue
-
             try:
-                with open(file_path, 'r', encoding='utf-8') as f:
+                with open(os.path.join(examples_dir, txt_file), 'r', encoding='utf-8') as f:
                     script_content = f.read().strip()
-
-                # Remove empty lines and lines with only whitespace
-                script_content = '\n'.join(line for line in script_content.split('\n') if line.strip())
-
-                if not script_content:
-                    continue
-
-                # Parse the script to determine number of speakers
-                num_speakers = self._get_num_speakers_from_script(script_content)
-
-                # Add to examples list as [num_speakers, script_content]
-                self.example_scripts.append([num_speakers, script_content])
-                print(f"Loaded example: {txt_file} with {num_speakers} speakers")
-
+                if script_content:
+                    self.example_scripts.append([1, script_content])
             except Exception as e:
-                print(f"Error loading example script {txt_file}: {e}")
-
-        if self.example_scripts:
-            print(f"Successfully loaded {len(self.example_scripts)} example scripts")
-        else:
-            print("No example scripts were loaded")
-
-    def _get_num_speakers_from_script(self, script: str) -> int:
-        """Determine the number of unique speakers in a script."""
-        import re
-        speakers = set()
-
-        lines = script.strip().split('\n')
-        for line in lines:
-            # Use regex to find speaker patterns
-            match = re.match(r'^Speaker\s+(\d+)\s*:', line.strip(), re.IGNORECASE)
-            if match:
-                speaker_id = int(match.group(1))
-                speakers.add(speaker_id)
-
-        # If no speakers found, default to 1
-        if not speakers:
-            return 1
-
-        # Return the maximum speaker ID + 1 (assuming 0-based indexing)
-        # or the count of unique speakers if they're 1-based
-        max_speaker = max(speakers)
-        min_speaker = min(speakers)
-
-        if min_speaker == 0:
-            return max_speaker + 1
-        else:
-            # Assume 1-based indexing, return the count
-            return len(speakers)
-
+                print(f"Error loading {txt_file}: {e}")
+
+
+def convert_to_16_bit_wav(data):
+    if torch.is_tensor(data):
+        data = data.detach().cpu().numpy()
+    data = np.array(data)
+    if np.max(np.abs(data)) > 1.0:
+        data = data / np.max(np.abs(data))
+    return (data * 32767).astype(np.int16)
+
 
 def create_demo_interface(demo_instance: VibeVoiceDemo):
-    """Create the Gradio interface with streaming support."""
-
-    # Custom CSS for high-end aesthetics with lighter theme
-    custom_css = """
-    /* Modern light theme with gradients */
-    .gradio-container {
-        background: linear-gradient(135deg, #f8fafc 0%, #e2e8f0 100%);
-        font-family: 'SF Pro Display', -apple-system, BlinkMacSystemFont, sans-serif;
-    }
-
-    /* Header styling */
-    .main-header {
-        background: linear-gradient(90deg, #667eea 0%, #764ba2 100%);
-        padding: 2rem;
-        border-radius: 20px;
-        margin-bottom: 2rem;
-        text-align: center;
-        box-shadow: 0 10px 40px rgba(102, 126, 234, 0.3);
-    }
-
-    .main-header h1 {
-        color: white;
-        font-size: 2.5rem;
-        font-weight: 700;
-        margin: 0;
-        text-shadow: 0 2px 4px rgba(0,0,0,0.3);
-    }
-
-    .main-header p {
-        color: rgba(255,255,255,0.9);
-        font-size: 1.1rem;
-        margin: 0.5rem 0 0 0;
-    }
-
-    /* Card styling */
-    .settings-card, .generation-card {
-        background: rgba(255, 255, 255, 0.8);
-        backdrop-filter: blur(10px);
-        border: 1px solid rgba(226, 232, 240, 0.8);
-        border-radius: 16px;
-        padding: 1.5rem;
-        margin-bottom: 1rem;
-        box-shadow: 0 8px 32px rgba(0, 0, 0, 0.1);
-    }
-
-    /* Speaker selection styling */
-    .speaker-grid {
-        display: grid;
-        gap: 1rem;
-        margin-bottom: 1rem;
-    }
-
-    .speaker-item {
-        background: linear-gradient(135deg, #e2e8f0 0%, #cbd5e1 100%);
-        border: 1px solid rgba(148, 163, 184, 0.4);
-        border-radius: 12px;
-        padding: 1rem;
-        color: #374151;
-        font-weight: 500;
-    }
-
-    /* Streaming indicator */
-    .streaming-indicator {
-        display: inline-block;
-        width: 10px;
-        height: 10px;
-        background: #22c55e;
-        border-radius: 50%;
-        margin-right: 8px;
-        animation: pulse 1.5s infinite;
-    }
-
-    @keyframes pulse {
-        0% { opacity: 1; transform: scale(1); }
-        50% { opacity: 0.5; transform: scale(1.1); }
-        100% { opacity: 1; transform: scale(1); }
-    }
-
-    /* Queue status styling */
-    .queue-status {
-        background: linear-gradient(135deg, #f0f9ff 0%, #e0f2fe 100%);
-        border: 1px solid rgba(14, 165, 233, 0.3);
-        border-radius: 8px;
-        padding: 0.75rem;
-        margin: 0.5rem 0;
-        text-align: center;
-        font-size: 0.9rem;
-        color: #0369a1;
-    }
-
-    .generate-btn {
-        background: linear-gradient(135deg, #059669 0%, #0d9488 100%);
-        border: none;
-        border-radius: 12px;
-        padding: 1rem 2rem;
-        color: white;
-        font-weight: 600;
-        font-size: 1.1rem;
-        box-shadow: 0 4px 20px rgba(5, 150, 105, 0.4);
-        transition: all 0.3s ease;
-    }
-
-    .generate-btn:hover {
-        transform: translateY(-2px);
-        box-shadow: 0 6px 25px rgba(5, 150, 105, 0.6);
-    }
-
-    .stop-btn {
-        background: linear-gradient(135deg, #ef4444 0%, #dc2626 100%);
-        border: none;
-        border-radius: 12px;
-        padding: 1rem 2rem;
-        color: white;
-        font-weight: 600;
-        font-size: 1.1rem;
-        box-shadow: 0 4px 20px rgba(239, 68, 68, 0.4);
-        transition: all 0.3s ease;
-    }
-
-    .stop-btn:hover {
-        transform: translateY(-2px);
-        box-shadow: 0 6px 25px rgba(239, 68, 68, 0.6);
-    }
-
-    /* Audio player styling */
-    .audio-output {
-        background: linear-gradient(135deg, #f1f5f9 0%, #e2e8f0 100%);
-        border-radius: 16px;
-        padding: 1.5rem;
-        border: 1px solid rgba(148, 163, 184, 0.3);
-    }
-
-    .complete-audio-section {
-        margin-top: 1rem;
-        padding: 1rem;
-        background: linear-gradient(135deg, #f0fdf4 0%, #dcfce7 100%);
-        border: 1px solid rgba(34, 197, 94, 0.3);
-        border-radius: 12px;
-    }
-
-    /* Text areas */
-    .script-input, .log-output {
-        background: rgba(255, 255, 255, 0.9) !important;
-        border: 1px solid rgba(148, 163, 184, 0.4) !important;
-        border-radius: 12px !important;
-        color: #1e293b !important;
-        font-family: 'JetBrains Mono', monospace !important;
-    }
-
-    .script-input::placeholder {
-        color: #64748b !important;
-    }
-
-    /* Sliders */
-    .slider-container {
-        background: rgba(248, 250, 252, 0.8);
-        border: 1px solid rgba(226, 232, 240, 0.6);
-        border-radius: 8px;
-        padding: 1rem;
-        margin: 0.5rem 0;
-    }
-
-    /* Labels and text */
-    .gradio-container label {
-        color: #374151 !important;
-        font-weight: 600 !important;
-    }
-
-    .gradio-container .markdown {
-        color: #1f2937 !important;
-    }
-
-    /* Responsive design */
-    @media (max-width: 768px) {
-        .main-header h1 { font-size: 2rem; }
-        .settings-card, .generation-card { padding: 1rem; }
-    }
-
-    /* Random example button styling - more subtle professional color */
-    .random-btn {
-        background: linear-gradient(135deg, #64748b 0%, #475569 100%);
-        border: none;
-        border-radius: 12px;
-        padding: 1rem 1.5rem;
-        color: white;
-        font-weight: 600;
-        font-size: 1rem;
-        box-shadow: 0 4px 20px rgba(100, 116, 139, 0.3);
-        transition: all 0.3s ease;
-        display: inline-flex;
-        align-items: center;
-        gap: 0.5rem;
-    }
-
-    .random-btn:hover {
-        transform: translateY(-2px);
-        box-shadow: 0 6px 25px rgba(100, 116, 139, 0.4);
-        background: linear-gradient(135deg, #475569 0%, #334155 100%);
-    }
-    """
-
     with gr.Blocks(
         title="VibeVoice - AI Podcast Generator",
-        css=custom_css,
-        theme=gr.themes.Soft(
-            primary_hue="blue",
-            secondary_hue="purple",
-            neutral_hue="slate",
-        )
+        theme=gr.themes.Soft(primary_hue="blue", secondary_hue="purple")
     ) as interface:
-
-        # Header
-        gr.HTML("""
-        <div class="main-header">
-            <h1>🎙️ Vibe Podcasting </h1>
-            <p>Generating Long-form Multi-speaker AI Podcast with VibeVoice</p>
-        </div>
-        """)
-
-        with gr.Row():
-            # Left column - Settings
-            with gr.Column(scale=1, elem_classes="settings-card"):
-                gr.Markdown("### 🎛️ **Podcast Settings**")
-
-                # Number of speakers
-                num_speakers = gr.Slider(
-                    minimum=1,
-                    maximum=4,
-                    value=2,
-                    step=1,
-                    label="Number of Speakers",
-                    elem_classes="slider-container"
-                )
-
-                # Speaker selection
-                gr.Markdown("### 🎭 **Speaker Selection**")
-
-                available_speaker_names = list(demo_instance.available_voices.keys())
-                # default_speakers = available_speaker_names[:4] if len(available_speaker_names) >= 4 else available_speaker_names
-                default_speakers = ['en-Alice_woman', 'en-Carter_man', 'en-Frank_man', 'en-Maya_woman']
-
-                speaker_selections = []
-                for i in range(4):
-                    default_value = default_speakers[i] if i < len(default_speakers) else None
-                    speaker = gr.Dropdown(
-                        choices=available_speaker_names,
-                        value=default_value,
-                        label=f"Speaker {i+1}",
-                        visible=(i < 2),  # Initially show only first 2 speakers
-                        elem_classes="speaker-item"
-                    )
-                    speaker_selections.append(speaker)
-
-                # Advanced settings
-                gr.Markdown("### ⚙️ **Advanced Settings**")
-
-                # Sampling parameters (contains all generation settings)
-                with gr.Accordion("Generation Parameters", open=False):
-                    cfg_scale = gr.Slider(
-                        minimum=1.0,
-                        maximum=2.0,
-                        value=1.3,
-                        step=0.05,
-                        label="CFG Scale (Guidance Strength)",
-                        # info="Higher values increase adherence to text",
-                        elem_classes="slider-container"
-                    )
-
-            # Right column - Generation
-            with gr.Column(scale=2, elem_classes="generation-card"):
-                gr.Markdown("### 📝 **Script Input**")
-
-                script_input = gr.Textbox(
-                    label="Conversation Script",
-                    placeholder="""Enter your podcast script here. You can format it as:
-
-Speaker 0: Welcome to our podcast today!
-Speaker 1: Thanks for having me. I'm excited to discuss...
-
-Or paste text directly and it will auto-assign speakers.""",
-                    lines=12,
-                    max_lines=20,
-                    elem_classes="script-input"
-                )
-
-                # Button row with Random Example on the left and Generate on the right
-                with gr.Row():
-                    # Random example button (now on the left)
-                    random_example_btn = gr.Button(
-                        "🎲 Random Example",
-                        size="lg",
-                        variant="secondary",
-                        elem_classes="random-btn",
-                        scale=1  # Smaller width
-                    )
-
-                    # Generate button (now on the right)
-                    generate_btn = gr.Button(
-                        "🚀 Generate Podcast",
-                        size="lg",
-                        variant="primary",
-                        elem_classes="generate-btn",
-                        scale=2  # Wider than random button
-                    )
-
-                # Stop button
-                stop_btn = gr.Button(
-                    "🛑 Stop Generation",
-                    size="lg",
-                    variant="stop",
-                    elem_classes="stop-btn",
-                    visible=False
-                )
-
-                # Streaming status indicator
-                streaming_status = gr.HTML(
-                    value="""
-                    <div style="background: linear-gradient(135deg, #dcfce7 0%, #bbf7d0 100%);
-                                border: 1px solid rgba(34, 197, 94, 0.3);
-                                border-radius: 8px;
-                                padding: 0.75rem;
-                                margin: 0.5rem 0;
-                                text-align: center;
-                                font-size: 0.9rem;
-                                color: #166534;">
-                        <span class="streaming-indicator"></span>
-                        <strong>LIVE STREAMING</strong> - Audio is being generated in real-time
-                    </div>
-                    """,
-                    visible=False,
-                    elem_id="streaming-status"
-                )
-
-                # Output section
-                gr.Markdown("### 🎵 **Generated Podcast**")
-
-                # Streaming audio output (outside of tabs for simpler handling)
-                audio_output = gr.Audio(
-                    label="Streaming Audio (Real-time)",
-                    type="numpy",
-                    elem_classes="audio-output",
-                    streaming=True,  # Enable streaming mode
-                    autoplay=True,
-                    show_download_button=False,  # Explicitly show download button
-                    visible=True
-                )
-
-                # Complete audio output (non-streaming)
-                complete_audio_output = gr.Audio(
-                    label="Complete Podcast (Download after generation)",
-                    type="numpy",
-                    elem_classes="audio-output complete-audio-section",
-                    streaming=False,  # Non-streaming mode
-                    autoplay=False,
-                    show_download_button=True,  # Explicitly show download button
-                    visible=False  # Initially hidden, shown when audio is ready
-                )
-
-                gr.Markdown("""
-                *💡 **Streaming**: Audio plays as it's being generated (may have slight pauses)
-                *💡 **Complete Audio**: Will appear below after generation finishes*
-                """)
-
-                # Generation log
-                log_output = gr.Textbox(
-                    label="Generation Log",
-                    lines=8,
-                    max_lines=15,
-                    interactive=False,
-                    elem_classes="log-output"
-                )
-
-        def update_speaker_visibility(num_speakers):
-            updates = []
-            for i in range(4):
-                updates.append(gr.update(visible=(i < num_speakers)))
-            return updates
-
-        num_speakers.change(
-            fn=update_speaker_visibility,
-            inputs=[num_speakers],
-            outputs=speaker_selections
+
+        gr.Markdown("## 🎙️ VibeVoice Podcast Generator (Final Audio Only)")
+
+        num_speakers = gr.Slider(1, 4, value=2, step=1, label="Number of Speakers")
+        available_speaker_names = list(demo_instance.available_voices.keys())
+        default_speakers = available_speaker_names[:4]
+
+        speaker_selections = []
+        for i in range(4):
+            speaker = gr.Dropdown(
+                choices=available_speaker_names,
+                value=default_speakers[i] if i < len(default_speakers) else None,
+                label=f"Speaker {i+1}",
+                visible=(i < 2)
+            )
+            speaker_selections.append(speaker)
+
+        cfg_scale = gr.Slider(1.0, 2.0, value=1.3, step=0.05, label="CFG Scale")
+
+        script_input = gr.Textbox(
+            label="Podcast Script",
+            placeholder="Enter your script here...",
+            lines=10
+        )
+
+        generate_btn = gr.Button("🚀 Generate Podcast")
+        audio_output = gr.Audio(
+            label="Generated Podcast (Download)",
+            type="numpy",
+            show_download_button=True
         )
-
-        # Main generation function with streaming
+        log_output = gr.Textbox(label="Log", interactive=False, lines=5)
+
         def generate_podcast_wrapper(num_speakers, script, *speakers_and_params):
-            """Wrapper function to handle the streaming generation call."""
             try:
-                # Extract speakers and parameters
-                speakers = speakers_and_params[:4]  # First 4 are speaker selections
-                cfg_scale = speakers_and_params[4]  # CFG scale
-
-                # Clear outputs and reset visibility at start
-                yield None, gr.update(value=None, visible=False), "🎙️ Starting generation...", gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)
-
-                # The generator will yield multiple times
-                final_log = "Starting generation..."
-
-                for streaming_audio, complete_audio, log, streaming_visible in demo_instance.generate_podcast(
+                speakers = speakers_and_params[:4]
+                cfg_scale = speakers_and_params[4]
+                audio, log = demo_instance.generate_podcast(
                     num_speakers=int(num_speakers),
                     script=script,
                     speaker_1=speakers[0],
@@ -717,280 +216,37 @@ Or paste text directly and it will auto-assign speakers.""",
                     speaker_3=speakers[2],
                     speaker_4=speakers[3],
                     cfg_scale=cfg_scale
-                ):
-                    final_log = log
-
-                    # Check if we have complete audio (final yield)
-                    if complete_audio is not None:
-                        # Final state: clear streaming, show complete audio
-                        yield None, gr.update(value=complete_audio, visible=True), log, gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)
-                    else:
-                        # Streaming state: update streaming audio only
-                        if streaming_audio is not None:
-                            yield streaming_audio, gr.update(visible=False), log, streaming_visible, gr.update(visible=False), gr.update(visible=True)
-                        else:
-                            # No new audio, just update status
-                            yield None, gr.update(visible=False), log, streaming_visible, gr.update(visible=False), gr.update(visible=True)
-
+                )
+                return audio, log
             except Exception as e:
-                error_msg = f"❌ A critical error occurred in the wrapper: {str(e)}"
-                print(error_msg)
-                import traceback
                 traceback.print_exc()
-                # Reset button states on error
-                yield None, gr.update(value=None, visible=False), error_msg, gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)
-
-        def stop_generation_handler():
-            """Handle stopping generation."""
-            demo_instance.stop_audio_generation()
-            # Return values for: log_output, streaming_status, generate_btn, stop_btn
-            return "🛑 Generation stopped.", gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)
-
-        # Add a clear audio function
-        def clear_audio_outputs():
-            """Clear both audio outputs before starting new generation."""
-            return None, gr.update(value=None, visible=False)
-
-        # Connect generation button with streaming outputs
+                return None, f"❌ Error: {str(e)}"
+
         generate_btn.click(
-            fn=clear_audio_outputs,
-            inputs=[],
-            outputs=[audio_output, complete_audio_output],
-            queue=False
-        ).then(
             fn=generate_podcast_wrapper,
             inputs=[num_speakers, script_input] + speaker_selections + [cfg_scale],
-            outputs=[audio_output, complete_audio_output, log_output, streaming_status, generate_btn, stop_btn],
-            queue=True  # Enable Gradio's built-in queue
-        )
-
-        # Connect stop button
-        stop_btn.click(
-            fn=stop_generation_handler,
-            inputs=[],
-            outputs=[log_output, streaming_status, generate_btn, stop_btn],
-            queue=False  # Don't queue stop requests
-        ).then(
-            # Clear both audio outputs after stopping
-            fn=lambda: (None, None),
-            inputs=[],
-            outputs=[audio_output, complete_audio_output],
-            queue=False
+            outputs=[audio_output, log_output]
         )
-
-        # Function to randomly select an example
-        def load_random_example():
-            """Randomly select and load an example script."""
-            import random
-
-            # Get available examples
-            if hasattr(demo_instance, 'example_scripts') and demo_instance.example_scripts:
-                example_scripts = demo_instance.example_scripts
-            else:
-                # Fallback to default
-                example_scripts = [
-                    [2, "Speaker 0: Welcome to our AI podcast demonstration!\nSpeaker 1: Thanks for having me. This is exciting!"]
-                ]
-
-            # Randomly select one
-            if example_scripts:
-                selected = random.choice(example_scripts)
-                num_speakers_value = selected[0]
-                script_value = selected[1]
-
-                # Return the values to update the UI
-                return num_speakers_value, script_value
-
-            # Default values if no examples
-            return 2, ""
-
-        # Connect random example button
-        random_example_btn.click(
-            fn=load_random_example,
-            inputs=[],
-            outputs=[num_speakers, script_input],
-            queue=False  # Don't queue this simple operation
-        )
-
-        # Add usage tips
-        gr.Markdown("""
-        ### 💡 **Usage Tips**
-
-        - Click **🚀 Generate Podcast** to start audio generation
-        - **Live Streaming** tab shows audio as it's generated (may have slight pauses)
-        - **Complete Audio** tab provides the full, uninterrupted podcast after generation
-        - During generation, you can click **🛑 Stop Generation** to interrupt the process
-        - The streaming indicator shows real-time generation progress
-        """)
-
-        # Add example scripts
-        gr.Markdown("### 📚 **Example Scripts**")
-
-        # Use dynamically loaded examples if available, otherwise provide a default
-        if hasattr(demo_instance, 'example_scripts') and demo_instance.example_scripts:
-            example_scripts = demo_instance.example_scripts
-        else:
-            # Fallback to a simple default example if no scripts loaded
-            example_scripts = [
-                [1, "Speaker 1: Welcome to our AI podcast demonstration! This is a sample script showing how VibeVoice can generate natural-sounding speech."]
-            ]
-
-        gr.Examples(
-            examples=example_scripts,
-            inputs=[num_speakers, script_input],
-            label="Try these example scripts:"
-        )
-
-    return interface
 
+    return interface
 
-def convert_to_16_bit_wav(data):
-    # Check if data is a tensor and move to cpu
-    if torch.is_tensor(data):
-        data = data.detach().cpu().numpy()
-
-    # Ensure data is numpy array
-    data = np.array(data)
-
-    # Normalize to range [-1, 1] if it's not already
-    if np.max(np.abs(data)) > 1.0:
-        data = data / np.max(np.abs(data))
-
-    # Scale to 16-bit integer range
-    data = (data * 32767).astype(np.int16)
-    return data
-
-
-def parse_args():
-    parser = argparse.ArgumentParser(description="VibeVoice Gradio Demo")
-    parser.add_argument(
-        "--model_path",
-        type=str,
-        default="/tmp/vibevoice-model",
-        help="Path to the VibeVoice model directory",
-    )
-    parser.add_argument(
-        "--device",
-        type=str,
-        default="cuda" if torch.cuda.is_available() else "cpu",
-        help="Device for inference",
-    )
-    parser.add_argument(
-        "--inference_steps",
-        type=int,
-        default=10,
-        help="Number of inference steps for DDPM (not exposed to users)",
-    )
-    parser.add_argument(
-        "--share",
-        action="store_true",
-        help="Share the demo publicly via Gradio",
-    )
-    parser.add_argument(
-        "--port",
-        type=int,
-        default=7860,
-        help="Port to run the demo on",
-    )
-
-    return parser.parse_args()
-
-
-def main():
-    """Main function to run the demo."""
-    args = parse_args()
-
-    set_seed(42)  # Set a fixed seed for reproducibility
-
-    print("🎙️ Initializing VibeVoice Demo with Streaming Support...")
-
-    # Initialize demo instance
-    demo_instance = VibeVoiceDemo(
-        model_path=args.model_path,
-        device=args.device,
-        inference_steps=args.inference_steps
-    )
-
-    # Create interface
-    interface = create_demo_interface(demo_instance)
-
-    print(f"🚀 Launching demo on port {args.port}")
-    print(f"📁 Model path: {args.model_path}")
-    print(f"🎭 Available voices: {len(demo_instance.available_voices)}")
-    print(f"🔴 Streaming mode: ENABLED")
-    print(f"🔒 Session isolation: ENABLED")
-
-    # Launch the interface
-    try:
-        interface.queue(
-            max_size=20,  # Maximum queue size
-            default_concurrency_limit=1  # Process one request at a time
-        ).launch(
-            share=args.share,
-            # server_port=args.port,
-            server_name="0.0.0.0" if args.share else "127.0.0.1",
-            show_error=True,
-            show_api=False  # Hide API docs for cleaner interface
-        )
-    except KeyboardInterrupt:
-        print("\n🛑 Shutting down gracefully...")
-    except Exception as e:
-        print(f"❌ Server error: {e}")
-        raise
 
 def run_demo(
     model_path: str = "microsoft/VibeVoice-1.5B",
     device: str = "cuda",
     inference_steps: int = 5,
     share: bool = True,
-) -> None:
-    """
-    Run the VibeVoice demo without any command-line arguments.
-    - share=True exposes the app publicly via a Gradio share link.
-    - Default Gradio port (7860) is used automatically.
-    - Errors are shown to help with debugging.
-    """
+):
     set_seed(42)
-
-    print("🎙️ Initializing VibeVoice Demo with Streaming Support...")
-
-    # Initialize demo instance
-    demo_instance = VibeVoiceDemo(
-        model_path=model_path,
-        device=device,
-        inference_steps=inference_steps
-    )
-
-    # Build UI
+    demo_instance = VibeVoiceDemo(model_path, device, inference_steps)
     interface = create_demo_interface(demo_instance)
-
-    # Info
-    print("🚀 Launching demo")
-    print(f"📁 Model path: {model_path}")
-    print(f"🎭 Available voices: {len(getattr(demo_instance, 'available_voices', []))}")
-    print(f"🔴 Streaming mode: ENABLED")
-    print(f"🔒 Session isolation: ENABLED")
-
-    # Launch (no server_port specified → default 7860)
-    try:
-        interface.queue(
-            max_size=20,
-            default_concurrency_limit=1
-        ).launch(
-            share=share,
-            server_name="0.0.0.0" if share else "127.0.0.1",
-            show_error=True,  # show full tracebacks (debug-friendly)
-            show_api=False  # cleaner interface
-        )
-    except KeyboardInterrupt:
-        print("\n🛑 Shutting down gracefully...")
-    except Exception as e:
-        print(f"❌ Server error: {e}")
-        raise
+    interface.queue().launch(
+        share=share,
+        server_name="0.0.0.0" if share else "127.0.0.1",
+        show_error=True,
+        show_api=False
+    )
 
 
-# Run automatically when this file is executed (no CLI needed)
 if __name__ == "__main__":
     run_demo()
-
-