"""Benchmark script for a Spark-TTS-style text-to-speech pipeline.

Measures per-stage latency (embedding generation, LLM inference, BiCodec
detokenization) and the overall real-time factor (RTF) over repeated runs.
"""

import os
import sys
import time

import numpy as np
import soundfile as sf
import torch
from librosa import resample
from transformers import AutoTokenizer, AutoModelForCausalLM

# Add the current directory to sys.path so the local modules can be found.
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(current_dir)

from sparktts.models.audio_tokenizer import BiCodecTokenizer
from utilities import generate_embeddings


def cuda_sync(device):
    """Wait for pending CUDA work on `device` so perf_counter timings are
    accurate. A no-op on CPU, where execution is synchronous anyway."""
    if torch.cuda.is_available() and str(device).startswith("cuda"):
        torch.cuda.synchronize(device)


def generate_speech(model, tokenizer, text, bicodec, prompt_text=None,
                    prompt_audio=None, max_new_tokens=3000, do_sample=True,
                    top_k=50, top_p=0.95, temperature=1.0, device="cuda:0"):
    """Generate speech from `text` and return the waveform plus per-stage timings."""
    timings = {}

    # --- 1. Generate embeddings ---
    t0 = time.perf_counter()
    eos_token_id = model.config.vocab_size - 1
    embeddings = generate_embeddings(
        model=model,
        tokenizer=tokenizer,
        text=text,
        bicodec=bicodec,
        prompt_text=prompt_text,
        prompt_audio=prompt_audio,
    )
    cuda_sync(device)
    t1 = time.perf_counter()
    timings['embedding_generation'] = t1 - t0

    # --- 2. LLM inference ---
    global_tokens = embeddings['global_tokens'].unsqueeze(0)
    model.eval()
    with torch.no_grad():
        generated_outputs = model.generate(
            inputs_embeds=embeddings['input_embs'],
            attention_mask=torch.ones(
                (1, embeddings['input_embs'].shape[1]),
                dtype=torch.long, device=device,
            ),
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            top_k=top_k,
            top_p=top_p,
            temperature=temperature,
            eos_token_id=eos_token_id,
            # pad_token_id may exist but be None, so test the value, not the attribute.
            pad_token_id=(tokenizer.pad_token_id
                          if tokenizer.pad_token_id is not None
                          else tokenizer.eos_token_id),
            use_cache=True,
        )
    cuda_sync(device)
    t2 = time.perf_counter()
    timings['llm_inference'] = t2 - t1

    # --- 3. Detokenization ---
    # Drop the trailing EOS token before decoding.
    semantic_tokens_tensor = generated_outputs[:, :-1]
    token_size = semantic_tokens_tensor.shape[1]
    print(f"Generated {token_size} semantic tokens ({token_size / (t2 - t1):.2f} tokens/s)")
    with torch.no_grad():
        wav = bicodec.detokenize(global_tokens, semantic_tokens_tensor)
    cuda_sync(device)
    t3 = time.perf_counter()
    timings['detokenization'] = t3 - t2

    return wav, timings


def main():
    device = 'cuda:2' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}")

    # --- Model loading ---
    print("Loading models and tokenizers...")
    audio_tokenizer = BiCodecTokenizer(model_dir=current_dir, device=device)
    tokenizer = AutoTokenizer.from_pretrained(current_dir, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(current_dir, trust_remote_code=True)
    model = model.bfloat16().to(device)
    model.eval()
    model = torch.compile(model)
    print("Models and tokenizers loaded.")

    # --- Prompt loading ---
    prompt_audio_file = os.path.join(current_dir, 'kafka.wav')
    prompt_audio, sampling_rate = sf.read(prompt_audio_file)
    target_sample_rate = audio_tokenizer.config['sample_rate']
    if sampling_rate != target_sample_rate:
        prompt_audio = resample(prompt_audio, orig_sr=sampling_rate,
                                target_sr=target_sample_rate)
    prompt_audio = np.array(prompt_audio, dtype=np.float32)

    # "Science and technology are the primary productive force; the recent rapid
    # development of AI shows us the hope of reaching the stars and the sea."
    text_to_synthesize = "科学技术是第一生产力,最近AI的迅猛发展让我们看到了迈向星辰大海的希望。"

    # --- Warm-up run (triggers torch.compile compilation, so it is not timed) ---
    print("\n--- Starting warm-up run (not timed) ---")
    _, _ = generate_speech(model, tokenizer, text_to_synthesize, audio_tokenizer,
                           prompt_audio=prompt_audio, device=device)
    print("--- Warm-up finished ---\n")

    # --- Benchmarking ---
    num_iterations = 100
    total_generation_time = 0.0
    total_audio_duration = 0.0
    total_timings = {'embedding_generation': 0.0, 'llm_inference': 0.0,
                     'detokenization': 0.0}

    print(f"--- Starting benchmark: {num_iterations} iterations ---")
    for i in range(num_iterations):
        start_time = time.perf_counter()
        wav, timings = generate_speech(model, tokenizer, text_to_synthesize,
                                       audio_tokenizer, prompt_audio=prompt_audio,
                                       device=device)
        end_time = time.perf_counter()

        generation_time = end_time - start_time
        audio_duration = len(wav) / target_sample_rate
        total_generation_time += generation_time
        total_audio_duration += audio_duration
        for key in total_timings:
            total_timings[key] += timings[key]

        timing_details = (f"Embed: {timings['embedding_generation']:.4f}s, "
                          f"LLM: {timings['llm_inference']:.4f}s, "
                          f"Decode: {timings['detokenization']:.4f}s")
        print(f"Iteration {i+1}/{num_iterations}: Total: {generation_time:.4f}s, "
              f"Audio: {audio_duration:.4f}s | {timing_details}")

    # --- Results ---
    # RTF = generation time / audio duration; values below 1.0 are faster than real time.
    if total_audio_duration > 0:
        rtf = total_generation_time / total_audio_duration
    else:
        rtf = float('inf')

    print("\n--- Benchmark Results ---")
    print(f"Total iterations: {num_iterations}")
    print(f"Total generation time: {total_generation_time:.4f} seconds")
    print(f"Total audio duration: {total_audio_duration:.4f} seconds")
    print(f"Average generation time: {total_generation_time / num_iterations:.4f} seconds")
    print(f"Real-Time Factor (RTF): {rtf:.4f}")
    print("-------------------------")

    # --- Detailed timings ---
    print("\n--- Detailed Timing Breakdown ---")
    avg_total_gen_time = total_generation_time / num_iterations
    for name, total_time in total_timings.items():
        avg_time = total_time / num_iterations
        percentage = (avg_time / avg_total_gen_time) * 100 if avg_total_gen_time > 0 else 0
        print(f"Average {name}: {avg_time:.4f}s ({percentage:.2f}%)")
    print("---------------------------------")


if __name__ == "__main__":
    main()