beyoru committed
Commit 7c5c8ba · verified · 1 Parent(s): dcf00ef

Update app.py

Files changed (1)
  1. app.py +91 -148
app.py CHANGED
@@ -1,149 +1,92 @@
- import gradio as gr
- import spaces
- from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, TextIteratorStreamer
- from qwen_vl_utils import process_vision_info
- import torch
- from PIL import Image
- import subprocess
- import numpy as np
  import os
- from threading import Thread
- import uuid
- import io
-
- # Model and Processor Loading
- MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct"
- model = Qwen2VLForConditionalGeneration.from_pretrained(
-     MODEL_ID,
-     trust_remote_code=True,
-     torch_dtype=torch.float16
- ).eval()
- processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
-
- DESCRIPTION = "[Qwen2-VL-2B Demo](https://huggingface.co/Qwen/Qwen2-VL-2B-Instruct)"
-
- image_extensions = Image.registered_extensions()
- video_extensions = ("avi", "mp4", "mov", "mkv", "flv", "wmv", "mjpeg", "wav", "gif", "webm", "m4v", "3gp")
-
-
- def identify_and_save_blob(blob_path):
-     """Identifies if the blob is an image or video and saves it accordingly."""
-     try:
-         with open(blob_path, 'rb') as file:
-             blob_content = file.read()
-
-         try:
-             Image.open(io.BytesIO(blob_content)).verify()  # Check if it's a valid image
-             extension = ".png"  # Default to PNG for saving
-             media_type = "image"
-         except (IOError, SyntaxError):
-             extension = ".mp4"  # Default to MP4 for saving
-             media_type = "video"
-
-         filename = f"temp_{uuid.uuid4()}_media{extension}"
-         with open(filename, "wb") as f:
-             f.write(blob_content)
-         return filename, media_type
-
-     except Exception as e:
-         raise ValueError(f"Error processing the file: {e}")
-
-
- @spaces.GPU
- def qwen_inference(media_input, text_input=None, system_prompt=None, max_tokens=1024):
-     try:
-         media_type = None  # Initialize media_type variable
-
-         if isinstance(media_input, str):
-             media_path = media_input
-             if media_path.endswith(tuple([i for i, f in image_extensions.items()])):
-                 media_type = "image"
-             elif media_path.endswith(video_extensions):
-                 media_type = "video"
-             else:
-                 # Handle the case where file format is unknown
-                 media_path, media_type = identify_and_save_blob(media_input)
-
-         if not media_type:  # Check if media_type was assigned properly
-             raise ValueError("Unsupported media type. Please upload an image or video.")
-
-         # Default system prompt if none is provided
-         system_prompt = system_prompt or "You are a helpful assistant. Answer questions based on the image or video provided, and explain your reasoning clearly."
-
-         messages = [
-             {
-                 "role": "system",
-                 "content": system_prompt
-             },
-             {
-                 "role": "user",
-                 "content": [
-                     {
-                         "type": media_type,
-                         media_type: media_path,
-                         **({"fps": 8.0} if media_type == "video" else {}),
-                     },
-                     {"type": "text", "text": text_input},
-                 ],
-             }
-         ]
-
-         text = processor.apply_chat_template(
-             messages, tokenize=False, add_generation_prompt=True
-         )
-         image_inputs, video_inputs = process_vision_info(messages)
-         inputs = processor(
-             text=[text],
-             images=image_inputs,
-             videos=video_inputs,
-             padding=True,
-             return_tensors="pt",
-         )
-         streamer = TextIteratorStreamer(
-             processor, skip_prompt=True, **{"skip_special_tokens": True}
-         )
-         generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=max_tokens)
-
-         thread = Thread(target=model.generate, kwargs=generation_kwargs)
-         thread.start()
-
-         buffer = ""
-         for new_text in streamer:
-             buffer += new_text
-             yield buffer
-
-     except Exception as e:
-         yield f"Error during inference: {e}"
-
-
- css = """
- #output {
-     height: 500px;
-     overflow: auto;
-     border: 1px solid #ccc;
- }
- """
-
- with gr.Blocks(css=css) as demo:
-     gr.Markdown(DESCRIPTION)
-
-     with gr.Tab(label="Image/Video Input"):
-         with gr.Row():
-             with gr.Column():
-                 input_media = gr.File(label="Upload Image or Video", type="filepath")
-                 system_prompt = gr.Textbox(
-                     label="System Prompt",
-                     value="You are a helpful assistant. Answer questions based on the image or video provided, and explain your reasoning clearly.",
-                     lines=3
-                 )
-                 text_input = gr.Textbox(label="Question")
-                 max_tokens = gr.Slider(label="Max New Tokens", minimum=16, maximum=2048, value=1024, step=16)
-                 submit_btn = gr.Button(value="Submit")
-             with gr.Column():
-                 output_text = gr.Textbox(label="Output Text", elem_id="output")
-
-         submit_btn.click(
-             qwen_inference, [input_media, text_input, system_prompt, max_tokens], [output_text]
-         )
-
- demo.launch(debug=True)
  import os
+ import copy
+ from argparse import ArgumentParser
+ from http import HTTPStatus
+
+ import gradio as gr
+ import dashscope
+ from dashscope import MultiModalConversation
+
+ # Read the DashScope API key from the environment and register it with the SDK
+ # (startup raises a KeyError if API_KEY is not set)
+ API_KEY = os.environ['API_KEY']
+ dashscope.api_key = API_KEY
+
+ # Define constants
+ MODEL_NAME = 'Qwen2-VL-2B-Instruct'
+
+ # Get arguments
+ def _get_args():
+     parser = ArgumentParser()
+     parser.add_argument("--share", action="store_true", default=False, help="Create a publicly shareable link.")
+     parser.add_argument("--server-port", type=int, default=7860, help="Server port.")
+     parser.add_argument("--server-name", type=str, default="127.0.0.1", help="Server name.")
+     return parser.parse_args()
+
+ # Stream a chat reply from the DashScope API
+ def predict(_chatbot, task_history, system_prompt):
+     chat_query = _chatbot[-1][0]
+     query = task_history[-1][0]
+     if not chat_query:
+         _chatbot.pop()
+         task_history.pop()
+         yield _chatbot
+         return
+     print("User:", query)
+     history_cp = copy.deepcopy(task_history)
+     # Rebuild the full conversation, keeping earlier assistant replies
+     messages = [{'role': 'system', 'content': [{'text': system_prompt}]}] if system_prompt else []
+     for q, a in history_cp:
+         messages.append({'role': 'user', 'content': [{'text': q}]})
+         if a is not None:
+             messages.append({'role': 'assistant', 'content': [{'text': a}]})
+     responses = MultiModalConversation.call(
+         model=MODEL_NAME, messages=messages, stream=True,
+     )
+     for response in responses:
+         if response.status_code != HTTPStatus.OK:
+             raise Exception(f'Error: {response.message}')
+         # Streamed chunks are cumulative, so each one carries the full text so far
+         response_text = ''.join(ele['text'] for ele in response.output.choices[0].message.content)
+         _chatbot[-1] = (chat_query, response_text)
+         task_history[-1] = (query, response_text)  # record the reply so later turns see it
+         yield _chatbot
+
+ # Add text to history
+ def add_text(history, task_history, text):
+     history.append((text, None))
+     task_history.append((text, None))
+     return history, task_history
+
+ # Reset input
+ def reset_user_input():
+     return gr.update(value="")
+
+ # Reset history
+ def reset_state(task_history):
+     task_history.clear()
+     return []
+
+ # Build the UI and launch the demo
+ def _launch_demo(args):
+     with gr.Blocks() as demo:
+         chatbot = gr.Chatbot(label='Qwen2-VL-2B-Instruct', height=500)
+         query = gr.Textbox(lines=2, label='Input')
+         system_prompt = gr.Textbox(lines=2, label='System Prompt', placeholder="Modify system prompt here...")
+         task_history = gr.State([])
+
+         with gr.Row():
+             submit_btn = gr.Button("🚀 Submit")
+             regen_btn = gr.Button("🤔️ Regenerate")
+             empty_bin = gr.Button("🧹 Clear History")
+
+         submit_btn.click(add_text, [chatbot, task_history, query], [chatbot, task_history]).then(
+             predict, [chatbot, task_history, system_prompt], [chatbot], show_progress=True
+         )
+         submit_btn.click(reset_user_input, [], [query])
+         empty_bin.click(reset_state, [task_history], [chatbot], show_progress=True)
+         regen_btn.click(predict, [chatbot, task_history, system_prompt], [chatbot], show_progress=True)
+
+         gr.Markdown("""<center><font size=3>Qwen2-VL-2B-Instruct Demo</center>""")
+         gr.Markdown("""<center><font size=2>Note: This demo uses the Qwen2-VL-2B-Instruct model. Please be mindful of ethical content creation.</center>""")
+
+     demo.queue().launch(share=args.share, server_port=args.server_port, server_name=args.server_name)
+
+ # Main function
+ def main():
+     args = _get_args()
+     _launch_demo(args)
+
+ if __name__ == '__main__':
+     main()
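
For quick sanity checks outside Gradio, a minimal non-streaming sketch of the DashScope call the new app relies on is shown below. This is an illustrative sketch, not part of the commit: it assumes dashscope is installed, the API_KEY environment variable is set, the model name is one that key can access, and the image URL is a hypothetical placeholder.

    import os
    import dashscope
    from http import HTTPStatus
    from dashscope import MultiModalConversation

    dashscope.api_key = os.environ['API_KEY']  # same environment variable the app reads

    messages = [{
        'role': 'user',
        'content': [
            {'image': 'https://example.com/sample.jpg'},  # hypothetical URL, for illustration only
            {'text': 'Describe this image.'},
        ],
    }]

    response = MultiModalConversation.call(model='Qwen2-VL-2B-Instruct', messages=messages)
    if response.status_code == HTTPStatus.OK:
        # message.content is a list of parts, e.g. [{'text': '...'}]
        print(''.join(part['text'] for part in response.output.choices[0].message.content))
    else:
        print(response.code, response.message)

The app itself starts with "python app.py"; the --share, --server-port, and --server-name flags behave as defined in _get_args().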