import io
import os

import torch
import torchaudio
import s3tokenizer
import onnxruntime
from huggingface_hub import hf_hub_download
import torchaudio.compliance.kaldi as kaldi
from flashcosyvoice.modules.hifigan import HiFTGenerator
from flashcosyvoice.utils.audio import mel_spectrogram
from hyperpyyaml import load_hyperpyyaml


class Token2wav():
    """Convert generated speech tokens into 24 kHz WAV bytes, conditioned on a prompt recording.

    The pipeline uses four models:

    * an S3 speech tokenizer ("speech_tokenizer_v2_25hz") to tokenize the prompt audio;
    * an ONNX speaker model (``campplus.onnx``, run on CPU) to embed the prompt speaker;
    * a flow model built from ``flow.yaml`` / ``flow.pt`` that produces a mel spectrogram;
    * a ``HiFTGenerator`` vocoder loaded from ``hift.pt`` that turns the mel into a waveform.

    All torch models are moved to CUDA, so a GPU is required.
    """

    def __init__(self, model_path, float16=False):
        """Load all models.

        Args:
            model_path: directory containing ``flow.yaml``, ``flow.pt``, ``hift.pt`` and
                (preferably) ``campplus.onnx``.
            float16: if True, the flow model is converted to half precision and
                inference runs under float16 autocast.
        """
        self.float16 = float16

        self.audio_tokenizer = s3tokenizer.load_model("speech_tokenizer_v2_25hz").cuda().eval()

        option = onnxruntime.SessionOptions()
        option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        option.intra_op_num_threads = 1
        # Resolve the speaker model next to the other artifacts; fall back to the
        # previously hard-coded cwd-relative location so existing setups keep working.
        spk_model_path = os.path.join(model_path, "campplus.onnx")
        if not os.path.isfile(spk_model_path):
            spk_model_path = "token2wav/campplus.onnx"
        self.spk_model = onnxruntime.InferenceSession(spk_model_path, sess_options=option, providers=["CPUExecutionProvider"])

        with open(f"{model_path}/flow.yaml", "r") as f:
            configs = load_hyperpyyaml(f)
            self.flow = configs['flow']
        if float16:
            self.flow.half()
        self.flow.load_state_dict(torch.load(f"{model_path}/flow.pt", map_location="cpu", weights_only=True), strict=True)
        self.flow.cuda().eval()

        self.hift = HiFTGenerator()
        # The checkpoint keys carry a 'generator.' prefix that HiFTGenerator does not expect.
        hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(f"{model_path}/hift.pt", map_location="cpu", weights_only=True).items()}
        self.hift.load_state_dict(hift_state_dict, strict=True)
        self.hift.cuda().eval()

    def _prompt_speech_tokens(self, audio_16k):
        """Tokenize the 16 kHz prompt waveform with the S3 tokenizer; return (tokens, lengths) on CUDA."""
        mels = s3tokenizer.log_mel_spectrogram(audio_16k)
        mels, mels_lens = s3tokenizer.padding([mels])
        return self.audio_tokenizer.quantize(mels.cuda(), mels_lens.cuda())

    def _speaker_embedding(self, audio_16k):
        """Embed the prompt speaker from mean-normalized 80-bin Kaldi fbank features; result is on CUDA."""
        spk_feat = kaldi.fbank(audio_16k.unsqueeze(0), num_mel_bins=80, dither=0, sample_frequency=16000)
        # Per-utterance mean normalization over the time axis.
        spk_feat = spk_feat - spk_feat.mean(dim=0, keepdim=True)
        spk_inputs = {self.spk_model.get_inputs()[0].name: spk_feat.unsqueeze(dim=0).cpu().numpy()}
        return torch.tensor(self.spk_model.run(None, spk_inputs)[0], device='cuda')

    @staticmethod
    def _prompt_mel(prompt_wav):
        """Load prompt_wav, downmix to mono, resample to 24 kHz; return (mels [1, T, num_mels], lengths) on CUDA."""
        audio, sample_rate = torchaudio.load(prompt_wav, backend='soundfile')
        audio = audio.mean(dim=0, keepdim=True)  # [1, T]
        if sample_rate != 24000:
            audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=24000)(audio)
        prompt_mel = mel_spectrogram(audio).transpose(1, 2).squeeze(0)  # [T, num_mels]
        prompt_mels = prompt_mel.unsqueeze(0).cuda()
        prompt_mels_lens = torch.tensor([prompt_mels.shape[1]], dtype=torch.int32, device='cuda')
        return prompt_mels, prompt_mels_lens

    def __call__(self, generated_speech_tokens, prompt_wav):
        """Synthesize speech for ``generated_speech_tokens`` in the voice of ``prompt_wav``.

        Args:
            generated_speech_tokens: sequence of integer speech-token ids
                (wrapped into a ``[1, N]`` int32 CUDA tensor).
            prompt_wav: path of the prompt recording.

        Returns:
            bytes of a WAV file sampled at 24000 Hz.
        """
        # No gradients are needed anywhere in synthesis; inference mode avoids
        # recording an autograd graph (notably through the vocoder forward).
        with torch.inference_mode():
            audio = s3tokenizer.load_audio(prompt_wav, sr=16000)  # [T]
            prompt_speech_tokens, prompt_speech_tokens_lens = self._prompt_speech_tokens(audio)
            spk_emb = self._speaker_embedding(audio)
            prompt_mels, prompt_mels_lens = self._prompt_mel(prompt_wav)

            generated_speech_tokens = torch.tensor([generated_speech_tokens], dtype=torch.int32, device='cuda')
            generated_speech_tokens_lens = torch.tensor([generated_speech_tokens.shape[1]], dtype=torch.int32, device='cuda')

            with torch.amp.autocast("cuda", dtype=torch.float16 if self.float16 else torch.float32):
                # NOTE(review): the trailing 10 is presumably the number of flow-matching
                # steps; confirm against flow.inference's signature.
                mel = self.flow.inference(generated_speech_tokens, generated_speech_tokens_lens,
                    prompt_speech_tokens, prompt_speech_tokens_lens,
                    prompt_mels, prompt_mels_lens, spk_emb, 10)
                wav, _ = self.hift(speech_feat=mel)

        output = io.BytesIO()
        torchaudio.save(output, wav.cpu(), sample_rate=24000, format='wav')
        return output.getvalue()


if __name__ == '__main__':
    token2wav = Token2wav('/mnt/gpfs/lijingbei/Step-Audio-2-mini/token2wav')

    tokens = [1493, 4299, 4218, 2049, 528, 2752, 4850, 4569, 4575, 6372, 2127, 4068, 2312, 4993, 4769, 2300, 226, 2175, 2160, 2152, 6311, 6065, 4859, 5102, 4615, 6534, 6426, 1763, 2249, 2209, 5938, 1725, 6048, 3816, 6058, 958, 63, 4460, 5914, 2379, 735, 5319, 4593, 2328, 890, 35, 751, 1483, 1484, 1483, 2112, 303, 4753, 2301, 5507, 5588, 5261, 5744, 5501, 2341, 2001, 2252, 2344, 1860, 2031, 414, 4366, 4366, 6059, 5300, 4814, 5092, 5100, 1923, 3054, 4320, 4296, 2148, 4371, 5831, 5084, 5027, 4946, 4946, 2678, 575, 575, 521, 518, 638, 1367, 2804, 3402, 4299]
    audio = token2wav(tokens, 'assets/default_male.wav')
    with open('assets/give_me_a_brief_introduction_to_the_great_wall.wav', 'wb') as f:
        f.write(audio)