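"""Gradio demo for Step Audio Chat.

Takes microphone and/or text input, forwards it to the remote chat, ASR, and
TTS endpoints in yuewen_api, and plays the synthesized reply back in a chat
UI. Run the script directly; --server-port and --server-name are optional.
"""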
import base64
import tempfile
import time
import traceback
from copy import deepcopy
from pathlib import Path

import gradio as gr

from yuewen_api import call_audiochat, call_asr, call_tts

CACHE_DIR = "/tmp/gradio/"  # where temporary TTS audio files are written
CACHE_CLEAN_AGE = 86400  # seconds (24 h); used by Blocks(delete_cache=...)

CHINESE_PROMPT_CONTENT = """你是一个为对话而设计的人工智能模型,目前无法连接到互联网。
当你需要唱歌时,请以(哼唱)开头。当你需要rap或说唱时,请以(RAP)开头。当你需要快速说话时,请以(快速)开头。当你需要慢速说话时,请以(慢速)开头。
现在,你需要倾听用户的语音内容,并以礼貌、简洁、口语化的文本进行回复。你需要尽量使用用户的语种进行回复。"""

ENGLISH_PROMPT_CONTENT = """You are an AI designed for conversation, currently unable to connect to the internet.
When you need to sing or rap, start your response with (RAP). When you need to speak fast, start your response with (fast). When you need to speak slowly, start your response with (slow).
Now, you need to listen to the user's voice content and respond with polite, concise, conversational text. Respond in accordance with the user's language."""

PROMPT_TEMPLATE = {
    "English Prompt": ENGLISH_PROMPT_CONTENT,
    "Chinese Prompt": CHINESE_PROMPT_CONTENT,
}


def make_system_prompt(prompt):
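    """Wrap a raw prompt string in an OpenAI-style system message dict."""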
    return {"role": "system", "content": prompt}


BASE_DIR = Path(__file__).resolve().parent
CHAT_EXAMPLES = [
    ["Input audio only", None, f"{BASE_DIR}/assets/tell_me_a_short_story_EN.wav"],
    [
        "Input text only",
        "What did the speaker mean when they said, it's raining cats and dogs?",
        None,
    ],
    [
        "Text and audio mixed input",
        "Answer the following query in English",
        f"{BASE_DIR}/assets/request_rap_ZH.wav",
    ],
]


def file_to_base64(filename):
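    """Return a file's contents as a base64-encoded string."""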
    with open(filename, "rb") as f:
        file_content = f.read()
        return base64.b64encode(file_content).decode("utf-8")


def get_audio_format(audio_filename):
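    """Map a file extension to the API's audio format tag ("mp3"/"wav")."""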
    extension = Path(audio_filename).suffix.lower()
    if extension == ".mp3":
        return "mp3"
    elif extension == ".wav":
        return "wav"
    else:
        return "other"


def add_message(chatbot, history, mic, text):
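    """Append the user's turn to both the UI chatbot and the API history.

    The history entry uses OpenAI-style content parts; a mixed text+audio
    turn ends up as:
        {"role": "user", "content": [
            {"type": "text", "text": "..."},
            {"type": "input_audio",
             "input_audio": {"data": "<base64>", "format": "wav"}}]}
    Returns (chatbot, history, error_or_None).
    """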
    if not mic and not text:
        return chatbot, history, "Input is empty"

    content = []

    if text:
        chatbot.append({"role": "user", "content": text})
        content.append({"type": "text", "text": text})
        print(f"add message {text=}")

    if mic and Path(mic).exists():
        chatbot.append({"role": "user", "content": {"path": mic}})
        content.append(
            {
                "type": "input_audio",
                "input_audio": {
                    "data": file_to_base64(mic),
                    "format": get_audio_format(mic),
                },
            }
        )
        print(f"add message {mic=}")

    history.append({"role": "user", "content": content})
    return chatbot, history, None


def get_tmp_audio_path():
    """Create a temp .mp3 path under CACHE_DIR for the TTS output.

    delete=False keeps the file on disk so Gradio can serve it; stale files
    are reclaimed via Blocks(delete_cache=...) below.
    """
    # Gradio usually creates CACHE_DIR itself, but ensure it exists before the
    # first TTS call.
    Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)
    temp_audio = tempfile.NamedTemporaryFile(dir=CACHE_DIR, delete=False, suffix=".mp3")
    return temp_audio.name


def predict(chatbot, history, user_prompt, enable_asr):
    """Generate a response from the model."""
    start_time = time.time()
    try:
        messages = [make_system_prompt(user_prompt)] + history
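        # Optional ASR enrichment: transcribe the newest user audio and append
        # the transcript as an extra text part on the same turn, so the chat
        # model sees both the raw audio and its text.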
        if enable_asr:
            asr_text = None
            if not isinstance(chatbot[-1]["content"], str):
                if isinstance(chatbot[-1]["content"], dict):
                    user_audio_path = chatbot[-1]["content"]["path"]
                else:
                    user_audio_path = chatbot[-1]["content"][0]
                print(f"do asr {user_audio_path=}")
                asr_text = call_asr(user_audio_path)
                print(f"{asr_text=}")
            if asr_text:
                last_input = deepcopy(history[-1])
                last_input["content"].append({"type": "text", "text": asr_text})
                messages = (
                    [make_system_prompt(user_prompt)] + history[:-1] + [last_input]
                )

        text = call_audiochat(messages)
        print(f"predict {text=}")
        audio_path = get_tmp_audio_path()
        call_tts(text, audio_path)
        print(f"save_tmp_audio {audio_path=}")
        chatbot.append({"role": "assistant", "content": text})
        chatbot.append({"role": "assistant", "content": {"path": audio_path}})
        history.append({"role": "assistant", "content": text})
    except Exception as e:
        traceback.print_exc()
        print(f"caught an exception: {e=}")
        gr.Warning(f"Caught an exception: {e}; try regenerating.")
    print(f"predict cost {(time.time() - start_time):.2f}s")
    return chatbot, history


def _launch_demo(args):
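    """Build the Gradio Blocks UI and launch the demo server."""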
    with gr.Blocks(delete_cache=(3600, CACHE_CLEAN_AGE)) as demo:
        # Keep the chat history in session state so the message payload isn't rebuilt every turn
        history = gr.State([])
        gr.Markdown("""<center><font size=8>Step Audio Chat</center>""")
        gr.Markdown(
            """<font size=4>This preview demonstrates core functionalities. To unlock the cormplete real-time voice conversation system with end-to-end encryption and advanced features, download the [Yuewen APP](https://m.yuewen.cn/call-app) with the link or via QR Code.</font>"""
        )
        with gr.Accordion(label="Click to view the QR code ", open=False):
            gr.Image(
                value="assets/yuewen.jpeg",
                interactive=False,
                show_label=False,
                show_download_button=False,
                show_fullscreen_button=False,
            )
        with gr.Accordion(
            label="The performance of English prompts is not as stable as that of Chinese prompts. You can click here to change sys prompt.",
            open=False,
        ):
            prompt_choice = gr.Radio(
                choices=list(PROMPT_TEMPLATE.keys()),
                value=list(PROMPT_TEMPLATE.keys())[0],
                label="Select prompt template",
            )
            user_prompt = gr.Textbox(
                show_label=False,
                value=list(PROMPT_TEMPLATE.values())[0],
                lines=6,
            )
            prompt_choice.change(
                fn=lambda choice: PROMPT_TEMPLATE[choice],
                inputs=prompt_choice,
                outputs=user_prompt,
            )
        chatbot = gr.Chatbot(
            elem_id="chatbot",
            avatar_images=["assets/user.png", "assets/assistant.png"],
            min_height=800,
            type="messages",
            show_share_button=True,
        )
        mic = gr.Audio(
            label="Only MP3 and WAV formats are supported for audio uploads.",
            sources=["microphone", "upload"],
            type="filepath",
        )
        text = gr.Textbox(
            placeholder="Enter message ...",
            label="Input text message",
            show_label=False,
        )
        with gr.Row():
            enable_asr = gr.Checkbox(
                value=True,
                label="Enhance understanding capability by ASR",
            )
        with gr.Row():
            clear_btn = gr.Button("🧹 Clear History (清除历史)")
            regen_btn = gr.Button("🤔️ Regenerate (重试)")
            gen_btn = gr.Button("🚀 Generate Response")

        example_comment = gr.Textbox(label="Input data type", visible=False)
        example_text = gr.Textbox(label="Input text message", visible=False)
        example_audio = gr.Audio(
            label="Input Audio",
            type="filepath",
            visible=False,
        )

        def update_examples(_, text, mic, user_prompt, enable_asr):
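            """Run a clicked example end-to-end on a fresh chat state."""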
            chatbot = []
            history = []
            chatbot, history, error = add_message(chatbot, history, mic, text)
            if error:
                gr.Warning(error)
                print(f"update_examples error")
                return chatbot, history
            else:
                chatbot, history = predict(chatbot, history, user_prompt, enable_asr)
                print(f"update_examples done")
                return chatbot, history

        with gr.Row():
            gr.Examples(
                fn=update_examples,
                examples=CHAT_EXAMPLES,
                inputs=[
                    example_comment,
                    example_text,
                    example_audio,
                    user_prompt,
                    enable_asr,
                ],
                outputs=[chatbot, history],
                run_on_click=True,
            )

        def on_submit(chatbot, history, mic, text, user_prompt, enable_asr):
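            """Handle Generate: append the turn, run predict, then clear both inputs."""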
            chatbot, history, error = add_message(chatbot, history, mic, text)
            if error:
                gr.Warning(error)
                return chatbot, history, None, None
            else:
                chatbot, history = predict(chatbot, history, user_prompt, enable_asr)
                return chatbot, history, None, None

        gen_btn.click(
            fn=on_submit,
            inputs=[chatbot, history, mic, text, user_prompt, enable_asr],
            outputs=[chatbot, history, mic, text],
            show_progress=True,
        )

        def clear_chat_history():
            return [], []

        clear_btn.click(
            clear_chat_history,
            outputs=[chatbot, history],
            show_progress=True,
        )

        def regenerate(chatbot, history, user_prompt, enable_asr):
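            """Drop trailing assistant turns, then re-run predict."""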
            while chatbot and chatbot[-1]["role"] == "assistant":
                chatbot.pop()
            while history and history[-1]["role"] == "assistant":
                print(f"discard {history[-1]}")
                history.pop()
            return predict(chatbot, history, user_prompt, enable_asr)

        regen_btn.click(
            regenerate,
            [chatbot, history, user_prompt, enable_asr],
            [chatbot, history],
            show_progress=True,
        )

    demo.queue().launch(
        share=False,
        server_port=args.server_port,
        server_name=args.server_name,
    )


if __name__ == "__main__":
    from argparse import ArgumentParser

    parser = ArgumentParser()
    parser.add_argument("--model-path", type=str, help="Tokenizer and TTS model path.")
    parser.add_argument(
        "--server-port", type=int, default=7860, help="Demo server port."
    )
    parser.add_argument(
        "--server-name", type=str, default="0.0.0.0", help="Demo server name."
    )
    args = parser.parse_args()
    _launch_demo(args)