|
import numpy as np |
|
import torch |
|
|
|
def generate_embeddings(model, tokenizer, text, bicodec, prompt_text=None, prompt_audio=None):
    """Build the input embedding sequence required by the Spark LLM for TTS prediction.

    Args:
        model: Spark LLM model exposing ``text_embedder``, ``global_embedder``,
            ``tts_tag_embedder`` and ``model.embeddings``.
        tokenizer: text tokenizer with an ``encode(text, add_special_tokens=...)`` method.
        text: text to synthesize speech for.
        bicodec: BiCodecTokenizer instance (only consulted when ``prompt_audio`` is given).
        prompt_text: optional prompt text prepended to ``text``.
        prompt_audio: optional prompt audio array used for voice cloning.

    Returns:
        dict: ``{"input_embs": (1, seq_len, hidden) tensor, "global_tokens": 1-D tensor}``.
    """
    device = next(model.parameters()).device

    # Tokenize the prompt audio (if any) into global + semantic token id lists.
    if prompt_audio is None:
        global_tokens, semantic_tokens = [], []
    else:
        audio_data = np.array(prompt_audio, dtype=np.float32)
        target_sample_rate = bicodec.config['sample_rate']
        # NOTE(review): assumes the caller already resampled to target_sample_rate — confirm.
        print(f"BiCodecTokenizer 期望的采样率: {target_sample_rate}Hz")
        print(f"音频数据形状: {audio_data.shape}")
        raw_global, raw_semantic = bicodec.tokenize(audio_data)
        global_tokens = raw_global.squeeze(0).squeeze(0).detach().cpu().tolist()
        semantic_tokens = raw_semantic.squeeze(0).squeeze(0).detach().cpu().tolist()

    # Prompt semantic tokens are carried over only when prompt text is also present.
    if prompt_text is None:
        full_text, initial_semantic_tokens = text, []
    else:
        full_text = prompt_text + text
        initial_semantic_tokens = semantic_tokens.copy()

    text_token_ids = tokenizer.encode(full_text, add_special_tokens=False)

    def as_long_tensor(values):
        # Shared constructor for all id tensors on the model's device.
        return torch.tensor(values, dtype=torch.long, device=device)

    text_embs = model.text_embedder(as_long_tensor(text_token_ids))
    global_tokens_tensor = as_long_tensor(global_tokens)
    global_embs = model.global_embedder(global_tokens_tensor)
    semantic_embs = model.model.embeddings(as_long_tensor(initial_semantic_tokens))

    # Single embedder call for the three TTS tag ids, sliced per tag below.
    tag_embs = model.tts_tag_embedder(as_long_tensor([0, 1, 2]))

    # Sequence layout: <tag2> text <tag0> global <tag1> semantic
    input_embs = torch.cat(
        [tag_embs[2:3], text_embs, tag_embs[0:1], global_embs, tag_embs[1:2], semantic_embs],
        dim=0,
    ).unsqueeze(0)  # add the batch dimension

    return {
        "input_embs": input_embs,
        "global_tokens": global_tokens_tensor,
    }
|
|
|
def generate_embeddings_batch(model, tokenizer, texts, bicodec, prompt_text=None, prompt_audio=None):
    """Build left-padded batch input embeddings for Spark LLM TTS prediction.

    Args:
        model: Spark LLM model exposing ``text_embedder``, ``global_embedder``,
            ``tts_tag_embedder`` and ``model.embeddings``.
        tokenizer: text tokenizer with an ``encode(text, add_special_tokens=...)`` method.
        texts: list of texts to synthesize speech for (must be non-empty).
        bicodec: BiCodecTokenizer instance (only consulted when ``prompt_audio`` is given).
        prompt_text: optional prompt text prepended to every entry of ``texts``.
        prompt_audio: optional prompt audio array used for voice cloning.

    Returns:
        tuple: ``({"input_embs": (B, L, H), "global_tokens": (B, G)}, attention_mask (B, L))``.
        Sequences are left-padded with zeros so generation continues from the end.

    Raises:
        ValueError: if ``texts`` is empty.
    """
    if not texts:
        # Fail early with a clear message instead of an opaque max()-on-empty error.
        raise ValueError("texts must contain at least one entry")

    device = next(model.parameters()).device
    dtype = next(model.parameters()).dtype
    batch_size = len(texts)

    # Tokenize the prompt audio (if any) into global + semantic token id lists.
    if prompt_audio is not None:
        audio_data = np.array(prompt_audio, dtype=np.float32)
        target_sample_rate = bicodec.config['sample_rate']
        # NOTE(review): assumes the caller already resampled to target_sample_rate — confirm.
        print(f"BiCodecTokenizer 期望的采样率: {target_sample_rate}Hz")
        print(f"音频数据形状: {audio_data.shape}")
        global_tokens, semantic_tokens = bicodec.tokenize(audio_data)
        global_tokens = global_tokens.squeeze(0).squeeze(0).detach().cpu().tolist()
        semantic_tokens = semantic_tokens.squeeze(0).squeeze(0).detach().cpu().tolist()
    else:
        global_tokens = []
        semantic_tokens = []

    # Everything below depends only on the prompt, not on the individual text,
    # so it is computed once instead of once per batch entry.
    initial_semantic_tokens = semantic_tokens.copy() if prompt_text is not None else []
    global_tokens_tensor = torch.tensor(global_tokens, dtype=torch.long, device=device)
    global_embs = model.global_embedder(global_tokens_tensor)
    semantic_embs = model.model.embeddings(
        torch.tensor(initial_semantic_tokens, dtype=torch.long, device=device)
    )
    tag_0_emb = model.tts_tag_embedder(torch.tensor([0], dtype=torch.long, device=device))
    tag_1_emb = model.tts_tag_embedder(torch.tensor([1], dtype=torch.long, device=device))
    tag_2_emb = model.tts_tag_embedder(torch.tensor([2], dtype=torch.long, device=device))

    all_input_embs = []
    for text in texts:
        full_text = prompt_text + text if prompt_text is not None else text
        text_tokens = tokenizer.encode(full_text, add_special_tokens=False)
        text_embs = model.text_embedder(
            torch.tensor(text_tokens, dtype=torch.long, device=device)
        )
        # Sequence layout: <tag2> text <tag0> global <tag1> semantic
        all_input_embs.append(torch.cat(
            [tag_2_emb, text_embs, tag_0_emb, global_embs, tag_1_emb, semantic_embs],
            dim=0,
        ))

    max_seq_len = max(emb.shape[0] for emb in all_input_embs)
    hidden_size = all_input_embs[0].shape[1]

    batch_input_embs = torch.zeros(batch_size, max_seq_len, hidden_size, device=device, dtype=dtype)
    batch_attention_mask = torch.zeros(batch_size, max_seq_len, dtype=torch.long, device=device)

    # Left-pad each sequence so that every entry ends at the last position.
    for i, input_embs in enumerate(all_input_embs):
        seq_len = input_embs.shape[0]
        batch_input_embs[i, -seq_len:, :] = input_embs
        batch_attention_mask[i, -seq_len:] = 1

    # expand() gives a broadcast view of the shared prompt global tokens (no copy).
    batch_global_tokens = global_tokens_tensor.unsqueeze(0).expand(batch_size, -1)

    return {
        "input_embs": batch_input_embs,
        "global_tokens": batch_global_tokens,
    }, batch_attention_mask