Spaces:
Running
Running
import base64 | |
from copy import deepcopy | |
import gradio as gr | |
import time | |
from pathlib import Path | |
from yuewen_api import call_audiochat, call_asr, call_tts | |
CACHE_DIR = "/tmp/gradio/" | |
CACHE_CLEAN_AGE = 86400 | |
# System prompts offered in the UI.  Both instruct the model to prefix
# special delivery styles -- singing, rap, fast/slow speech -- with a marker
# token (e.g. "(RAP)") so the downstream TTS stage can render them.
CHINESE_PROMPT_CONTENT = """你是一个为对话而设计的人工智能模型,目前无法连接到互联网。
当你需要唱歌时,请以(哼唱)开头。当你需要rap或说唱时,请以(RAP)开头。当你需要快速说话时,请以(快速)开头。当你需要慢速说话时,请以(慢速)开头。
现在,你需要倾听用户的语音内容,并以礼貌、简洁、口语化的文本进行回复。你需要尽量用用户的语种进行回复。"""
ENGLISH_PROMPT_CONTENT = """You are an AI designed for conversation, currently unable to connect to the internet.
when you need to sing or rap, start your response with (RAP). when you need to speak fast, you start your response with (fast). when you need to speak slowly, you start your response with (slow)
Now, you need to listen to the user's voice content and respond with politely, concise, conversational text. Respond in accordance with the user's language."""

# Maps the radio-button label shown in the UI to the prompt text it selects.
PROMPT_TEMPLATE = {
    "English Prompt": ENGLISH_PROMPT_CONTENT,
    "Chinese Prompt": CHINESE_PROMPT_CONTENT,
}
def make_system_prompt(prompt):
    """Wrap *prompt* as an OpenAI-style system message dict."""
    return dict(role="system", content=prompt)
# Directory containing this script; used to resolve the bundled example assets.
BASE_DIR = Path(__file__).resolve().parent
# Click-to-run rows for gr.Examples.  Column order matches the hidden example
# components in _launch_demo: [data-type label, text input, audio file path].
# A None entry means that modality is unused for the example.
CHAT_EXAMPLES = [
    ["Input audio only", None, f"{BASE_DIR}/assets/tell_me_a_short_story_EN.wav"],
    [
        "Input text only",
        "What did the speaker mean when they said, it's raining cats and dogs?",
        None,
    ],
    [
        "Text and audio mixed input",
        "Answer the following query in English",
        f"{BASE_DIR}/assets/request_rap_ZH.wav",
    ],
]
def file_to_base64(filename):
    """Read *filename* as bytes and return its base64 encoding as a str."""
    raw = Path(filename).read_bytes()
    return base64.b64encode(raw).decode("utf-8")
def get_audio_format(audio_filename):
    """Classify *audio_filename* by extension as "mp3", "wav", or "other"."""
    suffix = Path(audio_filename).suffix.lower()
    return {".mp3": "mp3", ".wav": "wav"}.get(suffix, "other")
def add_message(chatbot, history, mic, text):
    """Append the user's turn to both the UI chatbot and the model history.

    Args:
        chatbot: gr.Chatbot "messages"-format list shown in the UI.
        history: raw message list sent to the chat API.
        mic: optional path to a recorded/uploaded audio file.
        text: optional text message.

    Returns:
        (chatbot, history, error) where error is None on success or a
        human-readable string when both inputs are empty.
    """
    # Guard clause: nothing to add.
    if not mic and not text:
        return chatbot, history, "Input is empty"

    parts = []
    if text:
        chatbot.append({"role": "user", "content": text})
        parts.append({"type": "text", "text": text})
        print(f"add message {text=}")
    if mic and Path(mic).exists():
        chatbot.append({"role": "user", "content": {"path": mic}})
        audio_payload = {
            "data": file_to_base64(mic),
            "format": get_audio_format(mic),
        }
        parts.append({"type": "input_audio", "input_audio": audio_payload})
        print(f"add message {mic=}")
    history.append({"role": "user", "content": parts})
    return chatbot, history, None
def get_tmp_audio_path():
    """Create an empty, uniquely named .mp3 file under CACHE_DIR and return its path.

    The file is created (not just named) so concurrent callers cannot collide;
    the caller (TTS) then writes audio into it.
    """
    import os
    import tempfile

    # Bug fixes vs. original: the cache dir may not exist yet on first call,
    # and NamedTemporaryFile left its file handle open (descriptor leak).
    os.makedirs(CACHE_DIR, exist_ok=True)
    fd, path = tempfile.mkstemp(dir=CACHE_DIR, suffix=".mp3")
    os.close(fd)  # close immediately; only the path is needed downstream
    return path
def predict(chatbot, history, user_prompt, enable_asr):
    """Generate a response from the model and append it to the conversation.

    Args:
        chatbot: gr.Chatbot "messages"-format list; gets the assistant's text
            and synthesized audio appended.
        history: raw API message list; gets only the assistant's text appended.
        user_prompt: system prompt text prepended to every request.
        enable_asr: when True, transcribe the user's latest audio turn and
            attach the transcript to improve the model's understanding.

    Returns:
        (chatbot, history) — returned unchanged if an exception occurs, so the
        user can hit "Regenerate".
    """
    start_time = time.time()
    try:
        messages = [make_system_prompt(user_prompt)] + history
        if enable_asr:
            asr_text = None
            # The latest chatbot entry is an audio turn when its content is
            # not a plain string (dict with "path", or an indexable tuple).
            if not isinstance(chatbot[-1]["content"], str):
                if isinstance(chatbot[-1]["content"], dict):
                    user_audio_path = chatbot[-1]["content"]["path"]
                else:
                    user_audio_path = chatbot[-1]["content"][0]
                print(f"do asr {user_audio_path=}")
                asr_text = call_asr(user_audio_path)
                print(f"{asr_text=}")
            if asr_text:
                # Deep-copy so the transcript is attached only to the outgoing
                # request, never to the persistent history.
                last_input = deepcopy(history[-1])
                last_input["content"].append({"type": "text", "text": asr_text})
                messages = (
                    [make_system_prompt(user_prompt)] + history[:-1] + [last_input]
                )
        text = call_audiochat(messages)
        print(f"predict {text=}")
        audio_path = get_tmp_audio_path()
        call_tts(text, audio_path)
        print(f"save_tmp_audio {audio_path=}")
        # UI shows both the text and the synthesized audio; the API history
        # keeps only the text.
        chatbot.append({"role": "assistant", "content": text})
        chatbot.append({"role": "assistant", "content": {"path": audio_path}})
        history.append({"role": "assistant", "content": text})
    except Exception as e:
        import traceback

        traceback.print_exc()
        # Typo fix vs. original ("get a excption"); best-effort handling is
        # deliberate: warn the user and let them regenerate.
        print(f"got an exception: {e=}")
        gr.Warning(f"got an exception: {str(e)}, try regenerate")
    print(f"predict cost {(time.time() - start_time):.2f}s")
    return chatbot, history
def _launch_demo(args):
    """Build the Gradio Blocks UI and launch the demo server.

    Args:
        args: Parsed CLI namespace providing ``server_port`` and
            ``server_name`` for ``demo.launch``.
    """
    with gr.Blocks(delete_cache=(3600, CACHE_CLEAN_AGE)) as demo:
        # Keep the raw chat history in server-side state so the API payload
        # does not have to be rebuilt from the chatbot widget on every turn.
        history = gr.State([])
        gr.Markdown("""<center><font size=8>Step Audio Chat</center>""")
        # Typo fix vs. original: "cormplete" -> "complete".
        gr.Markdown(
            """<font size=4>This preview demonstrates core functionalities. To unlock the complete real-time voice conversation system with end-to-end encryption and advanced features, download the [Yuewen APP](https://m.yuewen.cn/call-app) with the link or via QR Code.</font>"""
        )
        with gr.Accordion(label="Click to view the QR code ", open=False):
            gr.Image(
                value="assets/yuewen.jpeg",
                interactive=False,
                show_label=False,
                show_download_button=False,
                show_fullscreen_button=False,
            )
        with gr.Accordion(
            label="The performance of English prompts is not as stable as that of Chinese prompts. You can click here to change sys prompt.",
            open=False,
        ):
            prompt_choice = gr.Radio(
                choices=list(PROMPT_TEMPLATE.keys()),
                value=list(PROMPT_TEMPLATE.keys())[0],
                label="Select prompt template",
            )
            user_prompt = gr.Textbox(
                show_label=False,
                value=list(PROMPT_TEMPLATE.values())[0],
                lines=6,
            )
            # Selecting a template replaces the editable prompt text.
            prompt_choice.change(
                fn=lambda choice: PROMPT_TEMPLATE[choice],
                inputs=prompt_choice,
                outputs=user_prompt,
            )
        chatbot = gr.Chatbot(
            elem_id="chatbot",
            avatar_images=["assets/user.png", "assets/assistant.png"],
            min_height=800,
            type="messages",
            show_share_button=True,
        )
        mic = gr.Audio(
            label="Only MP3 and WAV formats are supported for audio uploads.",
            sources=["microphone", "upload"],
            type="filepath",
        )
        text = gr.Textbox(
            placeholder="Enter message ...",
            label="Input text message",
            show_label=False,
        )
        with gr.Row():
            enable_asr = gr.Checkbox(
                value=True,
                label="Enhance understanding capability by ASR",
            )
        with gr.Row():
            clear_btn = gr.Button("🧹 Clear History (清除历史)")
            regen_btn = gr.Button("🤔️ Regenerate (重试)")
            gen_btn = gr.Button("🚀 Generate Response")

        # Hidden components mirroring the CHAT_EXAMPLES columns; gr.Examples
        # feeds them into update_examples on click.
        example_comment = gr.Textbox(label="Input data type", visible=False)
        example_text = gr.Textbox(label="Input text message", visible=False)
        example_audio = gr.Audio(
            label="Input Audio",
            type="filepath",
            visible=False,
        )

        def update_examples(_, text, mic, user_prompt, enable_asr):
            # Run a clicked example from a fresh conversation state.
            chatbot = []
            history = []
            chatbot, history, error = add_message(chatbot, history, mic, text)
            if error:
                gr.Warning(error)
                print("update_examples error")
                return chatbot, history
            else:
                chatbot, history = predict(chatbot, history, user_prompt, enable_asr)
                print("update_examples done")
                return chatbot, history

        with gr.Row():
            gr.Examples(
                fn=update_examples,
                examples=CHAT_EXAMPLES,
                inputs=[
                    example_comment,
                    example_text,
                    example_audio,
                    user_prompt,
                    enable_asr,
                ],
                outputs=[chatbot, history],
                run_on_click=True,
            )

        def on_submit(chatbot, history, mic, text, user_prompt, enable_asr):
            # Append the user's turn, generate a reply, and clear both inputs
            # (the trailing None, None reset mic and text).
            chatbot, history, error = add_message(chatbot, history, mic, text)
            if error:
                gr.Warning(error)
                return chatbot, history, None, None
            else:
                chatbot, history = predict(chatbot, history, user_prompt, enable_asr)
                return chatbot, history, None, None

        gen_btn.click(
            fn=on_submit,
            inputs=[chatbot, history, mic, text, user_prompt, enable_asr],
            outputs=[chatbot, history, mic, text],
            show_progress=True,
        )

        def clear_chat_history():
            # Reset both the visible chat and the server-side history.
            return [], []

        clear_btn.click(
            clear_chat_history,
            outputs=[chatbot, history],
            show_progress=True,
        )

        def regenerate(chatbot, history, user_prompt, enable_asr):
            # Drop all trailing assistant turns (text and audio entries),
            # then re-run prediction on the remaining user turn.
            while chatbot and chatbot[-1]["role"] == "assistant":
                chatbot.pop()
            while history and history[-1]["role"] == "assistant":
                print(f"discard {history[-1]}")
                history.pop()
            return predict(chatbot, history, user_prompt, enable_asr)

        regen_btn.click(
            regenerate,
            [chatbot, history, user_prompt, enable_asr],
            [chatbot, history],
            show_progress=True,
        )

    demo.queue().launch(
        share=False,
        server_port=args.server_port,
        server_name=args.server_name,
    )
if __name__ == "__main__":
    from argparse import ArgumentParser

    # CLI entry point: parse server options and start the Gradio demo.
    # (Removed an unused `import os` present in the original.)
    parser = ArgumentParser()
    parser.add_argument("--model-path", type=str, help="Tokenizer and TTS model path.")
    parser.add_argument(
        "--server-port", type=int, default=7860, help="Demo server port."
    )
    parser.add_argument(
        "--server-name", type=str, default="0.0.0.0", help="Demo server name."
    )
    args = parser.parse_args()
    _launch_demo(args)