# respark/epoch2/tts_using_chatrwkv.py
import os
os.environ["RWKV_V7_ON"] = "1" # enable this for rwkv-7 models
os.environ['RWKV_JIT_ON'] = '1'  # enable TorchScript JIT for the RWKV kernels
os.environ["RWKV_CUDA_ON"] = '1'  # compile and use the custom CUDA kernel
import sys
current_dir = os.path.dirname(os.path.abspath(__file__))
print('Adding current dir to sys.path:', current_dir)
sys.path.append(current_dir)
from rwkv.model import RWKV
# Load the converted RWKV weights; 'cuda bf16' is ChatRWKV's strategy string (run on GPU in bfloat16).
model = RWKV(model="model_converted", strategy='cuda bf16')
device = "cuda:0"
print(model)
from sparktts.models.audio_tokenizer import BiCodecTokenizer
audio_tokenizer = BiCodecTokenizer(model_dir=current_dir, device=device)
print(audio_tokenizer)
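# Note (assumption, based on Spark-TTS): BiCodecTokenizer wraps a wav2vec2
# feature extractor plus the BiCodec codec, both loaded from `model_dir`,
# so this directory must ship those weights alongside the RWKV model.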
import soundfile as sf
import numpy as np
# Transcript of the prompt audio clip (the zero-shot voice-cloning reference).
prompt_text = "我们并不是通过物理移动手段找到星河的。"
prompt_audio_file = os.path.join(current_dir, 'kafka.wav')
prompt_audio, sampling_rate = sf.read(prompt_audio_file)
print(f"Loaded prompt audio from {prompt_audio_file}")
print(f"Original sampling rate: {sampling_rate}Hz")
print(f"Audio shape: {prompt_audio.shape}")
target_sample_rate = audio_tokenizer.config['sample_rate']
if sampling_rate != target_sample_rate:
print(f"Resampling from {sampling_rate}Hz to {target_sample_rate}Hz...")
from librosa import resample
prompt_audio = resample(prompt_audio, orig_sr=sampling_rate, target_sr=target_sample_rate)
prompt_audio = np.array(prompt_audio, dtype=np.float32)
print(f"Resampled audio shape: {prompt_audio.shape}")
else:
print(f"Audio sampling rate already matches target ({target_sample_rate}Hz)")
# Target text to synthesize; it is concatenated after the prompt transcript below.
text = "二房他们已经接受了老爷子安排的:大房拿企业、二房拿钱的设定。富贵闲人他们也做了。在嫡长女和国资抢股权期间不出来搅局,就连老爷子的葬礼都没有露面,安安静静坐实老爷子一辈子的完美人设。"
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(current_dir, trust_remote_code=True)
print(tokenizer)
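# Illustrative only: round-trip a short string through the tokenizer to see
# the raw ids before they get shifted into the combined vocabulary below.
sample_ids = tokenizer.encode("我们并不是", add_special_tokens=False)
print(sample_ids, '->', tokenizer.decode(sample_ids))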
audio_data = np.array(prompt_audio, dtype=np.float32)
# prompt_audio was loaded (and resampled if needed) above; BiCodecTokenizer expects 16 kHz audio.
print(f"BiCodecTokenizer expected sampling rate: {target_sample_rate}Hz")
print(f"Audio data shape: {audio_data.shape}")
# Extract tokens with BiCodec (returned in order: global_tokens, semantic_tokens)
global_tokens, semantic_tokens = audio_tokenizer.tokenize(audio_data)
global_tokens = global_tokens.squeeze(0).squeeze(0).tolist()
semantic_tokens = semantic_tokens.squeeze(0).squeeze(0).tolist()
print(f"global_tokens: {global_tokens}")
print(f"semantic_tokens: {semantic_tokens}")
# Combined embedding vocabulary (sizes): | semantic 8193 | tts_tag 3 | global 4096 | text 65536 |
# i.e. semantic ids 0-8192, tag ids 8193-8195, global ids 8196-12291, text ids 12292-77827
text = prompt_text + text  # the prompt transcript precedes the target text for voice cloning
text_tokens = tokenizer.encode(text, add_special_tokens=False)
TTS_TAG_0 = 8193
TTS_TAG_1 = 8194
TTS_TAG_2 = 8195
import torch
global_tokens = [i + 8196 for i in global_tokens]  # shift into the global-id range
text_tokens = [i + 8196 + 4096 for i in text_tokens]  # shift into the text-id range
print(f"global_tokens: {global_tokens}")
print(f"text_tokens: {text_tokens}")
# input_embs = torch.cat([
# tag_2_emb,
# text_embs,
# tag_0_emb,
# global_embs,
# tag_1_emb,
# semantic_embs
# ], dim=0)
# Prompt sequence fed to the model, in the same order as the embedding sketch above.
all_idx = [TTS_TAG_2] + text_tokens + [TTS_TAG_0] + global_tokens + [TTS_TAG_1] + semantic_tokens
print(f'all_idx: {all_idx}')
import time
start_time = time.time()
# Prefill: run the whole prompt in one forward call; returns the logits at the
# last position plus the recurrent state used for incremental decoding below.
x, state = model.forward(all_idx, None)
end_time = time.time()
print(f'time: {end_time - start_time}s, prefill speed: {len(all_idx) / (end_time - start_time)} tokens/s')
print(f'x: {x.shape}')
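# Inspect the head width: the sampler below treats ids 0-8191 as semantic
# codes and 8192 as the stop id, so the logits must at least cover that range.
print(f'logit width: {x.shape[-1]}')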
from torch.nn import functional as F
def sample_logits(logits, temperature=1.0, top_p=0.85, top_k=0):
    """Top-p (nucleus) plus top-k sampling over a 1-D logits tensor."""
    if temperature == 0:
        temperature = 1.0
        top_p = 0  # temperature 0 degenerates to greedy: only the argmax survives the top_p cutoff
    probs = F.softmax(logits.float(), dim=-1)
top_k = int(top_k)
# 'privateuseone' is the type of custom devices like `torch_directml.device()`
if probs.device.type in ['cpu', 'privateuseone']:
probs = probs.cpu().numpy()
sorted_ids = np.argsort(probs)
sorted_probs = probs[sorted_ids][::-1]
cumulative_probs = np.cumsum(sorted_probs)
cutoff = float(sorted_probs[np.argmax(cumulative_probs >= top_p)])
probs[probs < cutoff] = 0
if top_k < len(probs) and top_k > 0:
probs[sorted_ids[:-top_k]] = 0
if temperature != 1.0:
probs = probs ** (1.0 / temperature)
probs = probs / np.sum(probs)
out = np.random.choice(a=len(probs), p=probs)
return int(out)
else:
sorted_ids = torch.argsort(probs)
sorted_probs = probs[sorted_ids]
sorted_probs = torch.flip(sorted_probs, dims=(0,))
cumulative_probs = torch.cumsum(sorted_probs, dim=-1).cpu().numpy()
cutoff = float(sorted_probs[np.argmax(cumulative_probs >= top_p)])
probs[probs < cutoff] = 0
if top_k < len(probs) and top_k > 0:
probs[sorted_ids[:-top_k]] = 0
if temperature != 1.0:
probs = probs ** (1.0 / temperature)
out = torch.multinomial(probs, num_samples=1)[0]
return int(out)
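# Usage note: with temperature=0 the function above degenerates to greedy
# (argmax) decoding, since top_p is forced to 0 and only the single
# highest-probability id survives the cumulative-probability cutoff.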
output_tokens = []
start_time = time.time()
while True:
    sampled_id = sample_logits(x, temperature=1.0, top_p=0.95, top_k=20)
    if sampled_id == 8192:  # stop id marking the end of the semantic stream
        break
    output_tokens.append(sampled_id)
    x, state = model.forward([sampled_id], state)
end_time = time.time()
decode_time = end_time - start_time
print(f'output_tokens: {output_tokens}')
print(f'time: {decode_time}s, decode speed: {len(output_tokens) / decode_time} tokens/s')
# Undo the +8196 shift so the global ids are back in BiCodec's own id space;
# the sampled semantic ids are already unshifted.
global_tokens = torch.tensor([[i - 8196 for i in global_tokens]], dtype=torch.int32, device=device)
semantic_tokens = torch.tensor([output_tokens], dtype=torch.int32, device=device)
with torch.no_grad():
    wav = audio_tokenizer.detokenize(global_tokens, semantic_tokens)
end_time = time.time()
all_time = end_time - start_time  # decode loop + detokenize, timed from the start of decoding
print(f'all_time: {all_time}s, detokenize time: {all_time - decode_time}s')
sf.write('output_rwkvchat.wav', wav, target_sample_rate)
wav_duration = len(wav) / target_sample_rate
print(f'wav_duration: {wav_duration}s')
# Real-time factor: generation time over audio duration (< 1 means faster than real time)
print(f'rtf: {all_time/wav_duration}')