import argparse
import codecs
import re
import tempfile
from pathlib import Path
import logging

import numpy as np
import soundfile as sf
import tomli
import torch
import torchaudio
from tqdm import tqdm
from einops import rearrange
from pydub import AudioSegment, silence
from transformers import pipeline
from huggingface_hub import login
from cached_path import cached_path
import matplotlib.pyplot as plt  # Needed for save_spectrogram
# --- Import Model Architectures ---
# !! Ensure these models are defined in your project's 'model' module !!
try:
    from model import UNetT, DiT
except ImportError:
    print("Warning: Could not import UNetT, DiT from 'model'. Using placeholders.")

    # Placeholder classes if the import fails (the script will not produce real audio with these).
    class MockModel:
        def __init__(self, *args, **kwargs): pass
        def to(self, *args, **kwargs): return self
        def eval(self): pass
        def sample(self, *args, **kwargs):
            duration = kwargs.get('duration', 500); mel_dim = 100
            return torch.randn(1, duration, mel_dim), None
        @property
        def device(self): return torch.device("cpu")
    DiT = MockModel
    UNetT = MockModel

# --- Import/Define Utility Functions ---
from tokenizers import Tokenizer
from phonemizer import phonemize
# --- Functions copied/adapted from app.py ---

# Function to load vocoder (from app.py context, may need adjustment)
def load_vocoder(device='cpu'):
    """Loads the Vocos vocoder."""
    print("Loading Vocos vocoder (charactr/vocos-mel-24khz)...")
    try:
        # Ensure the vocos library is installed: pip install vocos
        from vocos import Vocos
        vocos_model = Vocos.from_pretrained("charactr/vocos-mel-24khz").to(device)
        # Vocos generally runs in float32; an explicit half-precision cast is possible on CUDA,
        # but float32 is kept here as the safe default.
        # vocos_model = vocos_model.to(dtype=torch.float16)  # Optional casting on CUDA
        vocos_model.eval()
        print("Vocos vocoder loaded successfully.")
        return vocos_model
    except ImportError:
        print("Error: 'vocos' library not found. Please install it: pip install vocos")
        raise
    except Exception as e:
        print(f"Error loading Vocos model: {e}")
        raise
# Function to remove silence from edges (from app.py)
def remove_silence_edges(aseg):
    """Removes silence from the beginning and end of an AudioSegment."""
    print("Removing silence from audio edges...")
    start_trim = silence.detect_leading_silence(aseg)
    end_trim = silence.detect_leading_silence(aseg.reverse())
    duration = len(aseg)
    trimmed_aseg = aseg[start_trim:duration - end_trim]
    print(f"Removed {start_trim}ms from start, {end_trim}ms from end.")
    return trimmed_aseg
# Function to save spectrogram (from app.py)
def save_spectrogram(spectrogram, file_path):
    """Saves a spectrogram visualization to a file."""
    if spectrogram is None:
        print("Spectrogram data is None, cannot save.")
        return
    try:
        print(f"Saving spectrogram to {file_path}...")
        plt.figure(figsize=(10, 4))
        plt.imshow(spectrogram, aspect='auto', origin='lower', cmap='viridis')
        plt.colorbar(label='Mel power')
        plt.xlabel('Frames')
        plt.ylabel('Mel bins')
        plt.title('Generated Mel Spectrogram')
        plt.tight_layout()
        plt.savefig(file_path)
        plt.close()  # Close the figure to free memory
        print("Spectrogram saved.")
    except Exception as e:
        print(f"Error saving spectrogram: {e}")
# Helper function to load checkpoint (from app.py, slightly modified for CLI)
def load_checkpoint(model, ckpt_path, device, use_ema=False):
    """Loads model weights from a checkpoint file (.pt or .safetensors)."""
    print(f"Loading checkpoint from {ckpt_path}...")
    try:
        if ckpt_path.endswith(".safetensors"):
            # Ensure safetensors is installed: pip install safetensors
            from safetensors.torch import load_file
            state_dict = load_file(ckpt_path, device="cpu")
        elif ckpt_path.endswith(".pt"):
            state_dict = torch.load(ckpt_path, map_location="cpu")
        else:
            raise ValueError(f"Unsupported checkpoint format: {ckpt_path}. Must be .pt or .safetensors")

        # Standardize the state_dict format (e.g., unwrap a nested 'state_dict' key if present)
        if "state_dict" in state_dict:
            state_dict = state_dict["state_dict"]

        # Handle EMA weights
        ema_key_prefix = "ema_model."  # Adjust if your EMA keys use a different prefix
        final_state_dict = {}
        has_ema = any(k.startswith(ema_key_prefix) for k in state_dict.keys())
        if use_ema:
            if has_ema:
                print("Attempting to load EMA weights.")
                ema_state_dict = {k[len(ema_key_prefix):]: v for k, v in state_dict.items() if k.startswith(ema_key_prefix)}
                if ema_state_dict:
                    final_state_dict = ema_state_dict
                    print("Using EMA weights.")
                else:
                    # Should not happen when has_ema is True, but kept as a safeguard:
                    print("Warning: EMA weights requested but none found with the EMA prefix. Using regular weights.")
                    final_state_dict = {k: v for k, v in state_dict.items() if not k.startswith(ema_key_prefix)}
            else:
                print("Warning: EMA weights requested but no keys found with the EMA prefix. Using regular weights.")
                final_state_dict = state_dict  # Use the original dict if no EMA keys exist
        else:
            print("Loading non-EMA weights.")
            # Filter out EMA weights if they exist and we explicitly don't want them
            final_state_dict = {k: v for k, v in state_dict.items() if not k.startswith(ema_key_prefix)}

        # Load into the model, handling a potential 'module.' prefix left over from DDP training
        model_state_dict = model.state_dict()
        processed_state_dict = {}
        for k, v in final_state_dict.items():
            k_proc = k[len("module."):] if k.startswith("module.") else k
            if k_proc in model_state_dict:
                if model_state_dict[k_proc].shape == v.shape:
                    processed_state_dict[k_proc] = v
                else:
                    print(f"Warning: Shape mismatch for key {k_proc}. Checkpoint: {v.shape}, Model: {model_state_dict[k_proc].shape}. Skipping.")
            # else:  # Optional: log unexpected keys
            #     print(f"Warning: Key {k_proc} from checkpoint not found in model. Skipping.")
        missing_keys, unexpected_keys = model.load_state_dict(processed_state_dict, strict=False)
        if missing_keys:
            print(f"Warning: Keys present in the model but missing from the checkpoint: {missing_keys}")
        if unexpected_keys:
            # Should be empty given the filtering above, but worth checking.
            print(f"Warning: Unexpected keys (should not happen with filtering): {unexpected_keys}")
        print(f"Checkpoint loaded successfully from {ckpt_path}")
    except FileNotFoundError:
        print(f"Error: Checkpoint file not found at {ckpt_path}")
        raise
    except Exception as e:
        print(f"Error loading checkpoint from {ckpt_path}: {e}")
        raise  # Re-raise the exception
    model.eval()
    return model.to(device)
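
# Illustrative sketch (not part of the copied app.py code): the EMA handling above assumes
# checkpoint keys shaped roughly like the hypothetical examples below, where the "ema_model."
# and "module." prefixes are stripped before loading:
#
#   ema_model.transformer.blocks.0.attn.to_q.weight  ->  transformer.blocks.0.attn.to_q.weight
#   module.transformer.blocks.0.attn.to_q.weight     ->  transformer.blocks.0.attn.to_q.weight
#
# If your checkpoints use different prefixes, adjust `ema_key_prefix` accordingly.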
# Primary model loading function (from app.py)
def load_custom(model_cls, model_cfg, ckpt_path: str, vocab_size: int, device='cpu', use_ema=True):
    """Loads a custom TTS model (DiT or UNetT) with the specified config and checkpoint."""
    ckpt_path = ckpt_path.strip()
    if ckpt_path.startswith("hf://"):
        print(f"Downloading checkpoint from Hugging Face Hub: {ckpt_path}")
        try:
            ckpt_path = str(cached_path(ckpt_path))
            print(f"Checkpoint downloaded to: {ckpt_path}")
        except Exception as e:
            print(f"Error downloading checkpoint {ckpt_path}: {e}")
            raise
    if not Path(ckpt_path).exists():
        raise FileNotFoundError(f"Checkpoint file not found: {ckpt_path}")

    # Ensure necessary config keys are present (add defaults if missing)
    if 'mel_dim' not in model_cfg:
        model_cfg['mel_dim'] = 100  # Default number of mel channels
        print(f"Warning: 'mel_dim' not in model_cfg, defaulting to {model_cfg['mel_dim']}")
    if 'text_num_embeds' not in model_cfg:
        model_cfg['text_num_embeds'] = vocab_size
        print(f"Setting 'text_num_embeds' in model_cfg to vocab size: {vocab_size}")

    print(f"Instantiating model: {model_cls.__name__} with config: {model_cfg}")
    try:
        model = model_cls(**model_cfg).to(device)  # Instantiate the model
    except Exception as e:
        print(f"Error instantiating model {model_cls.__name__} with config {model_cfg}: {e}")
        raise

    # Load weights using the helper function
    model = load_checkpoint(model, ckpt_path, device, use_ema=use_ema)
    model.eval()  # Ensure the model is in evaluation mode
    return model
# Text chunking function (from app.py)
def chunk_text(text, max_chars):
    """Splits the input text into chunks based on punctuation and length limits."""
    if not isinstance(text, str):
        print("Warning: Input to chunk_text is not a string. Returning empty list.")
        return []
    if max_chars > 135:
        print(f"Warning: Calculated max_chars ({max_chars}) > 135. Capping at 135.")
        max_chars = 135
    if max_chars < 50:
        print(f"Warning: Calculated max_chars ({max_chars}) < 50. Setting to 50.")
        max_chars = 50
    split_after_space_chars = max_chars + int(max_chars * 0.33)
    chunks = []
    current_chunk = ""
    # Split the text into sentences on punctuation followed by whitespace
    sentences = re.split(r"(?<=[;:,.!?])\s+|(?<=[;:,。!?])\s*", text)  # \s* after CJK punctuation
    for sentence in sentences:
        sentence = sentence.strip()
        if not sentence:
            continue
        # Estimate the length after appending the sentence (plus a joining space)
        estimated_len = len(current_chunk) + len(sentence) + (1 if current_chunk else 0)
        if estimated_len <= max_chars:
            current_chunk += (" " + sentence) if current_chunk else sentence
        else:
            # Flush the current chunk if adding the new sentence would exceed max_chars
            while len(current_chunk) > split_after_space_chars:
                split_index = current_chunk.rfind(" ", 0, split_after_space_chars)
                if split_index == -1:
                    split_index = split_after_space_chars
                chunks.append(current_chunk[:split_index].strip())
                current_chunk = current_chunk[split_index:].strip()
            if current_chunk:
                chunks.append(current_chunk)
            # Start a new chunk; split the sentence itself if it is too long
            while len(sentence) > split_after_space_chars:
                split_index = sentence.rfind(" ", 0, split_after_space_chars)
                if split_index == -1:
                    split_index = split_after_space_chars
                chunks.append(sentence[:split_index].strip())
                sentence = sentence[split_index:].strip()
            current_chunk = sentence
    # Handle the last chunk
    while len(current_chunk) > split_after_space_chars:
        split_index = current_chunk.rfind(" ", 0, split_after_space_chars)
        if split_index == -1:
            split_index = split_after_space_chars
        chunks.append(current_chunk[:split_index].strip())
        current_chunk = current_chunk[split_index:].strip()
    if current_chunk:
        chunks.append(current_chunk.strip())
    return [c for c in chunks if c]  # Filter out empty chunks
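
# Illustrative usage (not executed at import time); exact chunk boundaries depend on the input
# punctuation, but the intent is that no chunk exceeds the hard split threshold used above:
#
#   chunks = chunk_text("First sentence. Second sentence! A third, longer sentence follows.", max_chars=60)
#   for c in chunks:
#       assert len(c) <= 60 + int(60 * 0.33)  # split_after_space_chars for max_chars=60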
# Text to IPA function (from app.py)
def text_to_ipa(text, language):
    """Converts text to IPA using phonemizer with the espeak backend."""
    if not isinstance(text, str) or not text.strip():
        print(f"Warning: Invalid input text for IPA conversion: {text}")
        return ""  # Return an empty string for invalid input
    try:
        # Ensure phonemizer is installed: pip install phonemizer
        # Ensure espeak-ng is installed: sudo apt-get install espeak-ng (or equivalent)
        ipa_text = phonemize(
            text,
            language=language,
            backend='espeak',
            strip=False,  # Keep punctuation
            preserve_punctuation=True,
            with_stress=True,
            language_switch='remove-flags',  # Use this instead of regex removal
            njobs=1,  # njobs=1 for better stability and simpler debugging
        )
        # Specific removals (possibly redundant with remove-flags, kept for consistency with app.py)
        ipa_text = re.sub(r'tʃˈaɪniːzlˈe̞tə', '', ipa_text)
        ipa_text = re.sub(r'tʃˈaɪniːzɭˈetə', '', ipa_text)
        ipa_text = re.sub(r'dʒˈapəniːzlˈe̞tə', '', ipa_text)
        ipa_text = re.sub(r'dʒˈapəniːzɭˈetə', '', ipa_text)
        ipa_text = ipa_text.strip()
        # Collapse runs of whitespace into single spaces
        ipa_text = re.sub(r'\s+', ' ', ipa_text)
        print(f"Text: '{text}' | Lang: {language} | IPA: '{ipa_text}'")
        return ipa_text
    except ImportError:
        print("Error: 'phonemizer' library not found. Please install it: pip install phonemizer")
        raise
    except Exception as e:
        # An espeak error usually means the language code is unsupported or espeak-ng is missing
        if "espeak" in str(e).lower():
            print(f"Error: espeak backend failed for language '{language}'. Is the language code valid and espeak-ng installed with support for it?")
            print(f"  Original error: {e}")
        else:
            print(f"Error phonemizing text: '{text}' with language '{language}'. Error: {e}")
        raise ValueError(f"Phonemization failed for '{text}' ({language})") from e
# --- End of functions from app.py ---


# --- Argument Parser Setup ---
parser = argparse.ArgumentParser(
    prog="python3 inference-cli.py",
    description="Command-line interface for F5/E2 TTS.",
)
parser.add_argument(
    "-c", "--config", type=str, default="inference-cli.toml",
    help="Path to configuration file (TOML format). Default: inference-cli.toml"
)
# --- Arguments overriding config or providing inputs ---
parser.add_argument("--ckpt_path", type=str, default=None, help="Path or Hub ID (hf://...) to the TTS model checkpoint (.pt/.safetensors). Overrides config.")
parser.add_argument("--ref_audio", type=str, default=None, help="Path to the reference audio file (<10s recommended). Overrides config.")
parser.add_argument("--ref_text", type=str, default=None, help="Reference text. If omitted, Whisper transcription is used. Overrides config.")
parser.add_argument("--gen_text", type=str, default=None, help="Text to synthesize. Overrides config.")
parser.add_argument("--gen_file", type=str, default=None, help="File containing text to synthesize (overrides --gen_text and config).")
parser.add_argument("--output_dir", type=str, default=None, help="Directory to save output audio and spectrogram. Overrides config.")
parser.add_argument("--output_name", type=str, default="out", help="Base name for output files (e.g., 'my_speech' -> my_speech.wav, my_speech.png). Default: out.")
# --- Parameter Arguments ---
parser.add_argument("--ref_language", type=str, default=None, help="Language code for reference text phonemization (e.g., 'en-us', 'pl', 'de'). Overrides config.")
parser.add_argument("--language", type=str, default=None, help="Language code for generated text phonemization (e.g., 'en-us', 'pl', 'de'). Overrides config.")
parser.add_argument("--speed", type=float, default=None, help="Speech speed multiplier. Overrides config.")
parser.add_argument("--nfe", type=int, default=None, help="Number of function evaluations (sampling steps). Overrides config.")
parser.add_argument("--cfg", type=float, default=None, help="Classifier-Free Guidance strength. Overrides config.")
parser.add_argument("--sway", type=float, default=None, help="Sway sampling coefficient. Overrides config.")
parser.add_argument("--cross_fade", type=float, default=None, help="Cross-fade duration between batches (seconds). Overrides config.")
parser.add_argument("--remove_silence", action=argparse.BooleanOptionalAction, default=None, help="Enable/disable final silence removal. Overrides config.")
parser.add_argument("--hf_token", type=str, default=None, help="Hugging Face API token (for downloading private models/checkpoints).")
parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to the tokenizer.json file. Overrides config.")
parser.add_argument("--device", type=str, default=None, help="Device to use ('cuda', 'cpu', 'mps'). Auto-detects if not set.")
parser.add_argument("--dtype", type=str, default=None, help="Data type ('float16', 'bfloat16', 'float32'). Auto-selects if not set.")
args = parser.parse_args()
# --- Load Configuration ---
config = {}
if Path(args.config).exists():
    try:
        with open(args.config, "rb") as f:
            config = tomli.load(f)
        print(f"Loaded configuration from {args.config}")
    except Exception as e:
        print(f"Warning: Could not load config file {args.config}. Error: {e}")
else:
    print(f"Warning: Config file {args.config} not found. Using defaults and CLI args.")

# --- Determine Parameters (CLI > Config > Defaults) ---
ckpt_path = args.ckpt_path or config.get("ckpt_path", "hf://Gregniuki/F5-tts_English_German_Polish/multi3/model_900000.pt")
ref_audio_path = args.ref_audio or config.get("ref_audio")
ref_text = args.ref_text if args.ref_text is not None else config.get("ref_text", "")
gen_text = args.gen_text or config.get("gen_text")
gen_file = args.gen_file or config.get("gen_file")
output_dir = Path(args.output_dir or config.get("output_dir", "."))
output_name = args.output_name or config.get("output_name", "out")
ref_language = args.ref_language or config.get("ref_language", "en-us")
language = args.language or config.get("language", "en-us")
speed = args.speed if args.speed is not None else config.get("speed", 1.0)
nfe_step = args.nfe if args.nfe is not None else config.get("nfe", 32)
cfg_strength = args.cfg if args.cfg is not None else config.get("cfg", 2.0)
sway_sampling_coef = args.sway if args.sway is not None else config.get("sway", -1.0)
cross_fade_duration = args.cross_fade if args.cross_fade is not None else config.get("cross_fade", 0.15)
remove_silence_flag = args.remove_silence if args.remove_silence is not None else config.get("remove_silence", False)
hf_token = args.hf_token or config.get("hf_token")
tokenizer_path = args.tokenizer_path or config.get("tokenizer_path", "data/Emilia_ZH_EN_pinyin/tokenizer.json")
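
# A minimal example of what inference-cli.toml could look like, matching the keys read above
# (paths and values here are placeholders, not shipped defaults):
#
#   ckpt_path = "hf://Gregniuki/F5-tts_English_German_Polish/multi3/model_900000.pt"
#   ref_audio = "samples/reference.wav"
#   ref_text = ""            # leave empty to transcribe the reference with Whisper
#   gen_text = "Text to synthesize."
#   output_dir = "outputs"
#   ref_language = "en-us"
#   language = "en-us"
#   speed = 1.0
#   nfe = 32
#   cfg = 2.0
#   sway = -1.0
#   cross_fade = 0.15
#   remove_silence = false
#   tokenizer_path = "data/Emilia_ZH_EN_pinyin/tokenizer.json"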
# --- Validate Required Arguments ---
if not ckpt_path:
    raise ValueError("Missing required argument/config: --ckpt_path")
if not ref_audio_path:
    raise ValueError("Missing required argument/config: --ref_audio")
if not gen_text and not gen_file:
    raise ValueError("Missing required argument/config: --gen_text or --gen_file")

# --- Read gen_text from file if provided ---
if gen_file:
    try:
        with codecs.open(gen_file, "r", "utf-8") as f:
            gen_text = f.read()
        print(f"Loaded generation text from {gen_file}")
    except Exception as e:
        raise ValueError(f"Error reading generation text file {gen_file}: {e}")
# --- Setup Device and Dtype ---
cli_device = args.device or config.get("device")
if cli_device:
    device = torch.device(cli_device)
else:
    device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")

cli_dtype = args.dtype or config.get("dtype")
if cli_dtype:
    dtype_map = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}
    if cli_dtype in dtype_map:
        dtype = dtype_map[cli_dtype]
    else:
        raise ValueError(f"Unsupported dtype: {cli_dtype}")
else:
    if device.type == "cuda":
        dtype = torch.float16
    elif device.type == "cpu" and getattr(getattr(torch.backends, "cpu", None), "supports_bfloat16", False):
        dtype = torch.bfloat16
    else:
        dtype = torch.float32
print(f"Using device: {device}, dtype: {dtype}")
# --- Hugging Face Login ---
if hf_token:
    print("Logging in to Hugging Face Hub...")
    try:
        login(token=hf_token)
        print("Logged in successfully.")
    except Exception as e:
        print(f"Warning: Hugging Face login failed: {e}")

# --- Create Output Directory ---
output_dir.mkdir(parents=True, exist_ok=True)
wave_path = output_dir / f"{output_name}.wav"
spectrogram_path = output_dir / f"{output_name}.png"
# --- Load Models and Tokenizer ---
print("Loading Tokenizer...")
try:
    if not Path(tokenizer_path).exists():
        raise FileNotFoundError(f"Tokenizer file not found: {tokenizer_path}")
    tokenizer = Tokenizer.from_file(tokenizer_path)
    vocab_size = tokenizer.get_vocab_size()
    print(f"Tokenizer loaded successfully. Vocab size: {vocab_size}")
except Exception as e:
    raise ValueError(f"Error loading tokenizer from {tokenizer_path}: {e}")

print("Loading Vocoder...")
vocos = load_vocoder(device=device)  # Already moved to device and set to eval()

print("Loading ASR Model (Whisper)...")
try:
    whisper_dtype = torch.float16 if device.type == 'cuda' else torch.float32
    # The transcription batch_size is passed at call time in main_infer; keep it small for CLI use.
    pipe = pipeline(
        "automatic-speech-recognition",
        model="openai/whisper-large-v3-turbo",
        torch_dtype=whisper_dtype,
        device=device,
        model_kwargs={"attn_implementation": "sdpa"},  # Use SDPA attention if available
    )
    print("Whisper model loaded.")
except Exception as e:
    print(f"Warning: Could not load Whisper ASR model: {e}. Transcription will not be available.")
    pipe = None
print("Loading TTS Model...") | |
# --- Determine Model Class and Config --- | |
# Example configs (ensure they match your actual model requirements) | |
F5TTS_model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4) | |
E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4) # Add mel_dim/text_num_embeds if needed by class | |
# Heuristic to determine model class (improve if needed) | |
if "E2TTS" in ckpt_path or "UNetT" in ckpt_path: | |
model_cls = UNetT | |
model_cfg = E2TTS_model_cfg | |
print(f"Assuming E2-TTS (UNetT) architecture for {ckpt_path}.") | |
elif "F5TTS" in ckpt_path or "DiT" in ckpt_path: | |
model_cls = DiT | |
model_cfg = F5TTS_model_cfg | |
print(f"Assuming F5-TTS (DiT) architecture for {ckpt_path}.") | |
else: | |
# Default or raise error if model type cannot be inferred | |
print(f"Warning: Cannot infer model type from '{ckpt_path}'. Defaulting to DiT/F5TTS.") | |
model_cls = DiT | |
model_cfg = F5TTS_model_cfg | |
try: | |
# Pass vocab_size needed by load_custom | |
ema_model = load_custom(model_cls, model_cfg, ckpt_path, vocab_size=vocab_size, device=device, use_ema=True) | |
# Ensure model is using the target runtime dtype | |
ema_model = ema_model.to(dtype=dtype) | |
print(f"TTS Model loaded successfully ({model_cls.__name__}).") | |
except Exception as e: | |
print(f"Critical Error: Failed to load TTS model from {ckpt_path}: {e}") | |
raise | |
# --- Settings from app.py --- | |
target_sample_rate = 24000 | |
n_mel_channels = model_cfg.get('mel_dim', 100) # Use mel_dim from config if available | |
hop_length = 256 | |
target_rms = 0.1 | |
# --- Main Inference Logic ---
def infer_batch(ref_audio_tuple, ref_text_ipa, gen_text_ipa_batches,
                ema_model, vocos, tokenizer,
                remove_silence_post, cross_fade_duration,
                nfe_step, cfg_strength, sway_sampling_coef, speed,
                target_sample_rate, hop_length, target_rms, device, dtype):
    """Generates audio for each text batch and combines the results."""
    audio, sr = ref_audio_tuple
    audio = audio.to(device, dtype=dtype)

    # Preprocess reference audio (mono downmix, RMS normalization, resampling)
    if audio.shape[0] > 1:
        audio = torch.mean(audio, dim=0, keepdim=True)
    current_rms = torch.sqrt(torch.mean(torch.square(audio)))
    rms_applied_factor = 1.0  # Track the scaling factor applied to the reference
    if current_rms < target_rms and current_rms > 1e-5:  # Safety check for near-silent audio
        print(f"Reference audio RMS ({current_rms:.3f}) below target ({target_rms}). Normalizing.")
        rms_applied_factor = target_rms / current_rms
        audio = audio * rms_applied_factor
    elif current_rms <= 1e-5:
        print("Warning: Reference audio is near silent. Skipping RMS normalization.")
    else:
        print(f"Reference audio RMS ({current_rms:.3f}) >= target ({target_rms}). No normalization.")
    if sr != target_sample_rate:
        print(f"Resampling reference audio from {sr} Hz to {target_sample_rate} Hz.")
        resampler = torchaudio.transforms.Resample(sr, target_sample_rate).to(device)
        audio = resampler(audio)
    ref_audio_len_frames = audio.shape[-1] // hop_length
    print(f"Reference audio length: {audio.shape[-1]/target_sample_rate:.2f}s ({ref_audio_len_frames} frames)")

    generated_waves = []
    spectrograms = []
    progress_bar = tqdm(gen_text_ipa_batches, desc="Generating Batches")
    for i, gen_text_ipa in enumerate(progress_bar):
        progress_bar.set_postfix({"Batch": f"{i+1}/{len(gen_text_ipa_batches)}"})
        # Combine reference and generated IPA text
        combined_ipa_text = ref_text_ipa + " " + gen_text_ipa
        # print(f"Batch {i+1} Combined IPA: {combined_ipa_text}")  # Debug

        # Tokenize
        try:
            # The tokenizer expects a single string (or a list of strings)
            encoding = tokenizer.encode(combined_ipa_text)
            tokens = encoding.ids
            token_str = encoding.tokens  # For logging/debugging
            # --- Model Input Formatting ---
            # Check how your model's `sample` method expects the 'text' input:
            #   Option 1 (as in app.py): a string of space-separated token strings
            #   Option 2: a list of token IDs, e.g. final_text_list = [tokens]
            #   Option 3: a tensor of token IDs (check the model docs)
            # Option 1 is used here to match app.py:
            token_input_string = ' '.join(map(str, token_str))
            final_text_list = [token_input_string]
            # print(f"Batch {i+1} Input Text List for Model: {final_text_list}")
        except Exception as e:
            print(f"Error tokenizing batch {i+1}: '{combined_ipa_text}'. Error: {e}")
            continue

        # Calculate duration: scale the reference frames-per-character rate by the generated text length
        ref_ipa_len = len(ref_text_ipa)
        gen_ipa_len = len(gen_text_ipa)
        if ref_ipa_len == 0:
            ref_ipa_len = 1  # Avoid division by zero
        duration_frames = ref_audio_len_frames + int(((ref_audio_len_frames / ref_ipa_len) * gen_ipa_len) / speed)
        min_duration_frames = max(10, target_sample_rate // hop_length // 4)  # Roughly 0.25s minimum
        duration_frames = max(min_duration_frames, duration_frames)
        max_duration_frames = 40 * target_sample_rate // hop_length  # Cap at ~40s of mel frames
        if duration_frames > max_duration_frames:
            print(f"Warning: Calculated duration {duration_frames} frames exceeds max {max_duration_frames}. Capping.")
            duration_frames = max_duration_frames
        # print(f"Batch {i+1}: Duration={duration_frames} frames")

        # Inference
        try:
            with torch.inference_mode():
                cond_audio = audio.to(ema_model.device, dtype=dtype)  # Match the model's device/dtype
                generated_mel, _ = ema_model.sample(
                    cond=cond_audio,
                    text=final_text_list,  # Formatted text input (see options above)
                    duration=duration_frames,
                    steps=nfe_step,
                    cfg_strength=cfg_strength,
                    sway_sampling_coef=sway_sampling_coef,
                )
                # Drop the reference portion and reshape to (batch, mel, frames)
                generated_mel = generated_mel.to(device, dtype=dtype)
                generated_mel = generated_mel[:, ref_audio_len_frames:, :]
                generated_mel_spec = rearrange(generated_mel, "1 n d -> 1 d n")
                # Vocoding (Vocos expects float32; the vocoder was moved to `device` at load time)
                vocos_input_mel = generated_mel_spec.to(device, dtype=torch.float32)
                generated_wave = vocos.decode(vocos_input_mel)
                generated_wave = generated_wave.to(device, dtype=torch.float32)

                # Undo the RMS normalization applied to the reference so the output
                # matches the original reference loudness
                if rms_applied_factor != 1.0:
                    generated_wave = generated_wave / rms_applied_factor

                # Convert to numpy
                generated_wave_np = generated_wave.squeeze().cpu().numpy()
                generated_waves.append(generated_wave_np)
                spectrograms.append(generated_mel_spec[0].cpu().to(torch.float32).numpy())
        except Exception as e:
            logging.exception(f"Error during inference/processing for batch {i+1}:")  # Log traceback
            print(f"Error details: {e}")
            continue

    if not generated_waves:
        print("No audio waves were generated.")
        return None, None

    # Combine batches
    print(f"Combining {len(generated_waves)} generated batches...")
    if cross_fade_duration <= 0 or len(generated_waves) == 1:
        final_wave = np.concatenate(generated_waves)
    else:
        # Linear cross-fade between consecutive batches
        final_wave = generated_waves[0]
        for i in range(1, len(generated_waves)):
            prev_wave = final_wave
            next_wave = generated_waves[i]
            cf_samples = min(int(cross_fade_duration * target_sample_rate), len(prev_wave), len(next_wave))
            if cf_samples <= 0:
                final_wave = np.concatenate([prev_wave, next_wave])
                continue
            p_olap = prev_wave[-cf_samples:]
            n_olap = next_wave[:cf_samples]
            f_out = np.linspace(1, 0, cf_samples, dtype=p_olap.dtype)
            f_in = np.linspace(0, 1, cf_samples, dtype=n_olap.dtype)
            cf_olap = p_olap * f_out + n_olap * f_in
            final_wave = np.concatenate([prev_wave[:-cf_samples], cf_olap, next_wave[cf_samples:]])
        print(f"Applied cross-fade of {cross_fade_duration:.2f}s between batches.")

    # Optional: remove silence after combining
    if remove_silence_post:
        print("Removing silence from final output...")
        try:
            final_wave_float32 = final_wave.astype(np.float32)
            with tempfile.NamedTemporaryFile(delete=True, suffix=".wav") as tmp_wav:
                sf.write(tmp_wav.name, final_wave_float32, target_sample_rate)
                aseg = AudioSegment.from_file(tmp_wav.name)
                non_silent_segs = silence.split_on_silence(
                    aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500
                )
                if not non_silent_segs:
                    print("Warning: Silence removal resulted in empty audio. Keeping original.")
                else:
                    non_silent_wave = sum(non_silent_segs, AudioSegment.silent(duration=0))
                    non_silent_wave.export(tmp_wav.name, format="wav")
                    final_wave_tensor, _ = torchaudio.load(tmp_wav.name)
                    final_wave = final_wave_tensor.squeeze().cpu().numpy()
                    print("Silence removal applied.")
        except Exception as e:
            print(f"Warning: Failed to remove silence: {e}. Using original.")

    # Combine spectrograms
    print("Combining spectrograms...")
    try:
        combined_spectrogram = np.concatenate(spectrograms, axis=1) if spectrograms else None
    except ValueError as e:
        print(f"Warning: Could not concatenate spectrograms: {e}. Skipping.")
        combined_spectrogram = None

    return final_wave, combined_spectrogram
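
# A quick sanity check of the linear cross-fade used above (illustrative, not executed):
# with fade-out f_out = [1.0, 0.5, 0.0] and fade-in f_in = [0.0, 0.5, 1.0], two overlapping
# constant signals a = [1, 1, 1] and b = [1, 1, 1] blend to a*f_out + b*f_in = [1, 1, 1],
# i.e. the linear (constant-gain) fade preserves level for correlated signals; an equal-power
# fade would be an alternative for uncorrelated material.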
def main_infer(ref_audio_orig_path, ref_text_input, gen_text_full,
               ema_model, vocos, tokenizer, pipe_asr,  # Loaded models/utilities
               ref_language, language,  # Languages
               speed, nfe_step, cfg_strength, sway_sampling_coef,  # Sampling params
               remove_silence_flag, cross_fade_duration,  # Post-processing
               target_sample_rate, hop_length, target_rms,  # Audio params
               device, dtype):  # System params
    """Main inference function coordinating preprocessing, batching, and generation."""
    print(f"Starting inference for text: '{gen_text_full[:100]}...'")

    # --- Reference Audio Preprocessing ---
    print("Processing reference audio...")
    processed_ref_path = None
    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_ref_wav:
            processed_ref_path = temp_ref_wav.name  # Store the path for later use
        aseg = AudioSegment.from_file(ref_audio_orig_path)
        print(f"Original ref duration: {len(aseg)/1000:.2f}s")
        # Edge silence removal + padding
        aseg = remove_silence_edges(aseg)
        aseg += AudioSegment.silent(duration=150)
        # Split on silence and recombine to drop long internal pauses
        non_silent_segs = silence.split_on_silence(
            aseg, min_silence_len=700, silence_thresh=-50, keep_silence=700
        )
        if non_silent_segs:
            aseg = sum(non_silent_segs, AudioSegment.silent(duration=0))
        else:
            print("Warning: Silence splitting/recombining resulted in empty audio. Using edge-trimmed audio.")
        # Clip to 10s
        max_ref_duration_ms = 10000
        if len(aseg) > max_ref_duration_ms:
            print(f"Reference audio exceeds {max_ref_duration_ms/1000}s. Clipping...")
            aseg = aseg[:max_ref_duration_ms]
        aseg.export(processed_ref_path, format="wav")
        print(f"Processed ref duration: {len(aseg)/1000:.2f}s. Saved to temp file: {processed_ref_path}")
        # Load the processed audio as a tensor
        ref_audio_tensor, sr_ref = torchaudio.load(processed_ref_path)
    except Exception as e:
        print(f"Error processing reference audio {ref_audio_orig_path}: {e}")
        if processed_ref_path and Path(processed_ref_path).exists():
            Path(processed_ref_path).unlink()  # Clean up the temp file on error
        raise

    # --- Reference Text Handling ---
    ref_text_processed = ""
    if not ref_text_input or ref_text_input.strip() == "":
        print("No reference text provided. Transcribing reference audio...")
        if pipe_asr is None:
            raise ValueError("Whisper ASR model not loaded. Cannot transcribe. Please provide --ref_text.")
        if not processed_ref_path:
            raise ValueError("Processed reference audio path is missing for transcription.")
        try:
            print(f"Transcribing: {processed_ref_path}")
            transcription_result = pipe_asr(
                processed_ref_path,
                chunk_length_s=15,
                batch_size=8,  # Small batch size for CLI use
                generate_kwargs={"task": "transcribe", "language": None},  # Let Whisper detect the language
                return_timestamps=False,
            )
            ref_text_processed = transcription_result["text"].strip()
            print(f"Transcription finished: '{ref_text_processed}'")
            if not ref_text_processed:
                print("Warning: Transcription resulted in empty text. Using placeholder.")
                ref_text_processed = "Reference audio"
        except Exception as e:
            logging.exception("Error during transcription:")
            raise ValueError("Transcription failed. Please provide --ref_text.") from e
    else:
        print("Using provided reference text.")
        ref_text_processed = ref_text_input

    # Clean up the temporary processed reference audio file
    if processed_ref_path and Path(processed_ref_path).exists():
        try:
            Path(processed_ref_path).unlink()
        except OSError as e:
            print(f"Warning: Could not delete temp ref file {processed_ref_path}: {e}")

    # Ensure the reference text ends with ". "
    if not ref_text_processed.endswith(". "):
        ref_text_processed = ref_text_processed.rstrip('. ') + ". "
    print(f"Final Reference Text: '{ref_text_processed}'")

    # --- Phonemize Reference Text ---
    print(f"Phonemizing reference text with language: {ref_language}")
    ref_text_ipa = text_to_ipa(ref_text_processed, language=ref_language)
    if not ref_text_ipa:
        raise ValueError("Reference text phonemization failed.")

    # --- Chunk and Phonemize Generation Text ---
    # Estimate characters-per-second from the reference to pick a chunk size that keeps each
    # batch (reference + generated speech) to a manageable duration.
    ref_audio_duration_sec = ref_audio_tensor.shape[-1] / sr_ref if sr_ref > 0 else 1.0
    if ref_audio_duration_sec <= 0:
        ref_audio_duration_sec = 1.0
    chars_per_sec = len(ref_text_processed.encode('utf-8')) / ref_audio_duration_sec
    if chars_per_sec <= 0:
        chars_per_sec = 10.0
    target_chunk_duration_sec = max(5.0, 20.0 - ref_audio_duration_sec)
    max_chars = int(chars_per_sec * target_chunk_duration_sec)
    print(f"Ref duration: {ref_audio_duration_sec:.2f}s => Calculated max_chars/batch: {max_chars}")

    gen_text_batches_plain = chunk_text(gen_text_full, max_chars=max_chars)
    if not gen_text_batches_plain:
        raise ValueError("Text chunking resulted in zero batches.")
    print(f"Split generation text into {len(gen_text_batches_plain)} batches.")

    print(f"Phonemizing generation text batches with language: {language}")
    gen_text_ipa_batches = []
    for i, batch_text in enumerate(gen_text_batches_plain):
        batch_ipa = text_to_ipa(batch_text, language=language)
        if batch_ipa:
            gen_text_ipa_batches.append(batch_ipa)
        else:
            print(f"Warning: Skipping batch {i+1} due to phonemization failure.")
    if not gen_text_ipa_batches:
        raise ValueError("Phonemization failed for all generation text batches.")

    # --- Run Batched Inference ---
    print(f"Starting batch inference process ({len(gen_text_ipa_batches)} batches)...")
    final_wave, combined_spectrogram = infer_batch(
        (ref_audio_tensor, sr_ref), ref_text_ipa, gen_text_ipa_batches,
        ema_model, vocos, tokenizer,
        remove_silence_flag, cross_fade_duration,
        nfe_step, cfg_strength, sway_sampling_coef, speed,
        target_sample_rate, hop_length, target_rms,
        device, dtype
    )
    return final_wave, combined_spectrogram
# --- Execution ---
if __name__ == "__main__":
    # Set up logging
    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
    try:
        final_wave_np, combined_spectrogram_np = main_infer(
            ref_audio_path, ref_text, gen_text,
            ema_model, vocos, tokenizer, pipe,
            ref_language, language,
            speed, nfe_step, cfg_strength, sway_sampling_coef,
            remove_silence_flag, cross_fade_duration,
            target_sample_rate, hop_length, target_rms,
            device, dtype
        )

        # --- Save Outputs ---
        output_saved = False
        if final_wave_np is not None and len(final_wave_np) > 0:
            print(f"Saving final audio ({len(final_wave_np)/target_sample_rate:.2f}s) to {wave_path}...")
            final_wave_float32 = final_wave_np.astype(np.float32)  # soundfile expects float32
            sf.write(str(wave_path), final_wave_float32, target_sample_rate)
            print("Audio saved successfully.")
            output_saved = True
        else:
            print("Inference did not produce a valid audio wave.")

        if combined_spectrogram_np is not None:
            print(f"Saving combined spectrogram to {spectrogram_path}...")
            save_spectrogram(combined_spectrogram_np, str(spectrogram_path))
            print("Spectrogram saved successfully.")
            output_saved = True

        if not output_saved:
            print("No output files were generated.")
    except FileNotFoundError as e:
        logging.error(f"File not found: {e}")
        print(f"\nError: A required file was not found. Please check paths. Details: {e}")
        exit(1)
    except ValueError as e:
        logging.error(f"Value error: {e}")
        print(f"\nError: An invalid value or configuration was encountered. Details: {e}")
        exit(1)
    except Exception as e:
        logging.exception("An unexpected error occurred during inference:")  # Log traceback
        print(f"\nAn unexpected error occurred: {e}")
        exit(1)

    print("\nInference completed.")