Update app.py
app.py
CHANGED
@@ -9,14 +9,16 @@ import random
 from transformers import LlavaProcessor, LlavaForConditionalGeneration, TextIteratorStreamer
 from threading import Thread
 import re
-import time
+import time
 import torch
 import cv2
 from gradio_client import Client, file
 
+
 def image_gen(prompt):
     client = Client("KingNish/Image-Gen-Pro")
-    return client.predict("Image Generation",None, prompt, api_name="/image_gen_pro")
+    return client.predict("Image Generation", None, prompt, api_name="/image_gen_pro")
+
 
 model_id = "llava-hf/llava-interleave-qwen-0.5b-hf"
 
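Review note (not part of the commit): the reformatted image_gen helper can be exercised on its own. A minimal sketch, assuming the KingNish/Image-Gen-Pro Space is reachable and that predict returns an indexable result (respond() below reads image[1]):

    from gradio_client import Client

    def image_gen(prompt):
        client = Client("KingNish/Image-Gen-Pro")
        return client.predict("Image Generation", None, prompt, api_name="/image_gen_pro")

    # Hypothetical call; requires the remote Space to be up.
    result = image_gen("a watercolor fox")
    print(result[1])  # respond() later hands this element to gr.Image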
@@ -28,27 +30,29 @@ model.to("cpu")
 
 def llava(message, history):
     if message["files"]:
-        image = message["files"][0]
+        image = message["files"][0]
     else:
         for hist in history:
-            if type(hist[0])==tuple:
+            if type(hist[0]) == tuple:
                 image = hist[0][0]
-
+
     txt = message["text"]
-
+
     gr.Info("Analyzing image")
     image = Image.open(image).convert("RGB")
     prompt = f"<|im_start|>user <image>\n{txt}<|im_start|>assistant"
-
+
     inputs = processor(prompt, image, return_tensors="pt")
     return inputs
 
+
 def extract_text_from_webpage(html_content):
     soup = BeautifulSoup(html_content, 'html.parser')
     for tag in soup(["script", "style", "header", "footer"]):
         tag.extract()
     return soup.get_text(strip=True)
 
+
 def search(query):
     term = query
     start = 0
@@ -69,7 +73,9 @@ def search(query):
             link = result.find("a", href=True)
             link = link["href"]
             try:
-                webpage = session.get(link, headers={"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/111.0"}, timeout=5, verify=False)
+                webpage = session.get(link, headers={
+                    "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/111.0"},
+                    timeout=5, verify=False)
                 webpage.raise_for_status()
                 visible_text = extract_text_from_webpage(webpage.text)
                 if len(visible_text) > max_chars_per_page:
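Review note: extract_text_from_webpage (unchanged above) is what the rewrapped session.get call feeds into. A self-contained check of its behavior:

    from bs4 import BeautifulSoup

    def extract_text_from_webpage(html_content):
        # Drop non-content tags, then return the remaining visible text.
        soup = BeautifulSoup(html_content, 'html.parser')
        for tag in soup(["script", "style", "header", "footer"]):
            tag.extract()
        return soup.get_text(strip=True)

    html = "<html><head><style>p{color:red}</style></head><body><header>nav</header><p>Hello world</p></body></html>"
    print(extract_text_from_webpage(html))  # -> Hello world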
@@ -79,12 +85,14 @@ def search(query):
                 all_results.append({"link": link, "text": None})
     return all_results
 
+
 # Initialize inference clients for different models
 client_gemma = InferenceClient("mistralai/Mistral-7B-Instruct-v0.3")
 client_mixtral = InferenceClient("NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO")
 client_llama = InferenceClient("meta-llama/Meta-Llama-3-8B-Instruct")
 client_yi = InferenceClient("01-ai/Yi-1.5-34B-Chat")
 
+
 # Define the main chat function
 def respond(message, history):
     func_caller = []
@@ -98,17 +106,31 @@ def respond(message, history):
 
         thread = Thread(target=model.generate, kwargs=generation_kwargs)
         thread.start()
-
+
         buffer = ""
         for new_text in streamer:
            buffer += new_text
            yield buffer
     else:
         functions_metadata = [
-            {"type": "function", "function": {"name": "web_search", "description": "Search query on google", "parameters": {"type": "object", "properties": {"query": {"type": "string", "description": "web search query"}}, "required": ["query"]}}},
-            {"type": "function", "function": {"name": "general_query", "description": "Reply general query of USER", "parameters": {"type": "object", "properties": {"prompt": {"type": "string", "description": "A detailed prompt"}}, "required": ["prompt"]}}},
-            {"type": "function", "function": {"name": "image_generation", "description": "Generate image for user", "parameters": {"type": "object", "properties": {"query": {"type": "string", "description": "image generation prompt"}}, "required": ["query"]}}},
-            {"type": "function", "function": {"name": "image_qna", "description": "Answer question asked by user related to image", "parameters": {"type": "object", "properties": {"query": {"type": "string", "description": "Question by user"}}, "required": ["query"]}}},
+            {"type": "function", "function": {"name": "web_search", "description": "Search query on google",
+                                              "parameters": {"type": "object", "properties": {
+                                                  "query": {"type": "string", "description": "web search query"}},
+                                                  "required": ["query"]}}},
+            {"type": "function", "function": {"name": "general_query", "description": "Reply general query of USER",
+                                              "parameters": {"type": "object", "properties": {
+                                                  "prompt": {"type": "string", "description": "A detailed prompt"}},
+                                                  "required": ["prompt"]}}},
+            {"type": "function", "function": {"name": "image_generation", "description": "Generate image for user",
+                                              "parameters": {"type": "object", "properties": {
+                                                  "query": {"type": "string",
+                                                            "description": "image generation prompt"}},
+                                                  "required": ["query"]}}},
+            {"type": "function",
+             "function": {"name": "image_qna", "description": "Answer question asked by user related to image",
+                          "parameters": {"type": "object",
+                                         "properties": {"query": {"type": "string", "description": "Question by user"}},
+                                         "required": ["query"]}}},
         ]
 
         for msg in history:
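Review note: the entries rewrapped above follow the common OpenAI-style tool schema, and each one is plain JSON. For instance, the first entry serializes cleanly:

    import json

    # First entry of functions_metadata, as reformatted above.
    web_search_tool = {"type": "function", "function": {
        "name": "web_search",
        "description": "Search query on google",
        "parameters": {"type": "object",
                       "properties": {"query": {"type": "string", "description": "web search query"}},
                       "required": ["query"]}}}
    print(json.dumps(web_search_tool, indent=2))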
@@ -116,20 +138,21 @@ def respond(message, history):
             func_caller.append({"role": "assistant", "content": f"{str(msg[1])}"})
 
         message_text = message["text"]
-        func_caller.append({"role": "user",
-                            "content": f'[SYSTEM]You are a helpful assistant. You have access to the following functions: \n {str(functions_metadata)}\n\nTo use these functions respond with:\n<functioncall> {{ "name": "function_name", "arguments": {{ "arg_1": "value_1", "arg_1": "value_1", ... }} }} </functioncall> [USER] {message_text}'})
+        func_caller.append({"role": "user",
+                            "content": f'[SYSTEM]You are a helpful assistant. You have access to the following functions: \n {str(functions_metadata)}\n\nTo use these functions respond with:\n<functioncall> {{ "name": "function_name", "arguments": {{ "arg_1": "value_1", "arg_1": "value_1", ... }} }} </functioncall> [USER] {message_text}'})
+
         response = client_gemma.chat_completion(func_caller, max_tokens=200)
         response = str(response)
         try:
             response = response[int(response.find("{")):int(response.rindex("</"))]
         except:
-            response = response[int(response.find("{")):(int(response.rfind("}"))+1)]
+            response = response[int(response.find("{")):(int(response.rfind("}")) + 1)]
         response = response.replace("\\n", "")
         response = response.replace("\\'", "'")
         response = response.replace('\\"', '"')
         response = response.replace('\\', '')
         print(f"\n{response}")
-
+
         try:
             json_data = json.loads(str(response))
             if json_data["name"] == "web_search":
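Review note: the bracket-slicing above recovers the JSON payload from a <functioncall> reply. A minimal sketch of that extraction, using a hypothetical model reply:

    import json

    # Hypothetical reply following the [SYSTEM] prompt's <functioncall> format.
    reply = '<functioncall> {"name": "web_search", "arguments": {"query": "latest AI news"}} </functioncall>'

    # Same strategy as the try-branch above: from the first "{" up to the closing "</".
    payload = reply[reply.find("{"):reply.rindex("</")]
    json_data = json.loads(payload)
    print(json_data["name"], json_data["arguments"]["query"])  # web_search latest AI news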
@@ -142,11 +165,12 @@ def respond(message, history):
                 for msg in history:
                     messages += f"\n<|im_start|>user\n{str(msg[0])}<|im_end|>"
                     messages += f"\n<|im_start|>assistant\n{str(msg[1])}<|im_end|>"
-                messages+=f"\n<|im_start|>user\n{message_text}<|im_end|>\n<|im_start|>web_result\n{web2}<|im_end|>\n<|im_start|>assistant\n"
-                stream = client_mixtral.text_generation(messages, max_new_tokens=2000, do_sample=True, stream=True, details=True, return_full_text=False)
+                messages += f"\n<|im_start|>user\n{message_text}<|im_end|>\n<|im_start|>web_result\n{web2}<|im_end|>\n<|im_start|>assistant\n"
+                stream = client_mixtral.text_generation(messages, max_new_tokens=2000, do_sample=True, stream=True,
+                                                        details=True, return_full_text=False)
                 output = ""
                 for response in stream:
-                    if not response.token.text == "hello":
+                    if not response.token.text == "hello":
                         output += response.token.text
                         yield output
         elif json_data["name"] == "image_generation":
@@ -158,7 +182,7 @@ def respond(message, history):
                     yield gr.Image(image[1])
                 except:
                     client_sd3 = InferenceClient("stabilityai/stable-diffusion-3-medium-diffusers")
-                    seed = random.randint(0,999999)
+                    seed = random.randint(0, 999999)
                     image = client_sd3.text_to_image(query, negative_prompt=f"{seed}")
                     yield gr.Image(image)
         elif json_data["name"] == "image_qna":
@@ -168,7 +192,7 @@ def respond(message, history):
 
             thread = Thread(target=model.generate, kwargs=generation_kwargs)
             thread.start()
-
+
             buffer = ""
             for new_text in streamer:
                 buffer += new_text
@@ -178,8 +202,9 @@ def respond(message, history):
             for msg in history:
                 messages += f"\n<|im_start|>user\n{str(msg[0])}<|im_end|>"
                 messages += f"\n<|im_start|>assistant\n{str(msg[1])}<|im_end|>"
-            messages+=f"\n<|im_start|>user\n{message_text}<|im_end|>\n<|im_start|>assistant\n"
-            stream = client_yi.text_generation(messages, max_new_tokens=2000, do_sample=True, stream=True, details=True, return_full_text=False)
+            messages += f"\n<|im_start|>user\n{message_text}<|im_end|>\n<|im_start|>assistant\n"
+            stream = client_yi.text_generation(messages, max_new_tokens=2000, do_sample=True, stream=True,
+                                               details=True, return_full_text=False)
             output = ""
             for response in stream:
                 if not response.token.text == "<|endoftext|>":
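Review note: the messages lines reindented above build the Qwen-style chat prompt by hand. Reduced to a self-contained sketch (the [[user, assistant], ...] history shape is assumed from the surrounding code):

    history = [["hi", "hello!"]]   # assumed [[user, assistant], ...] pairs
    message_text = "what's new?"

    messages = ""
    for msg in history:
        messages += f"\n<|im_start|>user\n{str(msg[0])}<|im_end|>"
        messages += f"\n<|im_start|>assistant\n{str(msg[1])}<|im_end|>"
    messages += f"\n<|im_start|>user\n{message_text}<|im_end|>\n<|im_start|>assistant\n"
    print(messages)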
@@ -190,19 +215,21 @@ def respond(message, history):
             for msg in history:
                 messages += f"\n<|start_header_id|>user\n{str(msg[0])}<|end_header_id|>"
                 messages += f"\n<|start_header_id|>assistant\n{str(msg[1])}<|end_header_id|>"
-            messages+=f"\n<|start_header_id|>user\n{message_text}<|end_header_id|>\n<|start_header_id|>assistant\n"
-            stream = client_llama.text_generation(messages, max_new_tokens=2000, do_sample=True, stream=True, details=True, return_full_text=False)
+            messages += f"\n<|start_header_id|>user\n{message_text}<|end_header_id|>\n<|start_header_id|>assistant\n"
+            stream = client_llama.text_generation(messages, max_new_tokens=2000, do_sample=True, stream=True,
+                                                  details=True, return_full_text=False)
             output = ""
             for response in stream:
                 if not response.token.text == "<|eot_id|>":
                     output += response.token.text
                     yield output
 
+
 # Create the Gradio interface
 demo = gr.ChatInterface(
     fn=respond,
     chatbot=gr.Chatbot(show_copy_button=True, likeable=True, layout="panel"),
-    description
+    description="# OpenGPT 4o \n ### chat, generate images, perform web searches, and Q&A with images.",
     textbox=gr.MultimodalTextbox(),
     multimodal=True,
     concurrency_limit=200,