# Step-Audio-2-mini / token2wav.py
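#
# Token2wav converts discrete speech tokens back into a 24 kHz waveform in
# three stages: (1) the prompt wav is tokenized with the S3 speech tokenizer
# and embedded with the CAM++ speaker model, (2) a flow-matching model
# predicts a mel spectrogram for the generated tokens, conditioned on the
# prompt, and (3) the HiFT vocoder renders the waveform.
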
import io
import torch
import torchaudio
import s3tokenizer
import onnxruntime
import torchaudio.compliance.kaldi as kaldi
from flashcosyvoice.modules.hifigan import HiFTGenerator
from flashcosyvoice.utils.audio import mel_spectrogram
from hyperpyyaml import load_hyperpyyaml

class Token2wav:

    def __init__(self, model_path, float16=False):
        self.float16 = float16
        # S3 speech tokenizer (25 Hz, v2), used to tokenize the prompt audio.
        self.audio_tokenizer = s3tokenizer.load_model(f"{model_path}/speech_tokenizer_v2_25hz.onnx").cuda().eval()
        # CAM++ speaker-embedding model, run on CPU via ONNX Runtime.
        option = onnxruntime.SessionOptions()
        option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        option.intra_op_num_threads = 1
        self.spk_model = onnxruntime.InferenceSession(f"{model_path}/campplus.onnx", sess_options=option, providers=["CPUExecutionProvider"])
        # Flow-matching acoustic model: speech tokens + prompt -> mel spectrogram.
        with open(f"{model_path}/flow.yaml", "r") as f:
            configs = load_hyperpyyaml(f)
        self.flow = configs['flow']
        if float16:
            self.flow.half()
        self.flow.load_state_dict(torch.load(f"{model_path}/flow.pt", map_location="cpu", weights_only=True), strict=True)
        self.flow.cuda().eval()
        # HiFT vocoder: mel spectrogram -> waveform. The checkpoint stores
        # weights under a 'generator.' prefix, which is stripped here.
        self.hift = HiFTGenerator()
        hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(f"{model_path}/hift.pt", map_location="cpu", weights_only=True).items()}
        self.hift.load_state_dict(hift_state_dict, strict=True)
        self.hift.cuda().eval()
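
    # Expected contents of `model_path`, inferred from the loads above:
    #   speech_tokenizer_v2_25hz.onnx - S3 speech tokenizer (ONNX)
    #   campplus.onnx                 - CAM++ speaker-embedding model (ONNX)
    #   flow.yaml / flow.pt           - flow model config and weights
    #   hift.pt                       - HiFT vocoder weights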

    def __call__(self, generated_speech_tokens, prompt_wav):
        # Tokenize the 16 kHz prompt audio into discrete speech tokens.
        audio = s3tokenizer.load_audio(prompt_wav, sr=16000)  # [T]
        mels = s3tokenizer.log_mel_spectrogram(audio)
        mels, mels_lens = s3tokenizer.padding([mels])
        prompt_speech_tokens, prompt_speech_tokens_lens = self.audio_tokenizer.quantize(mels.cuda(), mels_lens.cuda())

        # Extract a speaker embedding from mean-normalized fbank features.
        spk_feat = kaldi.fbank(audio.unsqueeze(0), num_mel_bins=80, dither=0, sample_frequency=16000)
        spk_feat = spk_feat - spk_feat.mean(dim=0, keepdim=True)
        spk_emb = torch.tensor(self.spk_model.run(
            None, {self.spk_model.get_inputs()[0].name: spk_feat.unsqueeze(dim=0).cpu().numpy()}
        )[0], device='cuda')

        # Compute the prompt mel spectrogram at the vocoder's 24 kHz rate.
        audio, sample_rate = torchaudio.load(prompt_wav, backend='soundfile')
        audio = audio.mean(dim=0, keepdim=True)  # downmix to mono, [1, T]
        if sample_rate != 24000:
            audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=24000)(audio)
        prompt_mel = mel_spectrogram(audio).transpose(1, 2).squeeze(0)  # [T, num_mels]
        prompt_mels = prompt_mel.unsqueeze(0).cuda()
        prompt_mels_lens = torch.tensor([prompt_mels.shape[1]], dtype=torch.int32, device='cuda')

        generated_speech_tokens = torch.tensor([generated_speech_tokens], dtype=torch.int32, device='cuda')
        generated_speech_tokens_lens = torch.tensor([generated_speech_tokens.shape[1]], dtype=torch.int32, device='cuda')

        # Run flow inference to predict the mel spectrogram (the trailing 10
        # is presumably the number of flow-matching steps), then vocode with
        # HiFT. float16 inference is gated behind autocast.
        with torch.amp.autocast("cuda", dtype=torch.float16 if self.float16 else torch.float32):
            mel = self.flow.inference(generated_speech_tokens, generated_speech_tokens_lens,
                                      prompt_speech_tokens, prompt_speech_tokens_lens,
                                      prompt_mels, prompt_mels_lens, spk_emb, 10)
            wav, _ = self.hift(speech_feat=mel)

        # Serialize to WAV bytes in memory.
        output = io.BytesIO()
        torchaudio.save(output, wav.cpu(), sample_rate=24000, format='wav')
        return output.getvalue()
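

# A minimal decoding sketch (hypothetical helper, not part of the original
# file): the bytes returned by Token2wav.__call__ are a complete WAV
# container, so they can be loaded back into a tensor for further processing.
def wav_bytes_to_tensor(wav_bytes):
    # Returns (waveform tensor [1, T], sample_rate), expected to be 24000 Hz.
    return torchaudio.load(io.BytesIO(wav_bytes), backend='soundfile')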


if __name__ == '__main__':
    token2wav = Token2wav('/mnt/gpfs/lijingbei/Step-Audio-2-mini/token2wav')
    # Example token sequence (a short utterance introducing the Great Wall).
    tokens = [1493, 4299, 4218, 2049, 528, 2752, 4850, 4569, 4575, 6372, 2127, 4068, 2312, 4993, 4769, 2300, 226, 2175, 2160, 2152, 6311, 6065, 4859, 5102, 4615, 6534, 6426, 1763, 2249, 2209, 5938, 1725, 6048, 3816, 6058, 958, 63, 4460, 5914, 2379, 735, 5319, 4593, 2328, 890, 35, 751, 1483, 1484, 1483, 2112, 303, 4753, 2301, 5507, 5588, 5261, 5744, 5501, 2341, 2001, 2252, 2344, 1860, 2031, 414, 4366, 4366, 6059, 5300, 4814, 5092, 5100, 1923, 3054, 4320, 4296, 2148, 4371, 5831, 5084, 5027, 4946, 4946, 2678, 575, 575, 521, 518, 638, 1367, 2804, 3402, 4299]
    audio = token2wav(tokens, 'assets/default_male.wav')
    with open('assets/give_me_a_brief_introduction_to_the_great_wall.wav', 'wb') as f:
        f.write(audio)