import os
import sys
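# Make the local sparktts package importable before the sparktts import below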
current_dir = os.path.dirname(os.path.abspath(__file__))
print(f'Adding current dir to sys.path: {current_dir}')
sys.path.append(current_dir)
from sparktts.models.audio_tokenizer import BiCodecTokenizer
from transformers import AutoTokenizer, AutoModelForCausalLM
import soundfile as sf
import numpy as np
import torch
from utilities import generate_embeddings
def generate_speech(model, tokenizer, text, bicodec, prompt_text=None, prompt_audio=None, 
                   max_new_tokens=3000, do_sample=True, top_k=50, top_p=0.95, 
                   temperature=1.0, device="cuda:0"):
    """
    生成语音的函数
    
    Args:
        model: 语言模型
        tokenizer: 文本分词器
        text: 要生成语音的文本
        bicodec: BiCodecTokenizer 实例
        prompt_text: 提示文本(可选)
        prompt_audio: 提示音频数组(可选)
        max_new_tokens: 最大生成token数
        do_sample: 是否使用采样
        top_k: top-k采样参数
        top_p: top-p采样参数
        temperature: 温度参数
        device: 设备
    
    Returns:
        wav: 生成的音频波形
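
    Example (a minimal usage sketch; assumes model, tokenizer and
    audio_tokenizer are loaded as at the bottom of this script):
        wav = generate_speech(model, tokenizer, "你好", audio_tokenizer,
                              prompt_audio=prompt_audio, device="cuda:0")
        sf.write("out.wav", wav, audio_tokenizer.config['sample_rate'])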
    """
    # Set eos_token_id; per the training code it is model.config.vocab_size - 1
    eos_token_id = model.config.vocab_size - 1
    print(f"EOS token ID: {eos_token_id}")
    
    # Build the input embeddings (text plus optional voice-cloning prompt)
    embeddings = generate_embeddings(
        model=model,
        tokenizer=tokenizer,
        text=text,
        bicodec=bicodec,
        prompt_text=prompt_text,
        prompt_audio=prompt_audio
    )
    
    print("开始生成语音...")
    print(f"输入嵌入形状: {embeddings['input_embs'].shape}")
    global_tokens = embeddings['global_tokens'].unsqueeze(0)
    # 设置模型为评估模式
    print(f'embeddings dtype: {embeddings["input_embs"].dtype}')
    model.eval()
    
    with torch.no_grad():
        # Generate semantic token IDs with the model's generate method
        generated_outputs = model.generate(
            inputs_embeds=embeddings['input_embs'],
            attention_mask=torch.ones((1, embeddings['input_embs'].shape[1]), dtype=torch.long, device=device),
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            top_k=top_k,
            top_p=top_p,
            temperature=temperature,
            eos_token_id=eos_token_id,
            pad_token_id=tokenizer.pad_token_id if hasattr(tokenizer, 'pad_token_id') else tokenizer.eos_token_id,
            use_cache=True
        )
    print(f"generated_outputs: {generated_outputs}")
    
    print(f"生成的token数量: {generated_outputs.shape}")
    print(f"生成的token IDs: {generated_outputs.tolist()}")
    
    # Use the generated token IDs directly as semantic tokens, dropping the trailing EOS token
    # Note: these IDs live in the model's own vocabulary, not the original text tokenizer's vocabulary
    semantic_tokens_tensor = generated_outputs[:, :-1]
    
    print(f"Semantic tokens shape: {semantic_tokens_tensor.shape}")
    
    # Simulate streaming synthesis
    target_sample_rate = bicodec.config['sample_rate']
    print(f"Global tokens shape: {global_tokens.shape}")
    BUF_SIZE = 25     # at 50 semantic tokens per second, 25 tokens is 0.5 seconds of context
    chunk_size = 125  # decode a new audio segment every 125 tokens (2.5 seconds)
    buffered_semantic_tokens = torch.zeros((1, 0), dtype=torch.long, device=device)
    whole_wav = np.array([], dtype=np.float32)
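    # Each pass prepends the last BUF_SIZE tokens of the previous chunk as decoding
    # context, detokenizes buffer + chunk together, then trims the audio that
    # corresponds to the overlap, so every sample is emitted exactly once and
    # chunk boundaries stay smooth.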
    for i in range(0, semantic_tokens_tensor.shape[1], chunk_size):
        buffered_size = buffered_semantic_tokens.shape[1]
        current_semantic_tokens = semantic_tokens_tensor[:, i:i+chunk_size]
        print(f"Generating segment [{i}:{i+chunk_size}]: shape {current_semantic_tokens.shape}")
        # Prepend the buffered context tokens from the previous chunk
        current_semantic_tokens = torch.cat([buffered_semantic_tokens, current_semantic_tokens], dim=1)
        print(f"After concat: shape {current_semantic_tokens.shape} with buffered shape {buffered_semantic_tokens.shape}")
        # Keep the last BUF_SIZE tokens as context for the next chunk
        buffered_semantic_tokens = current_semantic_tokens[:, -BUF_SIZE:]
        with torch.no_grad():
            wav = bicodec.detokenize(global_tokens, current_semantic_tokens)
        print(f"Generated audio shape: {wav.shape}")
        # Trim the audio produced by the overlap tokens (50 tokens per second),
        # since that span was already emitted with the previous chunk
        wav = wav[int(target_sample_rate * buffered_size / 50):]
        print(f"After cut: shape {wav.shape}")
        whole_wav = np.concatenate([whole_wav, wav])
    print(f"Whole wav shape: {whole_wav.shape}")
    return whole_wav

device = 'cuda:2'

audio_tokenizer = BiCodecTokenizer(model_dir=current_dir, device=device)

print(audio_tokenizer)

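# Load the text tokenizer and the LM from the local checkpoint; trust_remote_code
# executes the custom model code shipped with the checkpoint.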
tokenizer = AutoTokenizer.from_pretrained(current_dir, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(current_dir, trust_remote_code=True)
print(tokenizer)
print(model)

model = model.bfloat16().to(device)
model.eval()

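# Voice-cloning prompt: a reference recording and its transcript.
# Note: prompt_text is defined but the generate_speech call below only passes
# prompt_audio; pass prompt_text=prompt_text as well to also condition on the transcript.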
prompt_text = "我们并不是通过物理移动手段找到星河的。"
prompt_audio_file = os.path.join(current_dir, 'kafka.wav')
prompt_audio, sampling_rate = sf.read(prompt_audio_file)

print(f"Loaded prompt audio from {prompt_audio_file}")
print(f"Original sampling rate: {sampling_rate}Hz")
print(f"Audio shape: {prompt_audio.shape}")

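# BiCodec expects audio at its configured sample rate; resample the prompt if needed.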
target_sample_rate = audio_tokenizer.config['sample_rate']
if sampling_rate != target_sample_rate:
    print(f"Resampling from {sampling_rate}Hz to {target_sample_rate}Hz...")
    from librosa import resample
    prompt_audio = resample(prompt_audio, orig_sr=sampling_rate, target_sr=target_sample_rate)
    prompt_audio = np.array(prompt_audio, dtype=np.float32)
    print(f"Resampled audio shape: {prompt_audio.shape}")
else:
    print(f"Audio sampling rate already matches target ({target_sample_rate}Hz)")

text = "二房他们已经接受了老爷子安排的:大房拿企业、二房拿钱的设定。富贵闲人他们也做了。在嫡长女和国资抢股权期间不出来搅局,就连老爷子的葬礼都没有露面,安安静静坐实老爷子一辈子的完美人设。"
wav = generate_speech(model, tokenizer, text, audio_tokenizer, prompt_audio=prompt_audio, device=device)
sf.write('output_streaming.wav', wav, target_sample_rate)