import math
import warnings
from typing import List, Optional, Union, Dict, Any, Tuple
import os
import re

import numpy as np
import torch

from transformers.tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
from transformers.utils import TensorType, logging

from .vibevoice_tokenizer_processor import AudioNormalizer

logger = logging.get_logger(__name__)


class VibeVoiceProcessor:
    r"""
    Constructs a VibeVoice processor which wraps a VibeVoice tokenizer and audio processor into a single processor.

    [`VibeVoiceProcessor`] offers all the functionalities of [`VibeVoiceTokenizer`] and [`VibeVoiceTokenizerProcessor`].
    See the [`~VibeVoiceProcessor.__call__`] and [`~VibeVoiceProcessor.decode`] for more information.

    Args:
        tokenizer (`VibeVoiceTextTokenizer` or `VibeVoiceTextTokenizerFast`):
            The tokenizer for text processing.
        audio_processor (`VibeVoiceTokenizerProcessor`):
            The audio processor for speech processing.
        speech_tok_compress_ratio (`int`, *optional*, defaults to 3200):
            The compression ratio for speech tokenization.
        db_normalize (`bool`, *optional*, defaults to True):
            Whether to apply decibel normalization to audio inputs.
    """

    def __init__(self, tokenizer=None, audio_processor=None, speech_tok_compress_ratio=3200, db_normalize=True, **kwargs):
        self.tokenizer = tokenizer
        self.audio_processor = audio_processor
        self.speech_tok_compress_ratio = speech_tok_compress_ratio
        self.db_normalize = db_normalize
        self.audio_normalizer = AudioNormalizer() if db_normalize else None
        self.system_prompt = " Transform the text provided by various speakers into speech output, utilizing the distinct voice of each respective speaker.\n"

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """
        Instantiate a VibeVoiceProcessor from a pretrained VibeVoice processor.

        Args:
            pretrained_model_name_or_path (`str` or `os.PathLike`):
                This can be either:
                - a string, the *model id* of a pretrained model
                - a path to a *directory* containing processor config

        Returns:
            [`VibeVoiceProcessor`]: The processor object instantiated from pretrained model.
        """
        import os
        import json
        from .vibevoice_tokenizer_processor import VibeVoiceTokenizerProcessor
        from vibevoice.modular.modular_vibevoice_text_tokenizer import (
            VibeVoiceTextTokenizer,
            VibeVoiceTextTokenizerFast,
        )

        # Load processor configuration
        config_path = os.path.join(pretrained_model_name_or_path, "preprocessor_config.json")
        if os.path.exists(config_path):
            with open(config_path, 'r') as f:
                config = json.load(f)
        else:
            logger.warning(f"No preprocessor_config.json found at {pretrained_model_name_or_path}, using defaults")
            config = {
                "speech_tok_compress_ratio": 3200,
                "db_normalize": True,
            }

        # Extract main processor parameters
        speech_tok_compress_ratio = config.get("speech_tok_compress_ratio", 3200)
        db_normalize = config.get("db_normalize", True)

        # Load tokenizer - try from model path first, then fall back to Qwen
        language_model_pretrained_name = config.get("language_model_pretrained_name", None) or kwargs.pop("language_model_pretrained_name", "Qwen/Qwen2.5-1.5B")
        logger.info(f"Loading tokenizer from {language_model_pretrained_name}")
        if 'qwen' in language_model_pretrained_name.lower():
            tokenizer = VibeVoiceTextTokenizerFast.from_pretrained(
                language_model_pretrained_name, **kwargs
            )
        else:
            raise ValueError(f"Unsupported tokenizer type for {language_model_pretrained_name}. Supported types: Qwen, Llama, Gemma.")
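
        # For reference, the preprocessor_config.json consumed above typically looks like this
        # (illustrative values; only keys read by this method and written by save_pretrained are shown):
        #   {
        #     "processor_class": "VibeVoiceProcessor",
        #     "speech_tok_compress_ratio": 3200,
        #     "db_normalize": true,
        #     "language_model_pretrained_name": "Qwen/Qwen2.5-1.5B",
        #     "audio_processor": {"sampling_rate": 24000, "normalize_audio": true,
        #                         "target_dB_FS": -25, "eps": 1e-6}
        #   }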

        # Load audio processor
        if "audio_processor" in config:
            # Create audio processor from config
            audio_config = config["audio_processor"]
            audio_processor = VibeVoiceTokenizerProcessor(
                sampling_rate=audio_config.get("sampling_rate", 24000),
                normalize_audio=audio_config.get("normalize_audio", True),
                target_dB_FS=audio_config.get("target_dB_FS", -25),
                eps=audio_config.get("eps", 1e-6),
            )
        else:
            # Create default audio processor
            audio_processor = VibeVoiceTokenizerProcessor()

        # Create and return the processor
        return cls(
            tokenizer=tokenizer,
            audio_processor=audio_processor,
            speech_tok_compress_ratio=speech_tok_compress_ratio,
            db_normalize=db_normalize,
        )

    def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs):
        """
        Save a processor to a directory, so that it can be re-loaded using the
        [`~VibeVoiceProcessor.from_pretrained`] class method.

        Args:
            save_directory (`str` or `os.PathLike`):
                Directory where the processor will be saved.
        """
        import os
        import json

        os.makedirs(save_directory, exist_ok=True)

        # Save processor configuration
        processor_config = {
            "processor_class": "VibeVoiceProcessor",
            "speech_tok_compress_ratio": self.speech_tok_compress_ratio,
            "db_normalize": self.db_normalize,
            "audio_processor": {
                "feature_extractor_type": "VibeVoiceTokenizerProcessor",
                "sampling_rate": getattr(self.audio_processor, 'sampling_rate', 24000),
                "normalize_audio": getattr(self.audio_processor, 'normalize_audio', True),
                "target_dB_FS": getattr(self.audio_processor, 'target_dB_FS', -25),
                "eps": getattr(self.audio_processor, 'eps', 1e-6),
            },
        }

        config_path = os.path.join(save_directory, "preprocessor_config.json")
        with open(config_path, 'w') as f:
            json.dump(processor_config, f, indent=2)

        logger.info(f"Processor configuration saved in {config_path}")
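
    # Note: save_pretrained() above writes only preprocessor_config.json; the text tokenizer is not
    # serialized here and is re-loaded in from_pretrained() via `language_model_pretrained_name`
    # (defaulting to "Qwen/Qwen2.5-1.5B").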

    def __call__(
        self,
        text: Optional[Union[str, List[str], TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
        voice_samples: Optional[Union[List[Union[str, np.ndarray]], List[List[Union[str, np.ndarray]]]]] = None,
        padding: Union[bool, str, PaddingStrategy] = True,
        truncation: Union[bool, str, TruncationStrategy] = False,
        max_length: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        return_attention_mask: bool = True,
        **kwargs,
    ) -> BatchEncoding:
        """
        Main method to process one or more podcast scripts with optional voice samples.

        Args:
            text (`str`, `List[str]`):
                The input text(s) to process. Can be:
                - A single script string
                - A list of script strings for batch processing
                - A path to a .json or .txt file
                - A list of paths
            voice_samples (`List[Union[str, np.ndarray]]`, `List[List[Union[str, np.ndarray]]]`, *optional*):
                Voice samples for each script. Can be:
                - A list of samples for a single script
                - A list of lists for batch processing
            padding (`bool`, `str` or `PaddingStrategy`, defaults to `True`):
                Whether to pad sequences to the same length
            truncation (`bool`, `str` or `TruncationStrategy`, defaults to `False`):
                Whether to truncate sequences
            max_length (`int`, *optional*):
                Maximum length of the returned sequences
            return_tensors (`str` or `TensorType`, *optional*):
                If set, will return tensors of a particular framework
            return_attention_mask (`bool`, defaults to `True`):
                Whether to return the attention mask

        Returns:
            `BatchEncoding`: A BatchEncoding with the following fields:
            - **input_ids** -- List of token id sequences or tensor
            - **attention_mask** -- List of attention masks or tensor
            - **speech_tensors** -- Padded speech inputs (if voice_samples provided)
            - **speech_masks** -- Speech masks (if voice_samples provided)
            - **speech_input_mask** -- Boolean masks indicating speech token positions
        """
        # Handle single vs batch input
        if isinstance(text, str) or (isinstance(text, list) and len(text) > 0 and not isinstance(text[0], str)):
            # Single input
            texts = [text]
            is_batched = False
        else:
            # Batch input
            texts = text
            is_batched = True

        # Handle voice samples
        if voice_samples is not None:
            if not is_batched or (isinstance(voice_samples[0], (str, np.ndarray))):
                # Single set of voice samples
                voice_samples_list = [voice_samples]
            else:
                # Batch of voice samples
                voice_samples_list = voice_samples
        else:
            voice_samples_list = [None] * len(texts)

        # Process each input
        all_encodings = []
        for text_input, voice_input in zip(texts, voice_samples_list):
            encoding = self._process_single(text_input, voice_input)
            all_encodings.append(encoding)

        # Combine batch
        batch_encoding = self._batch_encode(
            all_encodings,
            padding=padding,
            truncation=truncation,
            max_length=max_length,
            return_tensors=return_tensors,
            return_attention_mask=return_attention_mask,
        )

        return batch_encoding

    def _process_single(
        self,
        text: Union[str, TextInput],
        voice_samples: Optional[List[Union[str, np.ndarray]]] = None,
    ) -> Dict[str, Any]:
        """Process a single podcast script."""
        # Determine if text is a file path or direct script
        script = None
        if isinstance(text, str):
            # Check if it's a file path
            if text.endswith('.json') and os.path.exists(text):
                script = self._convert_json_to_script(text)
            elif text.endswith('.txt') and os.path.exists(text):
                script = self._convert_text_to_script(text)
            else:
                # Assume it's the script content directly
                script = text

        if script is None:
            raise ValueError(f"Could not process input text: {text}")

        # Parse the script
        parsed_lines = self._parse_script(script)
        all_speakers = list(set(speaker_id for speaker_id, _ in parsed_lines))

        # Create system prompt
        # system_tokens = self.tokenizer.encode(self.system_prompt, add_special_tokens=False)
        system_tokens = self.tokenizer.encode(self.system_prompt)

        # Process voice samples if provided
        if voice_samples:
            voice_tokens, voice_speech_inputs, voice_speech_masks = self._create_voice_prompt(voice_samples[:len(all_speakers)])
        else:
            voice_tokens, voice_speech_inputs, voice_speech_masks = [], [], []

        # Build full token sequence
        full_tokens = system_tokens + voice_tokens
        speech_input_mask = [False] * len(system_tokens) + voice_speech_masks

        # Add text input section
        full_tokens += self.tokenizer.encode(' Text input:\n', add_special_tokens=False)
        speech_input_mask += [False] * len(self.tokenizer.encode(' Text input:\n', add_special_tokens=False))

        for speaker_id, speaker_text in parsed_lines:
            speaker_text_tokens = self.tokenizer.encode(f" Speaker {speaker_id}:{speaker_text}\n", add_special_tokens=False)
            full_tokens += speaker_text_tokens
            speech_input_mask += [False] * len(speaker_text_tokens)
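
        # At this point the sequence reads, schematically:
        #   [system prompt][" Voice input:\n" Speaker 0: <speech_start> <diffusion> x N <speech_end> ...]
        #   [" Text input:\n"][" Speaker 0:<text>\n"][" Speaker 1:<text>\n"] ...
        # speech_input_mask is True only at the <diffusion> placeholder positions produced by
        # _create_voice_prompt; everything appended below is masked False.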

        # Add speech output section
        full_tokens += self.tokenizer.encode(' Speech output:\n', add_special_tokens=False) + [self.tokenizer.speech_start_id]
        speech_input_mask += [False] * (len(self.tokenizer.encode(' Speech output:\n', add_special_tokens=False)) + 1)

        return {
            "input_ids": full_tokens,
            "speech_inputs": voice_speech_inputs if voice_speech_inputs else None,
            "speech_input_mask": speech_input_mask,
            "parsed_script": parsed_lines,
            "all_speakers": all_speakers,
        }

    def _batch_encode(
        self,
        encodings: List[Dict[str, Any]],
        padding: Union[bool, str, PaddingStrategy] = True,
        truncation: Union[bool, str, TruncationStrategy] = False,
        max_length: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        return_attention_mask: bool = True,
    ) -> BatchEncoding:
        """Combine multiple encodings into a batch with padding."""
        # Extract input_ids and create attention_mask
        input_ids_list = [enc["input_ids"] for enc in encodings]
        speech_input_masks_list = [enc["speech_input_mask"] for enc in encodings]

        # Determine padding strategy
        if isinstance(padding, bool):
            padding_strategy = PaddingStrategy.LONGEST if padding else PaddingStrategy.DO_NOT_PAD
        elif isinstance(padding, str):
            padding_strategy = PaddingStrategy(padding)
        else:
            padding_strategy = padding

        # Apply padding to input_ids
        if padding_strategy != PaddingStrategy.DO_NOT_PAD:
            if padding_strategy == PaddingStrategy.LONGEST:
                max_len = max(len(ids) for ids in input_ids_list)
            elif padding_strategy == PaddingStrategy.MAX_LENGTH and max_length is not None:
                max_len = max_length
            else:
                max_len = max(len(ids) for ids in input_ids_list)

            # Pad sequences
            padded_input_ids = []
            attention_masks = []
            padded_speech_input_masks = []

            for input_ids, speech_mask in zip(input_ids_list, speech_input_masks_list):
                # Truncate if needed
                if truncation and len(input_ids) > max_len:
                    input_ids = input_ids[:max_len]
                    speech_mask = speech_mask[:max_len]

                # Pad
                padding_length = max_len - len(input_ids)
                # padded_ids = [self.tokenizer.pad_token_id] * padding_length + input_ids
                padded_ids = [self.tokenizer.pad_id] * padding_length + input_ids
                attention_mask = [0] * padding_length + [1] * len(input_ids)
                padded_speech_mask = [False] * padding_length + speech_mask

                padded_input_ids.append(padded_ids)
                attention_masks.append(attention_mask)
                padded_speech_input_masks.append(padded_speech_mask)

            input_ids_list = padded_input_ids
            speech_input_masks_list = padded_speech_input_masks
        else:
            # No padding, just create attention masks
            attention_masks = [[1] * len(ids) for ids in input_ids_list] if return_attention_mask else None
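
        # Sequences are left-padded, e.g. for lengths 3 and 2 with max_len = 3 (schematic ids):
        #   input_ids      [[a, b, c], [d, e]] -> [[a, b, c], [pad_id, d, e]]
        #   attention_mask                      -> [[1, 1, 1], [0, 1, 1]]
        # and the speech_input_mask positions added by padding are False.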

        # Process speech inputs
        all_speech_inputs = []
        has_speech = False
        for enc in encodings:
            if enc["speech_inputs"] is not None:
                all_speech_inputs.extend(enc["speech_inputs"])
                has_speech = True

        # Prepare batch encoding
        batch_encoding = BatchEncoding()

        # Handle tensor conversion
        if return_tensors is not None:
            batch_encoding["input_ids"] = torch.tensor(input_ids_list, dtype=torch.long)
            if return_attention_mask and attention_masks is not None:
                batch_encoding["attention_mask"] = torch.tensor(attention_masks, dtype=torch.long)
            batch_encoding["speech_input_mask"] = torch.tensor(speech_input_masks_list, dtype=torch.bool)
        else:
            batch_encoding["input_ids"] = input_ids_list
            if return_attention_mask and attention_masks is not None:
                batch_encoding["attention_mask"] = attention_masks
            batch_encoding["speech_input_mask"] = speech_input_masks_list

        # Process speech tensors if present
        if has_speech:
            speech_dict = self.prepare_speech_inputs(
                all_speech_inputs,
                return_tensors=return_tensors,
            )
            batch_encoding["speech_tensors"] = speech_dict["padded_speeches"]
            batch_encoding["speech_masks"] = speech_dict["speech_masks"]
        else:
            batch_encoding["speech_tensors"] = None
            batch_encoding["speech_masks"] = None

        # Add metadata
        batch_encoding["parsed_scripts"] = [enc["parsed_script"] for enc in encodings]
        batch_encoding["all_speakers_list"] = [enc["all_speakers"] for enc in encodings]

        return batch_encoding

    def _create_voice_prompt(
        self,
        speaker_samples: List[Union[str, np.ndarray]]
    ) -> Tuple[List[int], List[np.ndarray], List[bool]]:
        """
        Create voice prompt tokens and process audio samples.

        Returns:
            tuple: (voice_tokens, voice_speech_inputs, voice_speech_masks)
        """
        vae_token_id = self.tokenizer.speech_diffusion_id
        voice_full_tokens = self.tokenizer.encode(' Voice input:\n', add_special_tokens=False)
        voice_speech_inputs = []
        voice_speech_masks = [False] * len(voice_full_tokens)

        for speaker_id, speaker_audio in enumerate(speaker_samples):
            prefix_tokens = self.tokenizer.encode(f" Speaker {speaker_id}:", add_special_tokens=False)

            # Process audio
            if isinstance(speaker_audio, str):
                # Load audio from file
                wav = self.audio_processor._load_audio_from_path(speaker_audio)
            else:
                wav = np.array(speaker_audio, dtype=np.float32)

            # Apply normalization if needed
            if self.db_normalize and self.audio_normalizer:
                wav = self.audio_normalizer(wav)

            # Calculate token length based on compression ratio
            # if speaker_audio.endswith('.pt') or speaker_audio.endswith('.npy'):
            #     vae_tok_len = wav.shape[0]
            # else:
            vae_tok_len = math.ceil(wav.shape[0] / self.speech_tok_compress_ratio)

            # Build tokens and masks
            speaker_tokens = (prefix_tokens +
                              [self.tokenizer.speech_start_id] +
                              [vae_token_id] * vae_tok_len +
                              [self.tokenizer.speech_end_id] +
                              self.tokenizer.encode('\n', add_special_tokens=False))
            vae_input_mask = ([False] * len(prefix_tokens) +
                              [False] +
                              [True] * vae_tok_len +
                              [False] +
                              [False])

            voice_full_tokens.extend(speaker_tokens)
            voice_speech_masks.extend(vae_input_mask)
            voice_speech_inputs.append(wav)

        return voice_full_tokens, voice_speech_inputs, voice_speech_masks
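
    # Token-length bookkeeping used above and in prepare_speech_inputs(): with the default
    # speech_tok_compress_ratio of 3200 and 24 kHz audio (sampling_rate=24000), a 3-second
    # clip has 72000 samples and maps to math.ceil(72000 / 3200) = 23 speech tokens.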

    def prepare_speech_inputs(
        self,
        speech_inputs: List[np.ndarray],
        return_tensors: Optional[Union[str, TensorType]] = None,
        device: Optional[Union[str, torch.device]] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> Dict[str, Any]:
        """
        Prepare speech inputs for model consumption.

        Args:
            speech_inputs: List of speech arrays
            return_tensors: Output tensor type
            device: Device to place tensors on
            dtype: Data type for tensors

        Returns:
            Dictionary with padded_speeches and speech_masks
        """
        if not speech_inputs:
            return {"padded_speeches": None, "speech_masks": None}

        # Calculate sequence lengths
        vae_tok_seqlens = [math.ceil(s.shape[0] / self.speech_tok_compress_ratio) for s in speech_inputs]
        # vae_tok_seqlens = [math.ceil(s.shape[0] / self.speech_tok_compress_ratio) if s.ndim == 1 else s.shape[0] for s in speech_inputs]
        max_speech_length = max(s.shape[0] for s in speech_inputs)

        # Pad speeches
        if speech_inputs[0].ndim == 1:
            padded_speeches = np.full((len(speech_inputs), max_speech_length), fill_value=0, dtype=np.float32)
        else:
            padded_speeches = np.full((len(speech_inputs), max_speech_length, speech_inputs[0].shape[-1]), fill_value=0, dtype=np.float32)
        speech_masks = np.zeros((len(speech_inputs), max(vae_tok_seqlens)), dtype=np.bool_)

        for i, (speech, vae_tok_length) in enumerate(zip(speech_inputs, vae_tok_seqlens)):
            padded_speeches[i, :len(speech)] = speech
            speech_masks[i, :vae_tok_length] = True

        result = {
            "padded_speeches": padded_speeches,
            "speech_masks": speech_masks,
        }

        # Convert to tensors if requested
        if return_tensors == "pt":
            result["padded_speeches"] = torch.tensor(padded_speeches, device=device, dtype=dtype or torch.float32)
            result["speech_masks"] = torch.tensor(speech_masks, device=device, dtype=torch.bool)

        return result

    def _convert_json_to_script(self, json_file: str) -> str:
        """
        Convert JSON format to script format.

        Expected JSON format:
        [
            {"speaker": "1", "text": "Hello everyone..."},
            {"speaker": "2", "text": "Great to be here..."}
        ]
        """
        import json

        with open(json_file, 'r', encoding='utf-8') as f:
            data = json.load(f)

        if not isinstance(data, list):
            raise ValueError("JSON file must contain a list of speaker entries")

        script_lines = []
        for item in data:
            if not isinstance(item, dict):
                logger.warning(f"Skipping non-dict entry: {item}")
                continue

            speaker = item.get('speaker')
            text = item.get('text')

            if speaker is None or text is None:
                logger.warning(f"Skipping entry missing speaker or text: {item}")
                continue

            # Ensure speaker ID is valid
            try:
                speaker_id = int(speaker)
            except (ValueError, TypeError):
                logger.warning(f"Invalid speaker ID: {speaker}, skipping entry")
                continue

            # Clean up text
            text = text.strip()
            if text:
                script_lines.append(f"Speaker {speaker_id}: {text}")

        if not script_lines:
            raise ValueError("No valid entries found in JSON file")

        return "\n".join(script_lines)
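
    # Illustrative _convert_json_to_script round trip (hypothetical file content):
    #   [{"speaker": "1", "text": "Hello everyone."}, {"speaker": "2", "text": "Great to be here."}]
    # is rendered as:
    #   Speaker 1: Hello everyone.
    #   Speaker 2: Great to be here.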
""" with open(text_file, 'r', encoding='utf-8') as f: lines = f.readlines() script_lines = [] current_speaker = 1 for line in lines: line = line.strip() if not line: continue # Try to parse as "Speaker X: text" format # Use regex to be more robust speaker_match = re.match(r'^Speaker\s+(\d+)\s*:\s*(.*)$', line, re.IGNORECASE) if speaker_match: speaker_id = int(speaker_match.group(1)) text = speaker_match.group(2).strip() if text: script_lines.append(f"Speaker {speaker_id}: {text}") else: # Treat as plain text - assign to current speaker script_lines.append(f"Speaker {current_speaker}: {line}") if not script_lines: raise ValueError("No valid content found in text file") return "\n".join(script_lines) def _parse_script(self, script: str) -> List[Tuple[int, str]]: """Parse script into list of (speaker_id, text) tuples.""" lines = script.strip().split("\n") parsed_lines = [] speaker_ids = [] # First pass: parse all lines and collect speaker IDs for line in lines: if not line.strip(): continue # Use regex to handle edge cases like multiple colons match = re.match(r'^Speaker\s+(\d+)\s*:\s*(.*)$', line.strip(), re.IGNORECASE) if match: speaker_id = int(match.group(1)) text = ' ' + match.group(2).strip() parsed_lines.append((speaker_id, text)) speaker_ids.append(speaker_id) else: logger.warning(f"Could not parse line: '{line}'") if not parsed_lines: raise ValueError("No valid speaker lines found in script") # Check if we need to normalize speaker IDs (only if all are > 0) min_speaker_id = min(speaker_ids) if min_speaker_id > 0: # Normalize to start from 0 normalized_lines = [] for speaker_id, text in parsed_lines: normalized_lines.append((speaker_id - 1, text)) return normalized_lines else: # Keep original IDs return parsed_lines def _merge_inputs(self, text_inputs: BatchEncoding, audio_inputs: Dict) -> BatchEncoding: """Merge text and audio inputs into a single BatchEncoding.""" # Start with text inputs merged = BatchEncoding(text_inputs) # Add audio-specific fields if "audio" in audio_inputs: merged["speech_inputs"] = audio_inputs["audio"] if "streaming" in audio_inputs: merged["streaming"] = audio_inputs["streaming"] return merged def batch_decode(self, *args, **kwargs): """ This method forwards all its arguments to VibeVoiceTextTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please refer to the docstring of this method for more information. """ return self.tokenizer.batch_decode(*args, **kwargs) def decode(self, *args, **kwargs): """ This method forwards all its arguments to VibeVoiceTextTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to the docstring of this method for more information. """ return self.tokenizer.decode(*args, **kwargs) @property def model_input_names(self): """ Return the list of inputs accepted by the model. """ tokenizer_input_names = self.tokenizer.model_input_names audio_processor_input_names = self.audio_processor.model_input_names return list(dict.fromkeys(tokenizer_input_names + audio_processor_input_names + ["speech_inputs", "speech_input_mask"])) def save_audio(self, audio: Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]], output_path: str = "output.wav", sampling_rate: Optional[int] = None, normalize: bool = False, batch_prefix: str = "audio_", ) -> str: """ Save audio data to a file. Args: audio (Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]]): The audio data to save. Can be a single tensor/array or a list of them. output_path (str, optional): Path to save the audio file. Defaults to "output.wav". 

    def _merge_inputs(self, text_inputs: BatchEncoding, audio_inputs: Dict) -> BatchEncoding:
        """Merge text and audio inputs into a single BatchEncoding."""
        # Start with text inputs
        merged = BatchEncoding(text_inputs)

        # Add audio-specific fields
        if "audio" in audio_inputs:
            merged["speech_inputs"] = audio_inputs["audio"]
        if "streaming" in audio_inputs:
            merged["streaming"] = audio_inputs["streaming"]

        return merged

    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to VibeVoiceTextTokenizer's
        [`~PreTrainedTokenizer.batch_decode`]. Please refer to the docstring of this method for more
        information.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to VibeVoiceTextTokenizer's
        [`~PreTrainedTokenizer.decode`]. Please refer to the docstring of this method for more
        information.
        """
        return self.tokenizer.decode(*args, **kwargs)

    @property
    def model_input_names(self):
        """
        Return the list of inputs accepted by the model.
        """
        tokenizer_input_names = self.tokenizer.model_input_names
        audio_processor_input_names = self.audio_processor.model_input_names
        return list(dict.fromkeys(tokenizer_input_names + audio_processor_input_names + ["speech_inputs", "speech_input_mask"]))

    def save_audio(
        self,
        audio: Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]],
        output_path: str = "output.wav",
        sampling_rate: Optional[int] = None,
        normalize: bool = False,
        batch_prefix: str = "audio_",
    ) -> str:
        """
        Save audio data to a file.

        Args:
            audio (Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]]):
                The audio data to save. Can be a single tensor/array or a list of them.
            output_path (str, optional): Path to save the audio file. Defaults to "output.wav".
            sampling_rate (int, optional): Sampling rate for the audio. If None, uses the processor's default.
            normalize (bool, optional): Whether to normalize the audio before saving. Defaults to False.
            batch_prefix (str, optional): Prefix for batch audio files. Defaults to "audio_".

        Returns:
            str: The path to the saved audio file.
        """
        return self.audio_processor.save_audio(
            audio,
            output_path=output_path,
            sampling_rate=sampling_rate,
            normalize=normalize,
            batch_prefix=batch_prefix,
        )


__all__ = [
    "VibeVoiceProcessor",
]
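

if __name__ == "__main__":
    # Minimal usage sketch, not part of the library API. The checkpoint directory and the voice
    # sample paths are hypothetical placeholders; replace them with real files before running.
    processor = VibeVoiceProcessor.from_pretrained("path/to/vibevoice_checkpoint")
    inputs = processor(
        text="Speaker 1: Hello and welcome to the show.\nSpeaker 2: Thanks for having me.",
        voice_samples=["path/to/speaker0.wav", "path/to/speaker1.wav"],
        return_tensors="pt",
    )
    print(inputs["input_ids"].shape, inputs["speech_tensors"].shape, inputs["speech_input_mask"].shape)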