respark / trained_190k_steps /test_rtf_batch.py
yueyulin's picture
Upload folder using huggingface_hub
fc99023 verified
import os
import sys
import time
import numpy as np
import soundfile as sf
from collections import defaultdict
import json
from datetime import datetime
current_dir = os.path.dirname(os.path.abspath(__file__))
print('add current dir to sys.path', current_dir)
sys.path.append(current_dir)
from sparktts.models.audio_tokenizer import BiCodecTokenizer
from transformers import AutoTokenizer, AutoModelForCausalLM
from utilities import generate_embeddings_batch
def calculate_rtf(audio_length_seconds, processing_time_seconds):
"""
计算RTF (Real-Time Factor)
RTF = 处理时间 / 音频长度
RTF < 1 表示实时处理,RTF > 1 表示处理时间超过音频长度
"""
return processing_time_seconds / audio_length_seconds
def generate_speech_batch_with_timing(model, tokenizer, texts, bicodec, prompt_text=None, prompt_audio=None,
max_new_tokens=3000, do_sample=True, top_k=50, top_p=0.95,
temperature=1.0, device="cuda:0"):
"""
带时间测量的批量语音生成函数
Returns:
tuple: (音频波形列表, 处理时间, 音频长度列表)
"""
import torch
# 设置eos_token_id
eos_token_id = model.config.vocab_size - 1
# 生成输入嵌入
embeddings, attention_mask = generate_embeddings_batch(
model=model,
tokenizer=tokenizer,
texts=texts,
bicodec=bicodec,
prompt_text=prompt_text,
prompt_audio=prompt_audio
)
batch_size = len(texts)
global_tokens = embeddings['global_tokens']
# 设置模型为评估模式
model.eval()
# 开始计时
start_time = time.time()
with torch.no_grad():
# 使用模型的generate方法进行批量生成
generated_outputs = model.generate(
inputs_embeds=embeddings['input_embs'],
attention_mask=attention_mask,
max_new_tokens=max_new_tokens,
do_sample=do_sample,
top_k=top_k,
top_p=top_p,
temperature=temperature,
eos_token_id=eos_token_id,
pad_token_id=tokenizer.pad_token_id if hasattr(tokenizer, 'pad_token_id') else tokenizer.eos_token_id,
use_cache=True
)
# 处理每个样本的生成结果
wavs = []
eos_index = torch.where(generated_outputs == eos_token_id)[1]
for i in range(batch_size):
# 获取当前样本的生成结果
sample_outputs = generated_outputs[i]
# 找到第一个eos_token_id
eos_token_id_index = eos_index[i]
sample_outputs = sample_outputs[:eos_token_id_index]
# 使用BiCodec解码生成音频
with torch.no_grad():
wav = bicodec.detokenize(global_tokens[i:i+1], sample_outputs.unsqueeze(0))
wavs.append(wav)
# 结束计时
end_time = time.time()
processing_time = end_time - start_time
# 计算每个音频的长度(秒)
sample_rate = bicodec.config['sample_rate']
audio_lengths = [len(wav) / sample_rate for wav in wavs]
return wavs, processing_time, audio_lengths
def generate_speech_with_timing(model, tokenizer, text, bicodec, prompt_text=None, prompt_audio=None,
max_new_tokens=3000, do_sample=True, top_k=50, top_p=0.95,
temperature=1.0, device="cuda:0"):
"""
带时间测量的单次语音生成函数
Returns:
tuple: (音频波形, 处理时间, 音频长度)
"""
import torch
# 设置eos_token_id
eos_token_id = model.config.vocab_size - 1
# 生成输入嵌入
embeddings, attention_mask = generate_embeddings_batch(
model=model,
tokenizer=tokenizer,
texts=[text],
bicodec=bicodec,
prompt_text=prompt_text,
prompt_audio=prompt_audio
)
batch_size = 1
global_tokens = embeddings['global_tokens']
# 设置模型为评估模式
model.eval()
# 开始计时
start_time = time.time()
with torch.no_grad():
# 使用模型的generate方法进行生成
generated_outputs = model.generate(
inputs_embeds=embeddings['input_embs'],
attention_mask=attention_mask,
max_new_tokens=max_new_tokens,
do_sample=do_sample,
top_k=top_k,
top_p=top_p,
temperature=temperature,
eos_token_id=eos_token_id,
pad_token_id=tokenizer.pad_token_id if hasattr(tokenizer, 'pad_token_id') else tokenizer.eos_token_id,
use_cache=True
)
# 处理生成结果
sample_outputs = generated_outputs[0]
eos_index = torch.where(generated_outputs[0] == eos_token_id)[0]
if len(eos_index) > 0:
sample_outputs = sample_outputs[:eos_index[0]]
# 使用BiCodec解码生成音频
with torch.no_grad():
wav = bicodec.detokenize(global_tokens, sample_outputs.unsqueeze(0))
# 结束计时
end_time = time.time()
processing_time = end_time - start_time
# 计算音频长度(秒)
sample_rate = bicodec.config['sample_rate']
audio_length = len(wav) / sample_rate
return wav, processing_time, audio_length
def warmup_model(model, tokenizer, audio_tokenizer, prompt_audio=None, device="cuda:0", warmup_count=3):
"""
模型预热函数,进行几次不计算时间的生成
Args:
model: 语言模型
tokenizer: 文本分词器
audio_tokenizer: BiCodecTokenizer实例
prompt_audio: 提示音频(可选)
device: 设备
warmup_count: 预热次数
Returns:
None
"""
import torch
print(f"开始模型预热,进行 {warmup_count} 次生成...")
# 预热用的简单文本
warmup_texts = [
"你好,这是一个预热测试。",
"人工智能技术正在快速发展。",
"语音合成技术将文本转换为自然的语音输出。"
]
for i in range(warmup_count):
print(f" 预热 {i+1}/{warmup_count}")
# 选择预热文本
warmup_text = warmup_texts[i % len(warmup_texts)]
try:
# 设置eos_token_id
eos_token_id = model.config.vocab_size - 1
# 生成输入嵌入
embeddings, attention_mask = generate_embeddings_batch(
model=model,
tokenizer=tokenizer,
texts=[warmup_text],
bicodec=audio_tokenizer,
prompt_text=None,
prompt_audio=prompt_audio
)
global_tokens = embeddings['global_tokens']
# 设置模型为评估模式
model.eval()
with torch.no_grad():
# 使用模型的generate方法进行生成
generated_outputs = model.generate(
inputs_embeds=embeddings['input_embs'],
attention_mask=attention_mask,
max_new_tokens=1000, # 预热时使用较少的token
do_sample=True,
top_k=50,
top_p=0.95,
temperature=1.0,
eos_token_id=eos_token_id,
pad_token_id=tokenizer.pad_token_id if hasattr(tokenizer, 'pad_token_id') else tokenizer.eos_token_id,
use_cache=True
)
# 处理生成结果
sample_outputs = generated_outputs[0]
eos_index = torch.where(generated_outputs[0] == eos_token_id)[0]
if len(eos_index) > 0:
sample_outputs = sample_outputs[:eos_index[0]]
# 使用BiCodec解码生成音频(不保存)
with torch.no_grad():
wav = audio_tokenizer.detokenize(global_tokens, sample_outputs.unsqueeze(0))
print(f" 预热完成,生成音频长度: {len(wav) / audio_tokenizer.config['sample_rate']:.2f}s")
except Exception as e:
print(f" 预热错误: {str(e)}")
print("模型预热完成!")
print("-" * 40)
def run_rtf_batch_test(texts, model, tokenizer, audio_tokenizer, prompt_audio=None,
device="cuda:0", output_dir="rtf_test_results", warmup_count=3, batch_size=4):
"""
运行批量RTF测试
Args:
texts: 要测试的文本列表
model: 语言模型
tokenizer: 文本分词器
audio_tokenizer: BiCodecTokenizer实例
prompt_audio: 提示音频(可选)
device: 设备
output_dir: 输出目录
warmup_count: 预热次数
batch_size: 批量大小
Returns:
dict: 测试结果统计
"""
# 创建输出目录
os.makedirs(output_dir, exist_ok=True)
# 首先进行模型预热
warmup_model(model, tokenizer, audio_tokenizer, prompt_audio, device, warmup_count)
# 测试结果存储
results = []
total_processing_time = 0
total_audio_length = 0
print(f"开始批量RTF测试,共 {len(texts)} 个文本,批量大小: {batch_size}...")
print("=" * 80)
# 将文本分批处理
for batch_start in range(0, len(texts), batch_size):
batch_end = min(batch_start + batch_size, len(texts))
batch_texts = texts[batch_start:batch_end]
batch_num = batch_start // batch_size + 1
total_batches = (len(texts) + batch_size - 1) // batch_size
print(f"\n处理批次 {batch_num}/{total_batches} (文本 {batch_start+1}-{batch_end})")
print(f"批次文本数量: {len(batch_texts)}")
try:
# 批量生成语音并计时
wavs, processing_time, audio_lengths = generate_speech_batch_with_timing(
model=model,
tokenizer=tokenizer,
texts=batch_texts,
bicodec=audio_tokenizer,
prompt_audio=prompt_audio,
device=device
)
# 计算批次的总音频时长
batch_total_audio_length = sum(audio_lengths)
# 计算批次的RTF:处理时间 / 总音频时长
batch_rtf = calculate_rtf(batch_total_audio_length, processing_time)
# 处理每个生成的音频
for i, (wav, audio_length) in enumerate(zip(wavs, audio_lengths)):
text_index = batch_start + i
text = batch_texts[i]
# 保存音频文件
output_filename = os.path.join(output_dir, f"test_{text_index+1:03d}.wav")
sf.write(output_filename, wav, audio_tokenizer.config['sample_rate'])
# 记录结果
result = {
"index": text_index + 1,
"batch": batch_num,
"text": text,
"batch_processing_time": processing_time, # 整个批次的处理时间
"audio_length": audio_length,
"batch_rtf": batch_rtf, # 整个批次的RTF
"output_file": output_filename
}
results.append(result)
print(f" 文本 {text_index+1}: 音频长度 {audio_length:.3f}s")
print(f" 批次总处理时间: {processing_time:.3f}s")
print(f" 批次总音频时长: {batch_total_audio_length:.3f}s")
print(f" 批次RTF: {batch_rtf:.3f}")
# 累加到总体统计
total_processing_time += processing_time
total_audio_length += batch_total_audio_length
except Exception as e:
print(f" 批次错误: {str(e)}")
# 记录失败的批次
for i, text in enumerate(batch_texts):
text_index = batch_start + i
result = {
"index": text_index + 1,
"batch": batch_num,
"text": text,
"error": str(e)
}
results.append(result)
# 计算总体统计
successful_results = [r for r in results if "error" not in r]
if successful_results:
# 计算批次级别的统计
batch_rtfs = [r["batch_rtf"] for r in successful_results]
batch_processing_times = [r["batch_processing_time"] for r in successful_results]
avg_audio_length = np.mean([r["audio_length"] for r in successful_results])
total_rtf = calculate_rtf(total_audio_length, total_processing_time)
stats = {
"total_tests": len(texts),
"successful_tests": len(successful_results),
"failed_tests": len(texts) - len(successful_results),
"batch_size": batch_size,
"total_batches": total_batches,
"total_processing_time": total_processing_time,
"total_audio_length": total_audio_length,
"total_rtf": total_rtf,
"avg_batch_rtf": np.mean(batch_rtfs),
"avg_batch_processing_time": np.mean(batch_processing_times),
"avg_audio_length": avg_audio_length,
"min_batch_rtf": min(batch_rtfs),
"max_batch_rtf": max(batch_rtfs),
"std_batch_rtf": np.std(batch_rtfs)
}
else:
stats = {
"total_tests": len(texts),
"successful_tests": 0,
"failed_tests": len(texts),
"batch_size": batch_size,
"total_batches": total_batches,
"error": "所有测试都失败了"
}
# 保存详细结果到JSON文件
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
results_file = os.path.join(output_dir, f"rtf_test_results_{timestamp}.json")
output_data = {
"test_info": {
"timestamp": timestamp,
"device": device,
"model_path": current_dir,
"batch_size": batch_size
},
"statistics": stats,
"detailed_results": results
}
with open(results_file, 'w', encoding='utf-8') as f:
json.dump(output_data, f, ensure_ascii=False, indent=2)
# 打印统计结果
print("\n" + "=" * 80)
print("RTF测试统计结果:")
print("=" * 80)
print(f"总测试数: {stats['total_tests']}")
print(f"成功测试数: {stats['successful_tests']}")
print(f"失败测试数: {stats['failed_tests']}")
print(f"批量大小: {batch_size}")
print(f"总批次数: {total_batches}")
if successful_results:
print(f"总处理时间: {stats['total_processing_time']:.3f}s")
print(f"总音频长度: {stats['total_audio_length']:.3f}s")
print(f"总体RTF: {stats['total_rtf']:.3f}")
print(f"平均批次RTF: {stats['avg_batch_rtf']:.3f}")
print(f"平均批次处理时间: {stats['avg_batch_processing_time']:.3f}s")
print(f"平均音频长度: {stats['avg_audio_length']:.3f}s")
print(f"最小批次RTF: {stats['min_batch_rtf']:.3f}")
print(f"最大批次RTF: {stats['max_batch_rtf']:.3f}")
print(f"批次RTF标准差: {stats['std_batch_rtf']:.3f}")
print(f"\n详细结果已保存到: {results_file}")
print(f"音频文件保存在: {output_dir}")
return stats, results
if __name__ == "__main__":
import torch
device = 'cuda:2'
# 初始化模型和分词器
print("正在加载模型和分词器...")
audio_tokenizer = BiCodecTokenizer(model_dir=current_dir, device=device)
tokenizer = AutoTokenizer.from_pretrained(current_dir, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(current_dir, trust_remote_code=True)
model = model.bfloat16().to(device)
model.eval()
# 加载提示音频(可选)
prompt_audio = None
prompt_audio_file = os.path.join(current_dir, 'kafka.wav')
if os.path.exists(prompt_audio_file):
print(f"加载提示音频: {prompt_audio_file}")
prompt_audio, sampling_rate = sf.read(prompt_audio_file)
target_sample_rate = audio_tokenizer.config['sample_rate']
if sampling_rate != target_sample_rate:
print(f"重采样从 {sampling_rate}Hz 到 {target_sample_rate}Hz...")
from librosa import resample
prompt_audio = resample(prompt_audio, orig_sr=sampling_rate, target_sr=target_sample_rate)
prompt_audio = np.array(prompt_audio, dtype=np.float32)
# 测试文本列表
test_texts = [
"一九五二年二月十日,志愿军大英雄张积慧击落美军双料王牌飞行员戴维斯,在自己飞机坠毁处距离戴维斯坠机处不足五百米的情况下,取得了世界空战史不可能复制的奇迹。伟大的张积慧。",
"在数字浪潮汹涌的今天,数智技术正以前所未有的力量重塑着社会的每一个角落。",
"为了点燃青少年对科技的热情,培养他们的创新思维与动手能力",
"杏花岭区巨轮街道社区教育学校携手中车社区教育分校,与太原市科学技术协会联手,于暑期精心策划了一场别开生面的青少年数智技术服务港探索之旅,吸引了众多社区青少年的积极参与。",
"一踏入数智技术服务港的大门,一股浓厚的科技气息便扑面而来。",
"科普课堂上,“简易红绿灯”科学实验更是将抽象的电路原理与日常生活紧密相连。",
"实验开始前,老师生动地介绍了实验物品,并引导青少年思考红绿灯的工作原理,激发了他们浓厚的探索兴趣。",
"在老师的指导下,青少年们开始动手组装电路,将红绿灯的各个部件连接起来。",
"他们小心翼翼地调整电路,确保每个部件都正确连接,红灯、绿灯、黄灯依次亮起,仿佛在讲述一个关于交通规则的故事。",
"实验过程中,青少年们不仅学到了电路知识,还体验到了动手实践的乐趣。",
"他们纷纷表示,这次实验不仅让他们对科技有了更深的理解,还培养了他们的创新思维和动手能力。",
"数智技术服务港,让科技触手可及,让创新无处不在。",
"人工智能技术正在快速发展,为各行各业带来了革命性的变化。",
"深度学习模型在语音识别、图像处理、自然语言处理等领域取得了突破性进展。",
"机器学习算法能够从大量数据中学习模式,并做出准确的预测和决策。",
"神经网络模拟人脑的工作方式,通过多层神经元处理复杂的信息。",
"计算机视觉技术让机器能够理解和分析图像内容。",
"自然语言处理技术使计算机能够理解和生成人类语言。",
"语音合成技术将文本转换为自然的语音输出。",
"大数据分析帮助企业发现隐藏的模式和趋势。"
]
# 运行RTF测试
stats, results = run_rtf_batch_test(
texts=test_texts,
model=model,
tokenizer=tokenizer,
audio_tokenizer=audio_tokenizer,
prompt_audio=prompt_audio,
device=device,
output_dir="rtf_test_results",
warmup_count=1, # 预热3次
batch_size=8
)