File size: 6,055 Bytes
ee3b868 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 |
import os
import sys
import time
import torch
import soundfile as sf
import numpy as np
from librosa import resample
# Add current directory to sys.path to find local modules
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(current_dir)
from sparktts.models.audio_tokenizer import BiCodecTokenizer
from transformers import AutoTokenizer, AutoModelForCausalLM
from utilities import generate_embeddings
def generate_speech(model, tokenizer, text, bicodec, prompt_text=None, prompt_audio=None,
max_new_tokens=3000, do_sample=True, top_k=50, top_p=0.95,
temperature=1.0, device="cuda:0"):
"""
Generates speech from text and returns timing for each major step.
"""
timings = {}
# --- 1. Generate Embeddings ---
t0 = time.perf_counter()
eos_token_id = model.config.vocab_size - 1
embeddings = generate_embeddings(
model=model,
tokenizer=tokenizer,
text=text,
bicodec=bicodec,
prompt_text=prompt_text,
prompt_audio=prompt_audio
)
torch.cuda.synchronize()
t1 = time.perf_counter()
timings['embedding_generation'] = t1 - t0
# --- 2. LLM Inference ---
global_tokens = embeddings['global_tokens'].unsqueeze(0)
model.eval()
with torch.no_grad():
generated_outputs = model.generate(
inputs_embeds=embeddings['input_embs'],
attention_mask=torch.ones((1, embeddings['input_embs'].shape[1]), dtype=torch.long, device=device),
max_new_tokens=max_new_tokens,
do_sample=do_sample,
top_k=top_k,
top_p=top_p,
temperature=temperature,
eos_token_id=eos_token_id,
pad_token_id=tokenizer.pad_token_id if hasattr(tokenizer, 'pad_token_id') else tokenizer.eos_token_id,
use_cache=True
)
torch.cuda.synchronize()
t2 = time.perf_counter()
timings['llm_inference'] = t2 - t1
# --- 3. Detokenization ---
semantic_tokens_tensor = generated_outputs[:,:-1]
token_size = semantic_tokens_tensor.shape[1]
print(f"Token size: {token_size} tokens per second = {token_size / (t2 - t1)}")
with torch.no_grad():
wav = bicodec.detokenize(global_tokens, semantic_tokens_tensor)
torch.cuda.synchronize()
t3 = time.perf_counter()
timings['detokenization'] = t3 - t2
return wav, timings
def main():
device = 'cuda:2' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")
# --- Model Loading ---
print("Loading models and tokenizers...")
audio_tokenizer = BiCodecTokenizer(model_dir=current_dir, device=device)
tokenizer = AutoTokenizer.from_pretrained(current_dir, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(current_dir, trust_remote_code=True)
model = model.bfloat16().to(device)
model.eval()
model = torch.compile(model)
print("Models and tokenizers loaded.")
# --- Prompt Loading ---
prompt_audio_file = os.path.join(current_dir, 'kafka.wav')
prompt_audio, sampling_rate = sf.read(prompt_audio_file)
target_sample_rate = audio_tokenizer.config['sample_rate']
if sampling_rate != target_sample_rate:
prompt_audio = resample(prompt_audio, orig_sr=sampling_rate, target_sr=target_sample_rate)
prompt_audio = np.array(prompt_audio, dtype=np.float32)
text_to_synthesize = "科学技术是第一生产力,最近AI的迅猛发展让我们看到了迈向星辰大海的希望。"
# --- Warm-up Run ---
print("\n--- Starting warm-up run (not timed) ---")
_, _ = generate_speech(model, tokenizer, text_to_synthesize, audio_tokenizer,
prompt_audio=prompt_audio, device=device)
print("--- Warm-up finished ---\n")
# --- Benchmarking ---
num_iterations = 100
total_generation_time = 0
total_audio_duration = 0
total_timings = {'embedding_generation': 0, 'llm_inference': 0, 'detokenization': 0}
print(f"--- Starting benchmark: {num_iterations} iterations ---")
for i in range(num_iterations):
start_time = time.perf_counter()
wav, timings = generate_speech(model, tokenizer, text_to_synthesize, audio_tokenizer,
prompt_audio=prompt_audio, device=device)
end_time = time.perf_counter()
generation_time = end_time - start_time
audio_duration = len(wav) / target_sample_rate
total_generation_time += generation_time
total_audio_duration += audio_duration
for key in total_timings:
total_timings[key] += timings[key]
timing_details = f"Embed: {timings['embedding_generation']:.4f}s, LLM: {timings['llm_inference']:.4f}s, Decode: {timings['detokenization']:.4f}s"
print(f"Iteration {i+1}/{num_iterations}: Total: {generation_time:.4f}s, Audio: {audio_duration:.4f}s | {timing_details}")
# --- Results ---
if total_audio_duration > 0:
rtf = total_generation_time / total_audio_duration
else:
rtf = float('inf')
print("\n--- Benchmark Results ---")
print(f"Total iterations: {num_iterations}")
print(f"Total generation time: {total_generation_time:.4f} seconds")
print(f"Total audio duration: {total_audio_duration:.4f} seconds")
print(f"Average generation time: {total_generation_time / num_iterations:.4f} seconds")
print(f"Real-Time Factor (RTF): {rtf:.4f}")
print("-------------------------")
# --- Detailed Timings ---
print("\n--- Detailed Timing Breakdown ---")
avg_total_gen_time = total_generation_time / num_iterations
for name, total_time in total_timings.items():
avg_time = total_time / num_iterations
percentage = (avg_time / avg_total_gen_time) * 100 if avg_total_gen_time > 0 else 0
print(f"Average {name}: {avg_time:.4f}s ({percentage:.2f}%)")
print("---------------------------------")
if __name__ == "__main__":
main() |