""" | |
VibeVoice Simple Chat Interface - Streamlined Audio Generation Demo | |
""" | |
import argparse
import os
import sys
import tempfile
import time
import threading
import subprocess
from pathlib import Path
from typing import Iterator, Dict, Any

import numpy as np
import gradio as gr
import librosa
import soundfile as sf
import torch

# Clone and set up VibeVoice if not already present
vibevoice_dir = Path('./VibeVoice')
if not vibevoice_dir.exists():
    print("Cloning VibeVoice repository...")
    subprocess.run(['git', 'clone', 'https://github.com/microsoft/VibeVoice.git'], check=True)
    print("Installing VibeVoice...")
    subprocess.run(['pip', 'install', '-e', './VibeVoice'], check=True)
    print("Installing wheel (required for flash-attn)...")
    subprocess.run(['pip', 'install', 'wheel'], check=True)
    print("Installing flash-attn...")
    try:
        subprocess.run(['pip', 'install', 'flash-attn', '--no-build-isolation'], check=True)
    except subprocess.CalledProcessError:
        print("Warning: flash-attn installation failed. Continuing without it "
              "(the model loader below falls back to SDPA attention).")

# Add the VibeVoice directory to the import path
sys.path.insert(0, str(vibevoice_dir))

# Import VibeVoice modules
try:
    # Try direct import first (works when installed as a package)
    from vibevoice.modular.configuration_vibevoice import VibeVoiceConfig
    from vibevoice.modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference
    from vibevoice.processor.vibevoice_processor import VibeVoiceProcessor
    from vibevoice.modular.streamer import AudioStreamer
except ImportError:
    try:
        # Fall back to loading the modules straight from the cloned checkout
        import importlib.util

        def load_module(module_name, file_path):
            """Load a single module from an explicit file path."""
            spec = importlib.util.spec_from_file_location(module_name, file_path)
            module = importlib.util.module_from_spec(spec)
            sys.modules[module_name] = module
            spec.loader.exec_module(module)
            return module

        # The package sources live under VibeVoice/vibevoice/ in the checkout,
        # mirroring the import path used above
        pkg_dir = vibevoice_dir / "vibevoice"
        config_module = load_module(
            "vibevoice_config",
            pkg_dir / "modular" / "configuration_vibevoice.py"
        )
        VibeVoiceConfig = config_module.VibeVoiceConfig
        model_module = load_module(
            "vibevoice_model",
            pkg_dir / "modular" / "modeling_vibevoice_inference.py"
        )
        VibeVoiceForConditionalGenerationInference = model_module.VibeVoiceForConditionalGenerationInference
        processor_module = load_module(
            "vibevoice_processor",
            pkg_dir / "processor" / "vibevoice_processor.py"
        )
        VibeVoiceProcessor = processor_module.VibeVoiceProcessor
        streamer_module = load_module(
            "vibevoice_streamer",
            pkg_dir / "modular" / "streamer.py"
        )
        AudioStreamer = streamer_module.AudioStreamer
    except Exception as e:
        raise ImportError(
            f"VibeVoice module not found. Error: {e}\n"
            "Please ensure VibeVoice is properly installed:\n"
            "git clone https://github.com/microsoft/VibeVoice.git\n"
            "cd VibeVoice/\n"
            "pip install -e .\n"
        )

from transformers.utils import logging
from transformers import set_seed

logging.set_verbosity_info()
logger = logging.get_logger(__name__)


class VibeVoiceChat:
    def __init__(self, model_path: str, device: str = "cuda", inference_steps: int = 5):
        """Initialize the VibeVoice chat model."""
        self.model_path = model_path
        self.device = device if torch.cuda.is_available() else "cpu"
        self.inference_steps = inference_steps
        self.is_generating = False
        self.stop_generation = False
        self.current_streamer = None

        # Report GPU availability
        if torch.cuda.is_available():
            print(f"✓ GPU detected: {torch.cuda.get_device_name(0)}")
            print(f"  Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
        else:
            print("⚠ No GPU detected, using CPU (generation will be slower)")

        self.load_model()
        self.setup_voice_presets()

    def load_model(self):
        """Load the VibeVoice model and processor."""
        print(f"Loading model from {self.model_path}")
        self.processor = VibeVoiceProcessor.from_pretrained(self.model_path)
        if torch.cuda.is_available():
            try:
                self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
                    self.model_path,
                    torch_dtype=torch.bfloat16,
                    device_map='cuda',
                    attn_implementation="flash_attention_2",
                )
            except Exception:
                # The flash-attn install above is allowed to fail, so fall back
                # to PyTorch's built-in scaled-dot-product attention
                print("Warning: flash_attention_2 unavailable, falling back to sdpa")
                self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
                    self.model_path,
                    torch_dtype=torch.bfloat16,
                    device_map='cuda',
                    attn_implementation="sdpa",
                )
        else:
            self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
                self.model_path,
                torch_dtype=torch.float32,
                device_map='cpu',
            )
        self.model.eval()

        # Configure the diffusion noise scheduler
        self.model.model.noise_scheduler = self.model.model.noise_scheduler.from_config(
            self.model.model.noise_scheduler.config,
            algorithm_type='sde-dpmsolver++',
            beta_schedule='squaredcos_cap_v2'
        )
        self.model.set_ddpm_inference_steps(num_steps=self.inference_steps)
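
    # Note on load_model: the sde-dpmsolver++ scheduler lets the diffusion
    # audio head run with only a handful of denoising steps; fewer steps
    # generally trade some audio quality for lower latency.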

    def setup_voice_presets(self):
        """Set up voice presets from the voices directory."""
        voices_dir = os.path.join(os.path.dirname(__file__), "voices")
        if not os.path.exists(voices_dir):
            print(f"Warning: Voices directory not found at {voices_dir}")
            self.available_voices = {}
            return
        self.available_voices = {}
        audio_extensions = ('.wav', '.mp3', '.flac', '.ogg', '.m4a', '.aac')
        for file in os.listdir(voices_dir):
            if file.lower().endswith(audio_extensions):
                name = os.path.splitext(file)[0]
                self.available_voices[name] = os.path.join(voices_dir, file)
        self.available_voices = dict(sorted(self.available_voices.items()))
        print(f"Found {len(self.available_voices)} voice presets")

    def read_audio(self, audio_path: str, target_sr: int = 24000) -> np.ndarray:
        """Read an audio file, downmix to mono, and resample to target_sr."""
        try:
            wav, sr = sf.read(audio_path)
            if len(wav.shape) > 1:
                wav = np.mean(wav, axis=1)
            if sr != target_sr:
                wav = librosa.resample(wav, orig_sr=sr, target_sr=target_sr)
            return wav
        except Exception as e:
            print(f"Error reading audio {audio_path}: {e}")
            return np.array([])

    def format_script(self, message: str, num_speakers: int = 2) -> str:
        """Format input message into a script with speaker assignments."""
        lines = message.strip().split('\n')
        formatted_lines = []
        for i, line in enumerate(lines):
            line = line.strip()
            if not line:
                continue
            # Keep lines that are already formatted
            if line.startswith('Speaker ') and ':' in line:
                formatted_lines.append(line)
            else:
                # Auto-assign speakers in rotation
                speaker_id = i % num_speakers
                formatted_lines.append(f"Speaker {speaker_id}: {line}")
        return '\n'.join(formatted_lines)
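
    # Example: with num_speakers=2, format_script("Hello there\nHi, how are you?")
    # returns "Speaker 0: Hello there\nSpeaker 1: Hi, how are you?"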

    def generate_audio_stream(
        self,
        message: str,
        history: list,
        voice_1: str,
        voice_2: str,
        num_speakers: int,
        cfg_scale: float
    ) -> Iterator[tuple]:
        """Generate audio stream from text input."""
        try:
            self.stop_generation = False
            self.is_generating = True

            # Validate inputs
            if not message.strip():
                yield None
                return

            # Format the script
            formatted_script = self.format_script(message, num_speakers)

            # Select voices based on the number of speakers
            selected_voices = [voice_1]
            if num_speakers > 1 and voice_2:
                selected_voices.append(voice_2)

            # Load voice samples
            voice_samples = []
            for i in range(num_speakers):
                # Use the matching voice for each speaker
                if i < len(selected_voices):
                    voice_name = selected_voices[i]
                else:
                    # Reuse the first voice if there aren't enough
                    voice_name = selected_voices[0] if selected_voices else None
                if voice_name and voice_name in self.available_voices:
                    audio_data = self.read_audio(self.available_voices[voice_name])
                    if len(audio_data) > 0:
                        voice_samples.append(audio_data)
                    else:
                        # Fall back to one second of silence if reading failed
                        voice_samples.append(np.zeros(24000))
                else:
                    # Fall back to one second of silence if no voice is available
                    voice_samples.append(np.zeros(24000))

            # Ensure exactly num_speakers voice samples
            voice_samples = voice_samples[:num_speakers]

            # Process inputs
            inputs = self.processor(
                text=[formatted_script],
                voice_samples=[voice_samples],
                padding=True,
                return_tensors="pt",
                return_attention_mask=True,
            )
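
            # The processor packs one script and one list of reference voice
            # samples into a batch of size 1 and returns PyTorch tensors.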

            # Move tensors to the target device
            if self.device == "cuda":
                inputs = {k: v.to(self.device) if torch.is_tensor(v) else v for k, v in inputs.items()}

            # Create the audio streamer
            audio_streamer = AudioStreamer(
                batch_size=1,
                stop_signal=None,
                timeout=None
            )
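
            # The AudioStreamer is the hand-off point between threads: the
            # generation thread pushes decoded audio chunks in, and the loop
            # below consumes them via get_stream(0) as they arrive.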
            self.current_streamer = audio_streamer

            # Start generation in a separate thread
            generation_thread = threading.Thread(
                target=self._generate_with_streamer,
                args=(inputs, cfg_scale, audio_streamer)
            )
            generation_thread.start()

            # Wait briefly for generation to start
            time.sleep(1)

            # Stream audio chunks
            sample_rate = 24000
            audio_stream = audio_streamer.get_stream(0)
            for audio_chunk in audio_stream:
                if self.stop_generation:
                    audio_streamer.end()
                    break

                # Convert to numpy
                if torch.is_tensor(audio_chunk):
                    if audio_chunk.dtype == torch.bfloat16:
                        audio_chunk = audio_chunk.float()
                    audio_np = audio_chunk.cpu().numpy().astype(np.float32)
                else:
                    audio_np = np.array(audio_chunk, dtype=np.float32)

                # Ensure 1D
                if len(audio_np.shape) > 1:
                    audio_np = audio_np.squeeze()

                # Convert to 16-bit PCM
                audio_16bit = self.convert_to_16_bit_wav(audio_np)
                yield (sample_rate, audio_16bit)

            # Wait for generation to complete
            generation_thread.join(timeout=5.0)
            self.current_streamer = None
            self.is_generating = False
        except Exception as e:
            print(f"Error in generation: {e}")
            import traceback
            traceback.print_exc()
            self.is_generating = False
            self.current_streamer = None
            yield None

    def _generate_with_streamer(self, inputs, cfg_scale, audio_streamer):
        """Helper method to run generation with the streamer."""
        try:
            def check_stop():
                return self.stop_generation

            outputs = self.model.generate(
                **inputs,
                max_new_tokens=None,
                cfg_scale=cfg_scale,
                tokenizer=self.processor.tokenizer,
                generation_config={'do_sample': False},
                audio_streamer=audio_streamer,
                stop_check_fn=check_stop,
                verbose=False,
                refresh_negative=True,
            )
        except Exception as e:
            print(f"Error in generation thread: {e}")
            import traceback
            traceback.print_exc()
            audio_streamer.end()

    def convert_to_16_bit_wav(self, data):
        """Convert float audio data to 16-bit PCM."""
        if torch.is_tensor(data):
            data = data.detach().cpu().numpy()
        data = np.array(data)
        # Normalize only if the signal exceeds [-1.0, 1.0]
        if np.max(np.abs(data)) > 1.0:
            data = data / np.max(np.abs(data))
        data = (data * 32767).astype(np.int16)
        return data
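
    # Example: convert_to_16_bit_wav(np.array([0.0, 0.5, -1.0]))
    # -> array([0, 16383, -32767], dtype=int16)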

    def stop_audio_generation(self):
        """Stop the current audio generation."""
        self.stop_generation = True
        if self.current_streamer:
            try:
                self.current_streamer.end()
            except Exception:
                # The streamer may already be closed; ignore
                pass


def create_chat_interface(chat_instance: VibeVoiceChat):
    """Create a simplified Gradio ChatInterface for VibeVoice."""
    # Get available voices
    voice_options = list(chat_instance.available_voices.keys()) if chat_instance.available_voices else ["None"]
    default_voice_1 = voice_options[0] if len(voice_options) > 0 else "None"
    default_voice_2 = voice_options[1] if len(voice_options) > 1 else voice_options[0]

    # Define the chat function
    def chat_fn(message: Dict[str, Any], history: list, voice_1: str, voice_2: str, num_speakers: int, cfg_scale: float):
        """Process a chat message and generate an audio response."""
        # Extract text from the message (handle both string and dict inputs)
        if isinstance(message, dict):
            text = message.get("text", "")
        else:
            text = message

        if not text.strip():
            return gr.Audio(value=None)

        try:
            # Generate the audio stream
            audio_generator = chat_instance.generate_audio_stream(
                text, history, voice_1, voice_2, num_speakers, cfg_scale
            )

            # chat_fn is not a generator, so return the first audio chunk
            # for an immediate response
            audio_data = None
            for audio_chunk in audio_generator:
                if audio_chunk is not None:
                    audio_data = audio_chunk
                    break

            # Return the audio component
            if audio_data:
                return gr.Audio(value=audio_data, streaming=True, autoplay=True)
            else:
                return gr.Audio(value=None)
        except Exception as e:
            print(f"Error in chat_fn: {e}")
            import traceback
            traceback.print_exc()
            return gr.Audio(value=None)

    # Create additional inputs
    additional_inputs = [
        gr.Dropdown(
            choices=voice_options,
            value=default_voice_1,
            label="Voice 1",
            info="Select voice for Speaker 0"
        ),
        gr.Dropdown(
            choices=voice_options,
            value=default_voice_2,
            label="Voice 2",
            info="Select voice for Speaker 1 (if using multiple speakers)"
        ),
        gr.Slider(
            minimum=1,
            maximum=2,
            value=2,
            step=1,
            label="Number of Speakers",
            info="Number of speakers in the dialogue"
        ),
        gr.Slider(
            minimum=1.0,
            maximum=2.0,
            value=1.3,
            step=0.05,
            label="CFG Scale",
            info="Guidance strength (higher = closer adherence to the text)"
        )
    ]

    # Create the ChatInterface without examples, which avoids an error at startup
    interface = gr.ChatInterface(
        fn=chat_fn,
        type="messages",
        title="🎙️ VibeVoice Chat",
        description="Generate natural dialogue audio with AI voices. Type your message or paste a script!",
        additional_inputs=additional_inputs,
        additional_inputs_accordion=gr.Accordion(label="Voice & Generation Settings", open=True),
        submit_btn="🎵 Generate Audio",
        stop_btn="⏹️ Stop",
        autofocus=True,
        autoscroll=True,
        show_progress="minimal",
        theme=gr.themes.Soft(
            primary_hue="blue",
            secondary_hue="purple"
        ),
        css="""
        .gradio-container {
            max-width: 1200px;
            margin: auto;
        }
        .message {
            font-size: 1.1em;
        }
        """,
        analytics_enabled=True,
        fill_height=True,
        fill_width=False,
    )
    return interface


def parse_args():
    parser = argparse.ArgumentParser(description="VibeVoice Chat Interface")
    parser.add_argument(
        "--model_path",
        type=str,
        default="microsoft/VibeVoice-1.5B",
        help="Path to the VibeVoice model",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Device for inference",
    )
    parser.add_argument(
        "--inference_steps",
        type=int,
        default=10,
        help="Number of DDPM inference steps",
    )
    return parser.parse_args()


def main():
    """Main function to run the chat interface."""
    args = parse_args()
    set_seed(42)

    print("🎙️ Initializing VibeVoice Chat Interface...")

    # Initialize the chat instance
    chat_instance = VibeVoiceChat(
        model_path=args.model_path,
        device=args.device,
        inference_steps=args.inference_steps
    )

    # Create the interface
    interface = create_chat_interface(chat_instance)

    print("🚀 Launching chat interface")
    print(f"📁 Model: {args.model_path}")
    print(f"💻 Device: {chat_instance.device}")
    print(f"🎭 Available voices: {len(chat_instance.available_voices)}")

    # Launch the interface
    interface.launch(
        show_error=True,
        quiet=False,
    )


if __name__ == "__main__":
    main()
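
# Minimal programmatic use without the Gradio UI (a sketch; the voice name
# "Alice" is hypothetical and must match a preset discovered in ./voices):
#
#   chat = VibeVoiceChat("microsoft/VibeVoice-1.5B", inference_steps=5)
#   chunks = []
#   for item in chat.generate_audio_stream("Hello world", [], "Alice", "Alice", 1, 1.3):
#       if item is not None:
#           sample_rate, pcm16 = item
#           chunks.append(pcm16)
#   sf.write("out.wav", np.concatenate(chunks), 24000)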