|
import os |
|
import json |
|
import torch |
|
import numpy as np |
|
from transformers import AutoTokenizer |
|
from properties_util import convert_standard_properties_to_tokens |
|
|
|
def print_properties_info(age: str, gender: str, emotion: str, pitch: float, speed: float):
    """
    Print a one-line summary of the speaker attributes.

    Args:
        age: age label
        gender: gender label
        emotion: emotion label
        pitch: pitch value
        speed: speed value
    """
    fields = [
        ('age', age),
        ('gender', gender),
        ('emotion', emotion),
        ('pitch', pitch),
        ('speed', speed),
    ]
    print(', '.join(f'{name}: {value}' for name, value in fields))
|
|
|
@torch.inference_mode()
def extract_embeddings_for_global_tokens(model, tokenizer, text, age: str, gender: str, emotion: str, pitch: float, speed: float,global_tokens: list = None):
    """
    Build the embedding sequence used to sample global tokens.

    Args:
        model: model instance exposing ``text_embedder``, ``global_embedder``
            and ``tts_tag_embedder`` sub-modules
        tokenizer: text tokenizer
        text: input text
        age: age attribute
        gender: gender attribute
        emotion: emotion attribute
        pitch: pitch attribute
        speed: speed attribute
        global_tokens: optional pre-computed global token ids

    Returns:
        torch.Tensor: concatenated embedding sequence, shape (seq_len, hidden)
    """
    device = next(model.parameters()).device

    def _as_long(ids):
        # Token id list -> long tensor on the model's device.
        return torch.tensor(ids, dtype=torch.long, device=device)

    # Encode the property description and the raw text separately.
    prop_string = convert_standard_properties_to_tokens(age, gender, emotion, pitch, speed)
    text_embs = model.text_embedder(_as_long(tokenizer.encode(text, add_special_tokens=False)))
    prop_embs = model.text_embedder(_as_long(tokenizer.encode(prop_string, add_special_tokens=False)))

    # TTS control-tag embeddings for ids 0, 1 and 2, each of shape (1, hidden).
    tag_embs = [model.tts_tag_embedder(_as_long([tag_id])) for tag_id in range(3)]

    # Base layout: [properties] <tag2> [text] <tag0>
    pieces = [prop_embs, tag_embs[2], text_embs, tag_embs[0]]
    if global_tokens is not None:
        # Extended layout appends the known global tokens: ... [global] <tag1>
        pieces.append(model.global_embedder(_as_long(global_tokens)))
        pieces.append(tag_embs[1])
    return torch.cat(pieces, dim=0)
|
|
|
def get_tokenizer(model_dir):
    """
    Load the tokenizer from *model_dir* and register the TTS special tokens.

    Args:
        model_dir: directory (or hub id) to load the base tokenizer from

    Returns:
        the tokenizer with the pad token and event markers registered
    """
    tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
    # Non-verbal / prosody event markers used by the TTS text front-end.
    event_markers = [
        '<|endofprompt|>',
        '[breath]', '<strong>', '</strong>', '[noise]',
        '[laughter]', '[cough]', '[clucking]', '[accent]',
        '[quick_breath]',
        "<laughter>", "</laughter>",
        "[hissing]", "[sigh]", "[vocalized-noise]",
        "[lipsmack]", "[mn]",
    ]
    tokenizer.add_special_tokens({
        'pad_token': '<|rwkv_tokenizer_end_of_text|>',
        'additional_special_tokens': event_markers,
    })
    return tokenizer
|
|
|
def get_respark_tts_tokenizer(model_dir):
    """
    Load the tokenizer and extend it with the Spark-TTS added tokens.

    Args:
        model_dir: directory (or hub id) to load the base tokenizer from

    Returns:
        tuple: (tokenizer, original_vocab_size) where original_vocab_size is
        the vocabulary size BEFORE the extra tokens were added, so callers can
        distinguish base ids from added ids.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
    # Capture the size before extension on purpose.
    original_vocab_size = tokenizer.vocab_size
    added_tokens_file = os.path.join(os.path.dirname(__file__), 'spark_tts_added_tokens.json')
    # Explicit encoding: JSON files are UTF-8 regardless of the platform's
    # locale default (the previous open() call relied on that default).
    with open(added_tokens_file, 'r', encoding='utf-8') as f:
        added_tokens = json.load(f)
    tokenizer.add_special_tokens(added_tokens)
    return tokenizer, original_vocab_size
|
@torch.inference_mode()
def generate_global_tokens(model, tokenizer, text, age: str, gender: str, emotion: str, pitch: float, speed: float,
                           num_global_tokens: int = 4096):
    """
    Sample a fixed-length run of global tokens conditioned on text and attributes.

    Args:
        model: model instance with a ``generate`` method and ``config.vocab_size``
        tokenizer: text tokenizer
        text: input text
        age: age attribute
        gender: gender attribute
        emotion: emotion attribute
        pitch: pitch attribute
        speed: speed attribute
        num_global_tokens: size of the global-token id space; ids at or above
            this value are suppressed during sampling

    Returns:
        the ``generate`` output object (``return_dict_in_generate=True``)
    """
    full_embs_for_sample = extract_embeddings_for_global_tokens(model, tokenizer, text, age, gender, emotion, pitch, speed)
    device = full_embs_for_sample.device
    vocab_size = model.config.vocab_size
    eos_token_id = vocab_size - 1
    # Restrict sampling to the global-token id range [0, num_global_tokens).
    suppress_tokens = list(range(num_global_tokens, vocab_size))
    # BUGFIX: full_embs_for_sample is 2-D (seq_len, hidden) here, so the
    # attention mask must use shape[0] (sequence length); the previous code
    # used shape[1], which is the hidden dimension.
    seq_len = full_embs_for_sample.shape[0]
    gen_args = {
        "inputs_embeds": full_embs_for_sample.unsqueeze(0),
        "attention_mask": torch.ones((1, seq_len), dtype=torch.long, device=device),
        # Exactly 32 global tokens are generated.
        "max_new_tokens": 32,
        "min_new_tokens": 32,
        "do_sample": True,
        "top_k": 50,
        "top_p": 0.95,
        "temperature": 1.0,
        "eos_token_id": eos_token_id,
        "pad_token_id": tokenizer.pad_token_id,
        "use_cache": True,
        "suppress_tokens": suppress_tokens,
        "return_dict_in_generate": True,
    }
    generated_outputs = model.generate(**gen_args)
    return generated_outputs
|
@torch.inference_mode()
def generate_input_embeddings(model,tokenizer,text,global_tokens):
    """
    Build the prompt embedding sequence: <tag2> text <tag0> global <tag1>.

    Args:
        model: model exposing ``text_embedder``, ``global_embedder`` and
            ``tts_tag_embedder`` sub-modules
        tokenizer: text tokenizer
        text: input text
        global_tokens: list of global token ids

    Returns:
        torch.Tensor: concatenated embedding sequence, shape (seq_len, hidden)
    """
    device = next(model.parameters()).device

    def _as_long(ids):
        # Token id list -> long tensor on the model's device.
        return torch.tensor(ids, dtype=torch.long, device=device)

    def _tag(tag_id):
        # Single TTS control-tag embedding, shape (1, hidden).
        return model.tts_tag_embedder(_as_long([tag_id]))

    text_embs = model.text_embedder(_as_long(tokenizer.encode(text, add_special_tokens=False)))
    global_embs = model.global_embedder(_as_long(global_tokens))

    return torch.cat([_tag(2), text_embs, _tag(0), global_embs, _tag(1)], dim=0)
|
|
|
def generate_embeddings(model, tokenizer, text, bicodec, prompt_text=None, prompt_audio=None):
    """
    Build the input embeddings the Spark LLM needs for prediction.

    Args:
        model: Spark LLM model
        tokenizer: text tokenizer
        text: text to synthesize
        bicodec: BiCodecTokenizer instance (only used when prompt_audio is given)
        prompt_text: optional prompt transcript; prepended to *text*
        prompt_audio: optional prompt audio sample array

    Returns:
        dict: {"input_embs": (1, seq_len, hidden) tensor for model prediction,
               "global_tokens": 1-D long tensor of global token ids}
    """
    device = next(model.parameters()).device

    # Tokenize the prompt audio (if any) into global + semantic token id lists.
    if prompt_audio is not None:
        audio_data = np.array(prompt_audio, dtype=np.float32)
        target_sample_rate = bicodec.config['sample_rate']
        # NOTE(review): assumes the caller already resampled the audio to the
        # codec's rate — only the expected rate is printed here, not enforced.
        print(f"BiCodecTokenizer 期望的采样率: {target_sample_rate}Hz")
        print(f"音频数据形状: {audio_data.shape}")
        global_tokens, semantic_tokens = bicodec.tokenize(audio_data)
        global_tokens = global_tokens.squeeze(0).squeeze(0).detach().cpu().tolist()
        semantic_tokens = semantic_tokens.squeeze(0).squeeze(0).detach().cpu().tolist()
    else:
        global_tokens, semantic_tokens = [], []

    # With a prompt transcript, the prompt's semantic tokens seed generation.
    if prompt_text is not None:
        full_text = prompt_text + text
        initial_semantic_tokens = semantic_tokens.copy()
    else:
        full_text = text
        initial_semantic_tokens = []

    def _as_long(ids):
        # Token id list -> long tensor on the model's device.
        return torch.tensor(ids, dtype=torch.long, device=device)

    def _tag(tag_id):
        # Single TTS control-tag embedding, shape (1, hidden).
        return model.tts_tag_embedder(_as_long([tag_id]))

    global_tokens_tensor = _as_long(global_tokens)
    text_embs = model.text_embedder(_as_long(tokenizer.encode(full_text, add_special_tokens=False)))
    global_embs = model.global_embedder(global_tokens_tensor)
    # Semantic tokens use the backbone's own embedding table.
    semantic_embs = model.model.embeddings(_as_long(initial_semantic_tokens))

    # Layout: <tag2> text <tag0> global <tag1> semantic, with a batch axis.
    input_embs = torch.cat(
        [_tag(2), text_embs, _tag(0), global_embs, _tag(1), semantic_embs],
        dim=0,
    ).unsqueeze(0)

    return {
        "input_embs": input_embs,
        "global_tokens": global_tokens_tensor,
    }