# Copyright (c) 2025 Tsinghua Univ. (authors: Xingchen Song)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Example Usage: see README.md"""

import argparse
import json
import os
import random
import sys
import time
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime

import numpy as np
import onnxruntime
import s3tokenizer
import torch
import torch.distributed as dist
import torchaudio
import torchaudio.compliance.kaldi as kaldi
from torch.utils.data import DataLoader, Dataset, DistributedSampler
from tqdm import tqdm

from flashcosyvoice.config import Config, CosyVoice2LLMConfig, SamplingParams
from flashcosyvoice.cosyvoice2 import CosyVoice2
from flashcosyvoice.utils.audio import mel_spectrogram


def set_all_random_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def save_file_async(
    wav, prompt_speech_tokens, generated_speech_tokens, info, timing_stats
):
    """Save generated audio and its metadata (runs in a background thread)."""
    try:
        os.makedirs(os.path.dirname(info['wav']), exist_ok=True)
        if wav is not None:
            wav = wav.cpu()
            torchaudio.save(info['wav'], wav, 24000)
            duration = wav.shape[-1] / 24000.0
            rtf = ((timing_stats['dataloader_time'] + timing_stats['model_inference_time']) / timing_stats['batch_size']) / duration
            timing_stats['rtf'] = rtf
        else:
            duration = 0.0
        info['timing_stats'] = timing_stats
        info['prompt_speech_tokens'] = prompt_speech_tokens
        info['generated_speech_tokens'] = generated_speech_tokens
        with open(f"{info['wav'].replace('.wav', '.json')}", "w") as f:
            json.dump(info, f, ensure_ascii=False, indent=4)
        return duration
    except Exception as e:
        timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
        tqdm.write(f"[{timestamp}] - [ERROR] - Error saving audio {info.get('key', 'unknown')}: {e}")
        return 0.0


class AudioDataset(Dataset):
    def __init__(self, text_norm, text_tokenizer, data_list, model_config: Config):
        self.datas = []
        self.text_norm = text_norm
        self.model_config = model_config
        """Example data_list (one JSON object per line):
        ```
        {"key": "uttid_1", "prompt_text": "你好,我是小明。", "text": "你好,我是小红。", "prompt_wav": "/mnt/data/audio/00000000.wav", "wav": "/mnt/data/audio_synthetic/uttid_1.wav"}
        {"key": "uttid_2", "prompt_text": "你好,我是小红。", "text": "你好,我是小明。", "prompt_wav": "/mnt/data/audio/00000001.wav", "wav": "/mnt/data/audio_synthetic/uttid_2.wav"}
        ```

        Note:
        - `key` is the key of this sample.
        - `prompt_text` is the text used for the prompt.
        - `text` is the text used for generating the target audio.
        - `prompt_wav` is the audio used for the prompt.
        - `wav` is the path where the generated audio will be saved (we highly recommend pre-defining the save path before running the script).
""" missing = 0 with open(data_list, 'r', encoding='utf-8') as f: lines = f.readlines() total_lines = len(lines) if torch.distributed.get_node_local_rank() == 0: iterator = tqdm(lines, desc='Loading data') else: iterator = lines for line in iterator: data = json.loads(line.strip()) valid = True for k in ['key', 'prompt_text', 'text', 'prompt_wav']: if k not in data: valid = False break if data[k] is None: valid = False break if not os.path.exists(data['prompt_wav']): valid = False if valid: self.datas.append(data) else: missing += 1 if torch.distributed.get_node_local_rank() == 0: timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3] tqdm.write(f'[{timestamp}] - [INFO] - Loaded {total_lines} lines, found {missing} missing lines, total valid lines == {len(self.datas)}.') self.text_tokenizer = text_tokenizer option = onnxruntime.SessionOptions() option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL option.intra_op_num_threads = 1 self.spk_model = onnxruntime.InferenceSession(f"{self.model_config.model}/campplus.onnx", sess_options=option, providers=["CPUExecutionProvider"]) def __len__(self): return len(self.datas) def __getitem__(self, idx): data = self.datas[idx] try: # 1. feature for s3tokenizer audio = s3tokenizer.load_audio(data['prompt_wav'], sr=16000) # [T] log_mel = s3tokenizer.log_mel_spectrogram(audio) # [num_mels, T] # 2. feature for speaker embedding spk_feat = kaldi.fbank(audio.unsqueeze(0), num_mel_bins=80, dither=0, sample_frequency=16000) spk_feat = spk_feat - spk_feat.mean(dim=0, keepdim=True) spk_emb = self.spk_model.run( None, {self.spk_model.get_inputs()[0].name: spk_feat.unsqueeze(dim=0).cpu().numpy()} )[0].flatten().tolist() # 3. feature for flow audio, sample_rate = torchaudio.load(data['prompt_wav'], backend='soundfile') audio = audio.mean(dim=0, keepdim=True) # [1, T] if sample_rate != 24000: audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=24000)(audio) mel = mel_spectrogram(audio).transpose(1, 2).squeeze(0) # [T, num_mels] mel_len = mel.shape[0] # 4. 
            if self.text_norm is not None:
                prompt_texts = [i["text"] for i in json.loads(self.text_norm.do_voicegen_frd(data['prompt_text'].strip()))["sentences"]]
                prompt_text = ''.join(prompt_texts)
                texts = [i["text"] for i in json.loads(self.text_norm.do_voicegen_frd(data['text'].strip()))["sentences"]]
                text = ''.join(texts)
            else:
                prompt_text = data['prompt_text']
                text = data['text']
            prompt_text_ids = self.text_tokenizer.encode(prompt_text)
            # shift text token ids past the speech vocabulary so they do not collide with speech token ids
            prompt_text_ids = [i + self.model_config.hf_config.speech_vocab_size + 2 for i in prompt_text_ids]
            text_ids = self.text_tokenizer.encode(text)
            text_ids = [i + self.model_config.hf_config.speech_vocab_size + 2 for i in text_ids]
            item = {
                "prompt_text_tokens": prompt_text_ids,
                "text_tokens": text_ids,
                "spk_emb": spk_emb,
                "mel": mel,
                "mel_len": mel_len,
                "log_mel": log_mel,
                "info": data,
                "min_tokens": len(text_ids) * self.model_config.min_token_text_ratio,
                "max_tokens": len(text_ids) * self.model_config.max_token_text_ratio,
            }
        except Exception as e:
            timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
            tqdm.write(f"[{timestamp}] - [WARNING] - Error processing data item {data.get('key', idx)}: {e}")
            return None
        return item


def collate_fn(batch):
    prompt_mels_for_llm = [item["log_mel"] for item in batch if item is not None]
    prompt_mels_for_llm, prompt_mels_lens_for_llm = s3tokenizer.padding(prompt_mels_for_llm)  # [B, num_mels=128, T]
    prompt_text_tokens_for_llm = [item["prompt_text_tokens"] for item in batch if item is not None]
    text_tokens_for_llm = [item["text_tokens"] for item in batch if item is not None]
    prompt_mels_for_flow = [item["mel"] for item in batch if item is not None]
    prompt_mels_for_flow = torch.nn.utils.rnn.pad_sequence(prompt_mels_for_flow, batch_first=True, padding_value=0)  # [B, T', num_mels=80]
    prompt_mels_lens_for_flow = [item["mel_len"] for item in batch if item is not None]
    prompt_mels_lens_for_flow = torch.tensor(prompt_mels_lens_for_flow)
    spk_emb_for_flow = [item["spk_emb"] for item in batch if item is not None]
    spk_emb_for_flow = torch.tensor(spk_emb_for_flow)
    sampling_params = [SamplingParams(min_tokens=item["min_tokens"], max_tokens=item["max_tokens"], use_ras=True) for item in batch if item is not None]
    infos = [item["info"] for item in batch if item is not None]
    return {
        "prompt_mels_for_llm": prompt_mels_for_llm,
        "prompt_mels_lens_for_llm": prompt_mels_lens_for_llm,
        "prompt_text_tokens_for_llm": prompt_text_tokens_for_llm,
        "text_tokens_for_llm": text_tokens_for_llm,
        "prompt_mels_for_flow": prompt_mels_for_flow,
        "prompt_mels_lens_for_flow": prompt_mels_lens_for_flow,
        "spk_emb_for_flow": spk_emb_for_flow,
        "sampling_params": sampling_params,
        "infos": infos,
    }


def init_distributed():
    world_size = int(os.environ.get('WORLD_SIZE', 1))
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    rank = int(os.environ.get('RANK', 0))
    timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
    tqdm.write(f'[{timestamp}] - [INFO] - Inference on multiple gpus, this gpu {local_rank}, rank {rank}, world_size {world_size}')
    torch.cuda.set_device(local_rank)
    dist.init_process_group("nccl")
    return world_size, local_rank, rank


def get_args():
    parser = argparse.ArgumentParser(description='FlashCosyVoice')
    parser.add_argument('--model_path', required=True, type=str, help='model path')
    parser.add_argument('--data_list', required=True, type=str, help='data list')
    parser.add_argument('--batch_size_dataloader', required=True, type=int, help='batch size (per-device) for dataloading')
    parser.add_argument('--batch_size_flow', required=True,
                        type=int, help='batch size (per-device) for flow-matching')
    parser.add_argument('--num_workers', type=int, default=4, help='workers for dataloader')
    parser.add_argument('--prefetch', type=int, default=5, help='prefetch for dataloader')
    parser.add_argument('--enable_tn', action='store_true', help='enable text normalization')
    parser.add_argument('--only_llm', action='store_true', help='only generate speech tokens from llm')
    parser.add_argument('--fp16_flow', action='store_true', help='enable fp16 flow')
    parser.add_argument('--seed', type=int, default=1986, help='random seed for generation')
    args = parser.parse_args()
    return args


def main():
    args = get_args()

    if args.enable_tn:
        # Check python version: ttsfrd is only available for python 3.10
        if sys.version_info.major == 3 and sys.version_info.minor == 10:
            # Check if ttsfrd is installed
            try:
                import ttsfrd
                from cosyvoice_ttsfrd import get_resource_path
            except ImportError as e:
                raise ImportError("ttsfrd is not installed, please install it first, see `https://github.com/xingchensong/CosyVoice-ttsfrd` for the installation guide.") from e
            text_norm = ttsfrd.TtsFrontendEngine()
            text_norm.initialize(get_resource_path())
            text_norm.set_lang_type('pinyinvg')
        else:
            timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
            tqdm.write(f"[{timestamp}] - [WARNING] - Only python 3.10 is supported for ttsfrd, see `https://github.com/xingchensong/CosyVoice-ttsfrd` for more info. Setting enable_tn to False...")
            # TODO: maybe we should use wetext if python version is not 3.10?
            args.enable_tn = False
            text_norm = None
    else:
        text_norm = None

    assert torch.cuda.is_available()
    world_size, local_rank, rank = init_distributed()
    config = Config(model=args.model_path, enforce_eager=True, tensor_parallel_size=1,
                    max_num_seqs=args.batch_size_dataloader,
                    hf_config=CosyVoice2LLMConfig(fp16_flow=args.fp16_flow), rank=local_rank)
    model = CosyVoice2(config)
    set_all_random_seed(args.seed)

    dataset = AudioDataset(text_norm, model.llm.tokenizer, args.data_list, config)
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    dataloader = DataLoader(dataset, batch_size=args.batch_size_dataloader, num_workers=args.num_workers,
                            pin_memory=True, sampler=sampler, shuffle=False,
                            prefetch_factor=args.prefetch, collate_fn=collate_fn)
    total_steps = len(dataset)

    if local_rank == 0:
        timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
        tqdm.write(f"[{timestamp}] - [INFO] - {args}")
        progress_bar = tqdm(total=total_steps, desc="Processing samples", unit="wav",
                            position=0, leave=True, dynamic_ncols=True)

    cpu_counts = os.cpu_count()
    # keep at least one worker so ThreadPoolExecutor does not fail on machines with fewer than 8 cpus
    executor = ThreadPoolExecutor(max_workers=min(args.batch_size_dataloader, max(1, cpu_counts // 8)))
    pending_futures = []

    dataloader_iter = iter(dataloader)
    succeed_duration = 0.01  # avoid division by zero
    start_time = time.time()
    estimated_total_wavs = 0
    succeed_wavs = 0
    failed_wavs = 0
    last_print_time = start_time
    while True:
        try:
            dataloader_start = time.time()
            batch = next(dataloader_iter)
            dataloader_time = time.time() - dataloader_start
            if len(batch['infos']) == 0:
                timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
                tqdm.write(f"[{timestamp}] - [WARNING] - rank {rank} of {world_size}: No valid batch found, skipping this batch...")
                continue
            model_start = time.time()
            results_dict, timing_stats = model(**batch, batch_size_flow=args.batch_size_flow, only_llm=args.only_llm)
            model_time = time.time() - model_start
            estimated_total_wavs += len(results_dict['generated_wavs'])
            timing_stats['dataloader_time'] = dataloader_time
            timing_stats['model_inference_time'] = model_time

            if args.only_llm:
                results_dict['generated_wavs'] = [None] * len(results_dict['prompt_speech_tokens'])
            for i in range(len(results_dict['generated_wavs'])):
                future = executor.submit(
                    save_file_async,
                    results_dict['generated_wavs'][i],
                    results_dict['prompt_speech_tokens'][i],
                    results_dict['generated_speech_tokens'][i],
                    batch['infos'][i].copy(),
                    timing_stats.copy()
                )
                pending_futures.append(future)

            completed_futures = []
            for future in pending_futures:
                if future.done():
                    try:
                        duration = future.result()
                        succeed_duration += duration
                        succeed_wavs += 1
                    except Exception as e:
                        failed_wavs += 1
                        timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
                        tqdm.write(f"[{timestamp}] - [ERROR] - rank {rank} of {world_size}: Error in async save task: {e}")
                    completed_futures.append(future)
            for future in completed_futures:
                pending_futures.remove(future)

            if local_rank == 0:
                update_n = world_size * len(batch["prompt_text_tokens_for_llm"])
                if progress_bar.n + update_n > progress_bar.total:
                    progress_bar.update(progress_bar.total - progress_bar.n)
                else:
                    progress_bar.update(update_n)

            current_time = time.time()
            if current_time - last_print_time >= 120 and not args.only_llm:
                elapsed_time = current_time - start_time
                avg_duration = succeed_duration / succeed_wavs if succeed_wavs > 0 else 0
                estimated_total_duration = avg_duration * estimated_total_wavs
                current_rtf = elapsed_time / estimated_total_duration if estimated_total_duration > 0.01 else 0
                timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
                tqdm.write(f"[{timestamp}] - [INFO] - rank {rank} of {world_size}: Estimated total wavs: {estimated_total_wavs} ({estimated_total_wavs - succeed_wavs} pending to save), Succeed wavs: {succeed_wavs}, Failed wavs: {failed_wavs}, Estimated total duration: {estimated_total_duration:.2f}s ({estimated_total_duration / 3600:.2f} h), Estimated RTF: {current_rtf:.5f}, Elapsed time: {elapsed_time:.2f}s")  # noqa
                last_print_time = current_time
        except StopIteration:
            break
        except Exception as e:
            failed_wavs += 1
            timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
            tqdm.write(f"[{timestamp}] - [ERROR] - rank {rank} of {world_size}: Error in main loop: {e}")
            continue

    total_time = time.time() - start_time

    if local_rank == 0:
        timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
        tqdm.write(f"[{timestamp}] - [INFO] - Waiting for {len(pending_futures)} pending save tasks to complete...")
    for future in pending_futures:
        try:
            duration = future.result(timeout=60)
            succeed_duration += duration
            succeed_wavs += 1
        except Exception as e:
            failed_wavs += 1
            timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
            tqdm.write(f"[{timestamp}] - [ERROR] - rank {rank} of {world_size}: Error in final async save task: {e}")
    executor.shutdown(wait=True)

    if local_rank == 0:
        timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
        tqdm.write(f"[{timestamp}] - [INFO] - All async save tasks completed.")
        progress_bar.close()

    if not args.only_llm:
        timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
        tqdm.write(f"[{timestamp}] - [INFO] - rank {rank} of {world_size}: Final Report - Succeed wavs: {succeed_wavs}, Failed wavs: {failed_wavs}, Total duration: {succeed_duration:.2f}s ({succeed_duration / 3600:.2f} h), RTF: {total_time / succeed_duration:.5f}")  # noqa

    dist.barrier()
    dist.destroy_process_group()


if __name__ == "__main__":
    main()
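# ------------------------------------------------------------------------------
# Illustrative launch sketch (see README.md for the authoritative usage). The
# GPU count, paths, and batch sizes below are placeholders, and `<this_script>`
# stands for this file's actual name; the flags are the ones defined in
# get_args() above. The script relies on the RANK / LOCAL_RANK / WORLD_SIZE
# environment variables that `torchrun` sets, e.g.:
#
#   torchrun --nproc_per_node=8 <this_script>.py \
#       --model_path /path/to/model \
#       --data_list /path/to/data.list \
#       --batch_size_dataloader 32 \
#       --batch_size_flow 16 \
#       --fp16_flow
# ------------------------------------------------------------------------------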