import time

import numpy as np
import torch


def generate_embeddings(model, tokenizer, text, bicodec, prompt_text=None, prompt_audio=None):
    """
    Build the input embeddings the Spark LLM needs for prediction.

    Args:
        model: Spark LLM model
        tokenizer: text tokenizer
        text: text to synthesize speech for
        bicodec: BiCodecTokenizer instance
        prompt_text: prompt transcript (optional)
        prompt_audio: prompt audio array (optional)

    Returns:
        dict: contains input_embs (and the prompt's global_tokens) for model prediction
    """
    device = next(model.parameters()).device

    with torch.no_grad():
        if prompt_audio is not None:
            # Tokenize the prompt audio into global (speaker) tokens and
            # semantic tokens that seed the generated speech.
            audio_data = np.array(prompt_audio, dtype=np.float32)
            target_sample_rate = bicodec.config['sample_rate']
            print(f"Sample rate expected by BiCodecTokenizer: {target_sample_rate}Hz")
            print(f"Audio data shape: {audio_data.shape}")

            global_tokens, semantic_tokens = bicodec.tokenize(audio_data)
            global_tokens = global_tokens.squeeze(0).squeeze(0).detach().cpu().tolist()
            semantic_tokens = semantic_tokens.squeeze(0).squeeze(0).detach().cpu().tolist()
        else:
            global_tokens = []
            semantic_tokens = []

        # The prompt transcript is prepended to the target text, and the prompt's
        # semantic tokens seed the generation so the model continues that voice.
        if prompt_text is not None:
            full_text = prompt_text + text
            initial_semantic_tokens = semantic_tokens.copy()
        else:
            full_text = text
            initial_semantic_tokens = []

        text_tokens = tokenizer.encode(full_text, add_special_tokens=False)

        text_tokens_tensor = torch.tensor(text_tokens, dtype=torch.long, device=device)
        global_tokens_tensor = torch.tensor(global_tokens, dtype=torch.long, device=device)
        semantic_tokens_tensor = torch.tensor(initial_semantic_tokens, dtype=torch.long, device=device)

        text_embs = model.text_embedder(text_tokens_tensor)
        global_embs = model.global_embedder(global_tokens_tensor)
        semantic_embs = model.model.embeddings(semantic_tokens_tensor)

        # TTS tag embeddings that delimit the segments of the input sequence.
        tag_0_emb = model.tts_tag_embedder(torch.tensor([0], dtype=torch.long, device=device))
        tag_1_emb = model.tts_tag_embedder(torch.tensor([1], dtype=torch.long, device=device))
        tag_2_emb = model.tts_tag_embedder(torch.tensor([2], dtype=torch.long, device=device))

        # Sequence layout: <tag_2> text <tag_0> global tokens <tag_1> semantic tokens
        input_embs = torch.cat([
            tag_2_emb,
            text_embs,
            tag_0_emb,
            global_embs,
            tag_1_emb,
            semantic_embs
        ], dim=0)

        input_embs = input_embs.unsqueeze(0)  # add batch dimension

    return {
        "input_embs": input_embs,
        "global_tokens": global_tokens_tensor,
    }
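
# Minimal usage sketch (file names here are placeholders; see the __main__
# block at the bottom of this file for a complete end-to-end run):
#
#     import soundfile as sf
#     prompt_audio, sr = sf.read("prompt.wav")  # resample to bicodec.config['sample_rate'] first
#     emb = generate_embeddings(model, tokenizer, "text to speak", bicodec,
#                               prompt_text="transcript of prompt.wav",
#                               prompt_audio=prompt_audio)
#     # emb["input_embs"]: (1, seq_len, hidden_size), ready for model.model.forward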


def generate_embeddings_batch(model, tokenizer, texts, bicodec, prompt_text=None, prompt_audio=None):
    """
    Build batched input embeddings for Spark LLM prediction, so several texts
    can be processed in parallel.

    Args:
        model: Spark LLM model
        tokenizer: text tokenizer
        texts: list of texts to synthesize speech for
        bicodec: BiCodecTokenizer instance
        prompt_text: prompt transcript (optional)
        prompt_audio: prompt audio array (optional)

    Returns:
        tuple: (embeddings_dict, attention_mask), a dict containing the batched
        input_embs plus the corresponding attention mask
    """
    device = next(model.parameters()).device
    dtype = next(model.parameters()).dtype
    batch_size = len(texts)

    with torch.no_grad():
        if prompt_audio is not None:
            audio_data = np.array(prompt_audio, dtype=np.float32)
            target_sample_rate = bicodec.config['sample_rate']
            print(f"Sample rate expected by BiCodecTokenizer: {target_sample_rate}Hz")
            print(f"Audio data shape: {audio_data.shape}")

            global_tokens, semantic_tokens = bicodec.tokenize(audio_data)
            global_tokens = global_tokens.squeeze(0).squeeze(0).detach().cpu().tolist()
            semantic_tokens = semantic_tokens.squeeze(0).squeeze(0).detach().cpu().tolist()
        else:
            global_tokens = []
            semantic_tokens = []

        # Build each text's embedding sequence, then left-pad them into one batch.
        all_input_embs = []
        all_attention_masks = []

        for text in texts:
            if prompt_text is not None:
                full_text = prompt_text + text
                initial_semantic_tokens = semantic_tokens.copy()
            else:
                full_text = text
                initial_semantic_tokens = []

            text_tokens = tokenizer.encode(full_text, add_special_tokens=False)

            text_tokens_tensor = torch.tensor(text_tokens, dtype=torch.long, device=device)
            global_tokens_tensor = torch.tensor(global_tokens, dtype=torch.long, device=device)
            semantic_tokens_tensor = torch.tensor(initial_semantic_tokens, dtype=torch.long, device=device)

            text_embs = model.text_embedder(text_tokens_tensor)
            global_embs = model.global_embedder(global_tokens_tensor)
            semantic_embs = model.model.embeddings(semantic_tokens_tensor)

            tag_0_emb = model.tts_tag_embedder(torch.tensor([0], dtype=torch.long, device=device))
            tag_1_emb = model.tts_tag_embedder(torch.tensor([1], dtype=torch.long, device=device))
            tag_2_emb = model.tts_tag_embedder(torch.tensor([2], dtype=torch.long, device=device))

            # Same layout as in generate_embeddings:
            # <tag_2> text <tag_0> global tokens <tag_1> semantic tokens
            input_embs = torch.cat([
                tag_2_emb,
                text_embs,
                tag_0_emb,
                global_embs,
                tag_1_emb,
                semantic_embs
            ], dim=0)

            all_input_embs.append(input_embs)
            all_attention_masks.append(torch.ones(input_embs.shape[0], dtype=torch.long, device=device))

        max_seq_len = max(emb.shape[0] for emb in all_input_embs)
        hidden_size = all_input_embs[0].shape[1]

        batch_input_embs = torch.zeros(batch_size, max_seq_len, hidden_size, device=device, dtype=dtype)
        batch_attention_mask = torch.zeros(batch_size, max_seq_len, dtype=torch.long, device=device)

        for i, (input_embs, attention_mask) in enumerate(zip(all_input_embs, all_attention_masks)):
            seq_len = input_embs.shape[0]
            # Left-pad: write each sequence into the tail of its row.
            batch_input_embs[i, -seq_len:, :] = input_embs
            batch_attention_mask[i, -seq_len:] = attention_mask

        # Every row shares the same prompt, so the global tokens are simply expanded.
        global_tokens_tensor = torch.tensor(global_tokens, dtype=torch.long, device=device)
        batch_global_tokens = global_tokens_tensor.unsqueeze(0).expand(batch_size, -1)

    return {
        "input_embs": batch_input_embs,
        "global_tokens": batch_global_tokens,
    }, batch_attention_mask
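
# Illustration of the left-padding above (values are made up): two sequences of
# lengths 3 and 5 padded to max_seq_len = 5 yield
#
#     batch_attention_mask = [[0, 0, 1, 1, 1],
#                             [1, 1, 1, 1, 1]]
#
# Left-padding keeps the last position of every row real, which is what
# generate() consumes via logits[:, -1, :].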


def ras_sampling(weighted_scores, decoded_tokens, top_p=0.8, top_k=25, win_size=10, tau_r=0.1):
    """
    Repetition-aware sampling (RAS): draw a candidate with nucleus sampling;
    if that candidate already fills at least tau_r of the last win_size
    decoded tokens, resample from the full distribution instead.
    """
    top_ids = nucleus_sampling(weighted_scores, top_p=top_p, top_k=top_k)
    rep_num = (torch.tensor(decoded_tokens[-win_size:]).to(weighted_scores.device) == top_ids).sum().item()
    if rep_num >= win_size * tau_r:
        top_ids = random_sampling(weighted_scores)
    return top_ids


def nucleus_sampling(weighted_scores, top_p=0.8, top_k=25):
    """
    Nucleus (top-p) sampling with a top-k cap: keep the highest-probability
    tokens until their cumulative mass reaches top_p or top_k tokens have been
    collected, then sample one token from the kept set.
    """
    prob, indices = [], []
    cum_prob = 0.0
    sorted_value, sorted_idx = weighted_scores.softmax(dim=0).sort(descending=True, stable=True)
    for i in range(len(sorted_idx)):
        if cum_prob < top_p and len(prob) < top_k:
            cum_prob += sorted_value[i]
            prob.append(sorted_value[i])
            indices.append(sorted_idx[i])
        else:
            break
    prob = torch.stack(prob).to(weighted_scores)
    indices = torch.stack(indices).to(weighted_scores.device)
    top_ids = indices[prob.multinomial(1, replacement=True)]
    return top_ids
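

# An equivalent vectorized formulation of nucleus_sampling (a sketch, not used
# elsewhere in this file): keep each sorted token while the cumulative mass
# *before* it is below top_p, cap the kept set at top_k, and sample from it.
def nucleus_sampling_vectorized(weighted_scores, top_p=0.8, top_k=25):
    probs = weighted_scores.softmax(dim=0)
    sorted_probs, sorted_idx = probs.sort(descending=True, stable=True)
    cum_before = sorted_probs.cumsum(dim=0) - sorted_probs  # mass before each entry
    keep = cum_before < top_p
    keep[top_k:] = False  # at most top_k survivors
    kept_probs = sorted_probs[keep]  # multinomial normalizes the weights itself
    return sorted_idx[keep][kept_probs.multinomial(1, replacement=True)]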


def random_sampling(weighted_scores):
    """Sample one token id from the full softmax distribution."""
    top_ids = weighted_scores.softmax(dim=0).multinomial(1, replacement=True)
    return top_ids
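
# Quick sanity check for the samplers on toy logits (a sketch to run manually;
# the values are arbitrary):
#
#     scores = torch.tensor([2.0, 1.0, 0.5, -1.0])
#     print(nucleus_sampling(scores, top_p=0.8, top_k=2))  # id from the top-2 nucleus
#     print(ras_sampling(scores, decoded_tokens=[0] * 5))  # may fall back to random_sampling
#     print(random_sampling(scores))                       # id from the full distribution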


def generate(model,
             inputs_embeds,
             attention_mask,
             new_max_tokens,
             top_k,
             top_p,
             temperature,
             eos_token_id,
             pad_token_id,
             past_key_values):
    """
    Run generation in two separate stages:
      1. prefill
      2. decode
    and measure the time and token throughput of each stage.
    """
    start_time = time.time()
    model.eval()
    device = inputs_embeds.device
    batch_size = inputs_embeds.shape[0]
    decoded_tokens = [[] for _ in range(batch_size)]
    is_decoding = [True for _ in range(batch_size)]
    with torch.no_grad():
        # Prefill: run the full prompt through the model once to build the KV cache.
        outputs = model.model.forward(
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            past_key_values=past_key_values,
            use_cache=True,
            output_attentions=False,
            output_hidden_states=True,
            return_dict=False
        )
        hidden_states = outputs[0]
        past_key_values = outputs[1]
        prefill_time = time.time() - start_time
        tokens = attention_mask.shape[0] * attention_mask.shape[1]
        print(f"Prefill time: {prefill_time} seconds, all tokens is {tokens}, speed is {tokens/prefill_time} tokens/s")

        # Decode: autoregressively sample one token per step for each sequence.
        start_time = time.time()
        decoded_tokens_size = 0
        while True:
            logits = model.lm_head(hidden_states)
            last_time_decoded = []
            logits = logits[:, -1, :]
            continue_decoding = False
            for i in range(batch_size):
                if is_decoding[i]:
                    # Temperature-scale the logits, then apply repetition-aware sampling.
                    logits_i = logits[i, :] / temperature
                    top_ids = ras_sampling(logits_i, decoded_tokens[i], top_p=top_p, top_k=top_k).item()
                    decoded_tokens[i].append(top_ids)
                    last_time_decoded.append([top_ids])
                    if top_ids == eos_token_id:
                        is_decoding[i] = False
                    else:
                        continue_decoding = True
                        decoded_tokens_size += 1
                else:
                    # Finished sequences keep emitting pad tokens so the batch stays aligned.
                    decoded_tokens[i].append(pad_token_id)
                    last_time_decoded.append([pad_token_id])
            # Stop when every sequence has emitted EOS or the token budget is spent.
            if not continue_decoding or len(decoded_tokens[0]) >= new_max_tokens:
                break
            last_time_decoded = torch.tensor(last_time_decoded, dtype=torch.long, device=device)
            lm_input = model.get_input_embeddings()(last_time_decoded)
            outputs = model.model.forward(
                inputs_embeds=lm_input,
                past_key_values=past_key_values,
                use_cache=True,
                output_attentions=False,
                output_hidden_states=True,
                return_dict=False
            )
            hidden_states = outputs[0]
            past_key_values = outputs[1]
    decode_time = time.time() - start_time
    print(f"Decode time: {decode_time} seconds, all tokens is {decoded_tokens_size}, speed is {decoded_tokens_size/decode_time} tokens/s")
    print(f"decoded_tokens: {decoded_tokens}")
    return decoded_tokens, past_key_values
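
# Note: the lists returned by generate() end with the EOS token for each
# finished sequence (followed by pad_token_id entries in a batch), so callers
# that hand the tokens to BiCodec should strip them; the __main__ block below
# does this with audio_tokens[:, :-1] for its single-sequence batch.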


if __name__ == "__main__":
    import os
    import sys

    current_dir = os.path.dirname(os.path.abspath(__file__))
    print('add current dir to sys.path', current_dir)
    sys.path.append(current_dir)

    device = 'cuda:2'

    from sparktts.models.audio_tokenizer import BiCodecTokenizer
    from transformers import AutoTokenizer, AutoModelForCausalLM
    import soundfile as sf

    audio_tokenizer = BiCodecTokenizer(model_dir=current_dir, device=device)
    print(audio_tokenizer)

    tokenizer = AutoTokenizer.from_pretrained(current_dir, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(current_dir, trust_remote_code=True)
    print(tokenizer)
    print(model)

    model = model.bfloat16().to(device)
    model.eval()

    # Transcript of the prompt audio (kafka.wav); it must match the recording.
    prompt_text = "我们并不是通过物理移动手段找到星河的。"
    prompt_audio_file = os.path.join(current_dir, 'kafka.wav')
    prompt_audio, sampling_rate = sf.read(prompt_audio_file)

    print(f"Loaded prompt audio from {prompt_audio_file}")
    print(f"Original sampling rate: {sampling_rate}Hz")
    print(f"Audio shape: {prompt_audio.shape}")

    target_sample_rate = audio_tokenizer.config['sample_rate']
    if sampling_rate != target_sample_rate:
        print(f"Resampling from {sampling_rate}Hz to {target_sample_rate}Hz...")
        from librosa import resample
        prompt_audio = resample(prompt_audio, orig_sr=sampling_rate, target_sr=target_sample_rate)
        prompt_audio = np.array(prompt_audio, dtype=np.float32)
        print(f"Resampled audio shape: {prompt_audio.shape}")
    else:
        print(f"Audio sampling rate already matches target ({target_sample_rate}Hz)")

    texts = ["为了点燃青少年对科技的热情,培养他们的创新思维与动手能力,杏花岭区巨轮街道社区教育学校携手中车社区教育分校,与太原市科学技术协会联手,于暑期精心策划了一场别开生面的青少年数智技术服务港探索之旅,吸引了众多社区青少年的积极参与。"]

    # Use the last vocabulary id as the audio EOS token.
    eos_token_id = model.config.vocab_size - 1
    print(f"EOS token ID: {eos_token_id}")

    embeddings, attention_mask = generate_embeddings_batch(
        model=model,
        tokenizer=tokenizer,
        texts=texts,
        bicodec=audio_tokenizer,
        prompt_text=prompt_text,
        prompt_audio=prompt_audio
    )
    input_embs = embeddings['input_embs']
    print(f"input_embs shape: {input_embs.shape}")
    print(f"attention_mask shape: {attention_mask.shape}")
    print(f"input_embs dtype: {input_embs.dtype}")
    print(f"attention_mask dtype: {attention_mask.dtype}")
    print(f"input_embs: {input_embs}")
    print(f"attention_mask: {attention_mask}")

    # Fall back to the EOS token when the tokenizer defines no pad token.
    pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id

    # Warm-up run; its outputs are discarded.
    with torch.no_grad():
        generate(model,
                 input_embs,
                 attention_mask,
                 new_max_tokens=1000,
                 top_k=25,
                 top_p=0.95,
                 temperature=1.0,
                 eos_token_id=eos_token_id,
                 pad_token_id=pad_token_id,
                 past_key_values=None)

    with torch.no_grad():
        audio_tokens, past_key_values = generate(model,
                                                 input_embs,
                                                 attention_mask,
                                                 new_max_tokens=1000,
                                                 top_k=50,
                                                 top_p=0.8,
                                                 temperature=1.0,
                                                 eos_token_id=eos_token_id,
                                                 pad_token_id=pad_token_id,
                                                 past_key_values=None)

    audio_tokens = torch.tensor(audio_tokens, dtype=torch.long, device=device)
    audio_tokens = audio_tokens[:, :-1]  # drop the trailing EOS token
    print(f"audio_tokens: {audio_tokens}")
    print(f"past_key_values: {past_key_values}")

    global_tokens = embeddings['global_tokens']
    print(f"global_tokens shape: {global_tokens.shape}")
    print(f"audio_tokens shape: {audio_tokens.shape}")

    with torch.no_grad():
        wav = audio_tokenizer.detokenize(global_tokens, audio_tokens)
        print(f"wav shape: {wav.shape}")
        sf.write('test.wav', wav, audio_tokenizer.config['sample_rate'])