import numpy as np
import torch
import time
def generate_embeddings(model, tokenizer, text, bicodec, prompt_text=None, prompt_audio=None):
"""
为 Spark LLM 生成预测所需的输入嵌入
Args:
model: Spark LLM 模型
tokenizer: 文本分词器
text: 要生成语音的文本
bicodec: BiCodecTokenizer 实例
prompt_text: 提示文本(可选)
prompt_audio: 提示音频数组(可选)
Returns:
dict: 包含 input_embs 的字典,用于模型预测
"""
device = next(model.parameters()).device
with torch.no_grad():
        # 1. Process the prompt audio and extract global_tokens and semantic_tokens.
        if prompt_audio is not None:
            # Make sure the audio data is float32.
            audio_data = np.array(prompt_audio, dtype=np.float32)
            target_sample_rate = bicodec.config['sample_rate']
            # Resampling is assumed to have happened upstream (e.g. right after
            # loading with soundfile); BiCodecTokenizer expects 16 kHz audio.
            print(f"BiCodecTokenizer expected sample rate: {target_sample_rate}Hz")
            print(f"Audio data shape: {audio_data.shape}")
            # Extract tokens with BiCodec (return order: global_tokens, semantic_tokens).
            global_tokens, semantic_tokens = bicodec.tokenize(audio_data)
            global_tokens = global_tokens.squeeze(0).squeeze(0).detach().cpu().tolist()
            semantic_tokens = semantic_tokens.squeeze(0).squeeze(0).detach().cpu().tolist()
else:
global_tokens = []
semantic_tokens = []
        # 2. Process the text.
        if prompt_text is not None:
            # Concatenate the prompt text and the target text.
            full_text = prompt_text + text
            # The initial semantic tokens are those extracted from prompt_audio.
            initial_semantic_tokens = semantic_tokens.copy()
else:
full_text = text
initial_semantic_tokens = []
        # 3. Get the text tokens.
text_tokens = tokenizer.encode(full_text, add_special_tokens=False)
        # 4. Convert to tensors.
text_tokens_tensor = torch.tensor(text_tokens, dtype=torch.long, device=device)
global_tokens_tensor = torch.tensor(global_tokens, dtype=torch.long, device=device)
semantic_tokens_tensor = torch.tensor(initial_semantic_tokens, dtype=torch.long, device=device)
        # 5. Look up the embeddings.
text_embs = model.text_embedder(text_tokens_tensor)
global_embs = model.global_embedder(global_tokens_tensor)
semantic_embs = model.model.embeddings(semantic_tokens_tensor)
        # 6. Look up the special-tag embeddings.
tag_0_emb = model.tts_tag_embedder(torch.tensor([0], dtype=torch.long, device=device))
tag_1_emb = model.tts_tag_embedder(torch.tensor([1], dtype=torch.long, device=device))
tag_2_emb = model.tts_tag_embedder(torch.tensor([2], dtype=torch.long, device=device))
        # 7. Concatenate the embeddings.
input_embs = torch.cat([
tag_2_emb,
text_embs,
tag_0_emb,
global_embs,
tag_1_emb,
semantic_embs
], dim=0)
        # 8. Add a batch dimension.
input_embs = input_embs.unsqueeze(0) # [1, seq_len, hidden_size]
return {
"input_embs": input_embs,
"global_tokens": global_tokens_tensor,
}
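# The concatenation above fixes the prompt layout fed to the model:
#   [tag_2] [text tokens] [tag_0] [global audio tokens] [tag_1] [semantic tokens]
# so autoregressive generation continues the semantic-token stream after tag_1.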
def generate_embeddings_batch(model, tokenizer, texts, bicodec, prompt_text=None, prompt_audio=None):
"""
为 Spark LLM 批量生成预测所需的输入嵌入,支持多个文本的并行处理
Args:
model: Spark LLM 模型
tokenizer: 文本分词器
texts: 要生成语音的文本列表
bicodec: BiCodecTokenizer 实例
prompt_text: 提示文本(可选)
prompt_audio: 提示音频数组(可选)
Returns:
tuple: (embeddings_dict, attention_mask) 包含批量 input_embs 的字典和注意力掩码
"""
device = next(model.parameters()).device
dtype = next(model.parameters()).dtype
batch_size = len(texts)
with torch.no_grad():
        # 1. Process the prompt audio and extract global_tokens and semantic_tokens.
        if prompt_audio is not None:
            # Make sure the audio data is float32.
            audio_data = np.array(prompt_audio, dtype=np.float32)
            target_sample_rate = bicodec.config['sample_rate']
            print(f"BiCodecTokenizer expected sample rate: {target_sample_rate}Hz")
            print(f"Audio data shape: {audio_data.shape}")
            # Extract tokens with BiCodec (return order: global_tokens, semantic_tokens).
            global_tokens, semantic_tokens = bicodec.tokenize(audio_data)
            global_tokens = global_tokens.squeeze(0).squeeze(0).detach().cpu().tolist()
            semantic_tokens = semantic_tokens.squeeze(0).squeeze(0).detach().cpu().tolist()
else:
global_tokens = []
semantic_tokens = []
        # 2. Process every text and collect its embedding components.
all_input_embs = []
all_attention_masks = []
for text in texts:
            # Process a single text.
if prompt_text is not None:
full_text = prompt_text + text
initial_semantic_tokens = semantic_tokens.copy()
else:
full_text = text
initial_semantic_tokens = []
            # Get the text tokens.
text_tokens = tokenizer.encode(full_text, add_special_tokens=False)
            # Convert to tensors.
text_tokens_tensor = torch.tensor(text_tokens, dtype=torch.long, device=device)
global_tokens_tensor = torch.tensor(global_tokens, dtype=torch.long, device=device)
semantic_tokens_tensor = torch.tensor(initial_semantic_tokens, dtype=torch.long, device=device)
            # Look up the embeddings.
text_embs = model.text_embedder(text_tokens_tensor)
global_embs = model.global_embedder(global_tokens_tensor)
semantic_embs = model.model.embeddings(semantic_tokens_tensor)
            # Look up the special-tag embeddings.
tag_0_emb = model.tts_tag_embedder(torch.tensor([0], dtype=torch.long, device=device))
tag_1_emb = model.tts_tag_embedder(torch.tensor([1], dtype=torch.long, device=device))
tag_2_emb = model.tts_tag_embedder(torch.tensor([2], dtype=torch.long, device=device))
            # Concatenate the embeddings.
input_embs = torch.cat([
tag_2_emb,
text_embs,
tag_0_emb,
global_embs,
tag_1_emb,
semantic_embs
], dim=0) # [seq_len, hidden_size]
all_input_embs.append(input_embs)
all_attention_masks.append(torch.ones(input_embs.shape[0], dtype=torch.long, device=device))
        # 3. Find the maximum sequence length.
max_seq_len = max(emb.shape[0] for emb in all_input_embs)
hidden_size = all_input_embs[0].shape[1]
        # 4. Build the batch tensors, using left padding with zeros.
batch_input_embs = torch.zeros(batch_size, max_seq_len, hidden_size, device=device, dtype=dtype)
batch_attention_mask = torch.zeros(batch_size, max_seq_len, dtype=torch.long, device=device)
for i, (input_embs, attention_mask) in enumerate(zip(all_input_embs, all_attention_masks)):
seq_len = input_embs.shape[0]
            # Left padding: place the sequence at the right edge, zeros on the left.
batch_input_embs[i, -seq_len:, :] = input_embs
batch_attention_mask[i, -seq_len:] = attention_mask
        # 5. Build the batched version of global_tokens.
        global_tokens_tensor = torch.tensor(global_tokens, dtype=torch.long, device=device)
batch_global_tokens = global_tokens_tensor.unsqueeze(0).expand(batch_size, -1)
return {
"input_embs": batch_input_embs,
"global_tokens": batch_global_tokens,
}, batch_attention_mask
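# Left padding keeps every sequence's most recent token aligned at the right
# edge of the batch, which is what incremental decoding with a KV cache
# expects for decoder-only models: the next position follows directly.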
# Repetition-Aware Sampling (RAS), as used in VALL-E 2.
def ras_sampling(weighted_scores, decoded_tokens, top_p=0.8, top_k=25, win_size=10, tau_r=0.1):
top_ids = nucleus_sampling(weighted_scores, top_p=top_p, top_k=top_k)
rep_num = (torch.tensor(decoded_tokens[-win_size:]).to(weighted_scores.device) == top_ids).sum().item()
if rep_num >= win_size * tau_r:
top_ids = random_sampling(weighted_scores)
return top_ids
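# A minimal sketch of the repetition guard above, with hypothetical scores: one
# dominant token plus a history window already full of it trips the fallback
# from the truncated nucleus distribution to a draw over the full softmax.
def _ras_sampling_demo():
    scores = torch.zeros(100)
    scores[7] = 10.0    # token 7 dominates, so nucleus sampling would pick it
    history = [7] * 10  # the last win_size steps already emitted token 7
    # rep_num == 10 >= win_size * tau_r == 1, so random_sampling is used instead.
    return ras_sampling(scores, history, win_size=10, tau_r=0.1)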
def nucleus_sampling(weighted_scores, top_p=0.8, top_k=25):
prob, indices = [], []
cum_prob = 0.0
sorted_value, sorted_idx = weighted_scores.softmax(dim=0).sort(descending=True, stable=True)
    for i in range(len(sorted_idx)):
        # Keep tokens until either the top-p mass or the top-k count is reached.
        if cum_prob < top_p and len(prob) < top_k:
            cum_prob += sorted_value[i]
            prob.append(sorted_value[i])
            indices.append(sorted_idx[i])
        else:
            break
    prob = torch.stack(prob).to(weighted_scores)
    indices = torch.stack(indices).to(weighted_scores.device)
top_ids = indices[prob.multinomial(1, replacement=True)]
return top_ids
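# A minimal sketch with hypothetical scores: softmax([4.0, 2.0, 0.5, 0.1]) is
# roughly [0.84, 0.11, 0.03, 0.02], so with top_p=0.8 only index 0 survives
# the cutoff and is drawn (the top-k limit never binds here).
def _nucleus_sampling_demo():
    scores = torch.tensor([4.0, 2.0, 0.5, 0.1])
    return nucleus_sampling(scores, top_p=0.8, top_k=25)  # -> tensor([0])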
def random_sampling(weighted_scores):
top_ids = weighted_scores.softmax(dim=0).multinomial(1, replacement=True)
return top_ids
def generate(model,
             inputs_embeds,
             attention_mask,
             new_max_tokens,
             top_k,
             top_p,
             temperature,
             eos_token_id,
             pad_token_id,
             past_key_values
             ):
"""
seperate two stages of generation:
1. prefill
2. decode
we will measure the time of each stage and the total time
"""
start_time = time.time()
model.eval()
    device = inputs_embeds.device  # target device for tokens fed back into the model
    batch_size = inputs_embeds.shape[0]
decoded_tokens = [[] for _ in range(batch_size)]
is_decoding = [True for _ in range(batch_size)]
with torch.no_grad():
# 1. prefill
outputs = model.model.forward(
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
past_key_values=past_key_values,
use_cache=True,
output_attentions=False,
output_hidden_states=True,
return_dict=False
)
hidden_states = outputs[0]
past_key_values = outputs[1]
prefill_time = time.time() - start_time
        tokens = attention_mask.shape[0] * attention_mask.shape[1]  # includes padded positions
        print(f"Prefill time: {prefill_time:.3f}s, {tokens} tokens, {tokens / prefill_time:.1f} tokens/s")
# 2. decode
start_time = time.time()
        # Sample the next tokens with RAS: top-k/top-p nucleus sampling plus a repetition guard.
        decoded_tokens_size = 0
        steps = 0
while True:
logits = model.lm_head(hidden_states)
last_time_decoded = []
            logits = logits[:, -1, :] / temperature  # last position only; scale by temperature
continue_decoding = False
for i in range(batch_size):
if is_decoding[i]:
logits_i = logits[i, :]
top_ids = ras_sampling(logits_i, decoded_tokens[i], top_p=top_p, top_k=top_k).item()
decoded_tokens[i].append(top_ids)
last_time_decoded.append([top_ids])
if top_ids == eos_token_id:
is_decoding[i] = False
else:
continue_decoding = True
decoded_tokens_size += 1
                else:
                    # Finished sequences keep emitting pad tokens so the batch stays rectangular.
                    decoded_tokens[i].append(pad_token_id)
                    last_time_decoded.append([pad_token_id])
            steps += 1
            # Stop when every sequence has hit EOS or the new_max_tokens budget is spent.
            if not continue_decoding or steps >= new_max_tokens:
                break
last_time_decoded = torch.tensor(last_time_decoded, dtype=torch.long, device=device)
lm_input = model.get_input_embeddings()(last_time_decoded)
outputs = model.model.forward(
inputs_embeds=lm_input,
past_key_values=past_key_values,
use_cache=True,
output_attentions=False,
output_hidden_states=True,
return_dict=False
)
hidden_states = outputs[0]
past_key_values = outputs[1]
        decode_time = time.time() - start_time
        print(f"Decode time: {decode_time:.3f}s, {decoded_tokens_size} tokens, {decoded_tokens_size / decode_time:.1f} tokens/s")
        print(f"decoded_tokens: {decoded_tokens}")
return decoded_tokens, past_key_values
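# A hypothetical helper (not used below): trim each decoded sequence at its
# first EOS, which also drops any pad tokens appended after a sequence
# finishes. This is an alternative to the fixed [:, :-1] slice in the demo
# below, which only removes a trailing EOS and assumes a single sequence.
def _trim_at_eos(decoded_tokens, eos_token_id):
    trimmed = []
    for seq in decoded_tokens:
        cut = seq.index(eos_token_id) if eos_token_id in seq else len(seq)
        trimmed.append(seq[:cut])
    return trimmed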
if __name__ == "__main__":
import os
import sys
current_dir = os.path.dirname(os.path.abspath(__file__))
    print('Adding current dir to sys.path:', current_dir)
sys.path.append(current_dir)
device = 'cuda:2'
from sparktts.models.audio_tokenizer import BiCodecTokenizer
from transformers import AutoTokenizer, AutoModelForCausalLM
import soundfile as sf
audio_tokenizer = BiCodecTokenizer(model_dir=current_dir, device=device)
print(audio_tokenizer)
tokenizer = AutoTokenizer.from_pretrained(current_dir, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(current_dir, trust_remote_code=True)
print(tokenizer)
print(model)
model = model.bfloat16().to(device)
model.eval()
prompt_text = "我们并不是通过物理移动手段找到星河的。"
prompt_audio_file = os.path.join(current_dir, 'kafka.wav')
prompt_audio, sampling_rate = sf.read(prompt_audio_file)
print(f"Loaded prompt audio from {prompt_audio_file}")
print(f"Original sampling rate: {sampling_rate}Hz")
print(f"Audio shape: {prompt_audio.shape}")
target_sample_rate = audio_tokenizer.config['sample_rate']
if sampling_rate != target_sample_rate:
print(f"Resampling from {sampling_rate}Hz to {target_sample_rate}Hz...")
from librosa import resample
prompt_audio = resample(prompt_audio, orig_sr=sampling_rate, target_sr=target_sample_rate)
prompt_audio = np.array(prompt_audio, dtype=np.float32)
print(f"Resampled audio shape: {prompt_audio.shape}")
else:
print(f"Audio sampling rate already matches target ({target_sample_rate}Hz)")
texts = ["为了点燃青少年对科技的热情,培养他们的创新思维与动手能力,杏花岭区巨轮街道社区教育学校携手中车社区教育分校,与太原市科学技术协会联手,于暑期精心策划了一场别开生面的青少年数智技术服务港探索之旅,吸引了众多社区青少年的积极参与。"]
eos_token_id = model.config.vocab_size - 1
print(f"EOS token ID: {eos_token_id}")
    # Build the input embeddings.
    embeddings, attention_mask = generate_embeddings_batch(
model=model,
tokenizer=tokenizer,
texts=texts,
bicodec=audio_tokenizer,
prompt_text=prompt_text,
prompt_audio=prompt_audio
)
input_embs = embeddings['input_embs']
print(f"input_embs shape: {input_embs.shape}")
print(f"attention_mask shape: {attention_mask.shape}")
print(f"input_embs dtype: {input_embs.dtype}")
print(f"attention_mask dtype: {attention_mask.dtype}")
print(f"input_embs: {input_embs}")
print(f"attention_mask: {attention_mask}")
print(f"input_embs: {input_embs}")
    # Warm-up generation pass; its outputs are discarded so the timed run below
    # is not skewed by one-time setup costs.
    with torch.no_grad():
        generate(model,
                 input_embs,
                 attention_mask,
                 new_max_tokens=1000,
                 top_k=25,
                 top_p=0.95,
                 temperature=1.0,
                 eos_token_id=eos_token_id,
                 pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
                 past_key_values=None)
    with torch.no_grad():
        audio_tokens, past_key_values = generate(model,
                                                 input_embs,
                                                 attention_mask,
                                                 new_max_tokens=1000,
                                                 top_k=50,
                                                 top_p=0.8,
                                                 temperature=1.0,
                                                 eos_token_id=eos_token_id,
                                                 pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
                                                 past_key_values=None)
    audio_tokens = torch.tensor(audio_tokens, dtype=torch.long, device=device)
    # Drop the trailing EOS token appended to each decoded sequence.
    audio_tokens = audio_tokens[:, :-1]
print(f"audio_tokens: {audio_tokens}")
print(f"past_key_values: {past_key_values}")
global_tokens = embeddings['global_tokens']
print(f"global_tokens shape: {global_tokens.shape}")
print(f"audio_tokens shape: {audio_tokens.shape}")
with torch.no_grad():
wav = audio_tokenizer.detokenize(global_tokens, audio_tokens)
print(f"wav shape: {wav.shape}")
sf.write('test.wav', wav, audio_tokenizer.config['sample_rate'])