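"""Gradio web demo for Qwen2-VL-2B-Instruct, served through the DashScope
MultiModalConversation API. Expects the API_KEY environment variable to hold
a valid DashScope API key."""
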
import os
import copy
from argparse import ArgumentParser
from http import HTTPStatus

import gradio as gr
import dashscope
from dashscope import MultiModalConversation

# Read the DashScope API key from the environment
API_KEY = os.environ['API_KEY']
dashscope.api_key = API_KEY

# Define constants
MODEL_NAME = 'Qwen2-VL-2B-Instruct'

# Parse command-line arguments
def _get_args():
    parser = ArgumentParser()
    parser.add_argument("--share", action="store_true", default=False, help="Create a publicly shareable link.")
    parser.add_argument("--server-port", type=int, default=7860, help="Server port.")
    parser.add_argument("--server-name", type=str, default="127.0.0.1", help="Server name.")
    return parser.parse_args()

# Stream a chat completion for the latest user query
def predict(_chatbot, task_history, system_prompt):
    chat_query = _chatbot[-1][0]
    query = task_history[-1][0]
    if not chat_query:
        _chatbot.pop()
        task_history.pop()
        yield _chatbot
        return
    print("User:", query)
    history_cp = copy.deepcopy(task_history)
    # Rebuild the conversation with both user and assistant turns,
    # prepending the system prompt when one is provided
    messages = []
    if system_prompt:
        messages.append({'role': 'system', 'content': [{'text': system_prompt}]})
    for q, a in history_cp:
        messages.append({'role': 'user', 'content': [{'text': q}]})
        if a is not None:
            messages.append({'role': 'assistant', 'content': [{'text': a}]})
    responses = MultiModalConversation.call(
        model=MODEL_NAME, messages=messages, stream=True,
    )
    response_text = ''
    for response in responses:
        if response.status_code != HTTPStatus.OK:
            raise Exception(f'Error: {response.message}')
        # Each streamed chunk carries the accumulated content so far
        response_text = ''.join(ele['text'] for ele in response.output.choices[0].message.content)
        _chatbot[-1] = (chat_query, response_text)
        yield _chatbot
    # Store the final answer so later turns and regeneration see the full history
    task_history[-1] = (query, response_text)
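
# Note: DashScope multimodal messages may mix images and text within one turn.
# An illustrative user turn (not exercised by this demo; the path is a placeholder):
#   {'role': 'user', 'content': [{'image': 'file:///path/to/photo.jpg'},
#                                {'text': 'Describe this image.'}]}
# This demo sends text-only turns, so each content list holds a single text item.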

# Append the user's text to the visible chat and the raw task history
def add_text(history, task_history, text):
    history.append((text, None))
    task_history.append((text, None))
    return history, task_history

# Reset input
def reset_user_input():
    return gr.update(value="")

# Reset history
def reset_state(task_history):
    task_history.clear()
    return []
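
# Regenerate helper (a sketch; the original wired the button straight to predict):
# clearing the stored answer first keeps the stale reply out of the rebuilt history.
def regenerate(_chatbot, task_history, system_prompt):
    if not task_history:
        return
    task_history[-1] = (task_history[-1][0], None)
    yield from predict(_chatbot, task_history, system_prompt)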

# Build and launch the Gradio demo
def _launch_demo(args):
    with gr.Blocks() as demo:
        gr.Markdown("""<center><font size=3>Qwen2-VL-2B-Instruct Demo</center>""")
        gr.Markdown("""<center><font size=2>Note: This demo uses the Qwen2-VL-2B-Instruct model. Please be mindful of ethical content creation.</center>""")

        chatbot = gr.Chatbot(label='Qwen2-VL-2B-Instruct', height=500)
        query = gr.Textbox(lines=2, label='Input')
        system_prompt = gr.Textbox(lines=2, label='System Prompt', placeholder="Modify system prompt here...")
        task_history = gr.State([])

        with gr.Row():
            submit_btn = gr.Button("🚀 Submit")
            regen_btn = gr.Button("🤔️ Regenerate")
            empty_bin = gr.Button("🧹 Clear History")

        submit_btn.click(add_text, [chatbot, task_history, query], [chatbot, task_history]).then(
            predict, [chatbot, task_history, system_prompt], [chatbot], show_progress=True
        )
        submit_btn.click(reset_user_input, [], [query])
        empty_bin.click(reset_state, [task_history], [chatbot], show_progress=True)
        regen_btn.click(regenerate, [chatbot, task_history, system_prompt], [chatbot], show_progress=True)

    demo.queue().launch(share=args.share, server_port=args.server_port, server_name=args.server_name)

# Main function
def main():
    args = _get_args()
    _launch_demo(args)

if __name__ == '__main__':
    main()
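
# Example usage (the script name is assumed; adjust to the actual filename):
#   export API_KEY=<your DashScope API key>
#   python web_demo.py --server-name 0.0.0.0 --server-port 7860 --share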