Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -31,9 +31,7 @@ from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
|
|
31 |
from diffusers import ShapEImg2ImgPipeline, ShapEPipeline
|
32 |
from diffusers.utils import export_to_ply
|
33 |
|
34 |
-
# -----------------------------------------------------------------------------
|
35 |
# Global constants and helper functions
|
36 |
-
# -----------------------------------------------------------------------------
|
37 |
|
38 |
MAX_SEED = np.iinfo(np.int32).max
|
39 |
|
@@ -52,9 +50,7 @@ def glb_to_data_url(glb_path: str) -> str:
|
|
52 |
b64_data = base64.b64encode(data).decode("utf-8")
|
53 |
return f"data:model/gltf-binary;base64,{b64_data}"
|
54 |
|
55 |
-
# -----------------------------------------------------------------------------
|
56 |
# Model class for Text-to-3D Generation (ShapE)
|
57 |
-
# -----------------------------------------------------------------------------
|
58 |
|
59 |
class Model:
|
60 |
def __init__(self):
|
@@ -113,9 +109,7 @@ class Model:
|
|
113 |
export_to_ply(images[0], ply_path.name)
|
114 |
return self.to_glb(ply_path.name)
|
115 |
|
116 |
-
# -----------------------------------------------------------------------------
|
117 |
# New Tools for Web Functionality using DuckDuckGo and smolagents
|
118 |
-
# -----------------------------------------------------------------------------
|
119 |
|
120 |
from typing import Any, Optional
|
121 |
from smolagents.tools import Tool
|
@@ -186,14 +180,68 @@ class VisitWebpageTool(Tool):
|
|
186 |
return f"Error fetching the webpage: {str(e)}"
|
187 |
except Exception as e:
|
188 |
return f"An unexpected error occurred: {str(e)}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
189 |
|
190 |
-
# -----------------------------------------------------------------------------
|
191 |
# Gradio UI configuration
|
192 |
-
# -----------------------------------------------------------------------------
|
193 |
|
194 |
DESCRIPTION = """
|
195 |
-
# Agent Dino π
|
196 |
-
"""
|
197 |
|
198 |
css = '''
|
199 |
h1 {
|
@@ -215,11 +263,9 @@ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
|
|
215 |
|
216 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
217 |
|
218 |
-
# -----------------------------------------------------------------------------
|
219 |
# Load Models and Pipelines for Chat, Image, and Multimodal Processing
|
220 |
-
# -----------------------------------------------------------------------------
|
221 |
-
|
222 |
# Load the text-only model and tokenizer (for pure text chat)
|
|
|
223 |
model_id = "prithivMLmods/FastThink-0.5B-Tiny"
|
224 |
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
225 |
model = AutoModelForCausalLM.from_pretrained(
|
@@ -244,9 +290,7 @@ model_m = Qwen2VLForConditionalGeneration.from_pretrained(
|
|
244 |
torch_dtype=torch.float16
|
245 |
).to("cuda").eval()
|
246 |
|
247 |
-
# -----------------------------------------------------------------------------
|
248 |
# Asynchronous text-to-speech
|
249 |
-
# -----------------------------------------------------------------------------
|
250 |
|
251 |
async def text_to_speech(text: str, voice: str, output_file="output.mp3"):
|
252 |
"""Convert text to speech using Edge TTS and save as MP3"""
|
@@ -254,9 +298,7 @@ async def text_to_speech(text: str, voice: str, output_file="output.mp3"):
|
|
254 |
await communicate.save(output_file)
|
255 |
return output_file
|
256 |
|
257 |
-
# -----------------------------------------------------------------------------
|
258 |
# Utility function to clean conversation history
|
259 |
-
# -----------------------------------------------------------------------------
|
260 |
|
261 |
def clean_chat_history(chat_history):
|
262 |
"""
|
@@ -269,9 +311,7 @@ def clean_chat_history(chat_history):
|
|
269 |
cleaned.append(msg)
|
270 |
return cleaned
|
271 |
|
272 |
-
# -----------------------------------------------------------------------------
|
273 |
# Stable Diffusion XL Pipeline for Image Generation
|
274 |
-
# -----------------------------------------------------------------------------
|
275 |
|
276 |
MODEL_ID_SD = os.getenv("MODEL_VAL_PATH") # SDXL Model repository path via env variable
|
277 |
MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
|
@@ -350,9 +390,7 @@ def generate_image_fn(
|
|
350 |
image_paths = [save_image(img) for img in images]
|
351 |
return image_paths, seed
|
352 |
|
353 |
-
# -----------------------------------------------------------------------------
|
354 |
# Text-to-3D Generation using the ShapE Pipeline
|
355 |
-
# -----------------------------------------------------------------------------
|
356 |
|
357 |
@spaces.GPU(duration=120, enable_queue=True)
|
358 |
def generate_3d_fn(
|
@@ -371,9 +409,7 @@ def generate_3d_fn(
|
|
371 |
glb_path = model3d.run_text(prompt, seed=seed, guidance_scale=guidance_scale, num_steps=num_steps)
|
372 |
return glb_path, seed
|
373 |
|
374 |
-
#
|
375 |
-
# Chat Generation Function with support for @tts, @image, @3d, and now @web commands
|
376 |
-
# -----------------------------------------------------------------------------
|
377 |
|
378 |
@spaces.GPU
|
379 |
def generate(
|
@@ -386,14 +422,12 @@ def generate(
|
|
386 |
repetition_penalty: float = 1.2,
|
387 |
):
|
388 |
"""
|
389 |
-
Generates chatbot responses with support for multimodal input
|
390 |
-
3D model generation, and web search/visit.
|
391 |
-
|
392 |
-
Special commands:
|
393 |
- "@tts1" or "@tts2": triggers text-to-speech.
|
394 |
- "@image": triggers image generation using the SDXL pipeline.
|
395 |
- "@3d": triggers 3D model generation using the ShapE pipeline.
|
396 |
-
- "@web": triggers a web search or webpage visit.
|
|
|
397 |
"""
|
398 |
text = input_dict["text"]
|
399 |
files = input_dict.get("files", [])
|
@@ -401,7 +435,7 @@ def generate(
|
|
401 |
# --- 3D Generation branch ---
|
402 |
if text.strip().lower().startswith("@3d"):
|
403 |
prompt = text[len("@3d"):].strip()
|
404 |
-
yield "Hold tight, generating a 3D mesh GLB file....."
|
405 |
glb_path, used_seed = generate_3d_fn(
|
406 |
prompt=prompt,
|
407 |
seed=1,
|
@@ -423,7 +457,7 @@ def generate(
|
|
423 |
# --- Image Generation branch ---
|
424 |
if text.strip().lower().startswith("@image"):
|
425 |
prompt = text[len("@image"):].strip()
|
426 |
-
yield "Generating image..."
|
427 |
image_paths, used_seed = generate_image_fn(
|
428 |
prompt=prompt,
|
429 |
negative_prompt="",
|
@@ -446,19 +480,28 @@ def generate(
|
|
446 |
# If the command starts with "visit", then treat the rest as a URL
|
447 |
if web_command.lower().startswith("visit"):
|
448 |
url = web_command[len("visit"):].strip()
|
449 |
-
yield "Visiting webpage..."
|
450 |
visitor = VisitWebpageTool()
|
451 |
content = visitor.forward(url)
|
452 |
yield content
|
453 |
else:
|
454 |
# Otherwise, treat the rest as a search query.
|
455 |
query = web_command
|
456 |
-
yield "
|
457 |
searcher = DuckDuckGoSearchTool()
|
458 |
results = searcher.forward(query)
|
459 |
yield results
|
460 |
return
|
461 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
462 |
# --- Text and TTS branch ---
|
463 |
tts_prefix = "@tts"
|
464 |
is_tts = any(text.strip().lower().startswith(f"{tts_prefix}{i}") for i in range(1, 3))
|
@@ -496,7 +539,7 @@ def generate(
|
|
496 |
thread.start()
|
497 |
|
498 |
buffer = ""
|
499 |
-
yield "Thinking..."
|
500 |
for new_text in streamer:
|
501 |
buffer += new_text
|
502 |
buffer = buffer.replace("<|im_end|>", "")
|
@@ -535,9 +578,7 @@ def generate(
|
|
535 |
output_file = asyncio.run(text_to_speech(final_response, voice))
|
536 |
yield gr.Audio(output_file, autoplay=True)
|
537 |
|
538 |
-
# -----------------------------------------------------------------------------
|
539 |
# Gradio Chat Interface Setup and Launch
|
540 |
-
# -----------------------------------------------------------------------------
|
541 |
|
542 |
demo = gr.ChatInterface(
|
543 |
fn=generate,
|
@@ -553,8 +594,9 @@ demo = gr.ChatInterface(
|
|
553 |
["@3d A birthday cupcake with cherry"],
|
554 |
[{"text": "summarize the letter", "files": ["examples/1.png"]}],
|
555 |
["@image Chocolate dripping from a donut against a yellow background, in the style of brocore, hyper-realistic"],
|
556 |
-
["
|
557 |
["@web latest breakthroughs in renewable energy"],
|
|
|
558 |
],
|
559 |
cache_examples=False,
|
560 |
type="messages",
|
@@ -570,10 +612,8 @@ demo = gr.ChatInterface(
|
|
570 |
if not os.path.exists("static"):
|
571 |
os.makedirs("static")
|
572 |
|
573 |
-
# Mount the static folder onto the FastAPI app so that GLB files are served.
|
574 |
from fastapi.staticfiles import StaticFiles
|
575 |
demo.app.mount("/static", StaticFiles(directory="static"), name="static")
|
576 |
|
577 |
if __name__ == "__main__":
|
578 |
-
# Launch without the unsupported static_dirs parameter.
|
579 |
demo.queue(max_size=20).launch(share=True)
|
|
|
31 |
from diffusers import ShapEImg2ImgPipeline, ShapEPipeline
|
32 |
from diffusers.utils import export_to_ply
|
33 |
|
|
|
34 |
# Global constants and helper functions
|
|
|
35 |
|
36 |
MAX_SEED = np.iinfo(np.int32).max
|
37 |
|
|
|
50 |
b64_data = base64.b64encode(data).decode("utf-8")
|
51 |
return f"data:model/gltf-binary;base64,{b64_data}"
|
52 |
|
|
|
53 |
# Model class for Text-to-3D Generation (ShapE)
|
|
|
54 |
|
55 |
class Model:
|
56 |
def __init__(self):
|
|
|
109 |
export_to_ply(images[0], ply_path.name)
|
110 |
return self.to_glb(ply_path.name)
|
111 |
|
|
|
112 |
# New Tools for Web Functionality using DuckDuckGo and smolagents
|
|
|
113 |
|
114 |
from typing import Any, Optional
|
115 |
from smolagents.tools import Tool
|
|
|
180 |
return f"Error fetching the webpage: {str(e)}"
|
181 |
except Exception as e:
|
182 |
return f"An unexpected error occurred: {str(e)}"
|
183 |
+
|
184 |
+
# New feature: rAgent reasoning with a Llama model served through the OpenAI-compatible client
|
185 |
+
|
186 |
+
from openai import OpenAI
|
187 |
+
|
188 |
+
ACCESS_TOKEN = os.getenv("HF_TOKEN")
|
189 |
+
ragent_client = OpenAI(
|
190 |
+
base_url="https://api-inference.huggingface.co/v1/",
|
191 |
+
api_key=ACCESS_TOKEN,
|
192 |
+
)
|
193 |
+
|
194 |
+
# System message sent as the first chat turn by ragent_reasoning; it asks the
# model to answer in a fixed Thought / Code / Observation / Final Answer cycle.
SYSTEM_PROMPT = """You are an expert assistant who can solve any task using code blobs. You will be given a task to solve as best you can.

To do so, you must follow a structured reasoning process in a cycle of:

1. **Thought:**
- Analyze the problem and explain your reasoning.
- Identify any necessary tools or techniques.

2. **Code:**
- Implement the solution using Python.
- Enclose the code block with `<end_code>`.

3. **Observation:**
- Explain the output and verify correctness.

4. **Final Answer:**
- Summarize the solution clearly.

Always adhere to the **Thought → Code → Observation → Final Answer** structure.
"""
|
214 |
+
|
215 |
+
def ragent_reasoning(prompt: str, history: list[dict], max_tokens: int = 512, temperature: float = 0.7, top_p: float = 0.95):
    """
    Stream a structured reasoning answer from the hosted Llama chat model.

    Builds a chat transcript made of SYSTEM_PROMPT, the user/assistant turns
    found in ``history`` and finally ``prompt``, then requests a streamed
    completion for "meta-llama/Meta-Llama-3.1-8B-Instruct" from ``ragent_client``.

    Args:
        prompt: The new user request to reason about.
        history: Earlier conversation messages as dicts with "role" and
            "content" keys; entries whose role is neither "user" nor
            "assistant" are ignored.
        max_tokens: Passed through as the completion's ``max_tokens``.
        temperature: Passed through as the sampling temperature.
        top_p: Passed through as the nucleus-sampling ``top_p``.

    Yields:
        The response text accumulated so far, once for every streamed chunk
        that carried new content.
    """
    messages = [{"role": "system", "content": SYSTEM_PROMPT}]
    # Incorporate conversation history (if any), keeping only chat turns.
    messages.extend(
        {"role": msg["role"], "content": msg["content"]}
        for msg in history
        if msg.get("role") in ("user", "assistant")
    )
    messages.append({"role": "user", "content": prompt})

    response = ""
    stream = ragent_client.chat.completions.create(
        model="meta-llama/Meta-Llama-3.1-8B-Instruct",
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
        messages=messages,
    )
    for chunk in stream:
        # Some streamed chunks carry no choices at all; nothing to append.
        if not chunk.choices:
            continue
        token = chunk.choices[0].delta.content
        # Role-only and finish chunks have content=None; concatenating that
        # to a str would raise TypeError, so skip them.
        if not token:
            continue
        response += token
        yield response
|
240 |
|
|
|
241 |
# Gradio UI configuration
|
|
|
242 |
|
243 |
DESCRIPTION = """
|
244 |
+
# Agent Dino π """
|
|
|
245 |
|
246 |
css = '''
|
247 |
h1 {
|
|
|
263 |
|
264 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
265 |
|
|
|
266 |
# Load Models and Pipelines for Chat, Image, and Multimodal Processing
|
|
|
|
|
267 |
# Load the text-only model and tokenizer (for pure text chat)
|
268 |
+
|
269 |
model_id = "prithivMLmods/FastThink-0.5B-Tiny"
|
270 |
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
271 |
model = AutoModelForCausalLM.from_pretrained(
|
|
|
290 |
torch_dtype=torch.float16
|
291 |
).to("cuda").eval()
|
292 |
|
|
|
293 |
# Asynchronous text-to-speech
|
|
|
294 |
|
295 |
async def text_to_speech(text: str, voice: str, output_file="output.mp3"):
|
296 |
"""Convert text to speech using Edge TTS and save as MP3"""
|
|
|
298 |
await communicate.save(output_file)
|
299 |
return output_file
|
300 |
|
|
|
301 |
# Utility function to clean conversation history
|
|
|
302 |
|
303 |
def clean_chat_history(chat_history):
|
304 |
"""
|
|
|
311 |
cleaned.append(msg)
|
312 |
return cleaned
|
313 |
|
|
|
314 |
# Stable Diffusion XL Pipeline for Image Generation
|
|
|
315 |
|
316 |
MODEL_ID_SD = os.getenv("MODEL_VAL_PATH") # SDXL Model repository path via env variable
|
317 |
MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
|
|
|
390 |
image_paths = [save_image(img) for img in images]
|
391 |
return image_paths, seed
|
392 |
|
|
|
393 |
# Text-to-3D Generation using the ShapE Pipeline
|
|
|
394 |
|
395 |
@spaces.GPU(duration=120, enable_queue=True)
|
396 |
def generate_3d_fn(
|
|
|
409 |
glb_path = model3d.run_text(prompt, seed=seed, guidance_scale=guidance_scale, num_steps=num_steps)
|
410 |
return glb_path, seed
|
411 |
|
412 |
+
# Chat Generation Function with support for @tts, @image, @3d, @web, and @rAgent commands
|
|
|
|
|
413 |
|
414 |
@spaces.GPU
|
415 |
def generate(
|
|
|
422 |
repetition_penalty: float = 1.2,
|
423 |
):
|
424 |
"""
|
425 |
+
Generates chatbot responses with support for multimodal input and special commands:
|
|
|
|
|
|
|
426 |
- "@tts1" or "@tts2": triggers text-to-speech.
|
427 |
- "@image": triggers image generation using the SDXL pipeline.
|
428 |
- "@3d": triggers 3D model generation using the ShapE pipeline.
|
429 |
+
- "@web": triggers a web search or webpage visit.
|
430 |
+
- "@rAgent": initiates a reasoning chain using a Llama model via the OpenAI-compatible client.
|
431 |
"""
|
432 |
text = input_dict["text"]
|
433 |
files = input_dict.get("files", [])
|
|
|
435 |
# --- 3D Generation branch ---
|
436 |
if text.strip().lower().startswith("@3d"):
|
437 |
prompt = text[len("@3d"):].strip()
|
438 |
+
yield "π Hold tight, generating a 3D mesh GLB file....."
|
439 |
glb_path, used_seed = generate_3d_fn(
|
440 |
prompt=prompt,
|
441 |
seed=1,
|
|
|
457 |
# --- Image Generation branch ---
|
458 |
if text.strip().lower().startswith("@image"):
|
459 |
prompt = text[len("@image"):].strip()
|
460 |
+
yield "🪧 Generating image..."
|
461 |
image_paths, used_seed = generate_image_fn(
|
462 |
prompt=prompt,
|
463 |
negative_prompt="",
|
|
|
480 |
# If the command starts with "visit", then treat the rest as a URL
|
481 |
if web_command.lower().startswith("visit"):
|
482 |
url = web_command[len("visit"):].strip()
|
483 |
+
yield "π Visiting webpage..."
|
484 |
visitor = VisitWebpageTool()
|
485 |
content = visitor.forward(url)
|
486 |
yield content
|
487 |
else:
|
488 |
# Otherwise, treat the rest as a search query.
|
489 |
query = web_command
|
490 |
+
yield "🧠 Performing a web search ..."
|
491 |
searcher = DuckDuckGoSearchTool()
|
492 |
results = searcher.forward(query)
|
493 |
yield results
|
494 |
return
|
495 |
|
496 |
+
# --- rAgent Reasoning branch ---
|
497 |
+
if text.strip().lower().startswith("@ragent"):
|
498 |
+
prompt = text[len("@ragent"):].strip()
|
499 |
+
yield "Initiating reasoning chain using Llama mode..."
|
500 |
+
# Pass the current chat history (cleaned) to help inform the chain.
|
501 |
+
for partial in ragent_reasoning(prompt, clean_chat_history(chat_history)):
|
502 |
+
yield partial
|
503 |
+
return
|
504 |
+
|
505 |
# --- Text and TTS branch ---
|
506 |
tts_prefix = "@tts"
|
507 |
is_tts = any(text.strip().lower().startswith(f"{tts_prefix}{i}") for i in range(1, 3))
|
|
|
539 |
thread.start()
|
540 |
|
541 |
buffer = ""
|
542 |
+
yield "🤔 Thinking..."
|
543 |
for new_text in streamer:
|
544 |
buffer += new_text
|
545 |
buffer = buffer.replace("<|im_end|>", "")
|
|
|
578 |
output_file = asyncio.run(text_to_speech(final_response, voice))
|
579 |
yield gr.Audio(output_file, autoplay=True)
|
580 |
|
|
|
581 |
# Gradio Chat Interface Setup and Launch
|
|
|
582 |
|
583 |
demo = gr.ChatInterface(
|
584 |
fn=generate,
|
|
|
594 |
["@3d A birthday cupcake with cherry"],
|
595 |
[{"text": "summarize the letter", "files": ["examples/1.png"]}],
|
596 |
["@image Chocolate dripping from a donut against a yellow background, in the style of brocore, hyper-realistic"],
|
597 |
+
["@rAgent Explain how a binary search algorithm works."],
|
598 |
["@web latest breakthroughs in renewable energy"],
|
599 |
+
|
600 |
],
|
601 |
cache_examples=False,
|
602 |
type="messages",
|
|
|
612 |
if not os.path.exists("static"):
|
613 |
os.makedirs("static")
|
614 |
|
|
|
615 |
from fastapi.staticfiles import StaticFiles
|
616 |
demo.app.mount("/static", StaticFiles(directory="static"), name="static")
|
617 |
|
618 |
if __name__ == "__main__":
|
|
|
619 |
demo.queue(max_size=20).launch(share=True)
|