# coding=utf-8
# Copyright 2025 OpenMOSS and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Processor class for MOSS-TTSD.
"""

from __future__ import annotations

import math
import os
import re
from dataclasses import asdict, dataclass
from typing import Any, Callable, Optional, Union

import numpy as np

from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
from transformers.tokenization_utils_base import BatchEncoding
from transformers.utils import is_torch_available, is_torchaudio_available
from transformers import AutoFeatureExtractor, AutoTokenizer, AutoModel
# from transformers.models.xy_tokenizer.modeling_xy_tokenizer import XYTokenizer


if is_torch_available():
    import torch

if is_torchaudio_available():
    import torchaudio


class MossTTSDProcessorKwargs(ProcessingKwargs, total=False):
    """
    Arguments for configuring MOSS-TTSD processing operations.

    Inherits from ProcessingKwargs and provides structured configuration for text and audio processing.
    """

    _defaults = {
        "text_kwargs": {
            "pad_token_id": 0,  # Fallback pad token ID, actual value comes from tokenizer.pad_token_id
        },
        "audio_kwargs": {
            "max_channels": 8,  # Maximum number of quantization channels
            "audio_pad_token_id": 1024,  # Padding token ID for non-text channels
            "silence_duration": 0.0,  # Duration of silence to append for encoder segmentation
            "input_sample_rate": 16000,  # Input audio sampling rate (fallback, inferred from audio_tokenizer.config)
            "encoder_downsample_rate": 320,  # Encoder downsampling rate (fallback, inferred from audio_tokenizer.config)
            "speech_token_range": [151665, 152689],  # Token range for speech tokens (first codebook offset mapping)
            "audio_bos_token": "<|begin_of_speech|>",
            "audio_eos_token": "<|end_of_speech|>",
        },
        "common_kwargs": {
            "return_tensors": "pt",
            "padding": True,
            "use_normalize": False,
        },
    }


@dataclass
class MossTTSDChatSample:
    """
    Intermediate representation of a single sample with T×C grid layout and metadata.

    Args:
        input_ids_2d (`torch.LongTensor`):
            Shape (T, C) tensor where column 0 contains text tokens and columns 1..C-1 contain
            quantized audio codebooks (or padding token 1024 for empty slots).
        label_ids_2d (`torch.LongTensor`, *optional*):
            Optional label tensor for training, same shape as input_ids_2d.
        meta (`dict`):
            Dictionary containing metadata for debugging and tracking purposes.
    """

    input_ids_2d: "torch.LongTensor"
    label_ids_2d: Optional["torch.LongTensor"]
    meta: dict

@dataclass
class MossTTSDBatchInput:
    """
    Batched input tensors for MOSS-TTSD model.

    Args:
        input_ids (`torch.LongTensor`):
            Shape (B, T, C) tensor containing batched input token IDs.
        attention_mask (`torch.LongTensor`):
            Shape (B, T) tensor containing attention mask for valid tokens.
        labels (`torch.LongTensor`, *optional*):
            Optional shape (B, T, C) tensor containing label token IDs for training.
    """

    input_ids: "torch.LongTensor"
    attention_mask: "torch.LongTensor"
    labels: Optional["torch.LongTensor"]


@dataclass
class MossTTSDResponse:
    """
    Unified response container for MOSS-TTSD inference outputs.

    Args:
        audio (`np.ndarray`, *optional*):
            Optional numpy array containing generated audio waveform.
        generated_text (`str`, *optional*, defaults to `""`):
            String containing generated text output.
        sampling_rate (`int`, *optional*):
            Optional integer specifying the sampling rate of the generated audio.
    """

    audio: Optional[np.ndarray] = None
    generated_text: str = ""
    sampling_rate: Optional[int] = None


class MossTTSDSampleProcessor:
    """
    Sample-level processor for MOSS-TTSD that handles individual sample processing without batch padding.

    This class handles per-sample processing logic:
    - Parses JSONL items (text/prompt_text/prompt_audio)
    - Optional text normalization
    - Audio loading/resampling/merging, feature extraction and encoding
    - Generates T×C grid and performs multi-channel shifting

    Args:
        tokenizer (`AutoTokenizer`):
            The text tokenizer for encoding text tokens.
        feature_extractor (`AutoFeatureExtractor`, *optional*):
            Optional feature extractor for audio preprocessing.
        audio_tokenizer (`AutoModel`, *optional*):
            Optional audio tokenizer for audio encoding/decoding.
        chat_template (`str`, *optional*):
            Optional chat template string for conversation formatting.
        speech_token_range (`List[int]`):
            List of [start, end] token IDs for speech token mapping.
        audio_bos_token (`str`):
            Beginning of speech token string.
        audio_eos_token (`str`):
            End of speech token string.
        audio_pad_token_id (`int`):
            Padding token ID for audio channels.
        max_channels (`int`):
            Maximum number of quantization channels.
        input_sample_rate (`int`):
            Target sample rate for input audio.
        encoder_downsample_rate (`int`):
            Downsampling rate of the audio encoder.
    """

    def __init__(
        self,
        tokenizer,
        feature_extractor: Optional[Any] = None,
        audio_tokenizer: Optional[Any] = None,
        *,
        chat_template: Optional[str],
        speech_token_range: list[int],
        audio_bos_token: str,
        audio_eos_token: str,
        audio_pad_token_id: int,
        max_channels: int,
        input_sample_rate: int,
        encoder_downsample_rate: int,
    ) -> None:
        self.tokenizer = tokenizer
        self.feature_extractor = feature_extractor
        self.audio_tokenizer = audio_tokenizer
        self.chat_template = chat_template
        self.speech_token_range = speech_token_range
        self.audio_bos_token = audio_bos_token
        self.audio_eos_token = audio_eos_token
        self.audio_pad_token_id = audio_pad_token_id
        self.max_channels = max_channels
        self.input_sample_rate = input_sample_rate
        self.encoder_downsample_rate = encoder_downsample_rate

    def prepare_sample(
        self,
        item: dict[str, Any],
        *,
        apply_chat_template: Callable[[str, dict], str],
        use_normalize: bool = False,
        silence_duration: float = 0.0,
        **kwargs,
    ) -> MossTTSDChatSample:
        """
        Prepare a single sample from JSONL item into MossTTSDChatSample format.

        Args:
            item (`dict`):
                Dictionary containing the input data (text, prompt_audio, etc.).
            apply_chat_template (`callable`):
                Function to apply chat template formatting.
            use_normalize (`bool`, *optional*, defaults to `False`):
                Whether to apply text normalization.
            silence_duration (`float`, *optional*, defaults to `0.0`):
                Duration of silence to append to audio for encoder segmentation.
            **kwargs:
                Additional keyword arguments passed to chat template.

        Returns:
            `MossTTSDChatSample`: Processed sample with 2D input tensor and metadata.
        """
        processed = self._process_jsonl_item(item)
        system_prompt = item.get("system_prompt")
        if isinstance(system_prompt, str):
            kwargs["system_prompt"] = system_prompt

        full_text = (processed["prompt_text"] or "") + processed["text"]
        original_full_text = full_text
        if use_normalize:
            full_text = self._normalize_text(full_text)
        final_text = full_text.replace("[S1]", "<speaker1>").replace("[S2]", "<speaker2>")

        # Load and resample audio (may be None)
        wav = self._process_audio_data(processed["prompt_audio"], target_sample_rate=self.input_sample_rate)

        # Assemble into grid (T, C)
        inputs_2d = self._build_inputs(
            text=final_text,
            audio_data=wav,
            apply_chat_template=apply_chat_template,
            silence_duration=silence_duration,
            **kwargs,
        )
        inputs_2d = self._shift_inputs(inputs_2d, pad_token_id=self.tokenizer.pad_token_id, max_channels=self.max_channels)

        meta = {
            "original_text": original_full_text,
            "normalized_text": self._normalize_text(original_full_text) if use_normalize else None,
            "final_text": final_text,
            "use_normalize": use_normalize,
        }
        ids_t = torch.tensor(inputs_2d, dtype=torch.long)
        return MossTTSDChatSample(input_ids_2d=ids_t, label_ids_2d=None, meta=meta)

    def collate(
        self,
        samples: list[MossTTSDChatSample],
        *,
        pad_token_id: int,
        audio_pad_token_id: int,
    ) -> MossTTSDBatchInput:
        """
        Collate multiple samples into a batch with proper padding.

        Args:
            samples (`List[MossTTSDChatSample]`):
                List of MossTTSDChatSample objects to collate.
            pad_token_id (`int`):
                Padding token ID for text tokens.
            audio_pad_token_id (`int`):
                Padding token ID for audio tokens.

        Returns:
            `MossTTSDBatchInput`: Batched input with padded tensors.
        """
        if not is_torch_available():
            raise ImportError("PyTorch is required for collation.")
        ids_list = [s.input_ids_2d for s in samples]
        labels_list = [s.label_ids_2d for s in samples]

        C = ids_list[0].shape[1]
        max_len = max(x.shape[0] for x in ids_list)
        padded_ids, padded_labels, padded_attn = [], [], []

        for ids, labels in zip(ids_list, labels_list):
            pad_len = max_len - ids.shape[0]
            pad_grid = torch.full((pad_len, C), audio_pad_token_id, dtype=torch.long)
            pad_grid[:, 0] = pad_token_id  # Text column uses tokenizer pad
            ids_padded = torch.cat([pad_grid, ids], dim=0)
            padded_ids.append(ids_padded)

            attn = torch.ones(ids.shape[0], dtype=torch.long)
            a_pad = torch.zeros(pad_len, dtype=torch.long)
            padded_attn.append(torch.cat([a_pad, attn], dim=0))

            if labels is None:
                padded_labels.append(None)
            else:
                lab_pad = torch.full((pad_len, C), audio_pad_token_id, dtype=torch.long)
                lab_pad[:, 0] = -100  # Text labels are ignored by default
                padded_labels.append(torch.cat([lab_pad, labels], dim=0))

        input_ids = torch.stack(padded_ids)  # (B, T, C)
        attention_mask = torch.stack(padded_attn)  # (B, T)
        if any(lab is not None for lab in padded_labels):
            labels = torch.stack(
                [lab if lab is not None else torch.full_like(input_ids[0], -100) for lab in padded_labels]
            )
        else:
            labels = None

        return MossTTSDBatchInput(input_ids=input_ids, attention_mask=attention_mask, labels=labels)

    @staticmethod
    def _process_jsonl_item(item: dict[str, Any]) -> dict[str, Any]:
        """
        Process a JSONL item to extract text and audio data.

        Supports both single-speaker and multi-speaker formats:
        - Single: {"prompt_audio": path, "prompt_text": text}
        - Multi: {"prompt_audio_speaker1": path1, "prompt_text_speaker1": text1, ...}

        Args:
            item: Dictionary containing the JSONL item data.

        Returns:
            Dictionary with extracted "text", "prompt_text", and "prompt_audio" fields.
        """
        base_path = item.get("base_path", "")
        text = item.get("text", "")

        prompt_audio = None
        prompt_text = ""

        if "prompt_audio" in item and "prompt_text" in item:
            pa = item["prompt_audio"]
            if pa:
                prompt_audio = os.path.join(base_path, pa) if isinstance(pa, str) and base_path else pa
                prompt_text = item.get("prompt_text", "")
        else:
            pa1, pt1 = item.get("prompt_audio_speaker1", ""), item.get("prompt_text_speaker1", "")
            pa2, pt2 = item.get("prompt_audio_speaker2", ""), item.get("prompt_text_speaker2", "")
            has1 = (isinstance(pa1, str) and pa1) or isinstance(pa1, tuple)
            has2 = (isinstance(pa2, str) and pa2) or isinstance(pa2, tuple)
            if has1 or has2:
                spk1 = os.path.join(base_path, pa1) if isinstance(pa1, str) and base_path and pa1 else pa1
                spk2 = os.path.join(base_path, pa2) if isinstance(pa2, str) and base_path and pa2 else pa2
                prompt_audio = {"speaker1": spk1, "speaker2": spk2}
            tmp = ""
            if pt1:
                tmp += f"[S1]{pt1}"
            if pt2:
                tmp += f"[S2]{pt2}"
            prompt_text = tmp.strip()

        return {"text": text, "prompt_text": prompt_text, "prompt_audio": prompt_audio}

    @staticmethod
    def _normalize_text(text: str) -> str:
        """
        Normalize text by applying various transformations for TTS processing.

        Performs speaker tag conversion, punctuation normalization, laughter conversion,
        and other text cleaning operations suitable for speech synthesis.

        Args:
            text: Input text string to normalize.

        Returns:
            Normalized text string.
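
        For example, `"[1]Hello everyone! Welcome."` becomes `"[S1]Hello everyone, Welcome."`: numeric speaker tags are
        rewritten to `[S<n>]` form and most sentence-internal punctuation is collapsed into commas.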
        """
        text = re.sub(r"\[(\d+)\]", r"[S\1]", text)
        remove_chars = '【】《》()『』「」"-""~~'
        text = re.sub(r"\[(?!S\d+\])([^\]]*)\]", r"\1", text)
        segments = re.split(r"(?=\[S\d+\])", text.replace("\n", " "))
        out = []
        for seg in segments:
            seg = seg.strip()
            if not seg:
                continue
            m = re.match(r"^(\[S\d+\])\s*(.*)", seg)
            tag, content = m.groups() if m else ("", seg)
            content = re.sub(f"[{re.escape(remove_chars)}]", "", content)
            content = re.sub(r"哈{2,}", "(笑)", content)
            content = re.sub(r"\b(ha(\s*ha)+)\b", "(laughs)", content, flags=re.IGNORECASE)
            content = content.replace("——", "，").replace("……", "，")
            trans = str.maketrans({"！": "，", "!": ",", "；": "，", ";": ",", "：": "，", ":": ",", "、": "，", "？": "，", "?": ","})
            content = content.translate(trans).strip()
            if len(content) > 1:
                last = "。" if content[-1] == "，" else ("." if content[-1] == "," else content[-1])
                body = content[:-1].replace("。", "，")
                content = body + last
            out.append(f"{tag}{content}".strip())
        return "".join(out)

    @staticmethod
    def _load_single_audio(audio_input: Union[str, tuple["torch.Tensor", int]]):
        """
        Load audio from file path or tensor tuple.

        Args:
            audio_input: Either a file path string or a tuple of (tensor, sample_rate).

        Returns:
            Tuple of (audio_tensor, sample_rate).

        Raises:
            ValueError: If audio input format is unsupported.
        """
        if isinstance(audio_input, tuple) and len(audio_input) == 2:
            return audio_input
        if isinstance(audio_input, str):
            try:
                return torchaudio.load(audio_input)
            except Exception:
                import soundfile as sf  # type: ignore
                data, sr = sf.read(audio_input, always_2d=True)
                data_t = torch.from_numpy(np.transpose(data))  # (C, T)
                return data_t, int(sr)
        raise ValueError(f"Unsupported audio input format: {type(audio_input)}")

    @staticmethod
    def _resample(audio: "torch.Tensor", sr: int, target_sr: int) -> tuple["torch.Tensor", int]:
        """
        Resample audio to target sample rate and convert to mono if needed.

        Args:
            audio: Input audio tensor with shape (channels, time).
            sr: Current sample rate.
            target_sr: Target sample rate.

        Returns:
            Tuple of (resampled_audio, target_sr) where audio is mono with shape (1, time).
        """
        if sr != target_sr:
            audio = torchaudio.functional.resample(audio, sr, target_sr)
        if audio.shape[0] > 1:
            audio = audio.mean(dim=0, keepdim=True)
        if audio.ndim == 1:
            audio = audio.unsqueeze(0)
        return audio, target_sr

    @classmethod
    def _load_audio_data(
        cls, audio_input: Union[str, tuple["torch.Tensor", int]], target_sample_rate: int
    ) -> tuple["torch.Tensor", int]:
        """
        Load and resample audio data to target sample rate.

        Args:
            audio_input: Audio file path or tensor tuple.
            target_sample_rate: Target sample rate for resampling.

        Returns:
            Tuple of (audio_tensor, target_sample_rate).
        """
        audio, sr = cls._load_single_audio(audio_input)
        return cls._resample(audio, sr, target_sample_rate)

    @classmethod
    def _merge_speaker_audios(
        cls,
        wav1: Union[str, tuple["torch.Tensor", int]],
        wav2: Union[str, tuple["torch.Tensor", int]],
        target_sample_rate: int,
    ) -> "torch.Tensor":
        """
        Merge two speaker audio inputs by concatenation.

        Args:
            wav1: Audio input for speaker 1.
            wav2: Audio input for speaker 2.
            target_sample_rate: Target sample rate for both audio inputs.

        Returns:
            Concatenated audio tensor.
        """
        a1, _ = cls._load_audio_data(wav1, target_sample_rate)
        a2, _ = cls._load_audio_data(wav2, target_sample_rate)
        return torch.cat([a1, a2], dim=1)

    @classmethod
    def _process_audio_data(
        cls, prompt_audio: Optional[Union[str, dict[str, Any], tuple["torch.Tensor", int]]], target_sample_rate: int
    ) -> Optional["torch.Tensor"]:
        """
        Process audio data from various input formats.

        Handles single audio files, multi-speaker audio dictionaries, or None input.

        Args:
            prompt_audio: Audio input in various formats (path, dict, tensor tuple, or None).
            target_sample_rate: Target sample rate for processing.

        Returns:
            Processed audio tensor or None if no audio provided.
        """
        if prompt_audio is None:
            return None
        if isinstance(prompt_audio, dict) and "speaker1" in prompt_audio and "speaker2" in prompt_audio:
            return cls._merge_speaker_audios(prompt_audio["speaker1"], prompt_audio["speaker2"], target_sample_rate)
        wav, _ = cls._load_audio_data(prompt_audio, target_sample_rate)
        return wav

    def _build_inputs(
        self,
        text: str,
        audio_data: Optional["torch.Tensor"],
        apply_chat_template: Callable[[str, dict], str],
        silence_duration: float,
        **kwargs,
    ) -> np.ndarray:
        """
        Build input grid from text and optional audio data.

        Creates a TxC grid where column 0 contains text tokens and columns 1..C-1 contain
        quantized audio codebook tokens. Audio tokens are mapped to speech token range.

        Args:
            text: Input text string to process.
            audio_data: Optional audio tensor with shape (channels, time).
            apply_chat_template: Function to apply chat template formatting.
            silence_duration: Duration of silence to append for encoder segmentation.
            **kwargs: Additional arguments for chat template.

        Returns:
            NumPy array with shape (T, max_channels) containing the input grid.
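
        Layout sketch: the first rows hold the text prompt (column 0 = text token ids, remaining columns =
        `audio_pad_token_id`); when `audio_data` is provided, the following rows hold the encoded prompt audio, with
        codebook 0 offset into `speech_token_range` in column 0 and the remaining codebooks in columns 1..C-1.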
        """
        assert isinstance(text, str), "text must be a string"
        prompt = apply_chat_template(text, kwargs)

        text_ids = np.array(self.tokenizer.encode(prompt, add_special_tokens=False))
        grid = np.full((text_ids.shape[0], self.max_channels), self.audio_pad_token_id, dtype=np.int64)
        grid[:, 0] = text_ids

        if audio_data is not None:
            silence_samples = int(max(0.0, silence_duration) * self.input_sample_rate)
            silence = torch.zeros(audio_data.shape[0], silence_samples, dtype=audio_data.dtype, device=audio_data.device)
            wav = torch.cat([audio_data, silence], dim=1)

            feat = self.feature_extractor(
                wav, sampling_rate=self.input_sample_rate, return_attention_mask=True, return_tensors="pt"
            )
            with torch.no_grad():
                enc = self.audio_tokenizer.encode(feat)
                # (time, codebooks)
                audio_codes = enc["audio_codes"][:, 0].permute(1, 0).cpu().numpy()
            # Map first codebook to speech token range
            audio_codes[:, 0] = audio_codes[:, 0] + self.speech_token_range[0]
            grid = np.concatenate([grid, audio_codes], axis=0)

            # Trim silence tokens at the end based on encoder downsampling
            silence_tokens = silence_duration * self.input_sample_rate / self.encoder_downsample_rate
            cut = math.floor(silence_tokens / 10) * 10
            if cut > 0:
                grid = grid[:-cut]

        return grid

    @staticmethod
    def _shift_inputs(input_ids: np.ndarray, pad_token_id: int, max_channels: int) -> np.ndarray:
        """
        Convert (T, C) grid to time-shifted multi-channel layout (preserving original implementation logic).

        Creates a shifted layout where new_len = T + C - 1, with column j shifted backwards by j positions.
        This enables the model to process multiple codebook channels with temporal alignment.

        Args:
            input_ids: Input grid with shape (T, C).
            pad_token_id: Padding token ID for text tokens.
            max_channels: Maximum number of channels.

        Returns:
            Shifted array with shape (T + max_channels - 1, max_channels).
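
        Illustration (T=2, C=3, text pad token `P`, audio pad token `1024`):

            input_ids                 shifted (T + C - 1 = 4 rows)
            [[t0, a0, b0],            [[t0, 1024, 1024],
             [t1, a1, b1]]    ->       [t1,   a0, 1024],
                                       [ P,   a1,   b0],
                                       [ P, 1024,   b1]]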
        """
        T, _ = input_ids.shape
        new_len = T + max_channels - 1
        shifted = np.full((new_len, max_channels), fill_value=1024, dtype=np.int64)
        shifted[:, 0] = np.full(new_len, pad_token_id, dtype=np.int64)
        for j in range(max_channels):
            shifted[j : (T + j), j] = input_ids[:, j]
        return shifted


class MossTTSDProcessor(ProcessorMixin):
    r"""
    Constructs a MOSS-TTSD processor which wraps a tokenizer, feature extractor, and audio tokenizer into a single
    processor. It provides unified text-speech processing capabilities while maintaining backward compatibility with
    previous API versions.

    [`MossTTSDProcessor`] offers all the functionalities of [`AutoTokenizer`], [`AutoFeatureExtractor`] and
    [`XYTokenizer`]. See the [`~MossTTSDProcessor.__call__`] and [`~MossTTSDProcessor.decode`] for more information.

    Args:
        tokenizer ([`AutoTokenizer`]):
            An instance of [`AutoTokenizer`]. The tokenizer is a required input.
        feature_extractor ([`AutoFeatureExtractor`]):
            An instance of [`AutoFeatureExtractor`]. The feature extractor is a required input.
        audio_tokenizer ([`XYTokenizer`]):
            An instance of [`XYTokenizer`]. The audio tokenizer is a required input.
        chat_template (`str`, *optional*):
            A template string for chat formatting when combining text and audio interactions.
        speech_token_range (`List[int]`, *optional*, defaults to `[151665, 152689]`):
            Token range [start, end] for mapping speech tokens.
        audio_bos_token (`str`, *optional*, defaults to `"<|begin_of_speech|>"`):
            Beginning of speech token string.
        audio_eos_token (`str`, *optional*, defaults to `"<|end_of_speech|>"`):
            End of speech token string.
        audio_pad_token_id (`int`, *optional*, defaults to `1024`):
            Padding token ID for audio channels.
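
    Example (illustrative sketch; the checkpoint path below is a placeholder, and `generated_ids` stands in for the
    `(batch, time, channels)` token tensor produced by the MOSS-TTSD model):

    ```python
    >>> processor = MossTTSDProcessor.from_pretrained("path/to/MOSS-TTSD")
    >>> inputs = processor({"text": "[S1]Hello there.[S2]Hi, nice to meet you."})
    >>> # ... run the MOSS-TTSD model on `inputs` to obtain `generated_ids` ...
    >>> texts, audios = processor.batch_decode(generated_ids)
    >>> processor.save_audio(audios, output_dir="output")
    ```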
    """
    feature_extractor_class = "AutoFeatureExtractor"
    tokenizer_class = "AutoTokenizer"
    audio_tokenizer_class = "PreTrainedModel"

    def __init__(
        self,
        tokenizer,
        feature_extractor,
        audio_tokenizer,
        chat_template: Optional[str] = None,
        speech_token_range: Optional[list[int]] = None,
        audio_bos_token: str = "<|begin_of_speech|>",
        audio_eos_token: str = "<|end_of_speech|>",
        audio_pad_token_id: int = 1024,
        **kwargs,
    ) -> None:
        super().__init__(tokenizer=tokenizer, feature_extractor=feature_extractor, audio_tokenizer=audio_tokenizer, **kwargs)

        audio_config = getattr(audio_tokenizer, "config", None)
        self.max_channels = (audio_tokenizer.quantizer.num_quantizers if audio_tokenizer is not None else None) or 8
        self.input_sample_rate = getattr(audio_config, "input_sample_rate", None) or 16000
        self.output_sample_rate = getattr(audio_config, "output_sample_rate", None) or 16000
        self.encoder_downsample_rate = getattr(audio_config, "encoder_downsample_rate", None) or 320

        # Use tokenizer's built-in chat template as primary
        self.chat_template = getattr(tokenizer, "chat_template", None) or chat_template

        # Read speech token range from tokenizer with fallback
        self.speech_token_range = (
            getattr(tokenizer, "speech_token_range", None) or speech_token_range or [151665, 152689]
        )
        self.audio_bos_token = getattr(tokenizer, "audio_bos_token", None) or audio_bos_token
        self.audio_eos_token = getattr(tokenizer, "audio_eos_token", None) or audio_eos_token
        self.audio_pad_token_id = getattr(tokenizer, "audio_pad_token_id", None) or audio_pad_token_id

        # Sample-level processor
        self.sample_processor = MossTTSDSampleProcessor(
            tokenizer=self.tokenizer,
            feature_extractor=self.feature_extractor,
            audio_tokenizer=self.audio_tokenizer,
            chat_template=self.chat_template,
            speech_token_range=self.speech_token_range,
            audio_bos_token=self.audio_bos_token,
            audio_eos_token=self.audio_eos_token,
            audio_pad_token_id=self.audio_pad_token_id,
            max_channels=self.max_channels,
            input_sample_rate=self.input_sample_rate,
            encoder_downsample_rate=self.encoder_downsample_rate,
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], trust_remote_code=True, **kwargs):
        """
        Instantiate a processor from a pretrained model.

        Args:
            pretrained_model_name_or_path (`str` or `os.PathLike`):
                The name of or path to the pretrained model.
            trust_remote_code (`bool`, *optional*, defaults to `True`):
                Whether to trust custom code when loading the feature extractor and audio tokenizer.
            **kwargs:
                Additional keyword arguments passed to the respective component loaders. A `codec_path` keyword can be
                used to override the default audio tokenizer location (`<pretrained_model_name_or_path>/XY_Tokenizer`).

        Returns:
            [`MossTTSDProcessor`]: A new instance of the processor.
        """
        kwargs.pop("_from_auto")
        audio_tokenizer_path = kwargs.pop("codec_path", os.path.join(pretrained_model_name_or_path, "XY_Tokenizer"))
        assert isinstance(audio_tokenizer_path, str), f"Unsupported audio_tokenizer_path input format: {type(audio_tokenizer_path)}"

        tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)
        feature_extractor = AutoFeatureExtractor.from_pretrained(audio_tokenizer_path, trust_remote_code=trust_remote_code, **kwargs)
        audio_tokenizer = AutoModel.from_pretrained(audio_tokenizer_path, trust_remote_code=trust_remote_code, **kwargs)

        return cls(
            tokenizer=tokenizer,
            feature_extractor=feature_extractor,
            audio_tokenizer=audio_tokenizer,
            **kwargs,
        )

    @classmethod
    def get_processor_dict(
        cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
    ) -> tuple[dict[str, Any], dict[str, Any]]:
        proc_dict, rest = super().get_processor_dict(pretrained_model_name_or_path, **kwargs)
        if "audio_tokenizer" in rest:
            proc_dict["audio_tokenizer"] = rest.pop("audio_tokenizer")
        for key in ("speech_token_range", "audio_bos_token", "audio_eos_token", "audio_pad_token_id"):
            if key in rest:
                proc_dict[key] = rest.pop(key)
        return proc_dict, rest

    def __call__(
        self,
        data: Union[dict[str, Any], list[dict[str, Any]]],
        **kwargs: Unpack[MossTTSDProcessorKwargs],
    ) -> BatchEncoding:
        """
        Main method to prepare inputs for the model from structured data.

        This method forwards the `data` and `kwargs` arguments to prepare inputs for MOSS-TTSD model. Please refer to the
        docstring of the respective methods for more information.

        Args:
            data (`dict` or `list[dict]`):
                Single dictionary or list of dictionaries containing input data. Expected keys include 'text',
                'prompt_text', 'prompt_audio', etc.
            **kwargs (`MossTTSDProcessorKwargs`):
                Additional processing arguments.

        Returns:
            [`BatchEncoding`]: Processed inputs ready for model consumption.
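
        Example (illustrative; the item keys follow the JSONL format handled by
        `MossTTSDSampleProcessor._process_jsonl_item`, and the checkpoint path is a placeholder):

        ```python
        >>> processor = MossTTSDProcessor.from_pretrained("path/to/MOSS-TTSD")
        >>> item = {
        ...     "text": "[S1]Hello there.[S2]Hi!",
        ...     "prompt_text": "[S1]A short voice prompt.",
        ...     "prompt_audio": "prompt.wav",
        ... }
        >>> batch = processor([item])
        >>> batch["input_ids"].shape  # (batch, time, channels)
        ```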
        """
        if isinstance(data, dict):
            data = [data]

        out_kwargs = self._merge_kwargs(MossTTSDProcessorKwargs, **kwargs)
        text_kwargs = out_kwargs["text_kwargs"]
        audio_kwargs = out_kwargs["audio_kwargs"]
        common_kwargs = out_kwargs["common_kwargs"]

        return_tensors = common_kwargs.get("return_tensors", "pt")
        padding = common_kwargs.get("padding", True)
        use_normalize = common_kwargs.get("use_normalize", False)

        pad_token_id = int(text_kwargs.get("pad_token_id", self.tokenizer.pad_token_id or 0))
        max_channels = int(audio_kwargs.get("max_channels", self.max_channels))
        audio_pad_token_id = int(audio_kwargs.get("audio_pad_token_id", self.audio_pad_token_id))
        silence_duration = float(audio_kwargs.get("silence_duration", 0.0))

        def _apply_chat_template(text: str, extra: dict) -> str:
            return self.apply_chat_template(conversation=None, text=text, **extra)

        samples: list[MossTTSDChatSample] = []
        for item in data:
            sample = self.sample_processor.prepare_sample(
                item,
                apply_chat_template=_apply_chat_template,
                use_normalize=use_normalize,
                silence_duration=silence_duration,
            )
            # Override with call-time max_channels (may differ from component initialization)
            if sample.input_ids_2d.shape[1] != max_channels:
                # Simplified: for clipping/extending channels, only pad/clip on the right side
                T, C = sample.input_ids_2d.shape
                if C > max_channels:
                    sample.input_ids_2d = sample.input_ids_2d[:, :max_channels]
                else:
                    pad = torch.full((T, max_channels - C), audio_pad_token_id, dtype=torch.long)
                    sample.input_ids_2d = torch.cat([sample.input_ids_2d, pad], dim=1)
            samples.append(sample)

        if not padding:
            raise NotImplementedError("Unpadded batches are not supported yet.")

        batch = self.sample_processor.collate(
            samples,
            pad_token_id=pad_token_id,
            audio_pad_token_id=audio_pad_token_id,
        )
        # Align with HiggsAudioProcessor: explicit dict -> BatchEncoding/Feature
        inputs = asdict(batch)
        inputs = {k: v for k, v in inputs.items() if v is not None}
        return BatchEncoding(inputs, tensor_type=return_tensors)

    def shifting_outputs(
        self,
        output_ids: "torch.Tensor",
        speech_token_range: list[int],
        max_channels: int = 8,
    ) -> "torch.Tensor":
        """
        Restore time-shifted layout to per-timestep C-channel arrangement and reverse-offset first codebook.

        Converts the time-shifted multi-channel output back to standard (batch, time, channels) format
        and maps the first codebook tokens back to their original space by subtracting the speech token offset.

        Args:
            output_ids: Time-shifted output tensor.
            speech_token_range: Speech token range for reverse mapping.
            max_channels: Number of codebook channels.

        Returns:
            Restored tensor with shape (batch, seq_len, max_channels).
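
        For example, with `max_channels=3` and a shifted sequence of length 6, `seq_len = 6 - 3 + 1 = 4`, and channel
        `j` is read from positions `j .. j + seq_len - 1`, undoing the per-channel delay applied by
        `MossTTSDSampleProcessor._shift_inputs`.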
        """
        seq_len = output_ids.shape[1] - max_channels + 1
        speech_ids = torch.full((output_ids.shape[0], seq_len, max_channels), 0, dtype=output_ids.dtype, device=output_ids.device)
        for j in range(max_channels):
            speech_ids[..., j] = output_ids[:, j : seq_len + j, j]
            if j == 0:
                speech_ids[..., j] = speech_ids[..., j] - speech_token_range[0]
        return speech_ids

    def _find_max_valid_positions(self, data: "torch.Tensor", invalid_value: int = 1024):
        """
        Locate continuous valid audio segment intervals in each sequence (all non-text channels valid simultaneously).

        Identifies contiguous spans where all audio channels (columns 1+) contain valid tokens
        (not the invalid_value padding token).

        Args:
            data: Input tensor with shape (batch, time, channels).
            invalid_value: Token ID considered as invalid/padding.

        Returns:
            List of lists containing valid audio segments for each sequence in the batch.
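
        For example, if a sequence has all audio channels valid only at positions 1-2 and 4-5, the returned entry for
        that sequence contains two tensors of shape `(2, channels)` covering those spans.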
        """
        mask = torch.all(data[:, :, 1:] != invalid_value, dim=2)
        valid_indices = torch.where(mask)
        result = [[] for _ in range(len(data))]
        if valid_indices[0].numel() == 0:
            return result
        grouped = []
        group_ids = []
        for i, seq_no in enumerate(valid_indices[0]):
            pos = valid_indices[1][i]
            if not group_ids or seq_no > group_ids[-1]:
                group_ids.append(seq_no)
                grouped.append([[pos, pos + 1]])
            elif pos == grouped[-1][-1][-1]:
                grouped[-1][-1][-1] += 1
            else:
                grouped[-1].append([pos, pos + 1])
        for gid, spans in zip(group_ids, grouped):
            for s, e in spans:
                result[gid].append(data[gid, s:e, :])
        return result

    def batch_decode(self, token_ids: "torch.Tensor", *args, **kwargs):
        """
        Decode a batch of token sequences into text and audio outputs.

        This method forwards the `token_ids` and `kwargs` arguments to decode text and audio outputs from the model.
        Please refer to the docstring of the respective methods for more information.

        Args:
            token_ids (`torch.Tensor`):
                Token tensor with shape (batch, time, channels).
            *args:
                Additional arguments passed to tokenizer.batch_decode.
            **kwargs:
                Additional keyword arguments passed to tokenizer.batch_decode.

        Returns:
            `tuple`: Tuple of (text_list, audio_list) where text_list contains decoded text strings and audio_list
                contains decoded audio arrays for each sequence.
        """
        assert token_ids.ndim == 3 and token_ids.shape[2] == self.max_channels, (
            f"Expected token_ids of shape (batch, time, {self.max_channels}), got {tuple(token_ids.shape)}."
        )
        text = self.tokenizer.batch_decode(token_ids[:, :, 0], *args, **kwargs)
        normal = self.shifting_outputs(token_ids, self.speech_token_range, self.max_channels)
        audio_frags = self._find_max_valid_positions(normal, self.audio_pad_token_id)
        decode_audio = []
        for seq_frags in audio_frags:
            if len(seq_frags):
                frag = torch.cat([f.permute(1, 0).unsqueeze(1) for f in seq_frags], dim=1)
                decode_audio.append(self.audio_tokenizer.decode(frag, overlap_seconds=10)["audio_values"])
            else:
                decode_audio.append([])
        return text, decode_audio

    def decode(self, token_ids: "torch.Tensor", *args, **kwargs) -> MossTTSDResponse:
        """
        Decode a single sequence of token IDs into text and audio.

        This method forwards the `token_ids` and `kwargs` arguments to decode a single sequence. Please refer to the
        docstring of the respective methods for more information.

        Args:
            token_ids (`torch.Tensor`):
                Token tensor with shape (time, channels).
            *args:
                Additional arguments passed to tokenizer.decode.
            **kwargs:
                Additional keyword arguments passed to tokenizer.decode.

        Returns:
            [`MossTTSDResponse`]: Response object containing generated text, audio, and sampling rate.
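
        Example (illustrative; `generated_ids` is a single `(time, channels)` sequence produced by the model):

        ```python
        >>> response = processor.decode(generated_ids)
        >>> response.generated_text, response.sampling_rate
        ```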
        """
        assert token_ids.ndim == 2 and token_ids.shape[1] == self.max_channels, (
            f"Expected token_ids of shape (time, {self.max_channels}), got {tuple(token_ids.shape)}."
        )
        text = self.tokenizer.decode(token_ids[:, 0], *args, **kwargs)
        normal = self.shifting_outputs(token_ids.unsqueeze(0), self.speech_token_range, self.max_channels)
        audio_frags = self._find_max_valid_positions(normal, self.audio_pad_token_id)[0]
        if len(audio_frags):
            frag = torch.cat([f.permute(1, 0).unsqueeze(1) for f in audio_frags], dim=1)
            audio = self.audio_tokenizer.decode(frag, overlap_seconds=10)["audio_values"]
        else:
            audio = None
        return MossTTSDResponse(
            audio=None if audio is None else audio.detach().cpu().numpy(),
            generated_text=text,
            sampling_rate=self.output_sample_rate,
        )

    def save_audio(self, audios, output_dir="output", prefix="audio"):
        """
        Save multiple audio fragments to files.

        Args:
            audios (`list`):
                Per-sequence lists of audio fragments, as returned by [`~MossTTSDProcessor.batch_decode`].
            output_dir (`str`, *optional*, defaults to `"output"`):
                Directory in which to save the audio files.
            prefix (`str`, *optional*, defaults to `"audio"`):
                Filename prefix for the saved audio files.
        """
        if not is_torchaudio_available():
            raise ImportError("Please install `torchaudio` to save audio files.")
        
        os.makedirs(output_dir, exist_ok=True)

        for i, data in enumerate(audios):
            for j, fragment in enumerate(data):
                filename = os.path.join(output_dir, f"{prefix}_{i}_{j}.wav")
                torchaudio.save(filename, fragment.cpu(), self.output_sample_rate)


__all__ = ["MossTTSDProcessor"]