"""
Load and use the 4-bit quantized VibeVoice model.
"""
|
|
|
import torch |
|
from transformers import BitsAndBytesConfig |
|
from vibevoice.modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference |
|
from vibevoice.processor.vibevoice_processor import VibeVoiceProcessor |
|
|
|
def load_quantized_model(model_path="/home/deveraux/Desktop/vibevoice/VibeVoice-Large-4bit"):
    """Load the pre-quantized VibeVoice model together with its processor.

    Args:
        model_path: Location of the checkpoint. It is passed unchanged to
            both ``from_pretrained`` calls. The default is a
            machine-specific absolute path.

    Returns:
        A ``(model, processor)`` tuple. ``eval()`` has already been called
        on the model.
    """
    print("Loading 4-bit quantized VibeVoice model...")

    # One dtype is used both for the 4-bit compute path and for torch_dtype.
    compute_dtype = torch.bfloat16

    # bitsandbytes settings: 4-bit NF4 weights with double quantization.
    quant_settings = {
        "load_in_4bit": True,
        "bnb_4bit_compute_dtype": compute_dtype,
        "bnb_4bit_use_double_quant": True,
        "bnb_4bit_quant_type": "nf4",
    }
    quantization = BitsAndBytesConfig(**quant_settings)

    # The processor is loaded first, then the model, both from model_path.
    tts_processor = VibeVoiceProcessor.from_pretrained(model_path)

    tts_model = VibeVoiceForConditionalGenerationInference.from_pretrained(
        model_path,
        quantization_config=quantization,
        device_map="cuda",
        torch_dtype=compute_dtype,
    )
    tts_model.eval()

    print("✅ Model loaded successfully!")
    # Reported in units of 1e9 bytes, as returned by torch.cuda.memory_allocated().
    print(f"💾 Memory usage: {torch.cuda.memory_allocated() / 1e9:.1f} GB")

    return tts_model, tts_processor
|
|
|
|
|
if __name__ == "__main__":
    # Example run: load the model, synthesize a two-speaker exchange, and
    # write the first generated clip to "output.wav".
    model, processor = load_quantized_model()

    script = "Speaker 1: Hello! Speaker 2: Hi there!"

    # Placeholder paths: one reference voice sample per speaker in the script.
    speaker_voices = ["path/to/voice1.wav", "path/to/voice2.wav"]

    # A batch of one: a single script paired with its list of voice samples.
    batch = processor(
        text=[script],
        voice_samples=[speaker_voices],
        padding=True,
        return_tensors="pt",
    )

    # NOTE(review): the batch is passed to generate() exactly as the processor
    # returned it; confirm that generate() handles device placement and needs
    # no further arguments (e.g. a tokenizer) for this model.
    with torch.no_grad():
        result = model.generate(**batch)

    first_clip = result.speech_outputs[0]
    processor.save_audio(first_clip, "output.wav")
|
|