import json
import math
import os
import re
from typing import List, Optional, Union, Dict, Any, Tuple

import numpy as np
import torch

from transformers.tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
from transformers.utils import TensorType, logging
from .vibevoice_tokenizer_processor import AudioNormalizer

logger = logging.get_logger(__name__)


class VibeVoiceProcessor:
    r"""

    Constructs a VibeVoice processor which wraps a VibeVoice tokenizer and audio processor into a single processor.



    [`VibeVoiceProcessor`] offers all the functionalities of [`VibeVoiceTokenizer`] and [`VibeVoiceTokenizerProcessor`]. 

    See the [`~VibeVoiceProcessor.__call__`] and [`~VibeVoiceProcessor.decode`] for more information.



    Args:

        tokenizer (`VibeVoiceTextTokenizer` or `VibeVoiceTextTokenizerFast`):

            The tokenizer for text processing.

        audio_processor (`VibeVoiceTokenizerProcessor`):

            The audio processor for speech processing.

        speech_tok_compress_ratio (`int`, *optional*, defaults to 3200):

            The compression ratio for speech tokenization.

        db_normalize (`bool`, *optional*, defaults to True):

            Whether to apply decibel normalization to audio inputs.

    """

    def __init__(self, tokenizer=None, audio_processor=None, speech_tok_compress_ratio=3200, db_normalize=True, **kwargs):
        self.tokenizer = tokenizer
        self.audio_processor = audio_processor
        self.speech_tok_compress_ratio = speech_tok_compress_ratio
        self.db_normalize = db_normalize
        self.audio_normalizer = AudioNormalizer() if db_normalize else None
        self.system_prompt = " Transform the text provided by various speakers into speech output, utilizing the distinct voice of each respective speaker.\n"

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """

        Instantiate a VibeVoiceProcessor from a pretrained VibeVoice processor.



        Args:

            pretrained_model_name_or_path (`str` or `os.PathLike`):

                This can be either:

                - a string, the *model id* of a pretrained model

                - a path to a *directory* containing processor config



        Returns:

            [`VibeVoiceProcessor`]: The processor object instantiated from pretrained model.

        """
        from .vibevoice_tokenizer_processor import VibeVoiceTokenizerProcessor
        from vibevoice.modular.modular_vibevoice_text_tokenizer import (
            VibeVoiceTextTokenizer, 
            VibeVoiceTextTokenizerFast
        )
        
        # Load processor configuration
        config_path = os.path.join(pretrained_model_name_or_path, "preprocessor_config.json")
        if os.path.exists(config_path):
            with open(config_path, 'r') as f:
                config = json.load(f)
        else:
            logger.warning(f"No preprocessor_config.json found at {pretrained_model_name_or_path}, using defaults")
            config = {
                "speech_tok_compress_ratio": 3200,
                "db_normalize": True,
            }
        
        # Extract main processor parameters
        speech_tok_compress_ratio = config.get("speech_tok_compress_ratio", 3200)
        db_normalize = config.get("db_normalize", True)
        
        # Resolve the tokenizer name: prefer the processor config, then kwargs, then default to Qwen
        language_model_pretrained_name = config.get("language_model_pretrained_name", None) or kwargs.pop("language_model_pretrained_name", "Qwen/Qwen2.5-1.5B")
        logger.info(f"Loading tokenizer from {language_model_pretrained_name}")
        if 'qwen' in language_model_pretrained_name.lower():
            tokenizer = VibeVoiceTextTokenizerFast.from_pretrained(
                language_model_pretrained_name,
                **kwargs
            )
        else:
            raise ValueError(f"Unsupported tokenizer type for {language_model_pretrained_name}. Only Qwen tokenizers are currently supported.")
        
        # Load audio processor
        if "audio_processor" in config:
            # Create audio processor from config
            audio_config = config["audio_processor"]
            audio_processor = VibeVoiceTokenizerProcessor(
                sampling_rate=audio_config.get("sampling_rate", 24000),
                normalize_audio=audio_config.get("normalize_audio", True),
                target_dB_FS=audio_config.get("target_dB_FS", -25),
                eps=audio_config.get("eps", 1e-6),
            )
        else:
            # Create default audio processor
            audio_processor = VibeVoiceTokenizerProcessor()
        
        # Create and return the processor
        return cls(
            tokenizer=tokenizer,
            audio_processor=audio_processor,
            speech_tok_compress_ratio=speech_tok_compress_ratio,
            db_normalize=db_normalize,
        )
    
    def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs):
        """

        Save a processor to a directory, so that it can be re-loaded using the

        [`~VibeVoiceProcessor.from_pretrained`] class method.



        Args:

            save_directory (`str` or `os.PathLike`):

                Directory where the processor will be saved.

        """
        
        os.makedirs(save_directory, exist_ok=True)
        
        # Save processor configuration
        processor_config = {
            "processor_class": "VibeVoiceProcessor",
            "speech_tok_compress_ratio": self.speech_tok_compress_ratio,
            "db_normalize": self.db_normalize,
            "audio_processor": {
                "feature_extractor_type": "VibeVoiceTokenizerProcessor",
                "sampling_rate": getattr(self.audio_processor, 'sampling_rate', 24000),
                "normalize_audio": getattr(self.audio_processor, 'normalize_audio', True),
                "target_dB_FS": getattr(self.audio_processor, 'target_dB_FS', -25),
                "eps": getattr(self.audio_processor, 'eps', 1e-6),
            }
        }
        
        config_path = os.path.join(save_directory, "preprocessor_config.json")
        with open(config_path, 'w') as f:
            json.dump(processor_config, f, indent=2)
        
        logger.info(f"Processor configuration saved in {config_path}")
    
    def __call__(
        self,
        text: Optional[Union[str, List[str], TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
        voice_samples: Optional[Union[List[Union[str, np.ndarray]], List[List[Union[str, np.ndarray]]]]] = None,
        padding: Union[bool, str, PaddingStrategy] = True,
        truncation: Union[bool, str, TruncationStrategy] = False,
        max_length: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        return_attention_mask: bool = True,
        **kwargs,
    ) -> BatchEncoding:
        """

        Main method to process one or more podcast scripts with optional voice samples.



        Args:

            text (`str`, `List[str]`):

                The input text(s) to process. Can be:

                - A single script string

                - A list of script strings for batch processing

                - A path to a .json or .txt file

                - A list of paths

            voice_samples (`List[Union[str, np.ndarray]]`, `List[List[Union[str, np.ndarray]]]`, *optional*):

                Voice samples for each script. Can be:

                - A list of samples for a single script

                - A list of lists for batch processing

            padding (`bool`, `str` or `PaddingStrategy`, defaults to `True`):

                Whether to pad sequences to the same length

            truncation (`bool`, `str` or `TruncationStrategy`, defaults to `False`):

                Whether to truncate sequences

            max_length (`int`, *optional*):

                Maximum length of the returned sequences

            return_tensors (`str` or `TensorType`, *optional*):

                If set, will return tensors of a particular framework

            return_attention_mask (`bool`, defaults to `True`):

                Whether to return the attention mask



        Returns:

            `BatchEncoding`: A BatchEncoding with the following fields:

                - **input_ids** -- List of token id sequences or tensor

                - **attention_mask** -- List of attention masks or tensor

                - **speech_tensors** -- Padded speech inputs (if voice_samples provided)

                - **speech_masks** -- Speech masks (if voice_samples provided)

                - **speech_input_mask** -- Boolean masks indicating speech token positions

        """
        # Handle single vs batch input: a plain string (or a pre-tokenized sequence,
        # i.e. a list whose first element is not a str) is treated as a single input;
        # a list of strings is treated as a batch.
        if isinstance(text, str) or (isinstance(text, list) and len(text) > 0 and not isinstance(text[0], str)):
            # Single input
            texts = [text]
            is_batched = False
        else:
            # Batch input
            texts = text
            is_batched = True
            
        # Handle voice samples
        if voice_samples is not None:
            if not is_batched or (isinstance(voice_samples[0], (str, np.ndarray))):
                # Single set of voice samples
                voice_samples_list = [voice_samples]
            else:
                # Batch of voice samples
                voice_samples_list = voice_samples
        else:
            voice_samples_list = [None] * len(texts)
        
        # Process each input
        all_encodings = []
        for text_input, voice_input in zip(texts, voice_samples_list):
            encoding = self._process_single(text_input, voice_input)
            all_encodings.append(encoding)
            
        # Combine batch
        batch_encoding = self._batch_encode(
            all_encodings,
            padding=padding,
            truncation=truncation,
            max_length=max_length,
            return_tensors=return_tensors,
            return_attention_mask=return_attention_mask,
        )
        
        return batch_encoding
    
    def _process_single(
        self,
        text: Union[str, TextInput],
        voice_samples: Optional[List[Union[str, np.ndarray]]] = None,
    ) -> Dict[str, Any]:
        """Process a single podcast script."""
        # Determine if text is a file path or direct script
        script = None
        if isinstance(text, str):
            # Check if it's a file path
            if text.endswith('.json') and os.path.exists(text):
                script = self._convert_json_to_script(text)
            elif text.endswith('.txt') and os.path.exists(text):
                script = self._convert_text_to_script(text)
            else:
                # Assume it's the script content directly
                script = text
        
        if script is None:
            raise ValueError(f"Could not process input text: {text}")
        
        # Parse the script
        parsed_lines = self._parse_script(script)
        all_speakers = list(set(speaker_id for speaker_id, _ in parsed_lines))
        
        # Create system prompt
        # system_tokens = self.tokenizer.encode(self.system_prompt, add_special_tokens=False)
        system_tokens = self.tokenizer.encode(self.system_prompt)
        
        # Process voice samples if provided
        if voice_samples:
            voice_tokens, voice_speech_inputs, voice_speech_masks = self._create_voice_prompt(voice_samples[:len(all_speakers)])
        else:
            voice_tokens, voice_speech_inputs, voice_speech_masks = [], [], []
        
        # Build full token sequence
        full_tokens = system_tokens + voice_tokens
        speech_input_mask = [False] * len(system_tokens) + voice_speech_masks
        
        # Add text input section (encode once and reuse for the mask length)
        text_input_tokens = self.tokenizer.encode(' Text input:\n', add_special_tokens=False)
        full_tokens += text_input_tokens
        speech_input_mask += [False] * len(text_input_tokens)
        
        for speaker_id, speaker_text in parsed_lines:
            speaker_text_tokens = self.tokenizer.encode(f" Speaker {speaker_id}:{speaker_text}\n", add_special_tokens=False)
            full_tokens += speaker_text_tokens
            speech_input_mask += [False] * len(speaker_text_tokens)
        
        # Add speech output section (encode once and reuse for the mask length)
        speech_output_tokens = self.tokenizer.encode(' Speech output:\n', add_special_tokens=False) + [self.tokenizer.speech_start_id]
        full_tokens += speech_output_tokens
        speech_input_mask += [False] * len(speech_output_tokens)
        
        return {
            "input_ids": full_tokens,
            "speech_inputs": voice_speech_inputs if voice_speech_inputs else None,
            "speech_input_mask": speech_input_mask,
            "parsed_script": parsed_lines,
            "all_speakers": all_speakers,
        }
    
    def _batch_encode(
        self,
        encodings: List[Dict[str, Any]],
        padding: Union[bool, str, PaddingStrategy] = True,
        truncation: Union[bool, str, TruncationStrategy] = False,
        max_length: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        return_attention_mask: bool = True,
    ) -> BatchEncoding:
        """Combine multiple encodings into a batch with padding."""
        # Extract input_ids and create attention_mask
        input_ids_list = [enc["input_ids"] for enc in encodings]
        speech_input_masks_list = [enc["speech_input_mask"] for enc in encodings]
        
        # Determine padding strategy
        if isinstance(padding, bool):
            padding_strategy = PaddingStrategy.LONGEST if padding else PaddingStrategy.DO_NOT_PAD
        elif isinstance(padding, str):
            padding_strategy = PaddingStrategy(padding)
        else:
            padding_strategy = padding
            
        # Apply padding to input_ids
        if padding_strategy != PaddingStrategy.DO_NOT_PAD:
            if padding_strategy == PaddingStrategy.LONGEST:
                max_len = max(len(ids) for ids in input_ids_list)
            elif padding_strategy == PaddingStrategy.MAX_LENGTH and max_length is not None:
                max_len = max_length
            else:
                max_len = max(len(ids) for ids in input_ids_list)
                
            # Pad sequences
            padded_input_ids = []
            attention_masks = []
            padded_speech_input_masks = []
            
            for input_ids, speech_mask in zip(input_ids_list, speech_input_masks_list):
                # Truncate if needed
                if truncation and len(input_ids) > max_len:
                    input_ids = input_ids[:max_len]
                    speech_mask = speech_mask[:max_len]
                    
                # Pad
                padding_length = max_len - len(input_ids)
                # padded_ids = [self.tokenizer.pad_token_id] * padding_length + input_ids
                padded_ids = [self.tokenizer.pad_id] * padding_length + input_ids
                attention_mask = [0] * padding_length + [1] * len(input_ids)
                padded_speech_mask = [False] * padding_length + speech_mask
                
                padded_input_ids.append(padded_ids)
                attention_masks.append(attention_mask)
                padded_speech_input_masks.append(padded_speech_mask)
                
            input_ids_list = padded_input_ids
            speech_input_masks_list = padded_speech_input_masks
        else:
            # No padding, just create attention masks
            attention_masks = [[1] * len(ids) for ids in input_ids_list] if return_attention_mask else None
            
        # Process speech inputs
        all_speech_inputs = []
        has_speech = False
        for enc in encodings:
            if enc["speech_inputs"] is not None:
                all_speech_inputs.extend(enc["speech_inputs"])
                has_speech = True
                
        # Prepare batch encoding
        batch_encoding = BatchEncoding()
        
        # Handle tensor conversion
        if return_tensors is not None:
            batch_encoding["input_ids"] = torch.tensor(input_ids_list, dtype=torch.long)
            if return_attention_mask and attention_masks is not None:
                batch_encoding["attention_mask"] = torch.tensor(attention_masks, dtype=torch.long)
            batch_encoding["speech_input_mask"] = torch.tensor(speech_input_masks_list, dtype=torch.bool)
        else:
            batch_encoding["input_ids"] = input_ids_list
            if return_attention_mask and attention_masks is not None:
                batch_encoding["attention_mask"] = attention_masks
            batch_encoding["speech_input_mask"] = speech_input_masks_list
            
        # Process speech tensors if present
        if has_speech:
            speech_dict = self.prepare_speech_inputs(
                all_speech_inputs,
                return_tensors=return_tensors,
            )
            batch_encoding["speech_tensors"] = speech_dict["padded_speeches"]
            batch_encoding["speech_masks"] = speech_dict["speech_masks"]
        else:
            batch_encoding["speech_tensors"] = None
            batch_encoding["speech_masks"] = None
            
        # Add metadata
        batch_encoding["parsed_scripts"] = [enc["parsed_script"] for enc in encodings]
        batch_encoding["all_speakers_list"] = [enc["all_speakers"] for enc in encodings]
        
        return batch_encoding

    def _create_voice_prompt(
        self,
        speaker_samples: List[Union[str, np.ndarray]]
    ) -> Tuple[List[int], List[np.ndarray], List[bool]]:
        """
        Create voice prompt tokens and process audio samples.

        Returns:
            tuple: (voice_tokens, voice_speech_inputs, voice_speech_masks)
        """
        vae_token_id = self.tokenizer.speech_diffusion_id
        
        voice_full_tokens = self.tokenizer.encode(' Voice input:\n', add_special_tokens=False)
        voice_speech_inputs = []
        voice_speech_masks = [False] * len(voice_full_tokens)
        
        for speaker_id, speaker_audio in enumerate(speaker_samples):
            prefix_tokens = self.tokenizer.encode(f" Speaker {speaker_id}:", add_special_tokens=False)
            
            # Process audio
            if isinstance(speaker_audio, str):
                # Load audio from file
                wav = self.audio_processor._load_audio_from_path(speaker_audio)
            else:
                wav = np.array(speaker_audio, dtype=np.float32)
            
            # Apply normalization if needed
            if self.db_normalize and self.audio_normalizer:
                wav = self.audio_normalizer(wav)
            
            # Calculate token length based on compression ratio
            # if speaker_audio.endswith('.pt') or speaker_audio.endswith('.npy'):
            #     vae_tok_len = wav.shape[0]
            # else:
            vae_tok_len = math.ceil(wav.shape[0] / self.speech_tok_compress_ratio)
            
            # Build tokens and masks
            speaker_tokens = (prefix_tokens + 
                            [self.tokenizer.speech_start_id] + 
                            [vae_token_id] * vae_tok_len + 
                            [self.tokenizer.speech_end_id] + 
                            self.tokenizer.encode('\n', add_special_tokens=False))
            
            vae_input_mask = ([False] * len(prefix_tokens) + 
                            [False] + 
                            [True] * vae_tok_len + 
                            [False] + 
                            [False])
            
            voice_full_tokens.extend(speaker_tokens)
            voice_speech_masks.extend(vae_input_mask)
            voice_speech_inputs.append(wav)
            
        return voice_full_tokens, voice_speech_inputs, voice_speech_masks

    def prepare_speech_inputs(
        self,
        speech_inputs: List[np.ndarray],
        return_tensors: Optional[Union[str, TensorType]] = None,
        device: Optional[Union[str, torch.device]] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> Dict[str, Any]:
        """
        Prepare speech inputs for model consumption.

        Args:
            speech_inputs: List of speech arrays
            return_tensors: Output tensor type
            device: Device to place tensors on
            dtype: Data type for tensors

        Returns:
            Dictionary with padded_speeches and speech_masks
        """
        if not speech_inputs:
            return {"padded_speeches": None, "speech_masks": None}
        
        # Calculate sequence lengths
        vae_tok_seqlens = [math.ceil(s.shape[0] / self.speech_tok_compress_ratio) for s in speech_inputs]
        # vae_tok_seqlens = [math.ceil(s.shape[0] / self.speech_tok_compress_ratio) if s.ndim == 1 else s.shape[0] for s in speech_inputs]
        max_speech_length = max(s.shape[0] for s in speech_inputs)
        
        # Pad speeches
        if speech_inputs[0].ndim == 1:
            padded_speeches = np.full((len(speech_inputs), max_speech_length), fill_value=0, dtype=np.float32)
        else:
            padded_speeches = np.full((len(speech_inputs), max_speech_length, speech_inputs[0].shape[-1]), fill_value=0, dtype=np.float32)
        speech_masks = np.zeros((len(speech_inputs), max(vae_tok_seqlens)), dtype=np.bool_)
        
        for i, (speech, vae_tok_length) in enumerate(zip(speech_inputs, vae_tok_seqlens)):
            padded_speeches[i, :len(speech)] = speech
            speech_masks[i, :vae_tok_length] = True
        
        result = {
            "padded_speeches": padded_speeches,
            "speech_masks": speech_masks,
        }
        
        # Convert to tensors if requested
        if return_tensors == "pt":
            result["padded_speeches"] = torch.tensor(padded_speeches, device=device, dtype=dtype or torch.float32)
            result["speech_masks"] = torch.tensor(speech_masks, device=device, dtype=torch.bool)
            
        return result
        
    def _convert_json_to_script(self, json_file: str) -> str:
        """
        Convert JSON format to script format.

        Expected JSON format:

        [
            {"speaker": "1", "text": "Hello everyone..."},
            {"speaker": "2", "text": "Great to be here..."}
        ]
        """
        with open(json_file, 'r', encoding='utf-8') as f:
            data = json.load(f)
        
        if not isinstance(data, list):
            raise ValueError("JSON file must contain a list of speaker entries")
        
        script_lines = []
        for item in data:
            if not isinstance(item, dict):
                logger.warning(f"Skipping non-dict entry: {item}")
                continue
                
            speaker = item.get('speaker')
            text = item.get('text')
            
            if speaker is None or text is None:
                logger.warning(f"Skipping entry missing speaker or text: {item}")
                continue
            
            # Ensure speaker ID is valid
            try:
                speaker_id = int(speaker)
            except (ValueError, TypeError):
                logger.warning(f"Invalid speaker ID: {speaker}, skipping entry")
                continue
            
            # Clean up text
            text = text.strip()
            if text:
                script_lines.append(f"Speaker {speaker_id}: {text}")
        
        if not script_lines:
            raise ValueError("No valid entries found in JSON file")
            
        return "\n".join(script_lines)

    def _convert_text_to_script(self, text_file: str) -> str:
        """
        Convert text file to script format.

        Handles multiple formats:

        1. Already formatted as "Speaker X: text"
        2. Plain text (assigned to Speaker 1)

        Handles edge cases like multiple colons in a line.
        """
        with open(text_file, 'r', encoding='utf-8') as f:
            lines = f.readlines()
        
        script_lines = []
        current_speaker = 1
        
        for line in lines:
            line = line.strip()
            if not line:
                continue
            
            # Try to parse as "Speaker X: text" format
            # Use regex to be more robust
            speaker_match = re.match(r'^Speaker\s+(\d+)\s*:\s*(.*)$', line, re.IGNORECASE)
            
            if speaker_match:
                speaker_id = int(speaker_match.group(1))
                text = speaker_match.group(2).strip()
                if text:
                    script_lines.append(f"Speaker {speaker_id}: {text}")
            else:
                # Treat as plain text - assign to current speaker
                script_lines.append(f"Speaker {current_speaker}: {line}")
        
        if not script_lines:
            raise ValueError("No valid content found in text file")
            
        return "\n".join(script_lines)

    def _parse_script(self, script: str) -> List[Tuple[int, str]]:
        """Parse script into list of (speaker_id, text) tuples."""
        lines = script.strip().split("\n")
        parsed_lines = []
        speaker_ids = []
                
        # First pass: parse all lines and collect speaker IDs
        for line in lines:
            if not line.strip():
                continue
                
            # Use regex to handle edge cases like multiple colons
            match = re.match(r'^Speaker\s+(\d+)\s*:\s*(.*)$', line.strip(), re.IGNORECASE)
            
            if match:
                speaker_id = int(match.group(1))
                text = ' ' + match.group(2).strip()
                parsed_lines.append((speaker_id, text))
                speaker_ids.append(speaker_id)
            else:
                logger.warning(f"Could not parse line: '{line}'")
        
        if not parsed_lines:
            raise ValueError("No valid speaker lines found in script")
        
        # Check if we need to normalize speaker IDs (only if all are > 0)
        min_speaker_id = min(speaker_ids)
        if min_speaker_id > 0:
            # Normalize to start from 0
            normalized_lines = []
            for speaker_id, text in parsed_lines:
                normalized_lines.append((speaker_id - 1, text))
            return normalized_lines
        else:
            # Keep original IDs
            return parsed_lines

    def _merge_inputs(self, text_inputs: BatchEncoding, audio_inputs: Dict) -> BatchEncoding:
        """Merge text and audio inputs into a single BatchEncoding."""
        # Start with text inputs
        merged = BatchEncoding(text_inputs)
        
        # Add audio-specific fields
        if "audio" in audio_inputs:
            merged["speech_inputs"] = audio_inputs["audio"]
        if "streaming" in audio_inputs:
            merged["streaming"] = audio_inputs["streaming"]
            
        return merged

    def batch_decode(self, *args, **kwargs):
        """

        This method forwards all its arguments to VibeVoiceTextTokenizer's [`~PreTrainedTokenizer.batch_decode`].

        Please refer to the docstring of this method for more information.

        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """

        This method forwards all its arguments to VibeVoiceTextTokenizer's [`~PreTrainedTokenizer.decode`].

        Please refer to the docstring of this method for more information.

        """
        return self.tokenizer.decode(*args, **kwargs)

    @property
    def model_input_names(self):
        """

        Return the list of inputs accepted by the model.

        """
        tokenizer_input_names = self.tokenizer.model_input_names
        audio_processor_input_names = self.audio_processor.model_input_names
        return list(dict.fromkeys(tokenizer_input_names + audio_processor_input_names + ["speech_inputs", "speech_input_mask"]))

    def save_audio(
        self,
        audio: Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]],
        output_path: str = "output.wav",
        sampling_rate: Optional[int] = None,
        normalize: bool = False,
        batch_prefix: str = "audio_",
    ) -> str:
        """
        Save audio data to a file.

        Args:
            audio (Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]]):
                The audio data to save. Can be a single tensor/array or a list of them.
            output_path (str, optional): Path to save the audio file. Defaults to "output.wav".
            sampling_rate (int, optional): Sampling rate for the audio. If None, uses the processor's default.
            normalize (bool, optional): Whether to normalize the audio before saving. Defaults to False.
            batch_prefix (str, optional): Prefix for batch audio files. Defaults to "audio_".

        Returns:
            str: The path to the saved audio file.
        """
        return self.audio_processor.save_audio(audio, output_path=output_path, sampling_rate=sampling_rate, normalize=normalize, batch_prefix=batch_prefix)
    
__all__ = [
    "VibeVoiceProcessor",
]