|
import copy
import os
import tempfile
from argparse import ArgumentParser
from http import HTTPStatus

import dashscope
import gradio as gr
import requests
from dashscope import MultiModalConversation
|
|
|
|
|
# DashScope credential; raises KeyError at import time if API_KEY is unset,
# which fails fast instead of erroring on the first model call.
API_KEY = os.environ['API_KEY']

# NOTE(review): only `MultiModalConversation` is imported from dashscope above;
# without a plain `import dashscope` this line raises NameError at import time.
dashscope.api_key = API_KEY

# DashScope model identifier used for every MultiModalConversation.call.
MODEL_NAME = 'Qwen2-VL-2B-Instruct'
|
|
|
|
|
def _get_args():
    """Parse the demo's command-line options.

    Returns:
        argparse.Namespace with ``share`` (bool), ``server_port`` (int)
        and ``server_name`` (str).
    """
    arg_parser = ArgumentParser()
    arg_parser.add_argument(
        "--share",
        action="store_true",
        default=False,
        help="Create a publicly shareable link.",
    )
    arg_parser.add_argument(
        "--server-port",
        type=int,
        default=7860,
        help="Server port.",
    )
    arg_parser.add_argument(
        "--server-name",
        type=str,
        default="127.0.0.1",
        help="Server name.",
    )
    return arg_parser.parse_args()
|
|
|
|
|
def predict(_chatbot, task_history, system_prompt):
    """Stream a model reply for the newest user turn.

    Generator used as a Gradio event handler: yields the updated chatbot
    list after each streamed chunk so the UI renders incrementally.

    Args:
        _chatbot: Gradio chatbot state, list of (user_text, bot_text) pairs;
            the last entry is the pending turn with bot_text == None.
        task_history: parallel history of raw (query, response) pairs.
        system_prompt: optional system instruction; sent as a leading
            'system' message when non-empty (previously accepted but ignored).

    Raises:
        Exception: if the DashScope API returns a non-OK status.
    """
    chat_query = _chatbot[-1][0]
    query = task_history[-1][0]
    if not chat_query:
        # Empty submission: drop the placeholder turn and stop.
        _chatbot.pop()
        task_history.pop()
        return _chatbot
    print("User:", query)
    history_cp = copy.deepcopy(task_history)

    messages = []
    # Bug fix: the system prompt was never forwarded to the model.
    if system_prompt:
        messages.append({'role': 'system', 'content': [{'text': system_prompt}]})
    # Bug fix: assistant replies were dropped, so the model never saw prior
    # turns' answers; rebuild the full alternating conversation instead.
    for q, a in history_cp:
        messages.append({'role': 'user', 'content': [{'text': q}]})
        if a:
            messages.append({'role': 'assistant', 'content': [{'text': a}]})

    responses = MultiModalConversation.call(
        model=MODEL_NAME, messages=messages, stream=True,
    )
    for response in responses:
        if response.status_code != HTTPStatus.OK:
            raise Exception(f'Error: {response.message}')
        # Each streamed response carries the cumulative answer so far.
        response_text = ''.join(ele['text'] for ele in response.output.choices[0].message.content)
        _chatbot[-1] = (chat_query, response_text)
        yield _chatbot
|
|
|
|
|
def add_text(history, task_history, text):
    """Append a new user turn to both history lists.

    Args:
        history: Gradio chatbot display history, list of (user, bot) pairs.
        task_history: raw history used to build model messages.
        text: the user's input text.

    Returns:
        (history, task_history, "") — the empty string clears the input box
        when wired to the textbox output.
    """
    task_text = text
    # Bug fix: the original called `_parse_text(text)`, but no such function
    # exists anywhere in this file, so every submit raised NameError.
    # Display the raw text instead.
    history.append((text, None))
    task_history.append((task_text, None))
    return history, task_history, ""
|
|
|
|
|
def reset_user_input():
    """Return a Gradio update that clears the input textbox."""
    cleared = gr.update(value="")
    return cleared
|
|
|
|
|
def reset_state(task_history):
    """Empty the raw task history in place and clear the chatbot display.

    Returns an empty list so the wired chatbot component is reset.
    """
    del task_history[:]
    return []
|
|
|
|
|
def _launch_demo(args):
    """Assemble the Gradio UI and start the server.

    Bug fixes vs. the original:
      * components were created with no ``gr.Blocks`` context and ``demo``
        was never defined, so ``demo.queue().launch(...)`` raised NameError;
      * ``add_text`` returns three values but only two outputs were wired,
        which errors at event time — the third output (the textbox) is now
        wired, which also clears the input after submit, so the separate
        ``reset_user_input`` click handler is no longer needed.

    Args:
        args: parsed namespace from ``_get_args`` (share, server_port,
            server_name).
    """
    with gr.Blocks() as demo:
        gr.Markdown("""<center><font size=3>Qwen2-VL-2B-Instruct Demo</center>""")

        chatbot = gr.Chatbot(label='Qwen2-VL-2B-Instruct', height=500)
        query = gr.Textbox(lines=2, label='Input')
        system_prompt = gr.Textbox(lines=2, label='System Prompt', placeholder="Modify system prompt here...")
        task_history = gr.State([])

        with gr.Row():
            submit_btn = gr.Button("🚀 Submit")
            regen_btn = gr.Button("🤔️ Regenerate")
            empty_bin = gr.Button("🧹 Clear History")

        # add_text returns (history, task_history, "") — wiring the textbox
        # as the third output clears it after each submission.
        submit_btn.click(
            add_text, [chatbot, task_history, query], [chatbot, task_history, query]
        ).then(
            predict, [chatbot, task_history, system_prompt], [chatbot], show_progress=True
        )
        empty_bin.click(reset_state, [task_history], [chatbot], show_progress=True)
        regen_btn.click(predict, [chatbot, task_history, system_prompt], [chatbot], show_progress=True)

        gr.Markdown("""<center><font size=2>Note: This demo uses Qwen2-VL-2B-Instruct model. Please be mindful of ethical content creation.</center>""")

    demo.queue().launch(share=args.share, server_port=args.server_port, server_name=args.server_name)
|
|
|
|
|
def main():
    """Script entry point: parse CLI options and launch the web demo."""
    _launch_demo(_get_args())
|
|
|
# Run the demo only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
|
|