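"""Benchmark script for Spark-TTS-style inference with a BiCodec audio tokenizer.

Times three stages per utterance (embedding preparation, LLM token generation,
BiCodec detokenization) over repeated runs and reports the Real-Time Factor.
"""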
import os
import sys
import time
import torch
import soundfile as sf
import numpy as np
from librosa import resample

# Add current directory to sys.path to find local modules
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(current_dir)

from sparktts.models.audio_tokenizer import BiCodecTokenizer
from transformers import AutoTokenizer, AutoModelForCausalLM
from utilities import generate_embeddings
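
def sync_if_cuda(device):
    """Block until queued CUDA work on `device` finishes (no-op on CPU).

    CUDA execution is asynchronous, so without a synchronize the perf_counter
    timings would only capture kernel-launch overhead. Passing the device
    explicitly matters here because the model may not live on the default GPU
    (main() pins it to cuda:2).
    """
    if torch.cuda.is_available() and str(device).startswith("cuda"):
        torch.cuda.synchronize(device)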

def generate_speech(model, tokenizer, text, bicodec, prompt_text=None, prompt_audio=None, 
                   max_new_tokens=3000, do_sample=True, top_k=50, top_p=0.95, 
                   temperature=1.0, device="cuda:0"):
    """
    Generates speech from text and returns timing for each major step.
    """
    timings = {}

    # --- 1. Generate Embeddings ---
    t0 = time.perf_counter()
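    # Assumes the audio EOS token occupies the final slot of the model vocabulary.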
    eos_token_id = model.config.vocab_size - 1
    embeddings = generate_embeddings(
        model=model,
        tokenizer=tokenizer,
        text=text,
        bicodec=bicodec,
        prompt_text=prompt_text,
        prompt_audio=prompt_audio
    )
    sync_if_cuda(device)
    t1 = time.perf_counter()
    timings['embedding_generation'] = t1 - t0

    # --- 2. LLM Inference ---
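    # Global tokens carry speaker/style information; add a batch dimension for detokenize() below.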
    global_tokens = embeddings['global_tokens'].unsqueeze(0)
    model.eval()
    with torch.no_grad():
        generated_outputs = model.generate(
            inputs_embeds=embeddings['input_embs'],
            attention_mask=torch.ones((1, embeddings['input_embs'].shape[1]), dtype=torch.long, device=device),
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            top_k=top_k,
            top_p=top_p,
            temperature=temperature,
            eos_token_id=eos_token_id,
            pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
            use_cache=True
        )
    sync_if_cuda(device)
    t2 = time.perf_counter()
    timings['llm_inference'] = t2 - t1

    # --- 3. Detokenization ---
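    # Strip the trailing EOS token; the remaining IDs are semantic tokens for BiCodec to decode.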
    semantic_tokens_tensor = generated_outputs[:, :-1]
    token_size = semantic_tokens_tensor.shape[1]
    print(f"Token size: {token_size} tokens per second = {token_size / (t2 - t1)}")
    with torch.no_grad():
        wav = bicodec.detokenize(global_tokens, semantic_tokens_tensor)
    sync_if_cuda(device)
    t3 = time.perf_counter()
    timings['detokenization'] = t3 - t2
    
    return wav, timings

def main():
    device = 'cuda:2' if torch.cuda.is_available() else 'cpu'  # hardcoded to GPU index 2; adjust for your machine
    print(f"Using device: {device}")

    # --- Model Loading ---
    print("Loading models and tokenizers...")
    audio_tokenizer = BiCodecTokenizer(model_dir=current_dir, device=device)
    tokenizer = AutoTokenizer.from_pretrained(current_dir, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(current_dir, trust_remote_code=True)
    
    model = model.bfloat16().to(device)
    model.eval()
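    # torch.compile speeds up steady-state inference at the cost of a slow first call.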
    model = torch.compile(model)
    print("Models and tokenizers loaded.")

    # --- Prompt Loading ---
    prompt_audio_file = os.path.join(current_dir, 'kafka.wav')
    prompt_audio, sampling_rate = sf.read(prompt_audio_file)
    
    target_sample_rate = audio_tokenizer.config['sample_rate']
    if sampling_rate != target_sample_rate:
        # BiCodec expects its configured sample rate; resample the prompt if it differs.
        prompt_audio = resample(prompt_audio, orig_sr=sampling_rate, target_sr=target_sample_rate)
    # Cast unconditionally: sf.read returns float64 by default, and the tokenizer works on float32 audio.
    prompt_audio = np.array(prompt_audio, dtype=np.float32)
    
    text_to_synthesize = "科学技术是第一生产力,最近AI的迅猛发展让我们看到了迈向星辰大海的希望。"

    # --- Warm-up Run ---
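    # One untimed pass absorbs one-off costs: torch.compile graph capture, CUDA context creation, and allocator warm-up.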
    print("\n--- Starting warm-up run (not timed) ---")
    _, _ = generate_speech(model, tokenizer, text_to_synthesize, audio_tokenizer, 
                           prompt_audio=prompt_audio, device=device)
    print("--- Warm-up finished ---\n")

    # --- Benchmarking ---
    num_iterations = 100
    total_generation_time = 0
    total_audio_duration = 0
    total_timings = {'embedding_generation': 0, 'llm_inference': 0, 'detokenization': 0}

    print(f"--- Starting benchmark: {num_iterations} iterations ---")
    for i in range(num_iterations):
        start_time = time.perf_counter()
        
        wav, timings = generate_speech(model, tokenizer, text_to_synthesize, audio_tokenizer, 
                                       prompt_audio=prompt_audio, device=device)
        
        end_time = time.perf_counter()

        generation_time = end_time - start_time
        audio_duration = len(wav) / target_sample_rate
        
        total_generation_time += generation_time
        total_audio_duration += audio_duration
        for key in total_timings:
            total_timings[key] += timings[key]
        
        timing_details = f"Embed: {timings['embedding_generation']:.4f}s, LLM: {timings['llm_inference']:.4f}s, Decode: {timings['detokenization']:.4f}s"
        print(f"Iteration {i+1}/{num_iterations}: Total: {generation_time:.4f}s, Audio: {audio_duration:.4f}s | {timing_details}")

    # --- Results ---
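    # Real-Time Factor = generation time / audio duration; values below 1.0 mean faster-than-real-time synthesis.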
    if total_audio_duration > 0:
        rtf = total_generation_time / total_audio_duration
    else:
        rtf = float('inf')

    print("\n--- Benchmark Results ---")
    print(f"Total iterations: {num_iterations}")
    print(f"Total generation time: {total_generation_time:.4f} seconds")
    print(f"Total audio duration: {total_audio_duration:.4f} seconds")
    print(f"Average generation time: {total_generation_time / num_iterations:.4f} seconds")
    print(f"Real-Time Factor (RTF): {rtf:.4f}")
    print("-------------------------")

    # --- Detailed Timings ---
    print("\n--- Detailed Timing Breakdown ---")
    avg_total_gen_time = total_generation_time / num_iterations
    for name, total_time in total_timings.items():
        avg_time = total_time / num_iterations
        percentage = (avg_time / avg_total_gen_time) * 100 if avg_total_gen_time > 0 else 0
        print(f"Average {name}: {avg_time:.4f}s ({percentage:.2f}%)")
    print("---------------------------------")

if __name__ == "__main__":
    main()