Spaces:
Running
on
T4
Running
on
T4
File size: 11,591 Bytes
ad798d2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 |
import json
import os
import sys
from typing import Optional

import librosa
import safetensors
# Explicit submodule import: `import safetensors` alone does not expose
# `safetensors.torch`, which this file calls via safetensors.torch.load_model.
import safetensors.torch
import torch
import torch.nn as nn
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer

from models.tts.llm_tts.chat_template import gen_chat_prompt_for_tts
from models.tts.llm_tts.mgm import MGMT2S
from models.tts.tadicodec.inference_tadicodec import TaDiCodecPipline
from utils.util import load_config
class MGMInferencePipeline(nn.Module):
    """
    MGM TTS inference pipeline that integrates TaDiCodec and MGM models.

    Uses diffusion-based generation with mask-guided modeling: the prompt
    utterance is tokenized into discrete audio codes by TaDiCodec, the MGM
    model generates target codes via reverse diffusion conditioned on text,
    and TaDiCodec decodes the combined codes back to a waveform.
    """

    def __init__(
        self,
        tadicodec_path: str,
        mgm_path: str,
        device: torch.device,
    ):
        """
        Args:
            tadicodec_path: Local directory containing the TaDiCodec checkpoint.
            mgm_path: Local directory containing the MGM checkpoint, tokenizer
                files, and config.json.
            device: Device to load both models onto.

        Raises:
            FileNotFoundError: If config.json is missing from mgm_path.
            ValueError: If the MGM section is missing from the config.
        """
        super().__init__()
        self.device = device
        self.mgm_path = mgm_path

        # Load TaDiCodec pipeline (audio <-> discrete codes).
        self.tadicodec = TaDiCodecPipline.from_pretrained(
            ckpt_dir=tadicodec_path, device=device
        )

        # Text tokenizer shipped alongside the MGM checkpoint.
        self.tokenizer = AutoTokenizer.from_pretrained(
            mgm_path,
            trust_remote_code=True,
        )

        config_path = os.path.join(mgm_path, "config.json")
        if not os.path.exists(config_path):
            raise FileNotFoundError(f"Config file not found at {config_path}")
        self.cfg = load_config(config_path)

        # Extract MGM config from the loaded config.
        mgm_config = self.cfg.model.mgmt2s
        if not mgm_config:
            raise ValueError("MGM config not found in config.json")

        # Build MGM model from config - same pattern as llm_infer_eval.py.
        self.mgm = MGMT2S(
            hidden_size=mgm_config.hidden_size,
            num_layers=mgm_config.num_layers,
            num_heads=mgm_config.num_heads,
            cfg_scale=mgm_config.cfg_scale,
            cond_codebook_size=mgm_config.cond_codebook_size,
            cond_dim=mgm_config.cond_dim,
            phone_vocab_size=mgm_config.phone_vocab_size,
        )

        # Load weights from the single-file checkpoint when present,
        # otherwise try the checkpoint directory itself.
        model_path = os.path.join(mgm_path, "model.safetensors")
        if os.path.exists(model_path):
            safetensors.torch.load_model(self.mgm, model_path, strict=True)
        else:
            safetensors.torch.load_model(self.mgm, mgm_path, strict=True)
        self.mgm.to(device)
        self.mgm.eval()

    def tensor_to_audio_string(self, tensor):
        """Convert a code tensor (or nested list) of audio IDs into the
        token-string format "<|start_of_audio|><|audio_N|>...".

        Only the first row (batch element) of *tensor* is serialized.
        """
        if isinstance(tensor, list) and isinstance(tensor[0], list):
            values = tensor[0]
        else:
            values = tensor[0].tolist() if hasattr(tensor, "tolist") else tensor[0]
        return "<|start_of_audio|>" + "".join(f"<|audio_{v}|>" for v in values)

    def extract_audio_ids(self, text):
        """Extract the integer audio IDs from a string of <|audio_N|> tokens."""
        import re

        return [int(token) for token in re.findall(r"<\|audio_(\d+)\|>", text)]

    @classmethod
    def from_pretrained(
        cls,
        model_id: Optional[str] = None,
        tadicodec_path: Optional[str] = None,
        mgm_path: Optional[str] = None,
        device: Optional[torch.device] = None,
        auto_download: bool = True,
    ):
        """
        Create pipeline from pretrained models.

        Args:
            model_id: Hugging Face model ID for the MGM model (e.g., "amphion/TaDiCodec-TTS-MGM")
            tadicodec_path: Path to TaDiCodec model or Hugging Face model ID (defaults to "amphion/TaDiCodec")
            mgm_path: Path to MGM model directory or Hugging Face model ID (overrides model_id if provided)
            device: Device to run on; defaults to CUDA when available.
            auto_download: Whether to automatically download models from Hugging Face if not found locally

        Returns:
            MGMInferencePipeline instance
        """
        resolved_device = (
            device
            if device is not None
            else torch.device("cuda" if torch.cuda.is_available() else "cpu")
        )

        # Set default paths if not provided.
        if tadicodec_path is None:
            tadicodec_path = "amphion/TaDiCodec"
        if mgm_path is None:
            mgm_path = model_id if model_id is not None else "./ckpt/TaDiCodec-TTS-MGM"

        resolved_tadicodec_path = cls._resolve_model_path(
            tadicodec_path, auto_download=auto_download, model_type="tadicodec"
        )
        resolved_mgm_path = cls._resolve_model_path(
            mgm_path, auto_download=auto_download, model_type="mgm"
        )

        return cls(
            tadicodec_path=resolved_tadicodec_path,
            mgm_path=resolved_mgm_path,
            device=resolved_device,
        )

    @staticmethod
    def _resolve_model_path(
        model_path: str, auto_download: bool = True, model_type: str = "mgm"
    ) -> str:
        """
        Resolve model path, downloading from Hugging Face if necessary.

        Args:
            model_path: Local path or Hugging Face model ID
            auto_download: Whether to auto-download from HF
            model_type: Type of model ("mgm" or "tadicodec")

        Returns:
            Resolved local path

        Raises:
            ValueError: If the download fails, or the path neither exists
                locally nor looks like a repo ID (when auto_download is True).
            FileNotFoundError: If the path does not exist and auto_download
                is False.
        """
        # If it's already an existing local path, return it as-is.
        if os.path.exists(model_path):
            return model_path

        # If it looks like a Hugging Face model ID (contains '/'), download it.
        if "/" in model_path and auto_download:
            print(f"Downloading {model_type} model from Hugging Face: {model_path}")
            try:
                # Download to the standard HF cache directory.
                cache_dir = os.path.join(
                    os.path.expanduser("~"), ".cache", "huggingface", "hub"
                )
                downloaded_path = snapshot_download(
                    repo_id=model_path,
                    cache_dir=cache_dir,
                    local_dir_use_symlinks=False,
                )
                print(
                    f"Successfully downloaded {model_type} model to: {downloaded_path}"
                )
                return downloaded_path
            except Exception as e:
                print(f"Failed to download {model_type} model from Hugging Face: {e}")
                raise ValueError(
                    f"Could not download {model_type} model from {model_path}"
                )

        # Path does not exist and was not downloadable. (The original code had
        # these two messages swapped: it told users already passing
        # auto_download=True to set auto_download=True.)
        if auto_download:
            raise ValueError(
                f"Model path does not exist and is not a Hugging Face model ID: {model_path}"
            )
        raise FileNotFoundError(
            f"Model path does not exist: {model_path}. Set auto_download=True to download from Hugging Face."
        )

    @torch.no_grad()
    def __call__(
        self,
        text: str,
        prompt_text: Optional[str] = None,
        prompt_speech_path: Optional[str] = None,
        n_timesteps_mgm: int = 25,
        n_timesteps: int = 25,
        target_len: Optional[int] = None,
        return_code: bool = False,
    ):
        """
        Perform MGM TTS inference.

        Args:
            text: Target text to synthesize
            prompt_text: Prompt transcript for conditioning (treated as empty
                when None)
            prompt_speech_path: Path to prompt audio file (required)
            n_timesteps_mgm: Number of diffusion timesteps for MGM
            n_timesteps: Number of diffusion timesteps for TaDiCodec decoding
            target_len: Target speech duration in seconds; estimated from the
                prompt duration and text-length ratio when None
            return_code: Whether to return audio codes instead of audio

        Returns:
            Generated audio array or audio codes

        Raises:
            ValueError: If prompt_speech_path is not provided.
        """
        if not prompt_speech_path:
            raise ValueError("prompt_speech_path is required")
        # Guard: the signature allows None, but the concatenation and
        # length-ratio math below require a string.
        prompt_text = prompt_text or ""

        # Get prompt audio codes from TaDiCodec.
        prompt_speech_code = self.tadicodec(
            speech_path=prompt_speech_path, return_code=True, text=""
        )
        prompt_codes = torch.tensor(prompt_speech_code).to(self.device)
        prompt_len = prompt_codes.shape[1]

        # Tokenize text for phone conditioning.
        input_text = gen_chat_prompt_for_tts(
            prompt_text + " " + text,
            "phi-3" if "phi" in self.cfg.preprocess.tokenizer_path else "qwen2",
        )
        print("input_text: ", input_text)  # debug
        text_token_ids = self.tokenizer.encode(input_text)
        text_token_ids = torch.tensor(text_token_ids).unsqueeze(0).to(self.device)

        # frame_rate converts a duration in seconds into a code-frame count.
        frame_rate = getattr(self.cfg.preprocess, "frame_rate", 6.25)
        if target_len is None:
            # Estimate duration from the prompt audio, scaled by the UTF-8
            # byte-length ratio of target text to prompt text.
            prompt_text_len = max(len(prompt_text.encode("utf-8")), 1)
            target_text_len = len(text.encode("utf-8"))
            # NOTE(review): `filename=` is deprecated in librosa >= 0.10
            # (renamed to `path=`); kept for compatibility with older librosa.
            prompt_speech_len = librosa.get_duration(filename=prompt_speech_path)
            target_speech_len = prompt_speech_len * target_text_len / prompt_text_len
            target_len = int(target_speech_len * frame_rate)
        else:
            # Caller supplied a duration in seconds; convert to frames.
            target_len = int(target_len * frame_rate)

        print(f"Prompt length: {prompt_len}, Target length: {target_len}")  # debug
        print(f"Text: {text}")  # debug
        print(f"Prompt text: {prompt_text}")  # debug

        # Generate audio codes using MGM reverse diffusion.
        generated_codes = self.mgm.reverse_diffusion(
            prompt=prompt_codes,
            target_len=target_len,
            phone_id=text_token_ids,
            n_timesteps=n_timesteps_mgm,
            cfg=1.5,
            rescale_cfg=0.75,
        )
        print(f"Generated codes shape: {generated_codes.shape}")
        combine_codes = torch.cat([prompt_codes, generated_codes], dim=1)
        if return_code:
            return combine_codes

        # Decode the combined (prompt + generated) codes back to mel, then
        # run the vocoder to obtain the waveform.
        prompt_mel = self.tadicodec.extract_mel_feature(prompt_speech_path)
        text_token_ids = self.tadicodec.tokenize_text(text, prompt_text)
        rec_mel = self.tadicodec.decode(
            indices=combine_codes,
            text_token_ids=text_token_ids,
            prompt_mel=prompt_mel,
            n_timesteps=n_timesteps,
        )
        rec_audio = (
            self.tadicodec.vocoder_model(rec_mel.transpose(1, 2))
            .detach()
            .cpu()
            .numpy()[0][0]
        )
        return rec_audio
# Usage example
if __name__ == "__main__":
    run_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Build the TTS pipeline from local checkpoints.
    tts_pipeline = MGMInferencePipeline.from_pretrained(
        tadicodec_path="./ckpt/TaDiCodec",
        mgm_path="./ckpt/TaDiCodec-TTS-MGM",
        device=run_device,
    )

    # Synthesize one utterance conditioned on a spoken prompt.
    generated_audio = tts_pipeline(
        text="但是 to those who 知道 her well, it was a 标志 of her unwavering 决心 and spirit.",
        prompt_text="In short, we embarked on a mission to make America great again, for all Americans.",
        prompt_speech_path="./use_examples/test_audio/trump_0.wav",
    )

    # Write the result to disk at 24 kHz.
    import soundfile as sf

    sf.write("./use_examples/test_audio/mgm_tts_output.wav", generated_audio, 24000)
|