# qew / app.py — "Update app.py" by beyoru (commit 7c5c8ba, verified), 3.26 kB
import copy
import os
import tempfile
from argparse import ArgumentParser
from http import HTTPStatus

import dashscope
import gradio as gr
import requests
from dashscope import MultiModalConversation
# Set environment variables and API key.
# `os.environ[...]` raises KeyError when API_KEY is unset, so a missing key
# stops the app at start-up instead of on the first request.
API_KEY = os.environ['API_KEY']
# Needs the `dashscope` module itself in scope (a module-level
# `import dashscope`), not just names imported from it.
dashscope.api_key = API_KEY
# Define constants
# Model identifier passed as `model=` to MultiModalConversation.call.
MODEL_NAME = 'Qwen2-VL-2B-Instruct'
# Get arguments
def _get_args():
parser = ArgumentParser()
parser.add_argument("--share", action="store_true", default=False, help="Create a publicly shareable link.")
parser.add_argument("--server-port", type=int, default=7860, help="Server port.")
parser.add_argument("--server-name", type=str, default="127.0.0.1", help="Server name.")
return parser.parse_args()
# Simplify chat prediction
def predict(_chatbot, task_history, system_prompt):
    """Stream the model's reply to the latest user turn into the chat display.

    Args:
        _chatbot: Display history, a list of ``(user_text, bot_text)`` pairs.
            The last pair is the turn being answered; it is overwritten in
            place as the reply streams in.
        task_history: Raw history with the same pairing, used to build the
            request. The finished reply is written back into its last entry
            so later turns can include it as conversation context.
        system_prompt: Optional instruction text; sent as a leading
            ``system`` message when it is not blank.

    Yields:
        ``_chatbot`` after every update, so the UI refreshes while streaming.

    Raises:
        Exception: If a streamed response carries a non-OK status code.
    """
    # Nothing to answer, e.g. "Regenerate" pressed on an empty conversation.
    if not _chatbot or not task_history:
        yield _chatbot
        return

    chat_query = _chatbot[-1][0]
    query = task_history[-1][0]
    if not chat_query:
        # Drop the blank turn. This function is a generator, so the trimmed
        # history must be yielded: a `return value` is never delivered to the
        # consumer of the generator.
        _chatbot.pop()
        task_history.pop()
        yield _chatbot
        return

    print("User:", query)

    messages = []
    if system_prompt and system_prompt.strip():
        messages.append({'role': 'system', 'content': [{'text': system_prompt}]})
    # Earlier turns go in as user/assistant pairs. The last entry is the turn
    # being answered, so any reply it already holds (regeneration) is left out.
    for past_query, past_reply in task_history[:-1]:
        messages.append({'role': 'user', 'content': [{'text': past_query}]})
        if past_reply:
            messages.append({'role': 'assistant', 'content': [{'text': past_reply}]})
    messages.append({'role': 'user', 'content': [{'text': query}]})

    responses = MultiModalConversation.call(
        model=MODEL_NAME, messages=messages, stream=True,
    )

    response_text = ''
    for response in responses:
        if response.status_code != HTTPStatus.OK:
            raise Exception(f'Error: {response.message}')
        content = response.output.choices[0].message.content
        # Each streamed response replaces the displayed reply rather than
        # appending to it (presumably chunks are cumulative - confirm against
        # the DashScope streaming docs). Non-text content parts are skipped.
        response_text = ''.join(ele['text'] for ele in content if 'text' in ele)
        _chatbot[-1] = (chat_query, response_text)
        yield _chatbot

    # Remember the reply so the next request carries the full dialogue.
    task_history[-1] = (query, response_text)
# Add text to history
def add_text(history, task_history, text):
    """Append a new, not-yet-answered user turn to both histories.

    The original body called ``_parse_text``, which is not defined in this
    file, so every submit failed with NameError. The text is now shown as
    typed (assumption: ``gr.Chatbot`` renders Markdown itself).

    Args:
        history: Display history (list of ``(user, bot)`` pairs); appended to
            in place.
        task_history: Raw history used to build model requests; appended to
            in place.
        text: The text the user submitted.

    Returns:
        ``(history, task_history, "")`` - the two updated lists and an empty
        string as the third value.
    """
    history.append((text, None))
    task_history.append((text, None))
    return history, task_history, ""
# Reset input
def reset_user_input():
    """Return a Gradio update that sets the target component's value to ""."""
    cleared = gr.update(value="")
    return cleared
# Reset history
def reset_state(task_history):
    """Empty the raw history in place and return a fresh empty list.

    Args:
        task_history: List of past turns; emptied in place so the caller's
            object is cleared as well.

    Returns:
        A new empty list.
    """
    del task_history[:]
    return list()
# Launch the demo
def _launch_demo(args):
    """Build the chat UI, wire up its events, and start the Gradio server.

    Args:
        args: Parsed options providing ``share``, ``server_port`` and
            ``server_name`` (see ``_get_args``).
    """
    # Components and event listeners have to be created inside a Blocks
    # context; `demo` is the app object that is queued and launched below.
    with gr.Blocks() as demo:
        chatbot = gr.Chatbot(label='Qwen2-VL-2B-Instruct', height=500)
        query = gr.Textbox(lines=2, label='Input')
        system_prompt = gr.Textbox(lines=2, label='System Prompt', placeholder="Modify system prompt here...")
        task_history = gr.State([])

        with gr.Row():
            submit_btn = gr.Button("🚀 Submit")
            regen_btn = gr.Button("🤔️ Regenerate")
            empty_bin = gr.Button("🧹 Clear History")

        # add_text returns three values; the third ("") is routed to `query`
        # so the input box is cleared by the same handler that records the
        # turn, and the output count matches the return count.
        submit_btn.click(add_text, [chatbot, task_history, query], [chatbot, task_history, query]).then(
            predict, [chatbot, task_history, system_prompt], [chatbot], show_progress=True
        )
        empty_bin.click(reset_state, [task_history], [chatbot], show_progress=True)
        regen_btn.click(predict, [chatbot, task_history, system_prompt], [chatbot], show_progress=True)

        gr.Markdown("""<center><font size=3>Qwen2-VL-2B-Instruct Demo</center>""")
        gr.Markdown("""<center><font size=2>Note: This demo uses Qwen2-VL-2B-Instruct model. Please be mindful of ethical content creation.</center>""")

    demo.queue().launch(share=args.share, server_port=args.server_port, server_name=args.server_name)
# Main function
def main():
    """Entry point: read the command-line options and launch the demo."""
    _launch_demo(_get_args())


if __name__ == '__main__':
    main()