beyoru committed
Commit 9680c53 · verified · 1 Parent(s): dc22316

Update app.py

Files changed (1):
  1. app.py +72 -80
app.py CHANGED
@@ -1,92 +1,84 @@
-import os
 import gradio as gr
-from argparse import ArgumentParser
-import copy
-import tempfile
 import requests
-from http import HTTPStatus
-from dashscope import MultiModalConversation
-
-# Set environment variables and API key
-API_KEY = os.environ['API_KEY']
-dashscope.api_key = API_KEY
-
-# Define constants
-MODEL_NAME = 'Qwen2-VL-2B-Instruct'
-
-# Get arguments
-def _get_args():
-    parser = ArgumentParser()
-    parser.add_argument("--share", action="store_true", default=False, help="Create a publicly shareable link.")
-    parser.add_argument("--server-port", type=int, default=7860, help="Server port.")
-    parser.add_argument("--server-name", type=str, default="127.0.0.1", help="Server name.")
-    return parser.parse_args()
-
-# Simplify chat prediction
-def predict(_chatbot, task_history, system_prompt):
-    chat_query = _chatbot[-1][0]
-    query = task_history[-1][0]
-    if not chat_query:
-        _chatbot.pop()
-        task_history.pop()
-        return _chatbot
-    print("User:", query)
-    history_cp = copy.deepcopy(task_history)
-    messages = [{'role': 'user', 'content': [{'text': q}]} for q, _ in history_cp]
-    responses = MultiModalConversation.call(
-        model=MODEL_NAME, messages=messages, stream=True,
     )
-    for response in responses:
-        if not response.status_code == HTTPStatus.OK:
-            raise Exception(f'Error: {response.message}')
-        response_text = ''.join([ele['text'] for ele in response.output.choices[0].message.content])
-        _chatbot[-1] = (chat_query, response_text)
-        yield _chatbot
-
-# Add text to history
-def add_text(history, task_history, text):
-    task_text = text
-    history.append((_parse_text(text), None))
-    task_history.append((task_text, None))
-    return history, task_history, ""
-
-# Reset input
-def reset_user_input():
-    return gr.update(value="")
-
-# Reset history
-def reset_state(task_history):
-    task_history.clear()
-    return []
-
-# Launch the demo
-def _launch_demo(args):
-    chatbot = gr.Chatbot(label='Qwen2-VL-2B-Instruct', height=500)
-    query = gr.Textbox(lines=2, label='Input')
-    system_prompt = gr.Textbox(lines=2, label='System Prompt', placeholder="Modify system prompt here...")
-    task_history = gr.State([])
-
-    with gr.Row():
-        submit_btn = gr.Button("🚀 Submit")
-        regen_btn = gr.Button("🤔️ Regenerate")
-        empty_bin = gr.Button("🧹 Clear History")
-
-    submit_btn.click(add_text, [chatbot, task_history, query], [chatbot, task_history]).then(
-        predict, [chatbot, task_history, system_prompt], [chatbot], show_progress=True
     )
-    submit_btn.click(reset_user_input, [], [query])
-    empty_bin.click(reset_state, [task_history], [chatbot], show_progress=True)
-    regen_btn.click(predict, [chatbot, task_history, system_prompt], [chatbot], show_progress=True)
-
-    gr.Markdown("""<center><font size=3>Qwen2-VL-2B-Instruct Demo</center>""")
-    gr.Markdown("""<center><font size=2>Note: This demo uses Qwen2-VL-2B-Instruct model. Please be mindful of ethical content creation.</center>""")
-
-    demo.queue().launch(share=args.share, server_port=args.server_port, server_name=args.server_name)
-
-# Main function
-def main():
-    args = _get_args()
-    _launch_demo(args)
-
-if __name__ == '__main__':
-    main()
 import gradio as gr
+from PIL import Image
+import torch
+from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
 import requests
 
+# Load the model and processor
+model = Qwen2VLForConditionalGeneration.from_pretrained(
+    "Qwen/Qwen2-VL-2B-Instruct", torch_dtype="auto",
+)
+processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
 
+# Process text and images for inference
+def generate_response(messages: list):
+    # Render the conversation (text + image placeholders) into the model's chat format
+    text_prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
+
+    # Collect the PIL images, which are nested inside each message's "content" list
+    images = [
+        item["image"]
+        for msg in messages
+        for item in msg["content"]
+        if isinstance(item, dict) and item.get("type") == "image"
+    ]
+    text = [text_prompt]
+
+    inputs = processor(
+        text=text, images=images or None, padding=True, return_tensors="pt"
     )
+
+    # Inference: generate the output
+    output_ids = model.generate(**inputs, max_new_tokens=128)
+    # Strip the prompt tokens so only the newly generated text is decoded
+    generated_ids = [
+        out_ids[len(in_ids):]
+        for in_ids, out_ids in zip(inputs.input_ids, output_ids)
+    ]
+    output_text = processor.batch_decode(
+        generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
+    )
+
+    return output_text[0]
 
+# Gradio chat interface function
+def chat_interface(user_input, image: Image.Image = None, history=[]):
+    # Note: the mutable default gives one shared, process-wide history
+    # Add the user turn (optional image + text) to the history
+    if image:
+        message = {
+            "role": "user",
+            "content": [{"type": "image", "image": image}, {"type": "text", "text": user_input}],
+        }
+    else:
+        message = {
+            "role": "user",
+            "content": [{"type": "text", "text": user_input}],
+        }
+    history.append(message)
+
+    # Get the model's response
+    response = generate_response(history)
+
+    # Add the model's response to the history
+    history.append({"role": "assistant", "content": [{"type": "text", "text": response}]})
+
+    # gr.Chatbot renders (user, assistant) text pairs, so flatten the history for display
+    chat_pairs = [
+        (u["content"][-1]["text"], a["content"][0]["text"])
+        for u, a in zip(history[0::2], history[1::2])
+    ]
+
+    return chat_pairs, response
 
+# Gradio interface setup
+def create_gradio_interface():
+    # Chat interface with image upload and text input
+    interface = gr.Interface(
+        fn=chat_interface,
+        inputs=[
+            gr.Textbox(type="text", label="Your Message"),
+            gr.Image(type="pil", label="Upload an Image", optional=True)  # optional= is a Gradio 3.x kwarg
+        ],
+        outputs=[
+            gr.Chatbot(label="Chatbot"),
+            gr.Textbox(label="Model's Response")
+        ],
+        title="Chat with Vision Model",
+        description="A multimodal demo: chat with the model using text and, optionally, an uploaded image.",
+        allow_flagging="never"
     )
+
+    return interface
 
 
 
+# Run the Gradio app
+if __name__ == "__main__":
+    interface = create_gradio_interface()
+    interface.launch()
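
A minimal smoke test for the new inference path, reusing generate_response and the message schema from app.py above. This is a sketch, not part of the committed code: the image URL is a placeholder, and it assumes the model weights are already downloaded.

    from PIL import Image
    import requests

    # Placeholder URL; substitute any reachable JPEG/PNG
    url = "https://example.com/sample.jpg"
    image = Image.open(requests.get(url, stream=True).raw)

    # Same nested-content schema that chat_interface builds
    messages = [{
        "role": "user",
        "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": "Describe this image."},
        ],
    }]
    print(generate_response(messages))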