import os
import copy
import tempfile
import requests
from argparse import ArgumentParser
from http import HTTPStatus

import gradio as gr
import dashscope
from dashscope import MultiModalConversation

# Set environment variables and API key
API_KEY = os.environ['API_KEY']
dashscope.api_key = API_KEY

# Define constants
MODEL_NAME = 'Qwen2-VL-2B-Instruct'
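# Note: MODEL_NAME is passed verbatim to MultiModalConversation.call below, so it is
# assumed to match a model identifier that is actually available on DashScope.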
# Get arguments
def _get_args():
    parser = ArgumentParser()
    parser.add_argument("--share", action="store_true", default=False,
                        help="Create a publicly shareable link.")
    parser.add_argument("--server-port", type=int, default=7860, help="Server port.")
    parser.add_argument("--server-name", type=str, default="127.0.0.1", help="Server name.")
    return parser.parse_args()
# Stream a chat completion for the latest user query
def predict(_chatbot, task_history, system_prompt):
    chat_query = _chatbot[-1][0]
    query = task_history[-1][0]
    if not chat_query:
        _chatbot.pop()
        task_history.pop()
        yield _chatbot
        return
    print("User:", query)
    history_cp = copy.deepcopy(task_history)

    # Build the DashScope message list: optional system prompt, then the
    # alternating user/assistant turns recorded in task_history.
    messages = []
    if system_prompt:
        messages.append({'role': 'system', 'content': [{'text': system_prompt}]})
    for q, a in history_cp:
        messages.append({'role': 'user', 'content': [{'text': q}]})
        if a:
            messages.append({'role': 'assistant', 'content': [{'text': a}]})

    responses = MultiModalConversation.call(
        model=MODEL_NAME, messages=messages, stream=True,
    )
    for response in responses:
        if response.status_code != HTTPStatus.OK:
            raise Exception(f'Error: {response.message}')
        # Each streamed response carries the full reply generated so far.
        response_text = ''.join(ele['text'] for ele in response.output.choices[0].message.content)
        _chatbot[-1] = (chat_query, response_text)
        yield _chatbot
    # Record the final answer so follow-up turns and "Regenerate" see it.
    task_history[-1] = (query, response_text)
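# For reference, each streamed response object is expected to look roughly like the
# (hypothetical, truncated) payload below; only the joined 'text' fields are used:
#   response.status_code                        -> 200
#   response.output.choices[0].message.content  -> [{'text': 'Hello! ...'}]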
# Add text to history
def add_text(history, task_history, text):
    # Store the raw text for the API call and show the same text in the chatbot;
    # no extra markdown parsing is applied here.
    history.append((text, None))
    task_history.append((text, None))
    return history, task_history
# Reset input
def reset_user_input():
    return gr.update(value="")

# Reset history
def reset_state(task_history):
    task_history.clear()
    return []
# Launch the demo
def _launch_demo(args):
    with gr.Blocks() as demo:
        chatbot = gr.Chatbot(label='Qwen2-VL-2B-Instruct', height=500)
        query = gr.Textbox(lines=2, label='Input')
        system_prompt = gr.Textbox(lines=2, label='System Prompt', placeholder="Modify system prompt here...")
        task_history = gr.State([])

        with gr.Row():
            submit_btn = gr.Button("🚀 Submit")
            regen_btn = gr.Button("🤔️ Regenerate")
            empty_bin = gr.Button("🧹 Clear History")

        submit_btn.click(add_text, [chatbot, task_history, query], [chatbot, task_history]).then(
            predict, [chatbot, task_history, system_prompt], [chatbot], show_progress=True
        )
        submit_btn.click(reset_user_input, [], [query])
        empty_bin.click(reset_state, [task_history], [chatbot], show_progress=True)
        regen_btn.click(predict, [chatbot, task_history, system_prompt], [chatbot], show_progress=True)

        gr.Markdown("""<center><font size=3>Qwen2-VL-2B-Instruct Demo</center>""")
        gr.Markdown("""<center><font size=2>Note: This demo uses the Qwen2-VL-2B-Instruct model. Please be mindful of ethical content creation.</center>""")

    demo.queue().launch(share=args.share, server_port=args.server_port, server_name=args.server_name)
# Main function
def main():
    args = _get_args()
    _launch_demo(args)

if __name__ == '__main__':
    main()
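# Example invocation (the script filename "web_demo.py" is illustrative; adjust as needed):
#   export API_KEY=<your DashScope API key>
#   python web_demo.py --server-port 7860 --share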