# coding=utf-8 # Copyright 2025 OpenMOSS and the HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ Processor class for MOSS-TTSD. """ from __future__ import annotations import math import os import re from dataclasses import asdict, dataclass from typing import Any, Callable, Optional, Union import numpy as np from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack from transformers.tokenization_utils_base import BatchEncoding from transformers.utils import is_torch_available, is_torchaudio_available from transformers import AutoFeatureExtractor, AutoTokenizer, AutoModel #from transformers.models.xy_tokenizer.modeling_xy_tokenizer import XYTokenizer if is_torch_available(): import torch if is_torchaudio_available(): import torchaudio class MossTTSDProcessorKwargs(ProcessingKwargs, total=False): """ Arguments for configuring MOSS-TTSD processing operations. Inherits from ProcessingKwargs and provides structured configuration for text and audio processing. """ _defaults = { "text_kwargs": { "pad_token_id": 0, # Fallback pad token ID, actual value comes from tokenizer.pad_token_id }, "audio_kwargs": { "max_channels": 8, # Maximum number of quantization channels "audio_pad_token_id": 1024, # Padding token ID for non-text channels "silence_duration": 0.0, # Duration of silence to append for encoder segmentation "input_sample_rate": 16000, # Input audio sampling rate (fallback, inferred from audio_tokenizer.config) "encoder_downsample_rate": 320, # Encoder downsampling rate (fallback, inferred from audio_tokenizer.config) "speech_token_range": [151665, 152689], # Token range for speech tokens (first codebook offset mapping) "audio_bos_token": "<|begin_of_speech|>", "audio_eos_token": "<|end_of_speech|>", }, "common_kwargs": { "return_tensors": "pt", "padding": True, "use_normalize": False, }, } @dataclass class MossTTSDChatSample: """ Intermediate representation of a single sample with T×C grid layout and metadata. Args: input_ids_2d (`torch.LongTensor`): Shape (T, C) tensor where column 0 contains text tokens and columns 1..C-1 contain quantized audio codebooks (or padding token 1024 for empty slots). label_ids_2d (`torch.LongTensor`, *optional*): Optional label tensor for training, same shape as input_ids_2d. meta (`dict`): Dictionary containing metadata for debugging and tracking purposes. """ input_ids_2d: "torch.LongTensor" label_ids_2d: Optional["torch.LongTensor"] meta: dict @dataclass class MossTTSDBatchInput: """ Batched input tensors for MOSS-TTSD model. Args: input_ids (`torch.LongTensor`): Shape (B, T, C) tensor containing batched input token IDs. attention_mask (`torch.LongTensor`): Shape (B, T) tensor containing attention mask for valid tokens. labels (`torch.LongTensor`, *optional*): Optional shape (B, T, C) tensor containing label token IDs for training. """ input_ids: "torch.LongTensor" attention_mask: "torch.LongTensor" labels: Optional["torch.LongTensor"] @dataclass class MossTTSDResponse: """ Unified response container for MOSS-TTSD inference outputs. Args: audio (`np.ndarray`, *optional*): Optional numpy array containing generated audio waveform. generated_text (`str`, *optional*, defaults to `""`): String containing generated text output. sampling_rate (`int`, *optional*): Optional integer specifying the sampling rate of the generated audio. """ audio: Optional[np.ndarray] = None generated_text: str = "" sampling_rate: Optional[int] = None class MossTTSDSampleProcessor: """ Sample-level processor for MOSS-TTSD that handles individual sample processing without batch padding. This class handles per-sample processing logic: - Parses JSONL items (text/prompt_text/prompt_audio) - Optional text normalization - Audio loading/resampling/merging, feature extraction and encoding - Generates T×C grid and performs multi-channel shifting Args: tokenizer (`AutoTokenizer`): The text tokenizer for encoding text tokens. feature_extractor (`AutoFeatureExtractor`, *optional*): Optional feature extractor for audio preprocessing. audio_tokenizer (`AutoModel`, *optional*): Optional audio tokenizer for audio encoding/decoding. chat_template (`str`, *optional*): Optional chat template string for conversation formatting. speech_token_range (`List[int]`): List of [start, end] token IDs for speech token mapping. audio_bos_token (`str`): Beginning of speech token string. audio_eos_token (`str`): End of speech token string. audio_pad_token_id (`int`): Padding token ID for audio channels. max_channels (`int`): Maximum number of quantization channels. input_sample_rate (`int`): Target sample rate for input audio. encoder_downsample_rate (`int`): Downsampling rate of the audio encoder. """ def __init__( self, tokenizer, feature_extractor: Optional = None, audio_tokenizer: Optional = None, *, chat_template: Optional[str], speech_token_range: list[int], audio_bos_token: str, audio_eos_token: str, audio_pad_token_id: int, max_channels: int, input_sample_rate: int, encoder_downsample_rate: int, ) -> None: self.tokenizer = tokenizer self.feature_extractor = feature_extractor self.audio_tokenizer = audio_tokenizer self.chat_template = chat_template self.speech_token_range = speech_token_range self.audio_bos_token = audio_bos_token self.audio_eos_token = audio_eos_token self.audio_pad_token_id = audio_pad_token_id self.max_channels = max_channels self.input_sample_rate = input_sample_rate self.encoder_downsample_rate = encoder_downsample_rate def prepare_sample( self, item: dict[str, Any], *, apply_chat_template: Callable[[str, dict], str], use_normalize: bool = False, silence_duration: float = 0.0, **kwargs, ) -> MossTTSDChatSample: """ Prepare a single sample from JSONL item into MossTTSDChatSample format. Args: item (`dict`): Dictionary containing the input data (text, prompt_audio, etc.). apply_chat_template (`callable`): Function to apply chat template formatting. use_normalize (`bool`, *optional*, defaults to `False`): Whether to apply text normalization. silence_duration (`float`, *optional*, defaults to `0.0`): Duration of silence to append to audio for encoder segmentation. **kwargs: Additional keyword arguments passed to chat template. Returns: `MossTTSDChatSample`: Processed sample with 2D input tensor and metadata. """ processed = self._process_jsonl_item(item) system_prompt = item.get("system_prompt") if isinstance(system_prompt, str): kwargs["system_prompt"] = system_prompt full_text = (processed["prompt_text"] or "") + processed["text"] original_full_text = full_text if use_normalize: full_text = self._normalize_text(full_text) final_text = full_text.replace("[S1]", "").replace("[S2]", "") # Load and resample audio (may be None) wav = self._process_audio_data(processed["prompt_audio"], target_sample_rate=self.input_sample_rate) # Assemble into grid (T, C) inputs_2d = self._build_inputs( text=final_text, audio_data=wav, apply_chat_template=apply_chat_template, silence_duration=silence_duration, **kwargs, ) inputs_2d = self._shift_inputs(inputs_2d, pad_token_id=self.tokenizer.pad_token_id, max_channels=self.max_channels) meta = { "original_text": original_full_text, "normalized_text": self._normalize_text(original_full_text) if use_normalize else None, "final_text": final_text, "use_normalize": use_normalize, } ids_t = torch.tensor(inputs_2d, dtype=torch.long) return MossTTSDChatSample(input_ids_2d=ids_t, label_ids_2d=None, meta=meta) def collate( self, samples: list[MossTTSDChatSample], *, pad_token_id: int, audio_pad_token_id: int, ) -> MossTTSDBatchInput: """ Collate multiple samples into a batch with proper padding. Args: samples (`List[MossTTSDChatSample]`): List of MossTTSDChatSample objects to collate. pad_token_id (`int`): Padding token ID for text tokens. audio_pad_token_id (`int`): Padding token ID for audio tokens. Returns: `MossTTSDBatchInput`: Batched input with padded tensors. """ assert is_torch_available(), "PyTorch is required for collation." ids_list = [s.input_ids_2d for s in samples] labels_list = [s.label_ids_2d for s in samples] C = ids_list[0].shape[1] max_len = max(x.shape[0] for x in ids_list) padded_ids, padded_labels, padded_attn = [], [], [] for ids, labels in zip(ids_list, labels_list): pad_len = max_len - ids.shape[0] pad_grid = torch.full((pad_len, C), audio_pad_token_id, dtype=torch.long) pad_grid[:, 0] = pad_token_id # Text column uses tokenizer pad ids_padded = torch.cat([pad_grid, ids], dim=0) padded_ids.append(ids_padded) attn = torch.ones(ids.shape[0], dtype=torch.long) a_pad = torch.zeros(pad_len, dtype=torch.long) padded_attn.append(torch.cat([a_pad, attn], dim=0)) if labels is None: padded_labels.append(None) else: lab_pad = torch.full((pad_len, C), audio_pad_token_id, dtype=torch.long) lab_pad[:, 0] = -100 # Text labels are ignored by default padded_labels.append(torch.cat([lab_pad, labels], dim=0)) input_ids = torch.stack(padded_ids) # (B, T, C) attention_mask = torch.stack(padded_attn) # (B, T) labels = torch.stack([l if l is not None else torch.full_like(input_ids[0], -100) for l in padded_labels]) \ if any(l is not None for l in padded_labels) else None return MossTTSDBatchInput(input_ids=input_ids, attention_mask=attention_mask, labels=labels) @staticmethod def _process_jsonl_item(item: dict[str, Any]) -> dict[str, Any]: """ Process a JSONL item to extract text and audio data. Supports both single-speaker and multi-speaker formats: - Single: {"prompt_audio": path, "prompt_text": text} - Multi: {"prompt_audio_speaker1": path1, "prompt_text_speaker1": text1, ...} Args: item: Dictionary containing the JSONL item data. Returns: Dictionary with extracted "text", "prompt_text", and "prompt_audio" fields. """ base_path = item.get("base_path", "") text = item.get("text", "") prompt_audio = None prompt_text = "" if "prompt_audio" in item and "prompt_text" in item: pa = item["prompt_audio"] if pa: prompt_audio = os.path.join(base_path, pa) if isinstance(pa, str) and base_path else pa prompt_text = item.get("prompt_text", "") else: pa1, pt1 = item.get("prompt_audio_speaker1", ""), item.get("prompt_text_speaker1", "") pa2, pt2 = item.get("prompt_audio_speaker2", ""), item.get("prompt_text_speaker2", "") has1 = (isinstance(pa1, str) and pa1) or isinstance(pa1, tuple) has2 = (isinstance(pa2, str) and pa2) or isinstance(pa2, tuple) if has1 or has2: spk1 = os.path.join(base_path, pa1) if isinstance(pa1, str) and base_path and pa1 else pa1 spk2 = os.path.join(base_path, pa2) if isinstance(pa2, str) and base_path and pa2 else pa2 prompt_audio = {"speaker1": spk1, "speaker2": spk2} tmp = "" if pt1: tmp += f"[S1]{pt1}" if pt2: tmp += f"[S2]{pt2}" prompt_text = tmp.strip() return {"text": text, "prompt_text": prompt_text, "prompt_audio": prompt_audio} @staticmethod def _normalize_text(text: str) -> str: """ Normalize text by applying various transformations for TTS processing. Performs speaker tag conversion, punctuation normalization, laughter conversion, and other text cleaning operations suitable for speech synthesis. Args: text: Input text string to normalize. Returns: Normalized text string. """ text = re.sub(r"\[(\d+)\]", r"[S\1]", text) remove_chars = '【】《》()『』「」"-""~~' text = re.sub(r"\[(?!S\d+\])([^\]]*)\]", r"\1", text) segments = re.split(r"(?=\[S\d+\])", text.replace("\n", " ")) out = [] for seg in segments: seg = seg.strip() if not seg: continue m = re.match(r"^(\[S\d+\])\s*(.*)", seg) tag, content = m.groups() if m else ("", seg) content = re.sub(f"[{re.escape(remove_chars)}]", "", content) content = re.sub(r"哈{2,}", "(笑)", content) content = re.sub(r"\b(ha(\s*ha)+)\b", "(laughs)", content, flags=re.IGNORECASE) content = content.replace("——", ",").replace("……", ",") trans = str.maketrans({"!": ",", "!": ",", ";": ",", ";": ",", ":": ",", ":": ",", "、": ",", "?": ",", "?": ","}) content = content.translate(trans).strip() if len(content) > 1: last = "。" if content[-1] == "," else ("." if content[-1] == "," else content[-1]) body = content[:-1].replace("。", ",") content = body + last out.append(f"{tag}{content}".strip()) return "".join(out) @staticmethod def _load_single_audio(audio_input: Union[str, tuple["torch.Tensor", int]]): """ Load audio from file path or tensor tuple. Args: audio_input: Either a file path string or a tuple of (tensor, sample_rate). Returns: Tuple of (audio_tensor, sample_rate). Raises: ValueError: If audio input format is unsupported. """ if isinstance(audio_input, tuple) and len(audio_input) == 2: return audio_input if isinstance(audio_input, str): try: return torchaudio.load(audio_input) except Exception: import soundfile as sf # type: ignore data, sr = sf.read(audio_input, always_2d=True) data_t = torch.from_numpy(np.transpose(data)) # (C, T) return data_t, int(sr) raise ValueError(f"Unsupported audio input format: {type(audio_input)}") @staticmethod def _resample(audio: "torch.Tensor", sr: int, target_sr: int) -> tuple["torch.Tensor", int]: """ Resample audio to target sample rate and convert to mono if needed. Args: audio: Input audio tensor with shape (channels, time). sr: Current sample rate. target_sr: Target sample rate. Returns: Tuple of (resampled_audio, target_sr) where audio is mono with shape (1, time). """ if sr != target_sr: audio = torchaudio.functional.resample(audio, sr, target_sr) if audio.shape[0] > 1: audio = audio.mean(dim=0, keepdim=True) if audio.ndim == 1: audio = audio.unsqueeze(0) return audio, target_sr @classmethod def _load_audio_data( cls, audio_input: Union[str, tuple["torch.Tensor", int]], target_sample_rate: int ) -> tuple["torch.Tensor", int]: """ Load and resample audio data to target sample rate. Args: audio_input: Audio file path or tensor tuple. target_sample_rate: Target sample rate for resampling. Returns: Tuple of (audio_tensor, target_sample_rate). """ audio, sr = cls._load_single_audio(audio_input) return cls._resample(audio, sr, target_sample_rate) @classmethod def _merge_speaker_audios( cls, wav1: Union[str, tuple["torch.Tensor", int]], wav2: Union[str, tuple["torch.Tensor", int]], target_sample_rate: int, ) -> "torch.Tensor": """ Merge two speaker audio inputs by concatenation. Args: wav1: Audio input for speaker 1. wav2: Audio input for speaker 2. target_sample_rate: Target sample rate for both audio inputs. Returns: Concatenated audio tensor. """ a1, _ = cls._load_audio_data(wav1, target_sample_rate) a2, _ = cls._load_audio_data(wav2, target_sample_rate) return torch.cat([a1, a2], dim=1) @classmethod def _process_audio_data( cls, prompt_audio: Optional[Union[str, dict[str, Any], tuple["torch.Tensor", int]]], target_sample_rate: int ) -> Optional["torch.Tensor"]: """ Process audio data from various input formats. Handles single audio files, multi-speaker audio dictionaries, or None input. Args: prompt_audio: Audio input in various formats (path, dict, tensor tuple, or None). target_sample_rate: Target sample rate for processing. Returns: Processed audio tensor or None if no audio provided. """ if prompt_audio is None: return None if isinstance(prompt_audio, dict) and "speaker1" in prompt_audio and "speaker2" in prompt_audio: return cls._merge_speaker_audios(prompt_audio["speaker1"], prompt_audio["speaker2"], target_sample_rate) wav, _ = cls._load_audio_data(prompt_audio, target_sample_rate) return wav def _build_inputs( self, text: str, audio_data: Optional["torch.Tensor"], apply_chat_template: Callable[[str, dict], str], silence_duration: float, **kwargs, ) -> np.ndarray: """ Build input grid from text and optional audio data. Creates a TxC grid where column 0 contains text tokens and columns 1..C-1 contain quantized audio codebook tokens. Audio tokens are mapped to speech token range. Args: text: Input text string to process. audio_data: Optional audio tensor with shape (channels, time). apply_chat_template: Function to apply chat template formatting. silence_duration: Duration of silence to append for encoder segmentation. **kwargs: Additional arguments for chat template. Returns: NumPy array with shape (T, max_channels) containing the input grid. """ assert isinstance(text, str), "text must be a string" prompt = apply_chat_template(text, kwargs) text_ids = np.array(self.tokenizer.encode(prompt, add_special_tokens=False)) grid = np.full((text_ids.shape[0], self.max_channels), self.audio_pad_token_id, dtype=np.int64) grid[:, 0] = text_ids if audio_data is not None: silence_samples = int(max(0.0, silence_duration) * self.input_sample_rate) silence = torch.zeros(audio_data.shape[0], silence_samples, device=audio_data.device) wav = torch.cat([audio_data, silence], dim=1) feat = self.feature_extractor( wav, sampling_rate=self.input_sample_rate, return_attention_mask=True, return_tensors="pt" ) with torch.no_grad(): enc = self.audio_tokenizer.encode(feat) # (time, codebooks) audio_codes = enc["audio_codes"][:, 0].permute(1, 0).cpu().numpy() # Map first codebook to speech token range audio_codes[:, 0] = audio_codes[:, 0] + self.speech_token_range[0] grid = np.concatenate([grid, audio_codes], axis=0) # Trim silence tokens at the end based on encoder downsampling silence_tokens = silence_duration * self.input_sample_rate / self.encoder_downsample_rate cut = math.floor(silence_tokens / 10) * 10 if cut > 0: grid = grid[:-cut] return grid @staticmethod def _shift_inputs(input_ids: np.ndarray, pad_token_id: int, max_channels: int) -> np.ndarray: """ Convert (T, C) grid to time-shifted multi-channel layout (preserving original implementation logic). Creates a shifted layout where new_len = T + C - 1, with column j shifted backwards by j positions. This enables the model to process multiple codebook channels with temporal alignment. Args: input_ids: Input grid with shape (T, C). pad_token_id: Padding token ID for text tokens. max_channels: Maximum number of channels. Returns: Shifted array with shape (T + max_channels - 1, max_channels). """ T, _ = input_ids.shape new_len = T + max_channels - 1 shifted = np.full((new_len, max_channels), fill_value=1024, dtype=np.int64) shifted[:, 0] = np.full(new_len, pad_token_id, dtype=np.int64) for j in range(max_channels): shifted[j : (T + j), j] = input_ids[:, j] return shifted class MossTTSDProcessor(ProcessorMixin): r""" Constructs a MOSS-TTSD processor which wraps a tokenizer, feature extractor, and audio tokenizer into a single processor. It provides unified text-speech processing capabilities while maintaining backward compatibility with previous API versions. [`MossTTSDProcessor`] offers all the functionalities of [`AutoTokenizer`], [`AutoFeatureExtractor`] and [`XYTokenizer`]. See the [`~MossTTSDProcessor.__call__`] and [`~MossTTSDProcessor.decode`] for more information. Args: tokenizer ([`AutoTokenizer`]): An instance of [`AutoTokenizer`]. The tokenizer is a required input. feature_extractor ([`AutoFeatureExtractor`]): An instance of [`AutoFeatureExtractor`]. The feature extractor is a required input. audio_tokenizer ([`XYTokenizer`]): An instance of [`XYTokenizer`]. The audio tokenizer is a required input. chat_template (`str`, *optional*): A template string for chat formatting when combining text and audio interactions. speech_token_range (`List[int]`, *optional*, defaults to `[151665, 152689]`): Token range [start, end] for mapping speech tokens. audio_bos_token (`str`, *optional*, defaults to `"<|begin_of_speech|>"`): Beginning of speech token string. audio_eos_token (`str`, *optional*, defaults to `"<|end_of_speech|>"`): End of speech token string. audio_pad_token_id (`int`, *optional*, defaults to `1024`): Padding token ID for audio channels. """ feature_extractor_class = "AutoFeatureExtractor" tokenizer_class = "AutoTokenizer" audio_tokenizer_class = "PreTrainedModel" def __init__( self, tokenizer, feature_extractor, audio_tokenizer, chat_template: Optional[str] = None, speech_token_range: Optional[list[int]] = None, audio_bos_token: str = "<|begin_of_speech|>", audio_eos_token: str = "<|end_of_speech|>", audio_pad_token_id: int = 1024, **kwargs, ) -> None: super().__init__(tokenizer=tokenizer, feature_extractor=feature_extractor, audio_tokenizer=audio_tokenizer, **kwargs) self.max_channels = (audio_tokenizer.quantizer.num_quantizers if audio_tokenizer else None) or 8 self.input_sample_rate = (getattr(audio_tokenizer, "config", None).input_sample_rate if audio_tokenizer else None) or 16000 self.output_sample_rate = (getattr(audio_tokenizer, "config", None).output_sample_rate if audio_tokenizer else None) or 16000 self.encoder_downsample_rate = (getattr(audio_tokenizer, "config", None).encoder_downsample_rate if audio_tokenizer else None) or 320 # Use tokenizer's built-in chat template as primary self.chat_template = getattr(tokenizer, "chat_template", None) or chat_template # Read speech token range from tokenizer with fallback self.speech_token_range = ( getattr(tokenizer, "speech_token_range", None) or speech_token_range or [151665, 152689] ) self.audio_bos_token = getattr(tokenizer, "audio_bos_token", None) or audio_bos_token self.audio_eos_token = getattr(tokenizer, "audio_eos_token", None) or audio_eos_token self.audio_pad_token_id = getattr(tokenizer, "audio_pad_token_id", None) or audio_pad_token_id # Sample-level processor self.sample_processor = MossTTSDSampleProcessor( tokenizer=self.tokenizer, feature_extractor=self.feature_extractor, audio_tokenizer=self.audio_tokenizer, chat_template=self.chat_template, speech_token_range=self.speech_token_range, audio_bos_token=self.audio_bos_token, audio_eos_token=self.audio_eos_token, audio_pad_token_id=self.audio_pad_token_id, max_channels=self.max_channels, input_sample_rate=self.input_sample_rate, encoder_downsample_rate=self.encoder_downsample_rate, ) @classmethod def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], trust_remote_code=True, **kwargs): """ Instantiate a processor from a pretrained model. Args: pretrained_model_name_or_path (`str` or `os.PathLike`): The name of or path to the pretrained model. **kwargs: Additional keyword arguments passed to the respective component loaders. Returns: [`MossTTSDProcessor`]: A new instance of the processor. """ kwargs.pop("_from_auto") audio_tokenizer_path = kwargs.pop("codec_path", os.path.join(pretrained_model_name_or_path, "XY_Tokenizer")) assert isinstance(audio_tokenizer_path, str), f"Unsupported audio_tokenizer_path input format: {type(audio_tokenizer_path)}" tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs) feature_extractor = AutoFeatureExtractor.from_pretrained(audio_tokenizer_path, trust_remote_code=trust_remote_code, **kwargs) audio_tokenizer = AutoModel.from_pretrained(audio_tokenizer_path, trust_remote_code=trust_remote_code, **kwargs) return cls( tokenizer=tokenizer, feature_extractor=feature_extractor, audio_tokenizer=audio_tokenizer, **kwargs, ) @classmethod def get_processor_dict( cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs ) -> tuple[dict[str, Any], dict[str, Any]]: proc_dict, rest = super().get_processor_dict(pretrained_model_name_or_path, **kwargs) if "audio_tokenizer" in rest: proc_dict["audio_tokenizer"] = rest.pop("audio_tokenizer") for key in ("speech_token_range", "audio_bos_token", "audio_eos_token", "audio_pad_token_id"): if key in rest: proc_dict[key] = rest.pop(key) return proc_dict, rest def __call__( self, data: Union[dict[str, Any], list[dict[str, Any]]], **kwargs: Unpack[MossTTSDProcessorKwargs], ) -> BatchEncoding: """ Main method to prepare inputs for the model from structured data. This method forwards the `data` and `kwargs` arguments to prepare inputs for MOSS-TTSD model. Please refer to the docstring of the respective methods for more information. Args: data (`dict` or `list[dict]`): Single dictionary or list of dictionaries containing input data. Expected keys include 'text', 'prompt_text', 'prompt_audio', etc. **kwargs (`MossTTSDProcessorKwargs`): Additional processing arguments. Returns: [`BatchEncoding`]: Processed inputs ready for model consumption. """ if isinstance(data, dict): data = [data] out_kwargs = self._merge_kwargs(MossTTSDProcessorKwargs, **kwargs) text_kwargs = out_kwargs["text_kwargs"] audio_kwargs = out_kwargs["audio_kwargs"] common_kwargs = out_kwargs["common_kwargs"] return_tensors = common_kwargs.get("return_tensors", "pt") padding = common_kwargs.get("padding", True) use_normalize = common_kwargs.get("use_normalize", False) pad_token_id = int(text_kwargs.get("pad_token_id", self.tokenizer.pad_token_id or 0)) max_channels = int(audio_kwargs.get("max_channels", self.max_channels)) audio_pad_token_id = int(audio_kwargs.get("audio_pad_token_id", self.audio_pad_token_id)) silence_duration = float(audio_kwargs.get("silence_duration", 0.0)) def _apply_chat_template(text: str, extra: dict) -> str: return self.apply_chat_template(conversation=None, text=text, **extra) samples: list[MossTTSDChatSample] = [] for item in data: sample = self.sample_processor.prepare_sample( item, apply_chat_template=_apply_chat_template, use_normalize=use_normalize, silence_duration=silence_duration, ) # Override with call-time max_channels (may differ from component initialization) if sample.input_ids_2d.shape[1] != max_channels: # Simplified: for clipping/extending channels, only pad/clip on the right side T, C = sample.input_ids_2d.shape if C > max_channels: sample.input_ids_2d = sample.input_ids_2d[:, :max_channels] else: pad = torch.full((T, max_channels - C), audio_pad_token_id, dtype=torch.long) sample.input_ids_2d = torch.cat([sample.input_ids_2d, pad], dim=1) samples.append(sample) if not padding: raise NotImplementedError("Unpadded batches are not supported yet.") batch = self.sample_processor.collate( samples, pad_token_id=pad_token_id, audio_pad_token_id=audio_pad_token_id, ) # Align with HiggsAudioProcessor: explicit dict -> BatchEncoding/Feature inputs = asdict(batch) inputs = {k: v for k, v in inputs.items() if v is not None} return BatchEncoding(inputs, tensor_type=return_tensors) def shifting_outputs( self, output_ids: "torch.Tensor", speech_token_range: list[int], max_channels: int = 8, ) -> "torch.Tensor": """ Restore time-shifted layout to per-timestep C-channel arrangement and reverse-offset first codebook. Converts the time-shifted multi-channel output back to standard (batch, time, channels) format and maps the first codebook tokens back to their original space by subtracting the speech token offset. Args: output_ids: Time-shifted output tensor. speech_token_range: Speech token range for reverse mapping. max_channels: Number of codebook channels. Returns: Restored tensor with shape (batch, seq_len, max_channels). """ seq_len = output_ids.shape[1] - max_channels + 1 speech_ids = torch.full((output_ids.shape[0], seq_len, max_channels), 0, dtype=output_ids.dtype, device=output_ids.device) for j in range(max_channels): speech_ids[..., j] = output_ids[:, j : seq_len + j, j] if j == 0: speech_ids[..., j] = speech_ids[..., j] - speech_token_range[0] return speech_ids def _find_max_valid_positions(self, data: "torch.Tensor", invalid_value: int = 1024): """ Locate continuous valid audio segment intervals in each sequence (all non-text channels valid simultaneously). Identifies contiguous spans where all audio channels (columns 1+) contain valid tokens (not the invalid_value padding token). Args: data: Input tensor with shape (batch, time, channels). invalid_value: Token ID considered as invalid/padding. Returns: List of lists containing valid audio segments for each sequence in the batch. """ mask = torch.all(data[:, :, 1:] != invalid_value, dim=2) valid_indices = torch.where(mask) result = [[] for _ in range(len(data))] if valid_indices[0].numel() == 0: return result grouped = [] group_ids = [] for i, seq_no in enumerate(valid_indices[0]): pos = valid_indices[1][i] if not group_ids or seq_no > group_ids[-1]: group_ids.append(seq_no) grouped.append([[pos, pos + 1]]) elif pos == grouped[-1][-1][-1]: grouped[-1][-1][-1] += 1 else: grouped[-1].append([pos, pos + 1]) for gid, spans in zip(group_ids, grouped): for s, e in spans: result[gid].append(data[gid, s:e, :]) return result def batch_decode(self, token_ids: "torch.Tensor", *args, **kwargs): """ Decode a batch of token sequences into text and audio outputs. This method forwards the `token_ids` and `kwargs` arguments to decode text and audio outputs from the model. Please refer to the docstring of the respective methods for more information. Args: token_ids (`torch.Tensor`): Token tensor with shape (batch, time, channels). *args: Additional arguments passed to tokenizer.batch_decode. **kwargs: Additional keyword arguments passed to tokenizer.batch_decode. Returns: `tuple`: Tuple of (text_list, audio_list) where text_list contains decoded text strings and audio_list contains decoded audio arrays for each sequence. """ assert token_ids.ndim == 3 and token_ids.shape[2] == self.max_channels text = self.tokenizer.batch_decode(token_ids[:, :, 0], *args, **kwargs) normal = self.shifting_outputs(token_ids, self.speech_token_range, self.max_channels) audio_frags = self._find_max_valid_positions(normal, self.audio_pad_token_id) decode_audio = [] for seq_frags in audio_frags: if len(seq_frags): frag = torch.cat([f.permute(1, 0).unsqueeze(1) for f in seq_frags], dim=1) decode_audio.append(self.audio_tokenizer.decode(frag, overlap_seconds=10)["audio_values"]) else: decode_audio.append([]) return text, decode_audio def decode(self, token_ids: "torch.Tensor", *args, **kwargs) -> MossTTSDResponse: """ Decode a single sequence of token IDs into text and audio. This method forwards the `token_ids` and `kwargs` arguments to decode a single sequence. Please refer to the docstring of the respective methods for more information. Args: token_ids (`torch.Tensor`): Token tensor with shape (time, channels). *args: Additional arguments passed to tokenizer.decode. **kwargs: Additional keyword arguments passed to tokenizer.decode. Returns: [`MossTTSDResponse`]: Response object containing generated text, audio, and sampling rate. """ assert token_ids.ndim == 2 and token_ids.shape[1] == self.max_channels text = self.tokenizer.decode(token_ids[:, 0].squeeze(-1), *args, **kwargs) normal = self.shifting_outputs(token_ids.unsqueeze(0), self.speech_token_range, self.max_channels) audio_frags = self._find_max_valid_positions(normal, self.audio_pad_token_id)[0] if len(audio_frags): frag = torch.cat([f.permute(1, 0).unsqueeze(1) for f in audio_frags], dim=1) audio = self.audio_tokenizer.decode(frag, overlap_seconds=10)["audio_values"] else: audio = None return MossTTSDResponse( audio=None if audio is None else audio.detach().cpu().numpy(), generated_text=text, sampling_rate=self.output_sample_rate, ) def save_audio(self, audios, output_dir="output", prefix="audio"): """ Save multiple audio fragments to files. Args: audios: List of audio data fragments from batch_decode output_dir (str): Directory to save audio files prefix (str): Prefix for audio filenames """ if not is_torchaudio_available(): raise ImportError("Please install `torchaudio` to save audio files.") os.makedirs(output_dir, exist_ok=True) for i, data in enumerate(audios): for j, fragment in enumerate(data): filename = f"{output_dir}/{prefix}_{i}_{j}.wav" torchaudio.save(filename, fragment.cpu(), self.output_sample_rate) __all__ = ["MossTTSDProcessor"]