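"""Benchmark script for the text-to-speech pipeline in this folder.

Times each stage of synthesis (embedding construction, LLM token
generation, BiCodec detokenization) and reports per-stage averages
plus the overall Real-Time Factor (RTF).
"""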
import os
import sys
import time
import torch
import soundfile as sf
import numpy as np
from librosa import resample
# Add current directory to sys.path to find local modules
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(current_dir)
from sparktts.models.audio_tokenizer import BiCodecTokenizer
from transformers import AutoTokenizer, AutoModelForCausalLM
from utilities import generate_embeddings
def generate_speech(model, tokenizer, text, bicodec, prompt_text=None, prompt_audio=None,
max_new_tokens=3000, do_sample=True, top_k=50, top_p=0.95,
temperature=1.0, device="cuda:0"):
"""
Generates speech from text and returns timing for each major step.
"""
timings = {}
# --- 1. Generate Embeddings ---
t0 = time.perf_counter()
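    # This checkpoint uses the last vocabulary id as the audio EOS token.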
eos_token_id = model.config.vocab_size - 1
embeddings = generate_embeddings(
model=model,
tokenizer=tokenizer,
text=text,
bicodec=bicodec,
prompt_text=prompt_text,
prompt_audio=prompt_audio
)
    # Ensure queued GPU work has finished before reading the timer (no-op on CPU).
    if str(device).startswith('cuda'):
        torch.cuda.synchronize(device)
t1 = time.perf_counter()
timings['embedding_generation'] = t1 - t0
# --- 2. LLM Inference ---
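    # Global (speaker/timbre) tokens from the prompt audio; add a batch dimension.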
global_tokens = embeddings['global_tokens'].unsqueeze(0)
model.eval()
with torch.no_grad():
generated_outputs = model.generate(
inputs_embeds=embeddings['input_embs'],
attention_mask=torch.ones((1, embeddings['input_embs'].shape[1]), dtype=torch.long, device=device),
max_new_tokens=max_new_tokens,
do_sample=do_sample,
top_k=top_k,
top_p=top_p,
temperature=temperature,
eos_token_id=eos_token_id,
            pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
use_cache=True
)
    if str(device).startswith('cuda'):
        torch.cuda.synchronize(device)
t2 = time.perf_counter()
timings['llm_inference'] = t2 - t1
# --- 3. Detokenization ---
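    # With only inputs_embeds (no input_ids), recent transformers versions of
    # generate() return just the newly generated ids, so slicing off the last
    # token below removes only the trailing EOS.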
    semantic_tokens_tensor = generated_outputs[:, :-1]
    token_size = semantic_tokens_tensor.shape[1]
    print(f"Generated {token_size} semantic tokens ({token_size / (t2 - t1):.2f} tokens/s)")
with torch.no_grad():
wav = bicodec.detokenize(global_tokens, semantic_tokens_tensor)
    if str(device).startswith('cuda'):
        torch.cuda.synchronize(device)
t3 = time.perf_counter()
timings['detokenization'] = t3 - t2
return wav, timings
def main():
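    # Pinned to GPU index 2; adjust the index (or use plain 'cuda') for your machine.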
device = 'cuda:2' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")
# --- Model Loading ---
print("Loading models and tokenizers...")
audio_tokenizer = BiCodecTokenizer(model_dir=current_dir, device=device)
tokenizer = AutoTokenizer.from_pretrained(current_dir, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(current_dir, trust_remote_code=True)
model = model.bfloat16().to(device)
model.eval()
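    # torch.compile adds one-time graph-capture overhead; the untimed warm-up
    # run below absorbs it so benchmark iterations measure steady-state speed.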
model = torch.compile(model)
print("Models and tokenizers loaded.")
# --- Prompt Loading ---
prompt_audio_file = os.path.join(current_dir, 'kafka.wav')
    prompt_audio, sampling_rate = sf.read(prompt_audio_file)
    if prompt_audio.ndim > 1:
        prompt_audio = prompt_audio.mean(axis=1)  # downmix multi-channel prompts to mono
    target_sample_rate = audio_tokenizer.config['sample_rate']
    if sampling_rate != target_sample_rate:
        prompt_audio = resample(prompt_audio, orig_sr=sampling_rate, target_sr=target_sample_rate)
    prompt_audio = np.array(prompt_audio, dtype=np.float32)
    # Chinese test sentence: "Science and technology are the primary productive force;
    # AI's recent rapid progress gives us hope of reaching the stars and the sea."
    text_to_synthesize = "科学技术是第一生产力,最近AI的迅猛发展让我们看到了迈向星辰大海的希望。"
# --- Warm-up Run ---
print("\n--- Starting warm-up run (not timed) ---")
_, _ = generate_speech(model, tokenizer, text_to_synthesize, audio_tokenizer,
prompt_audio=prompt_audio, device=device)
print("--- Warm-up finished ---\n")
# --- Benchmarking ---
num_iterations = 100
total_generation_time = 0
total_audio_duration = 0
total_timings = {'embedding_generation': 0, 'llm_inference': 0, 'detokenization': 0}
print(f"--- Starting benchmark: {num_iterations} iterations ---")
for i in range(num_iterations):
start_time = time.perf_counter()
wav, timings = generate_speech(model, tokenizer, text_to_synthesize, audio_tokenizer,
prompt_audio=prompt_audio, device=device)
end_time = time.perf_counter()
generation_time = end_time - start_time
audio_duration = len(wav) / target_sample_rate
total_generation_time += generation_time
total_audio_duration += audio_duration
for key in total_timings:
total_timings[key] += timings[key]
timing_details = f"Embed: {timings['embedding_generation']:.4f}s, LLM: {timings['llm_inference']:.4f}s, Decode: {timings['detokenization']:.4f}s"
print(f"Iteration {i+1}/{num_iterations}: Total: {generation_time:.4f}s, Audio: {audio_duration:.4f}s | {timing_details}")
# --- Results ---
if total_audio_duration > 0:
rtf = total_generation_time / total_audio_duration
else:
rtf = float('inf')
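    # RTF = generation time / audio duration; below 1.0 is faster than real time.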
print("\n--- Benchmark Results ---")
print(f"Total iterations: {num_iterations}")
print(f"Total generation time: {total_generation_time:.4f} seconds")
print(f"Total audio duration: {total_audio_duration:.4f} seconds")
print(f"Average generation time: {total_generation_time / num_iterations:.4f} seconds")
print(f"Real-Time Factor (RTF): {rtf:.4f}")
print("-------------------------")
# --- Detailed Timings ---
print("\n--- Detailed Timing Breakdown ---")
avg_total_gen_time = total_generation_time / num_iterations
for name, total_time in total_timings.items():
avg_time = total_time / num_iterations
percentage = (avg_time / avg_total_gen_time) * 100 if avg_total_gen_time > 0 else 0
print(f"Average {name}: {avg_time:.4f}s ({percentage:.2f}%)")
print("---------------------------------")
if __name__ == "__main__":
main()