|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
Processor class for MOSS-TTSD. |
|
""" |
|
|
|
from __future__ import annotations |
|
|
|
import math |
|
import os |
|
import re |
|
from dataclasses import asdict, dataclass |
|
from typing import Any, Callable, Optional, Union |
|
|
|
import numpy as np |
|
|
|
from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack |
|
from transformers.tokenization_utils_base import BatchEncoding |
|
from transformers.utils import is_torch_available, is_torchaudio_available |
|
from transformers import AutoFeatureExtractor, AutoTokenizer, AutoModel |
|
|
|
|
|
|
|
if is_torch_available(): |
|
import torch |
|
|
|
if is_torchaudio_available(): |
|
import torchaudio |
|
|
|
|
|
class MossTTSDProcessorKwargs(ProcessingKwargs, total=False):
    """
    Keyword-argument schema for [`MossTTSDProcessor`].

    Extends `ProcessingKwargs`. `_defaults` holds one dictionary of fallback values per
    modality (`text_kwargs`, `audio_kwargs`, `common_kwargs`); `MossTTSDProcessor.__call__`
    passes this class to `self._merge_kwargs` to combine caller kwargs with these defaults.
    """

    _defaults = dict(
        text_kwargs=dict(
            pad_token_id=0,
        ),
        audio_kwargs=dict(
            max_channels=8,
            audio_pad_token_id=1024,
            silence_duration=0.0,
            input_sample_rate=16000,
            encoder_downsample_rate=320,
            speech_token_range=[151665, 152689],
            audio_bos_token="<|begin_of_speech|>",
            audio_eos_token="<|end_of_speech|>",
        ),
        common_kwargs=dict(
            return_tensors="pt",
            padding=True,
            use_normalize=False,
        ),
    )
|
|
|
|
|
@dataclass
class MossTTSDChatSample:
    """
    A single processed (unbatched) sample: a token grid plus bookkeeping metadata.

    Args:
        input_ids_2d (`torch.LongTensor`):
            Grid of shape (T, C). Column 0 holds text tokens; columns 1..C-1 hold quantized
            audio codebook tokens, or the audio padding id (1024 by default) where empty.
        label_ids_2d (`torch.LongTensor`, *optional*):
            Training labels shaped like `input_ids_2d`, or `None` when not training.
        meta (`dict`):
            Free-form metadata kept for debugging and tracking.
    """

    # (T, C) text + audio token grid.
    input_ids_2d: "torch.LongTensor"
    # Same shape as input_ids_2d; None at inference time.
    label_ids_2d: Optional["torch.LongTensor"]
    # Arbitrary per-sample metadata.
    meta: dict
|
|
|
@dataclass
class MossTTSDBatchInput:
    """
    Padded batch of MOSS-TTSD model inputs.

    Args:
        input_ids (`torch.LongTensor`):
            Token ids of shape (B, T, C).
        attention_mask (`torch.LongTensor`):
            Mask of shape (B, T); 1 marks real tokens, 0 marks padding.
        labels (`torch.LongTensor`, *optional*):
            Training labels of shape (B, T, C), or `None`.
    """

    # (B, T, C) batched token grid.
    input_ids: "torch.LongTensor"
    # (B, T) validity mask.
    attention_mask: "torch.LongTensor"
    # (B, T, C) labels, or None when no sample carried labels.
    labels: Optional["torch.LongTensor"]
|
|
|
|
|
@dataclass
class MossTTSDResponse:
    """
    Container for what MOSS-TTSD inference produced.

    Args:
        audio (`np.ndarray`, *optional*):
            Generated waveform, or `None` when no audio was decoded.
        generated_text (`str`, *optional*, defaults to `""`):
            Decoded text output.
        sampling_rate (`int`, *optional*):
            Sampling rate of `audio`, if known.
    """

    # Decoded waveform (None when the sequence contained no audio span).
    audio: Optional[np.ndarray] = None
    # Decoded text channel.
    generated_text: str = ""
    # Sampling rate that goes with `audio`.
    sampling_rate: Optional[int] = None
|
|
|
|
|
class MossTTSDSampleProcessor:
    """
    Sample-level processor for MOSS-TTSD that handles individual sample processing without batch padding.

    This class handles per-sample processing logic:
    - Parses JSONL items (text/prompt_text/prompt_audio)
    - Optional text normalization
    - Audio loading/resampling/merging, feature extraction and encoding
    - Generates a TxC grid and performs multi-channel shifting

    Args:
        tokenizer (`AutoTokenizer`):
            The text tokenizer for encoding text tokens.
        feature_extractor (`AutoFeatureExtractor`, *optional*):
            Optional feature extractor for audio preprocessing.
        audio_tokenizer (`AutoModel`, *optional*):
            Optional audio tokenizer for audio encoding/decoding.
        chat_template (`str`, *optional*):
            Optional chat template string for conversation formatting.
        speech_token_range (`List[int]`):
            List of [start, end] token IDs for speech token mapping.
        audio_bos_token (`str`):
            Beginning of speech token string.
        audio_eos_token (`str`):
            End of speech token string.
        audio_pad_token_id (`int`):
            Padding token ID for audio channels.
        max_channels (`int`):
            Maximum number of quantization channels.
        input_sample_rate (`int`):
            Target sample rate for input audio.
        encoder_downsample_rate (`int`):
            Downsampling rate of the audio encoder (input samples per audio token).
    """

    def __init__(
        self,
        tokenizer,
        feature_extractor: Optional[Any] = None,
        audio_tokenizer: Optional[Any] = None,
        *,
        chat_template: Optional[str],
        speech_token_range: list[int],
        audio_bos_token: str,
        audio_eos_token: str,
        audio_pad_token_id: int,
        max_channels: int,
        input_sample_rate: int,
        encoder_downsample_rate: int,
    ) -> None:
        self.tokenizer = tokenizer
        self.feature_extractor = feature_extractor
        self.audio_tokenizer = audio_tokenizer
        self.chat_template = chat_template
        self.speech_token_range = speech_token_range
        self.audio_bos_token = audio_bos_token
        self.audio_eos_token = audio_eos_token
        self.audio_pad_token_id = audio_pad_token_id
        self.max_channels = max_channels
        self.input_sample_rate = input_sample_rate
        self.encoder_downsample_rate = encoder_downsample_rate

    def prepare_sample(
        self,
        item: dict[str, Any],
        *,
        apply_chat_template: Callable[[str, dict], str],
        use_normalize: bool = False,
        silence_duration: float = 0.0,
        **kwargs,
    ) -> "MossTTSDChatSample":
        """
        Prepare a single sample from a JSONL item into `MossTTSDChatSample` format.

        Args:
            item (`dict`):
                Dictionary containing the input data (text, prompt_audio, etc.). A string
                `system_prompt` entry is forwarded to the chat template.
            apply_chat_template (`callable`):
                Function called as `apply_chat_template(text, kwargs)` to format the prompt.
            use_normalize (`bool`, *optional*, defaults to `False`):
                Whether to apply text normalization.
            silence_duration (`float`, *optional*, defaults to `0.0`):
                Seconds of silence appended to the prompt audio before encoding; the
                corresponding tokens are trimmed again afterwards.
            **kwargs:
                Additional keyword arguments passed to the chat template.

        Returns:
            `MossTTSDChatSample`: Sample with the shifted 2D input tensor and text metadata.
        """
        processed = self._process_jsonl_item(item)
        system_prompt = item.get("system_prompt")
        if isinstance(system_prompt, str):
            kwargs["system_prompt"] = system_prompt

        original_full_text = (processed["prompt_text"] or "") + (processed["text"] or "")
        # Normalize once and reuse the result for both the model input and the metadata.
        normalized_text = self._normalize_text(original_full_text) if use_normalize else None
        full_text = normalized_text if normalized_text is not None else original_full_text
        final_text = full_text.replace("[S1]", "<speaker1>").replace("[S2]", "<speaker2>")

        wav = self._process_audio_data(processed["prompt_audio"], target_sample_rate=self.input_sample_rate)

        inputs_2d = self._build_inputs(
            text=final_text,
            audio_data=wav,
            apply_chat_template=apply_chat_template,
            silence_duration=silence_duration,
            **kwargs,
        )

        # Some tokenizers define no pad token; np.full cannot take None, so fall back to 0
        # (the same fallback MossTTSDProcessor.__call__ uses).
        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = 0
        inputs_2d = self._shift_inputs(
            inputs_2d,
            pad_token_id=pad_token_id,
            max_channels=self.max_channels,
            audio_pad_token_id=self.audio_pad_token_id,
        )

        meta = {
            "original_text": original_full_text,
            "normalized_text": normalized_text,
            "final_text": final_text,
            "use_normalize": use_normalize,
        }
        ids_t = torch.tensor(inputs_2d, dtype=torch.long)
        return MossTTSDChatSample(input_ids_2d=ids_t, label_ids_2d=None, meta=meta)

    def collate(
        self,
        samples: "list[MossTTSDChatSample]",
        *,
        pad_token_id: int,
        audio_pad_token_id: int,
    ) -> "MossTTSDBatchInput":
        """
        Collate multiple samples into a left-padded batch.

        Args:
            samples (`List[MossTTSDChatSample]`):
                Non-empty list of samples; all grids must have the same number of columns.
            pad_token_id (`int`):
                Padding token ID written into column 0 (text) of padded rows.
            audio_pad_token_id (`int`):
                Padding token ID written into the audio columns of padded rows.

        Returns:
            `MossTTSDBatchInput`: Batched input with padded tensors. `labels` is `None` unless
            at least one sample carries labels; samples without labels get all `-100`.

        Raises:
            ValueError: If `samples` is empty.
            ImportError: If PyTorch is not available.
        """
        if not samples:
            raise ValueError("Cannot collate an empty list of samples.")
        if not is_torch_available():
            raise ImportError("PyTorch is required for collation.")

        num_channels = samples[0].input_ids_2d.shape[1]
        max_len = max(s.input_ids_2d.shape[0] for s in samples)
        padded_ids, padded_labels, padded_attn = [], [], []

        for sample in samples:
            ids, labels = sample.input_ids_2d, sample.label_ids_2d
            pad_len = max_len - ids.shape[0]

            # Padding is prepended (left padding).
            pad_grid = torch.full((pad_len, num_channels), audio_pad_token_id, dtype=torch.long)
            pad_grid[:, 0] = pad_token_id
            padded_ids.append(torch.cat([pad_grid, ids], dim=0))

            attn = torch.cat(
                [torch.zeros(pad_len, dtype=torch.long), torch.ones(ids.shape[0], dtype=torch.long)], dim=0
            )
            padded_attn.append(attn)

            if labels is None:
                padded_labels.append(None)
            else:
                # Text column of padded rows is ignored by the loss (-100).
                lab_pad = torch.full((pad_len, num_channels), audio_pad_token_id, dtype=torch.long)
                lab_pad[:, 0] = -100
                padded_labels.append(torch.cat([lab_pad, labels], dim=0))

        input_ids = torch.stack(padded_ids)
        attention_mask = torch.stack(padded_attn)
        labels = None
        if any(lab is not None for lab in padded_labels):
            labels = torch.stack(
                [lab if lab is not None else torch.full_like(input_ids[0], -100) for lab in padded_labels]
            )

        return MossTTSDBatchInput(input_ids=input_ids, attention_mask=attention_mask, labels=labels)

    @staticmethod
    def _process_jsonl_item(item: dict[str, Any]) -> dict[str, Any]:
        """
        Process a JSONL item to extract text and audio data.

        Supports both single-speaker and multi-speaker formats:
        - Single: {"prompt_audio": path, "prompt_text": text}
        - Multi: {"prompt_audio_speaker1": path1, "prompt_text_speaker1": text1, ...}

        String audio paths are joined onto `item["base_path"]` when it is set.

        Args:
            item: Dictionary containing the JSONL item data.

        Returns:
            Dictionary with extracted "text", "prompt_text", and "prompt_audio" fields. In the
            multi-speaker case "prompt_audio" is {"speaker1": ..., "speaker2": ...} and
            "prompt_text" is the `[S1]`/`[S2]`-tagged concatenation of the speaker texts.
        """
        base_path = item.get("base_path", "")
        text = item.get("text", "")

        prompt_audio = None
        prompt_text = ""

        if "prompt_audio" in item and "prompt_text" in item:
            pa = item["prompt_audio"]
            if pa:
                prompt_audio = os.path.join(base_path, pa) if isinstance(pa, str) and base_path else pa
                prompt_text = item.get("prompt_text", "")
        else:
            pa1, pt1 = item.get("prompt_audio_speaker1", ""), item.get("prompt_text_speaker1", "")
            pa2, pt2 = item.get("prompt_audio_speaker2", ""), item.get("prompt_text_speaker2", "")
            # A speaker "has audio" when given a non-empty path or a (waveform, sample_rate) tuple.
            has1 = (isinstance(pa1, str) and pa1) or isinstance(pa1, tuple)
            has2 = (isinstance(pa2, str) and pa2) or isinstance(pa2, tuple)
            if has1 or has2:
                spk1 = os.path.join(base_path, pa1) if isinstance(pa1, str) and base_path and pa1 else pa1
                spk2 = os.path.join(base_path, pa2) if isinstance(pa2, str) and base_path and pa2 else pa2
                prompt_audio = {"speaker1": spk1, "speaker2": spk2}
                tmp = ""
                if pt1:
                    tmp += f"[S1]{pt1}"
                if pt2:
                    tmp += f"[S2]{pt2}"
                prompt_text = tmp.strip()

        return {"text": text, "prompt_text": prompt_text, "prompt_audio": prompt_audio}

    @staticmethod
    def _normalize_text(text: str) -> str:
        """
        Normalize a multi-speaker script for TTS processing.

        - Bare numeric speaker markers `[1]`, `[2]`, ... become `[S1]`, `[S2]`, ...; any other
          `[...]` tag loses its brackets but keeps its content; newlines become spaces.
        - Per speaker segment: decorative symbols (lenticular/angle/corner brackets, ASCII and
          full-width parentheses, straight and curly double quotes, hyphen, ASCII and full-width
          tilde) are removed.
        - Runs of two or more U+54C8 become "(U+7B11)"; English "haha" / "ha ha" becomes "(laughs)".
        - Double em dashes and double ellipses become a full-width comma (U+FF0C).
        - ASCII `! ; : ?` become "," while their full-width forms and U+3001 become U+FF0C.
        - Only the final terminator survives: inner ideographic full stops (U+3002) become
          U+FF0C; a trailing U+FF0C becomes U+3002 and a trailing "," becomes ".".

        Non-ASCII characters are written as `\\u` escapes so the mappings survive tooling that
        rewrites full-width punctuation into ASCII (which made the ASCII and full-width
        branches indistinguishable in an earlier version).

        Args:
            text: Input text string to normalize.

        Returns:
            Normalized text string.
        """
        fw_comma = "\uff0c"  # full-width comma
        fw_stop = "\u3002"  # ideographic full stop

        text = re.sub(r"\[(\d+)\]", r"[S\1]", text)
        # Drop brackets of non-speaker tags, keeping their content.
        text = re.sub(r"\[(?!S\d+\])([^\]]*)\]", r"\1", text)

        remove_chars = (
            "\u3010\u3011\u300a\u300b"  # lenticular and double angle brackets
            "()\uff08\uff09"  # ASCII and full-width parentheses
            "\u300e\u300f\u300c\u300d"  # white and plain corner brackets
            "\"\u201c\u201d"  # straight and curly double quotes
            "-~\uff5e"  # hyphen, ASCII and full-width tilde
        )
        remove_pattern = re.compile(f"[{re.escape(remove_chars)}]")
        # Inner punctuation maps to a comma of the same width.
        punct_map = str.maketrans(
            {
                "!": ",",
                "\uff01": fw_comma,
                ";": ",",
                "\uff1b": fw_comma,
                ":": ",",
                "\uff1a": fw_comma,
                "\u3001": fw_comma,
                "?": ",",
                "\uff1f": fw_comma,
            }
        )

        out = []
        # Zero-width lookahead split keeps each speaker tag at the start of its segment.
        for seg in re.split(r"(?=\[S\d+\])", text.replace("\n", " ")):
            seg = seg.strip()
            if not seg:
                continue
            m = re.match(r"^(\[S\d+\])\s*(.*)", seg)
            tag, content = m.groups() if m else ("", seg)
            content = remove_pattern.sub("", content)
            content = re.sub("\u54c8{2,}", "(\u7b11)", content)
            content = re.sub(r"\b(ha(\s*ha)+)\b", "(laughs)", content, flags=re.IGNORECASE)
            content = content.replace("\u2014\u2014", fw_comma).replace("\u2026\u2026", fw_comma)
            content = content.translate(punct_map).strip()
            if len(content) > 1:
                last = content[-1]
                if last == fw_comma:
                    last = fw_stop
                elif last == ",":
                    last = "."
                content = content[:-1].replace(fw_stop, fw_comma) + last
            out.append(f"{tag}{content}".strip())
        return "".join(out)

    @staticmethod
    def _load_single_audio(audio_input: Union[str, tuple["torch.Tensor", int]]):
        """
        Load audio from a file path or a (waveform, sample_rate) tuple.

        Args:
            audio_input: Either a file path string or a tuple of (tensor or numpy array,
                sample_rate). NumPy waveforms are converted to tensors.

        Returns:
            Tuple of (audio_tensor, sample_rate).

        Raises:
            ValueError: If audio input format is unsupported.
        """
        if isinstance(audio_input, tuple) and len(audio_input) == 2:
            wav, sr = audio_input
            if isinstance(wav, np.ndarray):
                wav = torch.from_numpy(wav)
            return wav, sr
        if isinstance(audio_input, str):
            try:
                return torchaudio.load(audio_input)
            except Exception:
                # Best-effort fallback, e.g. torchaudio not installed (the module-level import is
                # conditional) or unable to decode this file.
                import soundfile as sf

                # float32 matches the dtype torchaudio.load returns by default.
                data, sr = sf.read(audio_input, always_2d=True, dtype="float32")
                data_t = torch.from_numpy(np.transpose(data))
                return data_t, int(sr)
        raise ValueError(f"Unsupported audio input format: {type(audio_input)}")

    @staticmethod
    def _resample(audio: "torch.Tensor", sr: int, target_sr: int) -> tuple["torch.Tensor", int]:
        """
        Convert audio to mono with shape (1, time) and resample it to the target rate.

        Args:
            audio: Input audio tensor with shape (channels, time) or (time,).
            sr: Current sample rate.
            target_sr: Target sample rate.

        Returns:
            Tuple of (resampled_audio, target_sr) where audio is mono with shape (1, time).
        """
        # Promote 1-D input first; otherwise the channel mixdown below would average over time.
        if audio.ndim == 1:
            audio = audio.unsqueeze(0)
        # Mix down before resampling so only one channel has to be resampled.
        if audio.shape[0] > 1:
            audio = audio.mean(dim=0, keepdim=True)
        if sr != target_sr:
            audio = torchaudio.functional.resample(audio, sr, target_sr)
        return audio, target_sr

    @classmethod
    def _load_audio_data(
        cls, audio_input: Union[str, tuple["torch.Tensor", int]], target_sample_rate: int
    ) -> tuple["torch.Tensor", int]:
        """
        Load audio and convert it to mono at the target sample rate.

        Args:
            audio_input: Audio file path or (waveform, sample_rate) tuple.
            target_sample_rate: Target sample rate for resampling.

        Returns:
            Tuple of (audio_tensor, target_sample_rate).
        """
        audio, sr = cls._load_single_audio(audio_input)
        return cls._resample(audio, sr, target_sample_rate)

    @classmethod
    def _merge_speaker_audios(
        cls,
        wav1: Union[str, tuple["torch.Tensor", int]],
        wav2: Union[str, tuple["torch.Tensor", int]],
        target_sample_rate: int,
    ) -> "torch.Tensor":
        """
        Merge two speaker audio inputs by concatenating them along time.

        Args:
            wav1: Audio input for speaker 1.
            wav2: Audio input for speaker 2.
            target_sample_rate: Target sample rate for both audio inputs.

        Returns:
            Concatenated (1, time) audio tensor.
        """
        a1, _ = cls._load_audio_data(wav1, target_sample_rate)
        a2, _ = cls._load_audio_data(wav2, target_sample_rate)
        return torch.cat([a1, a2], dim=1)

    @classmethod
    def _process_audio_data(
        cls, prompt_audio: Optional[Union[str, dict[str, Any], tuple["torch.Tensor", int]]], target_sample_rate: int
    ) -> Optional["torch.Tensor"]:
        """
        Load prompt audio from any of the supported input formats.

        Handles a single audio source (path or tuple), a multi-speaker dict with "speaker1"
        and/or "speaker2" entries, or None. Empty-string or missing speaker entries are skipped,
        so a prompt with only one speaker's audio works.

        Args:
            prompt_audio: Audio input in various formats (path, dict, tensor tuple, or None).
            target_sample_rate: Target sample rate for processing.

        Returns:
            Processed (1, time) audio tensor, or None if no audio was provided.

        Raises:
            ValueError: If a dict provides no usable speaker audio, or the format is unsupported.
        """
        if prompt_audio is None:
            return None
        if isinstance(prompt_audio, dict):
            sources = [prompt_audio.get("speaker1"), prompt_audio.get("speaker2")]
            sources = [src for src in sources if src is not None and not (isinstance(src, str) and not src)]
            if not sources:
                raise ValueError("prompt_audio dict must provide audio for 'speaker1' and/or 'speaker2'.")
            if len(sources) == 2:
                return cls._merge_speaker_audios(sources[0], sources[1], target_sample_rate)
            wav, _ = cls._load_audio_data(sources[0], target_sample_rate)
            return wav
        wav, _ = cls._load_audio_data(prompt_audio, target_sample_rate)
        return wav

    def _build_inputs(
        self,
        text: str,
        audio_data: Optional["torch.Tensor"],
        apply_chat_template: Callable[[str, dict], str],
        silence_duration: float,
        **kwargs,
    ) -> np.ndarray:
        """
        Build the input grid from text and optional audio data.

        Creates a TxC grid where column 0 contains text tokens and columns 1..C-1 contain
        quantized audio codebook tokens. The first audio codebook is offset into the speech
        token range. When audio is given, `silence_duration` seconds of silence are appended
        before encoding and (a multiple of 10 of) the resulting silence tokens are trimmed.

        Args:
            text: Input text string to process.
            audio_data: Optional audio tensor with shape (channels, time).
            apply_chat_template: Function to apply chat template formatting.
            silence_duration: Duration of silence to append for encoder segmentation.
            **kwargs: Additional arguments for chat template.

        Returns:
            NumPy int64 array with shape (T, max_channels) containing the input grid.

        Raises:
            TypeError: If `text` is not a string.
        """
        if not isinstance(text, str):
            raise TypeError("text must be a string")
        prompt = apply_chat_template(text, kwargs)

        text_ids = np.asarray(self.tokenizer.encode(prompt, add_special_tokens=False), dtype=np.int64)
        grid = np.full((text_ids.shape[0], self.max_channels), self.audio_pad_token_id, dtype=np.int64)
        grid[:, 0] = text_ids

        if audio_data is None:
            return grid

        silence_samples = int(max(0.0, silence_duration) * self.input_sample_rate)
        silence = torch.zeros(
            audio_data.shape[0], silence_samples, dtype=audio_data.dtype, device=audio_data.device
        )
        wav = torch.cat([audio_data, silence], dim=1)

        feat = self.feature_extractor(
            wav, sampling_rate=self.input_sample_rate, return_attention_mask=True, return_tensors="pt"
        )
        with torch.no_grad():
            enc = self.audio_tokenizer.encode(feat)

        # After [:, 0] and permute, rows are time steps and columns are codebooks (the result is
        # stacked under `grid` along axis 0). Layout of enc["audio_codes"] itself is assumed to be
        # (codebooks, batch, frames) — TODO confirm against the codec.
        audio_codes = enc["audio_codes"][:, 0].permute(1, 0).cpu().numpy()
        audio_codes[:, 0] = audio_codes[:, 0] + self.speech_token_range[0]
        grid = np.concatenate([grid, audio_codes], axis=0)

        # Trim only the tokens produced by the appended silence; this must stay inside the audio
        # branch, otherwise text tokens would be cut when no audio is given.
        silence_tokens = max(0.0, silence_duration) * self.input_sample_rate / self.encoder_downsample_rate
        cut = math.floor(silence_tokens / 10) * 10
        if cut > 0:
            grid = grid[:-cut]

        return grid

    @staticmethod
    def _shift_inputs(
        input_ids: np.ndarray, pad_token_id: int, max_channels: int, audio_pad_token_id: int = 1024
    ) -> np.ndarray:
        """
        Convert a (T, C) grid to the time-shifted multi-channel layout.

        Column j of the input is written starting at row j, so the output has
        T + max_channels - 1 rows. Unwritten cells hold `pad_token_id` in column 0 and
        `audio_pad_token_id` elsewhere.

        Args:
            input_ids: Input grid with shape (T, C), C >= max_channels.
            pad_token_id: Padding token ID for text tokens.
            max_channels: Maximum number of channels.
            audio_pad_token_id: Padding token ID for audio columns (defaults to 1024).

        Returns:
            Shifted array with shape (T + max_channels - 1, max_channels).
        """
        num_steps = input_ids.shape[0]
        shifted = np.full((num_steps + max_channels - 1, max_channels), audio_pad_token_id, dtype=np.int64)
        shifted[:, 0] = pad_token_id
        for channel in range(max_channels):
            shifted[channel : channel + num_steps, channel] = input_ids[:, channel]
        return shifted
|
|
|
|
|
class MossTTSDProcessor(ProcessorMixin):
    r"""
    Constructs a MOSS-TTSD processor which wraps a tokenizer, feature extractor, and audio tokenizer into a single
    processor. It provides unified text-speech processing capabilities while maintaining backward compatibility with
    previous API versions.

    [`MossTTSDProcessor`] offers all the functionalities of [`AutoTokenizer`], [`AutoFeatureExtractor`] and
    [`XYTokenizer`]. See the [`~MossTTSDProcessor.__call__`] and [`~MossTTSDProcessor.decode`] for more information.

    Args:
        tokenizer ([`AutoTokenizer`]):
            An instance of [`AutoTokenizer`]. The tokenizer is a required input.
        feature_extractor ([`AutoFeatureExtractor`]):
            An instance of [`AutoFeatureExtractor`]. The feature extractor is a required input.
        audio_tokenizer ([`XYTokenizer`]):
            An instance of [`XYTokenizer`]. The audio tokenizer is a required input.
        chat_template (`str`, *optional*):
            A template string for chat formatting when combining text and audio interactions. A template
            defined on the tokenizer takes precedence.
        speech_token_range (`List[int]`, *optional*, defaults to `[151665, 152689]`):
            Token range [start, end] for mapping speech tokens.
        audio_bos_token (`str`, *optional*, defaults to `"<|begin_of_speech|>"`):
            Beginning of speech token string.
        audio_eos_token (`str`, *optional*, defaults to `"<|end_of_speech|>"`):
            End of speech token string.
        audio_pad_token_id (`int`, *optional*, defaults to `1024`):
            Padding token ID for audio channels.
    """

    feature_extractor_class = "AutoFeatureExtractor"
    tokenizer_class = "AutoTokenizer"
    audio_tokenizer_class = "PreTrainedModel"

    def __init__(
        self,
        tokenizer,
        feature_extractor,
        audio_tokenizer,
        chat_template: Optional[str] = None,
        speech_token_range: Optional[list[int]] = None,
        audio_bos_token: str = "<|begin_of_speech|>",
        audio_eos_token: str = "<|end_of_speech|>",
        audio_pad_token_id: int = 1024,
        **kwargs,
    ) -> None:
        super().__init__(tokenizer=tokenizer, feature_extractor=feature_extractor, audio_tokenizer=audio_tokenizer, **kwargs)

        # Codec-derived settings. Each falls back to a default when the audio tokenizer is absent or
        # lacks the attribute (previously a missing `config` raised AttributeError on None).
        quantizer = getattr(audio_tokenizer, "quantizer", None) if audio_tokenizer else None
        codec_config = getattr(audio_tokenizer, "config", None) if audio_tokenizer else None
        self.max_channels = getattr(quantizer, "num_quantizers", None) or 8
        self.input_sample_rate = getattr(codec_config, "input_sample_rate", None) or 16000
        self.output_sample_rate = getattr(codec_config, "output_sample_rate", None) or 16000
        self.encoder_downsample_rate = getattr(codec_config, "encoder_downsample_rate", None) or 320

        # Values defined on the tokenizer take precedence over constructor arguments.
        self.chat_template = getattr(tokenizer, "chat_template", None) or chat_template
        self.speech_token_range = (
            getattr(tokenizer, "speech_token_range", None) or speech_token_range or [151665, 152689]
        )
        self.audio_bos_token = getattr(tokenizer, "audio_bos_token", None) or audio_bos_token
        self.audio_eos_token = getattr(tokenizer, "audio_eos_token", None) or audio_eos_token
        self.audio_pad_token_id = getattr(tokenizer, "audio_pad_token_id", None) or audio_pad_token_id

        self.sample_processor = MossTTSDSampleProcessor(
            tokenizer=self.tokenizer,
            feature_extractor=self.feature_extractor,
            audio_tokenizer=self.audio_tokenizer,
            chat_template=self.chat_template,
            speech_token_range=self.speech_token_range,
            audio_bos_token=self.audio_bos_token,
            audio_eos_token=self.audio_eos_token,
            audio_pad_token_id=self.audio_pad_token_id,
            max_channels=self.max_channels,
            input_sample_rate=self.input_sample_rate,
            encoder_downsample_rate=self.encoder_downsample_rate,
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], trust_remote_code=True, **kwargs):
        """
        Instantiate a processor from a pretrained model.

        Args:
            pretrained_model_name_or_path (`str` or `os.PathLike`):
                The name of or path to the pretrained model.
            trust_remote_code (`bool`, *optional*, defaults to `True`):
                Forwarded to the feature extractor and audio tokenizer loaders.
            **kwargs:
                Additional keyword arguments passed to the respective component loaders. `codec_path`
                (str or path-like) overrides the audio tokenizer location, which otherwise defaults to
                `<pretrained_model_name_or_path>/XY_Tokenizer`.

        Returns:
            [`MossTTSDProcessor`]: A new instance of the processor.

        Raises:
            TypeError: If the resolved audio tokenizer path is not a string.
        """
        # Injected by AutoProcessor; absent when this method is called directly, so pop with a default.
        kwargs.pop("_from_auto", None)
        codec_path = kwargs.pop("codec_path", None)
        if codec_path is None:
            codec_path = os.path.join(os.fspath(pretrained_model_name_or_path), "XY_Tokenizer")
        audio_tokenizer_path = os.fspath(codec_path) if isinstance(codec_path, os.PathLike) else codec_path
        if not isinstance(audio_tokenizer_path, str):
            raise TypeError(f"Unsupported audio_tokenizer_path input format: {type(audio_tokenizer_path)}")

        tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)
        feature_extractor = AutoFeatureExtractor.from_pretrained(audio_tokenizer_path, trust_remote_code=trust_remote_code, **kwargs)
        audio_tokenizer = AutoModel.from_pretrained(audio_tokenizer_path, trust_remote_code=trust_remote_code, **kwargs)

        return cls(
            tokenizer=tokenizer,
            feature_extractor=feature_extractor,
            audio_tokenizer=audio_tokenizer,
            **kwargs,
        )

    @classmethod
    def get_processor_dict(
        cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
    ) -> tuple[dict[str, Any], dict[str, Any]]:
        """
        Extend the parent's processor dict with MOSS-TTSD specific entries.

        Moves `audio_tokenizer` and the speech-token settings from the leftover kwargs into the
        processor dict so they reach `__init__`.
        """
        proc_dict, rest = super().get_processor_dict(pretrained_model_name_or_path, **kwargs)
        if "audio_tokenizer" in rest:
            proc_dict["audio_tokenizer"] = rest.pop("audio_tokenizer")
        for key in ("speech_token_range", "audio_bos_token", "audio_eos_token", "audio_pad_token_id"):
            if key in rest:
                proc_dict[key] = rest.pop(key)
        return proc_dict, rest

    def __call__(
        self,
        data: Union[dict[str, Any], list[dict[str, Any]]],
        **kwargs: Unpack[MossTTSDProcessorKwargs],
    ) -> BatchEncoding:
        """
        Main method to prepare inputs for the model from structured data.

        This method forwards the `data` and `kwargs` arguments to prepare inputs for MOSS-TTSD model. Please refer to the
        docstring of the respective methods for more information.

        Args:
            data (`dict` or `list[dict]`):
                Single dictionary or list of dictionaries containing input data. Expected keys include 'text',
                'prompt_text', 'prompt_audio', etc.
            **kwargs (`MossTTSDProcessorKwargs`):
                Additional processing arguments.

        Returns:
            [`BatchEncoding`]: Processed inputs ready for model consumption (`input_ids`, `attention_mask`, and
            `labels` when present).

        Raises:
            NotImplementedError: If `padding` is disabled.
        """
        if isinstance(data, dict):
            data = [data]

        out_kwargs = self._merge_kwargs(MossTTSDProcessorKwargs, **kwargs)
        text_kwargs = out_kwargs["text_kwargs"]
        audio_kwargs = out_kwargs["audio_kwargs"]
        common_kwargs = out_kwargs["common_kwargs"]

        return_tensors = common_kwargs.get("return_tensors", "pt")
        padding = common_kwargs.get("padding", True)
        use_normalize = common_kwargs.get("use_normalize", False)

        # Fail fast, before any (expensive) audio encoding happens.
        if not padding:
            raise NotImplementedError("Unpadded batches are not supported yet.")

        pad_token_id = int(text_kwargs.get("pad_token_id", self.tokenizer.pad_token_id or 0))
        max_channels = int(audio_kwargs.get("max_channels", self.max_channels))
        audio_pad_token_id = int(audio_kwargs.get("audio_pad_token_id", self.audio_pad_token_id))
        silence_duration = float(audio_kwargs.get("silence_duration", 0.0))

        def _apply_chat_template(text: str, extra: dict) -> str:
            return self.apply_chat_template(conversation=None, text=text, **extra)

        samples: list[MossTTSDChatSample] = []
        for item in data:
            sample = self.sample_processor.prepare_sample(
                item,
                apply_chat_template=_apply_chat_template,
                use_normalize=use_normalize,
                silence_duration=silence_duration,
            )

            # Force every grid to exactly `max_channels` columns (truncate or pad with audio padding).
            num_steps, num_cols = sample.input_ids_2d.shape
            if num_cols > max_channels:
                sample.input_ids_2d = sample.input_ids_2d[:, :max_channels]
            elif num_cols < max_channels:
                filler = torch.full((num_steps, max_channels - num_cols), audio_pad_token_id, dtype=torch.long)
                sample.input_ids_2d = torch.cat([sample.input_ids_2d, filler], dim=1)
            samples.append(sample)

        batch = self.sample_processor.collate(
            samples,
            pad_token_id=pad_token_id,
            audio_pad_token_id=audio_pad_token_id,
        )

        inputs = {k: v for k, v in asdict(batch).items() if v is not None}
        return BatchEncoding(inputs, tensor_type=return_tensors)

    def shifting_outputs(
        self,
        output_ids: "torch.Tensor",
        speech_token_range: list[int],
        max_channels: int = 8,
    ) -> "torch.Tensor":
        """
        Restore time-shifted layout to per-timestep C-channel arrangement and reverse-offset first codebook.

        Converts the time-shifted multi-channel output back to standard (batch, time, channels) format
        and maps the first codebook tokens back to their original space by subtracting the speech token offset.

        Args:
            output_ids: Time-shifted output tensor of shape (batch, time, >= max_channels).
            speech_token_range: Speech token range for reverse mapping.
            max_channels: Number of codebook channels.

        Returns:
            Restored tensor with shape (batch, seq_len, max_channels), seq_len = time - max_channels + 1.
        """
        seq_len = output_ids.shape[1] - max_channels + 1
        speech_ids = torch.zeros(
            (output_ids.shape[0], seq_len, max_channels), dtype=output_ids.dtype, device=output_ids.device
        )
        for j in range(max_channels):
            # Channel j was written starting j steps late; undo that shift.
            speech_ids[..., j] = output_ids[:, j : seq_len + j, j]
        speech_ids[..., 0] = speech_ids[..., 0] - speech_token_range[0]
        return speech_ids

    def _find_max_valid_positions(self, data: "torch.Tensor", invalid_value: int = 1024):
        """
        Locate maximal contiguous spans where all non-text channels are valid simultaneously.

        A time step is valid when every audio channel (columns 1+) differs from `invalid_value`.

        Args:
            data: Input tensor with shape (batch, time, channels).
            invalid_value: Token ID considered as invalid/padding.

        Returns:
            List (one entry per sequence) of lists of `data[b, start:end, :]` slices, in time order.
        """
        valid_rows = torch.all(data[:, :, 1:] != invalid_value, dim=2).tolist()
        result = [[] for _ in range(len(data))]
        for seq_idx, row in enumerate(valid_rows):
            start = None
            for pos, is_valid in enumerate(row):
                if is_valid and start is None:
                    start = pos
                elif not is_valid and start is not None:
                    result[seq_idx].append(data[seq_idx, start:pos, :])
                    start = None
            if start is not None:
                result[seq_idx].append(data[seq_idx, start:, :])
        return result

    def batch_decode(self, token_ids: "torch.Tensor", *args, **kwargs):
        """
        Decode a batch of token sequences into text and audio outputs.

        This method forwards the `token_ids` and `kwargs` arguments to decode text and audio outputs from the model.
        Please refer to the docstring of the respective methods for more information.

        Args:
            token_ids (`torch.Tensor`):
                Token tensor with shape (batch, time, channels), channels == `self.max_channels`.
            *args:
                Additional arguments passed to tokenizer.batch_decode.
            **kwargs:
                Additional keyword arguments passed to tokenizer.batch_decode.

        Returns:
            `tuple`: Tuple of (text_list, audio_list) where text_list contains decoded text strings and audio_list
            contains the audio tokenizer's `audio_values` for each sequence (an empty list when a sequence
            has no valid audio span).

        Raises:
            ValueError: If `token_ids` does not have the expected shape.
        """
        if token_ids.ndim != 3 or token_ids.shape[2] != self.max_channels:
            raise ValueError(
                f"Expected token_ids of shape (batch, time, {self.max_channels}), got {tuple(token_ids.shape)}."
            )
        text = self.tokenizer.batch_decode(token_ids[:, :, 0], *args, **kwargs)
        normal = self.shifting_outputs(token_ids, self.speech_token_range, self.max_channels)
        audio_frags = self._find_max_valid_positions(normal, self.audio_pad_token_id)
        decode_audio = []
        for seq_frags in audio_frags:
            if seq_frags:
                # Each fragment (T_i, C) -> (C, 1, T_i), stacked on dim 1.
                # NOTE(review): torch.cat on dim 1 requires all fragments to share T_i.
                frag = torch.cat([f.permute(1, 0).unsqueeze(1) for f in seq_frags], dim=1)
                decode_audio.append(self.audio_tokenizer.decode(frag, overlap_seconds=10)["audio_values"])
            else:
                decode_audio.append([])
        return text, decode_audio

    def decode(self, token_ids: "torch.Tensor", *args, **kwargs) -> MossTTSDResponse:
        """
        Decode a single sequence of token IDs into text and audio.

        This method forwards the `token_ids` and `kwargs` arguments to decode a single sequence. Please refer to the
        docstring of the respective methods for more information.

        Args:
            token_ids (`torch.Tensor`):
                Token tensor with shape (time, channels), channels == `self.max_channels`.
            *args:
                Additional arguments passed to tokenizer.decode.
            **kwargs:
                Additional keyword arguments passed to tokenizer.decode.

        Returns:
            [`MossTTSDResponse`]: Response object containing generated text, audio (NumPy, or None when
            there is no valid audio span), and `self.output_sample_rate`.

        Raises:
            ValueError: If `token_ids` does not have the expected shape.
        """
        if token_ids.ndim != 2 or token_ids.shape[1] != self.max_channels:
            raise ValueError(
                f"Expected token_ids of shape (time, {self.max_channels}), got {tuple(token_ids.shape)}."
            )
        text = self.tokenizer.decode(token_ids[:, 0], *args, **kwargs)
        normal = self.shifting_outputs(token_ids.unsqueeze(0), self.speech_token_range, self.max_channels)
        audio_frags = self._find_max_valid_positions(normal, self.audio_pad_token_id)[0]
        audio = None
        if audio_frags:
            frag = torch.cat([f.permute(1, 0).unsqueeze(1) for f in audio_frags], dim=1)
            audio = self.audio_tokenizer.decode(frag, overlap_seconds=10)["audio_values"]
        return MossTTSDResponse(
            audio=None if audio is None else audio.detach().cpu().numpy(),
            generated_text=text,
            sampling_rate=self.output_sample_rate,
        )

    def save_audio(self, audios, output_dir="output", prefix="audio"):
        """
        Save multiple audio fragments to WAV files named `<prefix>_<i>_<j>.wav`.

        Args:
            audios: List of per-sequence audio fragment collections, as returned by `batch_decode`.
            output_dir (str): Directory to save audio files (created if missing).
            prefix (str): Prefix for audio filenames.

        Raises:
            ImportError: If torchaudio is not installed.
        """
        if not is_torchaudio_available():
            raise ImportError("Please install `torchaudio` to save audio files.")

        os.makedirs(output_dir, exist_ok=True)

        for i, seq_audio in enumerate(audios):
            for j, fragment in enumerate(seq_audio):
                filename = os.path.join(output_dir, f"{prefix}_{i}_{j}.wav")
                torchaudio.save(filename, fragment.cpu(), self.output_sample_rate)
|
|
|
|
|
# Names exported by `from <module> import *`: only the top-level processor.
__all__ = ["MossTTSDProcessor"]
|
|