# File: VibeVoice-Colab/processor/vibevoice_processor.py
# Hugging Face file-page header captured with this source: uploaded by NeuralFalcon
# ("Upload 34 files"), commit 909e36b (verified), 30.2 kB; page links: raw, history, blame.
import math
import warnings
from typing import List, Optional, Union, Dict, Any, Tuple
import os
import re
import numpy as np
import torch
from transformers.tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
from transformers.utils import TensorType, logging
from .vibevoice_tokenizer_processor import AudioNormalizer
# Module-level logger from the transformers logging utilities, named after this module.
logger = logging.get_logger(name=__name__)
class VibeVoiceProcessor:
    r"""
    Constructs a VibeVoice processor which wraps a VibeVoice tokenizer and audio processor into a single processor.

    [`VibeVoiceProcessor`] offers all the functionalities of [`VibeVoiceTokenizer`] and [`VibeVoiceTokenizerProcessor`].
    See the [`~VibeVoiceProcessor.__call__`] and [`~VibeVoiceProcessor.decode`] for more information.

    Args:
        tokenizer (`VibeVoiceTextTokenizer` or `VibeVoiceTextTokenizerFast`):
            The tokenizer for text processing.
        audio_processor (`VibeVoiceTokenizerProcessor`):
            The audio processor for speech processing.
        speech_tok_compress_ratio (`int`, *optional*, defaults to 3200):
            Number of audio samples (along the first axis) represented by one speech placeholder token.
            A voice sample of length L occupies ceil(L / speech_tok_compress_ratio) tokens.
        db_normalize (`bool`, *optional*, defaults to True):
            Whether to apply decibel normalization to audio inputs.
    """

    # Matches "Speaker <id>: <text>" lines (case-insensitive). Shared by the .txt converter and the
    # script parser so both accept exactly the same syntax.
    _SPEAKER_LINE_RE = re.compile(r'^Speaker\s+(\d+)\s*:\s*(.*)$', re.IGNORECASE)

    def __init__(self, tokenizer=None, audio_processor=None, speech_tok_compress_ratio=3200, db_normalize=True, **kwargs):
        # Extra keyword arguments are accepted for call-site compatibility and ignored.
        self.tokenizer = tokenizer
        self.audio_processor = audio_processor
        self.speech_tok_compress_ratio = speech_tok_compress_ratio
        self.db_normalize = db_normalize
        # The normalizer is only constructed when dB normalization is enabled.
        self.audio_normalizer = AudioNormalizer() if db_normalize else None
        self.system_prompt = " Transform the text provided by various speakers into speech output, utilizing the distinct voice of each respective speaker.\n"

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """
        Instantiate a VibeVoiceProcessor from a pretrained VibeVoice processor.

        Args:
            pretrained_model_name_or_path (`str` or `os.PathLike`):
                This can be either:
                - a string, the *model id* of a pretrained model
                - a path to a *directory* containing processor config
                Note: `preprocessor_config.json` is only looked up on the local filesystem
                (`os.path.exists`); if it is not found there, default settings are used.
            **kwargs:
                `language_model_pretrained_name` is always consumed here. It selects the text
                tokenizer when the config does not name one. All remaining kwargs are forwarded
                to the tokenizer's `from_pretrained`.

        Returns:
            [`VibeVoiceProcessor`]: The processor object instantiated from pretrained model.

        Raises:
            ValueError: If the resolved language-model name does not contain "qwen".
        """
        import json
        from .vibevoice_tokenizer_processor import VibeVoiceTokenizerProcessor
        from vibevoice.modular.modular_vibevoice_text_tokenizer import VibeVoiceTextTokenizerFast

        # Load processor configuration (local path only).
        config_path = os.path.join(pretrained_model_name_or_path, "preprocessor_config.json")
        if os.path.exists(config_path):
            with open(config_path, 'r', encoding='utf-8') as f:
                config = json.load(f)
        else:
            logger.warning(f"No preprocessor_config.json found at {pretrained_model_name_or_path}, using defaults")
            config = {
                "speech_tok_compress_ratio": 3200,
                "db_normalize": True,
            }

        # Extract main processor parameters.
        speech_tok_compress_ratio = config.get("speech_tok_compress_ratio", 3200)
        db_normalize = config.get("db_normalize", True)

        # Resolve the tokenizer source: config first, then the kwarg, then the Qwen default.
        # The kwarg is popped unconditionally so it is never forwarded to the tokenizer's
        # from_pretrained. Previously the `or` short-circuit skipped the pop whenever the
        # config already provided a name.
        kwarg_language_model_name = kwargs.pop("language_model_pretrained_name", None)
        language_model_pretrained_name = (
            config.get("language_model_pretrained_name")
            or kwarg_language_model_name
            or "Qwen/Qwen2.5-1.5B"
        )
        logger.info(f"Loading tokenizer from {language_model_pretrained_name}")
        if 'qwen' in language_model_pretrained_name.lower():
            tokenizer = VibeVoiceTextTokenizerFast.from_pretrained(
                language_model_pretrained_name,
                **kwargs
            )
        else:
            raise ValueError(f"Unsupported tokenizer type for {language_model_pretrained_name}. Supported types: Qwen.")

        # Load audio processor.
        if "audio_processor" in config:
            # Create audio processor from config.
            audio_config = config["audio_processor"]
            audio_processor = VibeVoiceTokenizerProcessor(
                sampling_rate=audio_config.get("sampling_rate", 24000),
                normalize_audio=audio_config.get("normalize_audio", True),
                target_dB_FS=audio_config.get("target_dB_FS", -25),
                eps=audio_config.get("eps", 1e-6),
            )
        else:
            # Create default audio processor.
            audio_processor = VibeVoiceTokenizerProcessor()

        return cls(
            tokenizer=tokenizer,
            audio_processor=audio_processor,
            speech_tok_compress_ratio=speech_tok_compress_ratio,
            db_normalize=db_normalize,
        )

    def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs):
        """
        Save a processor to a directory, so that it can be re-loaded using the
        [`~VibeVoiceProcessor.from_pretrained`] class method.

        Only `preprocessor_config.json` is written. The tokenizer itself is not saved.

        Args:
            save_directory (`str` or `os.PathLike`):
                Directory where the processor will be saved (created if missing).
        """
        import json

        os.makedirs(save_directory, exist_ok=True)
        # Audio-processor fields fall back to the same defaults from_pretrained uses.
        processor_config = {
            "processor_class": "VibeVoiceProcessor",
            "speech_tok_compress_ratio": self.speech_tok_compress_ratio,
            "db_normalize": self.db_normalize,
            "audio_processor": {
                "feature_extractor_type": "VibeVoiceTokenizerProcessor",
                "sampling_rate": getattr(self.audio_processor, 'sampling_rate', 24000),
                "normalize_audio": getattr(self.audio_processor, 'normalize_audio', True),
                "target_dB_FS": getattr(self.audio_processor, 'target_dB_FS', -25),
                "eps": getattr(self.audio_processor, 'eps', 1e-6),
            }
        }
        config_path = os.path.join(save_directory, "preprocessor_config.json")
        with open(config_path, 'w', encoding='utf-8') as f:
            json.dump(processor_config, f, indent=2)
        logger.info(f"Processor configuration saved in {config_path}")

    def __call__(
        self,
        text: Optional[Union[str, List[str], TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
        voice_samples: Optional[Union[List[Union[str, np.ndarray]], List[List[Union[str, np.ndarray]]]]] = None,
        padding: Union[bool, str, PaddingStrategy] = True,
        truncation: Union[bool, str, TruncationStrategy] = False,
        max_length: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        return_attention_mask: bool = True,
        **kwargs,
    ) -> BatchEncoding:
        """
        Main method to process one or more podcast scripts with optional voice samples.

        Args:
            text (`str`, `List[str]`):
                The input text(s) to process. Can be:
                - A single script string
                - A list of script strings for batch processing
                - A path to a .json or .txt file
                - A list of paths
            voice_samples (`List[Union[str, np.ndarray]]`, `List[List[Union[str, np.ndarray]]]`, *optional*):
                Voice samples for each script. Can be:
                - A list of samples for a single script
                - A list of lists for batch processing (one list per script)
                An empty list is treated the same as `None`.
            padding (`bool`, `str` or `PaddingStrategy`, defaults to `True`):
                Whether to pad sequences to the same length (padding is applied on the left).
            truncation (`bool`, `str` or `TruncationStrategy`, defaults to `False`):
                Whether to truncate sequences to `max_length`.
            max_length (`int`, *optional*):
                Maximum length of the returned sequences.
            return_tensors (`str` or `TensorType`, *optional*):
                If set, will return tensors of a particular framework.
            return_attention_mask (`bool`, defaults to `True`):
                Whether to return the attention mask.

        Returns:
            `BatchEncoding`: A BatchEncoding with the following fields:
            - **input_ids** -- List of token id sequences or tensor
            - **attention_mask** -- List of attention masks or tensor
            - **speech_tensors** -- Padded speech inputs (if voice_samples provided)
            - **speech_masks** -- Speech masks (if voice_samples provided)
            - **speech_input_mask** -- Boolean masks indicating speech token positions

        Raises:
            ValueError: If `text` is missing or empty, or if the number of voice-sample sets does
                not match the number of scripts.
        """
        if text is None:
            raise ValueError("`text` must be provided: a script string, a file path, or a list of them.")

        # A str (or a non-empty list whose first item is not a str) is a single input.
        if isinstance(text, str) or (isinstance(text, list) and len(text) > 0 and not isinstance(text[0], str)):
            texts = [text]
            is_batched = False
        else:
            texts = text
            is_batched = True
        if len(texts) == 0:
            raise ValueError("`text` must contain at least one script.")

        # Pair each script with its voice samples.
        if voice_samples is None or len(voice_samples) == 0:
            voice_samples_list = [None] * len(texts)
        elif not is_batched or isinstance(voice_samples[0], (str, np.ndarray)):
            # A flat list of samples belongs to a single script.
            voice_samples_list = [voice_samples]
        else:
            # One list of samples per script.
            voice_samples_list = list(voice_samples)
        # zip() would otherwise silently drop scripts (or sample sets) on a count mismatch.
        if len(voice_samples_list) != len(texts):
            raise ValueError(
                f"Got {len(texts)} script(s) but {len(voice_samples_list)} set(s) of voice samples; "
                "pass one list of voice samples per script."
            )

        all_encodings = [
            self._process_single(text_input, voice_input)
            for text_input, voice_input in zip(texts, voice_samples_list)
        ]

        return self._batch_encode(
            all_encodings,
            padding=padding,
            truncation=truncation,
            max_length=max_length,
            return_tensors=return_tensors,
            return_attention_mask=return_attention_mask,
        )

    def _process_single(
        self,
        text: Union[str, TextInput],
        voice_samples: Optional[List[Union[str, np.ndarray]]] = None,
    ) -> Dict[str, Any]:
        """
        Process a single podcast script into token ids and a speech-position mask.

        The sequence is laid out as:
        system prompt, optional voice prompt, " Text input:\\n", one " Speaker i:<text>\\n" segment
        per parsed line, " Speech output:\\n", then the speech-start token.

        Raises:
            ValueError: If `text` is not a string, or the script contains no parsable speaker lines.
        """
        # Determine if text is a file path or direct script.
        script = None
        if isinstance(text, str):
            if text.endswith('.json') and os.path.exists(text):
                script = self._convert_json_to_script(text)
            elif text.endswith('.txt') and os.path.exists(text):
                script = self._convert_text_to_script(text)
            else:
                # Assume it's the script content directly.
                script = text
        if script is None:
            raise ValueError(f"Could not process input text: {text}")

        parsed_lines = self._parse_script(script)
        # Sorted so the reported speaker list is deterministic (set iteration order is not guaranteed).
        all_speakers = sorted({speaker_id for speaker_id, _ in parsed_lines})

        # The system prompt uses the tokenizer's default special-token handling. All other
        # segments below pass add_special_tokens=False.
        system_tokens = self.tokenizer.encode(self.system_prompt)

        # Process voice samples if provided; at most one sample per distinct speaker is used.
        # NOTE(review): samples are bound positionally to " Speaker 0..k-1" in _create_voice_prompt,
        # so scripts with non-contiguous speaker ids will not line up with their samples.
        if voice_samples:
            voice_tokens, voice_speech_inputs, voice_speech_masks = self._create_voice_prompt(voice_samples[:len(all_speakers)])
        else:
            voice_tokens, voice_speech_inputs, voice_speech_masks = [], [], []

        # Build the full token sequence and a parallel mask (True only at speech placeholders).
        full_tokens = system_tokens + voice_tokens
        speech_input_mask = [False] * len(system_tokens) + voice_speech_masks

        text_header_tokens = self.tokenizer.encode(' Text input:\n', add_special_tokens=False)
        full_tokens += text_header_tokens
        speech_input_mask += [False] * len(text_header_tokens)

        for speaker_id, speaker_text in parsed_lines:
            speaker_text_tokens = self.tokenizer.encode(f" Speaker {speaker_id}:{speaker_text}\n", add_special_tokens=False)
            full_tokens += speaker_text_tokens
            speech_input_mask += [False] * len(speaker_text_tokens)

        output_header_tokens = self.tokenizer.encode(' Speech output:\n', add_special_tokens=False) + [self.tokenizer.speech_start_id]
        full_tokens += output_header_tokens
        speech_input_mask += [False] * len(output_header_tokens)

        return {
            "input_ids": full_tokens,
            "speech_inputs": voice_speech_inputs if voice_speech_inputs else None,
            "speech_input_mask": speech_input_mask,
            "parsed_script": parsed_lines,
            "all_speakers": all_speakers,
        }

    def _batch_encode(
        self,
        encodings: List[Dict[str, Any]],
        padding: Union[bool, str, PaddingStrategy] = True,
        truncation: Union[bool, str, TruncationStrategy] = False,
        max_length: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        return_attention_mask: bool = True,
    ) -> BatchEncoding:
        """
        Combine per-script encodings into a batch, truncating and left-padding as requested.

        Truncation (when enabled and `max_length` is given) keeps the first `max_length` tokens of
        each sequence, independent of the padding strategy. Padding prepends `tokenizer.pad_id` to
        `max_length` (for the "max_length" strategy with `max_length` set) or to the longest sequence.
        Any non-None `return_tensors` converts ids and masks to torch tensors. Speech tensors are
        converted only when `return_tensors == "pt"` (see `prepare_speech_inputs`).
        """
        input_ids_list = [enc["input_ids"] for enc in encodings]
        speech_input_masks_list = [enc["speech_input_mask"] for enc in encodings]

        # Resolve the padding strategy.
        if isinstance(padding, bool):
            padding_strategy = PaddingStrategy.LONGEST if padding else PaddingStrategy.DO_NOT_PAD
        elif isinstance(padding, str):
            padding_strategy = PaddingStrategy(padding)
        else:
            padding_strategy = padding

        # Resolve truncation explicitly. A plain truthiness test treats the string
        # "do_not_truncate" (and TruncationStrategy.DO_NOT_TRUNCATE) as "truncate".
        if truncation is None or isinstance(truncation, bool):
            do_truncate = bool(truncation)
        else:
            do_truncate = TruncationStrategy(truncation) != TruncationStrategy.DO_NOT_TRUNCATE

        if do_truncate and max_length is not None:
            input_ids_list = [ids[:max_length] for ids in input_ids_list]
            speech_input_masks_list = [mask[:max_length] for mask in speech_input_masks_list]

        if padding_strategy != PaddingStrategy.DO_NOT_PAD:
            if padding_strategy == PaddingStrategy.MAX_LENGTH and max_length is not None:
                max_len = max_length
            else:
                max_len = max(len(ids) for ids in input_ids_list)

            pad_id = self.tokenizer.pad_id
            padded_input_ids = []
            attention_masks = []
            padded_speech_input_masks = []
            for input_ids, speech_mask in zip(input_ids_list, speech_input_masks_list):
                # Sequences already longer than max_len (no truncation) are left as-is.
                padding_length = max(max_len - len(input_ids), 0)
                padded_input_ids.append([pad_id] * padding_length + list(input_ids))
                attention_masks.append([0] * padding_length + [1] * len(input_ids))
                padded_speech_input_masks.append([False] * padding_length + list(speech_mask))
            input_ids_list = padded_input_ids
            speech_input_masks_list = padded_speech_input_masks
        else:
            # No padding, just create attention masks.
            attention_masks = [[1] * len(ids) for ids in input_ids_list] if return_attention_mask else None

        # Flatten the voice samples of every script into one list.
        all_speech_inputs = []
        has_speech = False
        for enc in encodings:
            if enc["speech_inputs"] is not None:
                all_speech_inputs.extend(enc["speech_inputs"])
                has_speech = True

        batch_encoding = BatchEncoding()

        if return_tensors is not None:
            batch_encoding["input_ids"] = torch.tensor(input_ids_list, dtype=torch.long)
            if return_attention_mask and attention_masks is not None:
                batch_encoding["attention_mask"] = torch.tensor(attention_masks, dtype=torch.long)
            batch_encoding["speech_input_mask"] = torch.tensor(speech_input_masks_list, dtype=torch.bool)
        else:
            batch_encoding["input_ids"] = input_ids_list
            if return_attention_mask and attention_masks is not None:
                batch_encoding["attention_mask"] = attention_masks
            batch_encoding["speech_input_mask"] = speech_input_masks_list

        if has_speech:
            speech_dict = self.prepare_speech_inputs(
                all_speech_inputs,
                return_tensors=return_tensors,
            )
            batch_encoding["speech_tensors"] = speech_dict["padded_speeches"]
            batch_encoding["speech_masks"] = speech_dict["speech_masks"]
        else:
            batch_encoding["speech_tensors"] = None
            batch_encoding["speech_masks"] = None

        # Metadata passed through from each script.
        batch_encoding["parsed_scripts"] = [enc["parsed_script"] for enc in encodings]
        batch_encoding["all_speakers_list"] = [enc["all_speakers"] for enc in encodings]
        return batch_encoding

    def _create_voice_prompt(
        self,
        speaker_samples: List[Union[str, np.ndarray]]
    ) -> Tuple[List[int], List[np.ndarray], List[bool]]:
        """
        Create voice prompt tokens and process audio samples.

        The i-th sample becomes " Speaker i:" + speech_start + N placeholder tokens + speech_end + "\\n",
        where N = ceil(len(wav) / speech_tok_compress_ratio). String entries are loaded through
        `audio_processor._load_audio_from_path`. Other entries are converted with
        `np.array(..., dtype=float32)`. When enabled, dB normalization is applied in both cases.

        Returns:
            tuple: (voice_tokens, voice_speech_inputs, voice_speech_masks), where the mask is True
            exactly at the placeholder-token positions and has the same length as voice_tokens.
        """
        vae_token_id = self.tokenizer.speech_diffusion_id
        newline_tokens = self.tokenizer.encode('\n', add_special_tokens=False)
        voice_full_tokens = list(self.tokenizer.encode(' Voice input:\n', add_special_tokens=False))
        voice_speech_inputs = []
        voice_speech_masks = [False] * len(voice_full_tokens)

        for speaker_id, speaker_audio in enumerate(speaker_samples):
            prefix_tokens = self.tokenizer.encode(f" Speaker {speaker_id}:", add_special_tokens=False)

            if isinstance(speaker_audio, str):
                wav = self.audio_processor._load_audio_from_path(speaker_audio)
            else:
                wav = np.array(speaker_audio, dtype=np.float32)

            if self.db_normalize and self.audio_normalizer:
                wav = self.audio_normalizer(wav)

            vae_tok_len = math.ceil(wav.shape[0] / self.speech_tok_compress_ratio)

            speaker_tokens = (prefix_tokens +
                              [self.tokenizer.speech_start_id] +
                              [vae_token_id] * vae_tok_len +
                              [self.tokenizer.speech_end_id] +
                              newline_tokens)
            # Size the trailing False run from the actual newline encoding so the mask stays
            # aligned with the tokens even if "\n" encodes to more than one token.
            vae_input_mask = ([False] * (len(prefix_tokens) + 1) +
                              [True] * vae_tok_len +
                              [False] * (1 + len(newline_tokens)))

            voice_full_tokens.extend(speaker_tokens)
            voice_speech_masks.extend(vae_input_mask)
            voice_speech_inputs.append(wav)

        return voice_full_tokens, voice_speech_inputs, voice_speech_masks

    def prepare_speech_inputs(
        self,
        speech_inputs: List[np.ndarray],
        return_tensors: Optional[Union[str, TensorType]] = None,
        device: Optional[Union[str, torch.device]] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> Dict[str, Any]:
        """
        Prepare speech inputs for model consumption.

        Args:
            speech_inputs: List of speech arrays. The layout of the first entry (1-D or not)
                decides the padded shape for all entries.
            return_tensors: If "pt", results are converted to torch tensors; otherwise numpy arrays
                are returned.
            device: Device to place tensors on (only used when converting to torch).
            dtype: Data type for the speech tensor (defaults to torch.float32).

        Returns:
            Dictionary with:
            - "padded_speeches": zero-padded array of shape (batch, max_len) for 1-D inputs, or
              (batch, max_len, last_dim) otherwise.
            - "speech_masks": boolean array of shape (batch, max_tokens), True for the first
              ceil(len / speech_tok_compress_ratio) positions of each row.
            Both values are None for an empty input list.
        """
        if not speech_inputs:
            return {"padded_speeches": None, "speech_masks": None}

        vae_tok_seqlens = [math.ceil(s.shape[0] / self.speech_tok_compress_ratio) for s in speech_inputs]
        max_speech_length = max(s.shape[0] for s in speech_inputs)

        if speech_inputs[0].ndim == 1:
            padded_speeches = np.full((len(speech_inputs), max_speech_length), fill_value=0, dtype=np.float32)
        else:
            padded_speeches = np.full((len(speech_inputs), max_speech_length, speech_inputs[0].shape[-1]), fill_value=0, dtype=np.float32)
        speech_masks = np.zeros((len(speech_inputs), max(vae_tok_seqlens)), dtype=np.bool_)

        for i, (speech, vae_tok_length) in enumerate(zip(speech_inputs, vae_tok_seqlens)):
            padded_speeches[i, :len(speech)] = speech
            speech_masks[i, :vae_tok_length] = True

        result = {
            "padded_speeches": padded_speeches,
            "speech_masks": speech_masks,
        }

        if return_tensors == "pt":
            result["padded_speeches"] = torch.tensor(padded_speeches, device=device, dtype=dtype or torch.float32)
            result["speech_masks"] = torch.tensor(speech_masks, device=device, dtype=torch.bool)

        return result

    def _convert_json_to_script(self, json_file: str) -> str:
        """
        Convert JSON format to script format.

        Expected JSON format:
        [
            {"speaker": "1", "text": "Hello everyone..."},
            {"speaker": "2", "text": "Great to be here..."}
        ]

        Entries that are not dicts, lack speaker/text, or have a non-integer speaker are skipped
        with a warning. Entries with empty text are dropped.

        Raises:
            ValueError: If the top-level value is not a list or no valid entries remain.
        """
        import json

        with open(json_file, 'r', encoding='utf-8') as f:
            data = json.load(f)

        if not isinstance(data, list):
            raise ValueError("JSON file must contain a list of speaker entries")

        script_lines = []
        for item in data:
            if not isinstance(item, dict):
                logger.warning(f"Skipping non-dict entry: {item}")
                continue

            speaker = item.get('speaker')
            text = item.get('text')
            if speaker is None or text is None:
                logger.warning(f"Skipping entry missing speaker or text: {item}")
                continue

            try:
                speaker_id = int(speaker)
            except (ValueError, TypeError):
                logger.warning(f"Invalid speaker ID: {speaker}, skipping entry")
                continue

            # str() guards against non-string JSON values (e.g. numbers), which have no .strip().
            text = str(text).strip()
            if text:
                script_lines.append(f"Speaker {speaker_id}: {text}")

        if not script_lines:
            raise ValueError("No valid entries found in JSON file")

        return "\n".join(script_lines)

    def _convert_text_to_script(self, text_file: str) -> str:
        """
        Convert text file to script format.

        Handles multiple formats:
        1. Already formatted as "Speaker X: text"
        2. Plain text (assigns to Speaker 1)

        Handles edge cases like multiple colons in a line. Blank lines, and speaker lines with no
        text, are dropped.

        Raises:
            ValueError: If the file yields no script lines.
        """
        with open(text_file, 'r', encoding='utf-8') as f:
            lines = f.readlines()

        script_lines = []
        # Plain (unlabelled) lines are always attributed to this speaker.
        current_speaker = 1

        for line in lines:
            line = line.strip()
            if not line:
                continue

            speaker_match = self._SPEAKER_LINE_RE.match(line)
            if speaker_match:
                speaker_id = int(speaker_match.group(1))
                text = speaker_match.group(2).strip()
                if text:
                    script_lines.append(f"Speaker {speaker_id}: {text}")
            else:
                script_lines.append(f"Speaker {current_speaker}: {line}")

        if not script_lines:
            raise ValueError("No valid content found in text file")

        return "\n".join(script_lines)

    def _parse_script(self, script: str) -> List[Tuple[int, str]]:
        """
        Parse script into list of (speaker_id, text) tuples.

        Each text is returned with a single leading space. Unparsable lines are logged and skipped.
        If every speaker id is greater than 0, all ids are shifted down by one (1-based -> 0-based).

        Raises:
            ValueError: If no line matches "Speaker <id>: <text>".
        """
        parsed_lines = []
        for line in script.strip().split("\n"):
            stripped = line.strip()
            if not stripped:
                continue
            match = self._SPEAKER_LINE_RE.match(stripped)
            if match:
                parsed_lines.append((int(match.group(1)), ' ' + match.group(2).strip()))
            else:
                logger.warning(f"Could not parse line: '{line}'")

        if not parsed_lines:
            raise ValueError("No valid speaker lines found in script")

        min_speaker_id = min(speaker_id for speaker_id, _ in parsed_lines)
        if min_speaker_id > 0:
            return [(speaker_id - 1, text) for speaker_id, text in parsed_lines]
        return parsed_lines

    def _merge_inputs(self, text_inputs: BatchEncoding, audio_inputs: Dict) -> BatchEncoding:
        """Merge text and audio inputs into a single BatchEncoding ("audio" -> "speech_inputs", "streaming" copied)."""
        merged = BatchEncoding(text_inputs)

        if "audio" in audio_inputs:
            merged["speech_inputs"] = audio_inputs["audio"]
        if "streaming" in audio_inputs:
            merged["streaming"] = audio_inputs["streaming"]

        return merged

    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to VibeVoiceTextTokenizer's [`~PreTrainedTokenizer.batch_decode`].
        Please refer to the docstring of this method for more information.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to VibeVoiceTextTokenizer's [`~PreTrainedTokenizer.decode`].
        Please refer to the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)

    @property
    def model_input_names(self):
        """
        Return the list of inputs accepted by the model (tokenizer names, audio-processor names,
        then "speech_inputs" and "speech_input_mask", de-duplicated in order).
        """
        tokenizer_input_names = self.tokenizer.model_input_names
        audio_processor_input_names = self.audio_processor.model_input_names
        return list(dict.fromkeys(tokenizer_input_names + audio_processor_input_names + ["speech_inputs", "speech_input_mask"]))

    def save_audio(self,
                   audio: Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]],
                   output_path: str = "output.wav",
                   sampling_rate: Optional[int] = None,
                   normalize: bool = False,
                   batch_prefix: str = "audio_",
                   ) -> str:
        """
        Save audio data to a file by delegating to `audio_processor.save_audio`.

        Args:
            audio (Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]]):
                The audio data to save. Can be a single tensor/array or a list of them.
            output_path (str, optional): Path to save the audio file. Defaults to "output.wav".
            sampling_rate (int, optional): Sampling rate for the audio. If None, uses the processor's default.
            normalize (bool, optional): Whether to normalize the audio before saving. Defaults to False.
            batch_prefix (str, optional): Prefix for batch audio files. Defaults to "audio_".

        Returns:
            str: The path to the saved audio file.
        """
        return self.audio_processor.save_audio(audio, output_path=output_path, sampling_rate=sampling_rate, normalize=normalize, batch_prefix=batch_prefix)
# Names exported by `from <module> import *`.
__all__ = ["VibeVoiceProcessor"]