import time

import numpy as np
import torch


def _extract_prompt_tokens(bicodec, prompt_audio):
    """Tokenize the prompt audio into ``(global_tokens, semantic_tokens)`` lists.

    Returns two empty lists when ``prompt_audio`` is None.
    """
    if prompt_audio is None:
        return [], []

    # The audio is handed to BiCodec as float32 samples.
    audio_data = np.array(prompt_audio, dtype=np.float32)
    target_sample_rate = bicodec.config['sample_rate']
    # NOTE: no resampling is done here; the caller is expected to have
    # resampled prompt_audio to the tokenizer's sample rate beforehand.
    print(f"BiCodecTokenizer 期望的采样率: {target_sample_rate}Hz")
    print(f"音频数据形状: {audio_data.shape}")

    # tokenize() returns (global_tokens, semantic_tokens), in that order.
    global_tokens, semantic_tokens = bicodec.tokenize(audio_data)
    global_tokens = global_tokens.squeeze(0).squeeze(0).detach().cpu().tolist()
    semantic_tokens = semantic_tokens.squeeze(0).squeeze(0).detach().cpu().tolist()
    return global_tokens, semantic_tokens


def _embed_prompt_parts(model, global_tokens, semantic_tokens, device):
    """Embed every piece of the input sequence that does not depend on the text."""

    def ids(values):
        return torch.tensor(values, dtype=torch.long, device=device)

    return {
        "tag_0": model.tts_tag_embedder(ids([0])),
        "tag_1": model.tts_tag_embedder(ids([1])),
        "tag_2": model.tts_tag_embedder(ids([2])),
        "global": model.global_embedder(ids(global_tokens)),
        "semantic": model.model.embeddings(ids(semantic_tokens)),
    }


def _assemble_input_embs(model, tokenizer, full_text, parts, device):
    """Embed ``full_text`` and splice it with the prompt parts into [seq_len, hidden]."""
    text_tokens = tokenizer.encode(full_text, add_special_tokens=False)
    text_embs = model.text_embedder(
        torch.tensor(text_tokens, dtype=torch.long, device=device)
    )
    # Layout: <tag 2> text <tag 0> global tokens <tag 1> semantic tokens
    return torch.cat(
        [
            parts["tag_2"],
            text_embs,
            parts["tag_0"],
            parts["global"],
            parts["tag_1"],
            parts["semantic"],
        ],
        dim=0,
    )


def _tokens_per_second(count, seconds):
    """Return ``count / seconds``, or ``inf`` when no measurable time elapsed."""
    return count / seconds if seconds > 0 else float("inf")


def generate_embeddings(model, tokenizer, text, bicodec, prompt_text=None, prompt_audio=None):
    """Build the input embeddings the Spark LLM needs for one prediction.

    Args:
        model: Spark LLM model.
        tokenizer: Text tokenizer.
        text: Text to synthesize.
        bicodec: BiCodecTokenizer instance.
        prompt_text: Optional prompt text; it is prepended to ``text``.
        prompt_audio: Optional prompt audio samples.

    Returns:
        dict: ``"input_embs"`` with shape [1, seq_len, hidden_size] and
        ``"global_tokens"``, the prompt's global tokens as a 1-D long tensor.
    """
    device = next(model.parameters()).device

    with torch.no_grad():
        global_tokens, semantic_tokens = _extract_prompt_tokens(bicodec, prompt_audio)

        if prompt_text is not None:
            full_text = prompt_text + text
            # The prompt's semantic tokens seed the sequence only when a
            # prompt text accompanies them.
            initial_semantic_tokens = semantic_tokens
        else:
            full_text = text
            initial_semantic_tokens = []

        parts = _embed_prompt_parts(model, global_tokens, initial_semantic_tokens, device)
        input_embs = _assemble_input_embs(model, tokenizer, full_text, parts, device)
        global_tokens_tensor = torch.tensor(global_tokens, dtype=torch.long, device=device)

    return {
        "input_embs": input_embs.unsqueeze(0),  # add the batch dimension
        "global_tokens": global_tokens_tensor,
    }


def generate_embeddings_batch(model, tokenizer, texts, bicodec, prompt_text=None, prompt_audio=None):
    """Build left-padded input embeddings for several texts at once.

    Args:
        model: Spark LLM model.
        tokenizer: Text tokenizer.
        texts: Sequence of texts to synthesize.
        bicodec: BiCodecTokenizer instance.
        prompt_text: Optional prompt text; it is prepended to every text.
        prompt_audio: Optional prompt audio samples, shared by all texts.

    Returns:
        tuple: ``(embeddings_dict, attention_mask)``. ``embeddings_dict`` has
        ``"input_embs"`` of shape [batch, max_seq_len, hidden_size] and
        ``"global_tokens"`` of shape [batch, n_global]; ``attention_mask`` is
        a long tensor of shape [batch, max_seq_len] with 1 on real positions.

    Raises:
        ValueError: If ``texts`` is empty.
    """
    batch_size = len(texts)
    if batch_size == 0:
        raise ValueError("texts must contain at least one item")

    first_param = next(model.parameters())
    device, dtype = first_param.device, first_param.dtype

    with torch.no_grad():
        global_tokens, semantic_tokens = _extract_prompt_tokens(bicodec, prompt_audio)

        # Everything except the text embedding is identical for every sample,
        # so it is computed once instead of once per text.
        if prompt_text is not None:
            prefix = prompt_text
            initial_semantic_tokens = semantic_tokens
        else:
            prefix = ""
            initial_semantic_tokens = []
        parts = _embed_prompt_parts(model, global_tokens, initial_semantic_tokens, device)

        all_input_embs = [
            _assemble_input_embs(model, tokenizer, prefix + text, parts, device)
            for text in texts
        ]

        max_seq_len = max(embs.shape[0] for embs in all_input_embs)
        hidden_size = all_input_embs[0].shape[1]

        # Left padding: each sequence is right-aligned, zeros fill the left.
        batch_input_embs = torch.zeros(
            batch_size, max_seq_len, hidden_size, device=device, dtype=dtype
        )
        batch_attention_mask = torch.zeros(
            batch_size, max_seq_len, dtype=torch.long, device=device
        )
        for row, embs in enumerate(all_input_embs):
            start = max_seq_len - embs.shape[0]
            batch_input_embs[row, start:, :] = embs
            batch_attention_mask[row, start:] = 1

        global_tokens_tensor = torch.tensor(global_tokens, dtype=torch.long, device=device)
        batch_global_tokens = global_tokens_tensor.unsqueeze(0).expand(batch_size, -1)

    return {
        "input_embs": batch_input_embs,
        "global_tokens": batch_global_tokens,
    }, batch_attention_mask


# Repetition Aware Sampling in VALL-E 2
def ras_sampling(weighted_scores, decoded_tokens, top_p=0.8, top_k=25, win_size=10, tau_r=0.1):
    """Sample one token id with repetition-aware sampling.

    A nucleus sample is drawn first. If that token already occurs at least
    ``win_size * tau_r`` times among the last ``win_size`` decoded tokens, it
    is replaced by a sample from the full distribution.

    Args:
        weighted_scores: 1-D tensor of logits over the vocabulary.
        decoded_tokens: List of token ids decoded so far for this sequence.
        top_p: Nucleus probability mass.
        top_k: Maximum number of nucleus candidates.
        win_size: Size of the repetition window.
        tau_r: Repetition ratio threshold.

    Returns:
        Long tensor of shape [1] holding the sampled token id.
    """
    top_ids = nucleus_sampling(weighted_scores, top_p=top_p, top_k=top_k)
    recent = torch.tensor(
        decoded_tokens[-win_size:], dtype=torch.long, device=weighted_scores.device
    )
    rep_num = (recent == top_ids).sum().item()
    if rep_num >= win_size * tau_r:
        top_ids = random_sampling(weighted_scores)
    return top_ids


def nucleus_sampling(weighted_scores, top_p=0.8, top_k=25):
    """Sample one token id from the combined top-p / top-k nucleus.

    Tokens are taken in descending probability order while the probability
    mass accumulated *before* a token is below ``top_p`` and fewer than
    ``top_k`` tokens have been taken. The most likely token is always kept,
    so degenerate ``top_p`` / ``top_k`` values still yield a sample.

    Args:
        weighted_scores: 1-D tensor of logits over the vocabulary.
        top_p: Nucleus probability mass.
        top_k: Maximum number of candidates.

    Returns:
        Long tensor of shape [1] holding the sampled token id.
    """
    scores = weighted_scores
    if scores.dtype not in (torch.float32, torch.float64):
        # Accumulate probabilities in float32; half-precision sums are too
        # coarse for a reliable top-p cut-off.
        scores = scores.float()
    sorted_probs, sorted_idx = scores.softmax(dim=0).sort(descending=True, stable=True)

    k = max(1, min(int(top_k), sorted_probs.numel()))
    cumulative = sorted_probs[:k].cumsum(dim=0)
    # cumulative[j] is the mass before candidate j + 1; the sequence is
    # non-decreasing, so counting the hits gives the length of the prefix.
    n_keep = 1 + int((cumulative[:-1] < top_p).sum().item())

    choice = sorted_probs[:n_keep].multinomial(1, replacement=True)
    return sorted_idx[:n_keep][choice]


def random_sampling(weighted_scores):
    """Sample one token id from the full softmax distribution of the logits."""
    top_ids = weighted_scores.softmax(dim=0).multinomial(1, replacement=True)
    return top_ids


def generate(model,
             inputs_embeds,
             attention_mask,
             new_max_tokens,
             top_k,
             top_p,
             temperate,
             eos_token_id,
             pad_token_id,
             past_key_values
             ):
    """Run prefill and then token-by-token decoding, timing both stages.

    Args:
        model: Causal LM exposing ``model.model.forward``, ``model.lm_head``
            and ``model.get_input_embeddings()``.
        inputs_embeds: Prompt embeddings; dimension 0 is the batch.
        attention_mask: Mask for the prompt; only used during prefill.
        new_max_tokens: Maximum number of decode steps. ``None`` means no
            limit (decode until every sequence has produced EOS).
        top_k: Top-k for nucleus sampling.
        top_p: Top-p for nucleus sampling.
        temperate: Sampling temperature; applied when positive and != 1.0.
        eos_token_id: Token id that ends a sequence.
        pad_token_id: Token id appended to sequences that already finished
            while other sequences in the batch are still decoding.
        past_key_values: Initial cache passed to the prefill call (or None).

    Returns:
        tuple: ``(decoded_tokens, past_key_values)`` where ``decoded_tokens``
        is one list of token ids per batch row (all rows have equal length)
        and ``past_key_values`` is the cache after the last forward pass.
    """
    start_time = time.perf_counter()
    model.eval()
    # Take the device from the inputs rather than from a module-level global,
    # which only exists when this file is executed as a script.
    device = inputs_embeds.device
    batch_size = inputs_embeds.shape[0]
    decoded_tokens = [[] for _ in range(batch_size)]
    is_decoding = [True] * batch_size
    max_steps = None if new_max_tokens is None else int(new_max_tokens)
    scale_logits = temperate is not None and temperate > 0 and temperate != 1.0

    with torch.no_grad():
        # 1. prefill
        outputs = model.model.forward(
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            past_key_values=past_key_values,
            use_cache=True,
            output_attentions=False,
            output_hidden_states=True,
            return_dict=False
        )
        hidden_states = outputs[0]
        past_key_values = outputs[1]
        prefill_time = time.perf_counter() - start_time
        tokens = attention_mask.shape[0] * attention_mask.shape[1]
        print(f"Prefill time: {prefill_time} seconds, all tokens is {tokens}, speed is {_tokens_per_second(tokens, prefill_time)} tokens/s ")

        # 2. decode
        start_time = time.perf_counter()
        decoded_tokens_size = 0
        steps = 0
        while max_steps is None or steps < max_steps:
            # Only the last position is needed to predict the next token.
            logits = model.lm_head(hidden_states[:, -1:, :])[:, -1, :]
            if scale_logits:
                logits = logits / temperate

            step_tokens = []
            any_active = False
            for i in range(batch_size):
                if not is_decoding[i]:
                    # Finished rows are padded so every row keeps the same length.
                    decoded_tokens[i].append(pad_token_id)
                    step_tokens.append([pad_token_id])
                    continue
                token_id = ras_sampling(
                    logits[i, :], decoded_tokens[i], top_p=top_p, top_k=top_k
                ).item()
                decoded_tokens[i].append(token_id)
                step_tokens.append([token_id])
                if token_id == eos_token_id:
                    is_decoding[i] = False
                else:
                    any_active = True
                    decoded_tokens_size += 1
            steps += 1

            if not any_active:
                break
            if max_steps is not None and steps >= max_steps:
                # Budget exhausted: skip the forward pass nobody would read.
                break

            step_tensor = torch.tensor(step_tokens, dtype=torch.long, device=device)
            lm_input = model.get_input_embeddings()(step_tensor)
            outputs = model.model.forward(
                inputs_embeds=lm_input,
                past_key_values=past_key_values,
                use_cache=True,
                output_attentions=False,
                output_hidden_states=True,
                return_dict=False
            )
            hidden_states = outputs[0]
            past_key_values = outputs[1]

        decode_time = time.perf_counter() - start_time
        print(f"Decode time: {decode_time} seconds, all tokens is {decoded_tokens_size}, speed is {_tokens_per_second(decoded_tokens_size, decode_time)} tokens/s ")
        print(f"decoded_tokens: {decoded_tokens}")
        return decoded_tokens, past_key_values


def main():
    """Demo: clone the voice in ``kafka.wav`` and synthesize one sentence to ``test.wav``."""
    import os
    import sys

    current_dir = os.path.dirname(os.path.abspath(__file__))
    print('add current dir to sys.path', current_dir)
    sys.path.append(current_dir)

    # Prefer the originally hard-coded GPU, but stay runnable on machines
    # that have fewer GPUs or none at all.
    if not torch.cuda.is_available():
        device = 'cpu'
    elif torch.cuda.device_count() > 2:
        device = 'cuda:2'
    else:
        device = 'cuda:0'

    from sparktts.models.audio_tokenizer import BiCodecTokenizer
    from transformers import AutoTokenizer, AutoModelForCausalLM
    import soundfile as sf

    audio_tokenizer = BiCodecTokenizer(model_dir=current_dir, device=device)
    print(audio_tokenizer)

    tokenizer = AutoTokenizer.from_pretrained(current_dir, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(current_dir, trust_remote_code=True)
    print(tokenizer)
    print(model)

    model = model.bfloat16().to(device)
    model.eval()

    prompt_text = "我们并不是通过物理移动手段找到星河的。"
    prompt_audio_file = os.path.join(current_dir, 'kafka.wav')
    prompt_audio, sampling_rate = sf.read(prompt_audio_file)

    print(f"Loaded prompt audio from {prompt_audio_file}")
    print(f"Original sampling rate: {sampling_rate}Hz")
    print(f"Audio shape: {prompt_audio.shape}")

    if prompt_audio.ndim > 1:
        # soundfile returns (frames, channels) for multi-channel files, while
        # librosa resamples along the last axis by default, so downmix first.
        # NOTE(review): assumes the tokenizer wants mono audio - confirm.
        prompt_audio = prompt_audio.mean(axis=1)

    target_sample_rate = audio_tokenizer.config['sample_rate']
    if sampling_rate != target_sample_rate:
        print(f"Resampling from {sampling_rate}Hz to {target_sample_rate}Hz...")
        from librosa import resample
        prompt_audio = resample(prompt_audio, orig_sr=sampling_rate, target_sr=target_sample_rate)
        prompt_audio = np.array(prompt_audio, dtype=np.float32)
        print(f"Resampled audio shape: {prompt_audio.shape}")
    else:
        print(f"Audio sampling rate already matches target ({target_sample_rate}Hz)")

    texts = ["为了点燃青少年对科技的热情,培养他们的创新思维与动手能力,杏花岭区巨轮街道社区教育学校携手中车社区教育分校,与太原市科学技术协会联手,于暑期精心策划了一场别开生面的青少年数智技术服务港探索之旅,吸引了众多社区青少年的积极参与。"]

    eos_token_id = model.config.vocab_size - 1
    print(f"EOS token ID: {eos_token_id}")

    # hasattr() is true even when the attribute is None, so test the value.
    pad_token_id = getattr(tokenizer, 'pad_token_id', None)
    if pad_token_id is None:
        pad_token_id = getattr(tokenizer, 'eos_token_id', None)
    if pad_token_id is None:
        pad_token_id = eos_token_id

    # generate() now enforces its token limit (it used to ignore it), so use
    # a generous cap to avoid truncating long sentences.
    max_new_tokens = 3000

    # Build the input embeddings.
    embeddings, attention_mask = generate_embeddings_batch(
        model=model,
        tokenizer=tokenizer,
        texts=texts,
        bicodec=audio_tokenizer,
        prompt_text=prompt_text,
        prompt_audio=prompt_audio
    )
    input_embs = embeddings['input_embs']
    print(f"input_embs shape: {input_embs.shape}")
    print(f"attention_mask shape: {attention_mask.shape}")
    print(f"input_embs dtype: {input_embs.dtype}")
    print(f"attention_mask dtype: {attention_mask.dtype}")
    print(f"input_embs: {input_embs}")
    print(f"attention_mask: {attention_mask}")

    # First pass: its result is discarded (presumably a warm-up run so the
    # timings printed by the second pass are representative).
    with torch.no_grad():
        generate(model, input_embs, attention_mask,
                 new_max_tokens=max_new_tokens, top_k=25, top_p=0.95, temperate=1.0,
                 eos_token_id=eos_token_id, pad_token_id=pad_token_id,
                 past_key_values=None)

    with torch.no_grad():
        audio_tokens, past_key_values = generate(
            model, input_embs, attention_mask,
            new_max_tokens=max_new_tokens, top_k=50, top_p=0.8, temperate=1.0,
            eos_token_id=eos_token_id, pad_token_id=pad_token_id,
            past_key_values=None)

    audio_tokens = torch.tensor(audio_tokens, dtype=torch.long, device=device)
    # Drop the final column when it carries the EOS marker; if decoding
    # stopped at the token limit instead, every column is real audio.
    if audio_tokens.shape[1] > 0 and bool((audio_tokens[:, -1] == eos_token_id).any()):
        audio_tokens = audio_tokens[:, :-1]
    print(f"audio_tokens: {audio_tokens}")
    print(f"past_key_values: {past_key_values}")

    global_tokens = embeddings['global_tokens']
    print(f"global_tokens shape: {global_tokens.shape}")
    print(f"audio_tokens shape: {audio_tokens.shape}")
    with torch.no_grad():
        wav = audio_tokenizer.detokenize(global_tokens, audio_tokens)
    print(f"wav shape: {wav.shape}")
    sf.write('test.wav', wav, audio_tokenizer.config['sample_rate'])


if __name__ == "__main__":
    main()