prithivMLmods committed on
Commit 14bfced · verified · 1 Parent(s): e448df3

Update app.py

Files changed (1): app.py +79 -39
app.py CHANGED
@@ -31,9 +31,7 @@ from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
 from diffusers import ShapEImg2ImgPipeline, ShapEPipeline
 from diffusers.utils import export_to_ply

-# -----------------------------------------------------------------------------
 # Global constants and helper functions
-# -----------------------------------------------------------------------------

 MAX_SEED = np.iinfo(np.int32).max

@@ -52,9 +50,7 @@ def glb_to_data_url(glb_path: str) -> str:
     b64_data = base64.b64encode(data).decode("utf-8")
     return f"data:model/gltf-binary;base64,{b64_data}"

-# -----------------------------------------------------------------------------
 # Model class for Text-to-3D Generation (ShapE)
-# -----------------------------------------------------------------------------

 class Model:
     def __init__(self):
@@ -113,9 +109,7 @@ class Model:
         export_to_ply(images[0], ply_path.name)
         return self.to_glb(ply_path.name)

-# -----------------------------------------------------------------------------
 # New Tools for Web Functionality using DuckDuckGo and smolagents
-# -----------------------------------------------------------------------------

 from typing import Any, Optional
 from smolagents.tools import Tool
@@ -186,14 +180,68 @@ class VisitWebpageTool(Tool):
             return f"Error fetching the webpage: {str(e)}"
         except Exception as e:
             return f"An unexpected error occurred: {str(e)}"
+
+# New Feature: rAgent Reasoning using Llama mode OpenAI
+
+from openai import OpenAI
+
+ACCESS_TOKEN = os.getenv("HF_TOKEN")
+ragent_client = OpenAI(
+    base_url="https://api-inference.huggingface.co/v1/",
+    api_key=ACCESS_TOKEN,
+)
+
+SYSTEM_PROMPT = """You are an expert assistant who can solve any task using code blobs. You will be given a task to solve as best you can.
+
+To do so, you must follow a structured reasoning process in a cycle of:
+
+1. **Thought:**
+   - Analyze the problem and explain your reasoning.
+   - Identify any necessary tools or techniques.
+
+2. **Code:**
+   - Implement the solution using Python.
+   - Enclose the code block with `<end_code>`.
+
+3. **Observation:**
+   - Explain the output and verify correctness.
+
+4. **Final Answer:**
+   - Summarize the solution clearly.
+
+Always adhere to the **Thought → Code → Observation → Final Answer** structure.
+"""
+
+def ragent_reasoning(prompt: str, history: list[dict], max_tokens: int = 512, temperature: float = 0.7, top_p: float = 0.95):
+    """
+    Uses the Llama mode OpenAI model to perform a structured reasoning chain.
+    """
+    messages = [{"role": "system", "content": SYSTEM_PROMPT}]
+    # Incorporate conversation history (if any)
+    for msg in history:
+        if msg.get("role") == "user":
+            messages.append({"role": "user", "content": msg["content"]})
+        elif msg.get("role") == "assistant":
+            messages.append({"role": "assistant", "content": msg["content"]})
+    messages.append({"role": "user", "content": prompt})
+    response = ""
+    stream = ragent_client.chat.completions.create(
+        model="meta-llama/Meta-Llama-3.1-8B-Instruct",
+        max_tokens=max_tokens,
+        stream=True,
+        temperature=temperature,
+        top_p=top_p,
+        messages=messages,
+    )
+    for message in stream:
+        token = message.choices[0].delta.content
+        response += token
+        yield response

-# -----------------------------------------------------------------------------
 # Gradio UI configuration
-# -----------------------------------------------------------------------------

 DESCRIPTION = """
-# Agent Dino 🌠
-"""
+# Agent Dino 🌠 """

 css = '''
 h1 {
@@ -215,11 +263,9 @@ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

-# -----------------------------------------------------------------------------
 # Load Models and Pipelines for Chat, Image, and Multimodal Processing
-# -----------------------------------------------------------------------------
-
 # Load the text-only model and tokenizer (for pure text chat)
+
 model_id = "prithivMLmods/FastThink-0.5B-Tiny"
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 model = AutoModelForCausalLM.from_pretrained(
@@ -244,9 +290,7 @@ model_m = Qwen2VLForConditionalGeneration.from_pretrained(
     torch_dtype=torch.float16
 ).to("cuda").eval()

-# -----------------------------------------------------------------------------
 # Asynchronous text-to-speech
-# -----------------------------------------------------------------------------

 async def text_to_speech(text: str, voice: str, output_file="output.mp3"):
     """Convert text to speech using Edge TTS and save as MP3"""
@@ -254,9 +298,7 @@ async def text_to_speech(text: str, voice: str, output_file="output.mp3"):
     await communicate.save(output_file)
     return output_file

-# -----------------------------------------------------------------------------
 # Utility function to clean conversation history
-# -----------------------------------------------------------------------------

 def clean_chat_history(chat_history):
     """
@@ -269,9 +311,7 @@ def clean_chat_history(chat_history):
             cleaned.append(msg)
     return cleaned

-# -----------------------------------------------------------------------------
 # Stable Diffusion XL Pipeline for Image Generation
-# -----------------------------------------------------------------------------

 MODEL_ID_SD = os.getenv("MODEL_VAL_PATH")  # SDXL Model repository path via env variable
 MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
@@ -350,9 +390,7 @@ def generate_image_fn(
     image_paths = [save_image(img) for img in images]
     return image_paths, seed

-# -----------------------------------------------------------------------------
 # Text-to-3D Generation using the ShapE Pipeline
-# -----------------------------------------------------------------------------

 @spaces.GPU(duration=120, enable_queue=True)
 def generate_3d_fn(
@@ -371,9 +409,7 @@ def generate_3d_fn(
     glb_path = model3d.run_text(prompt, seed=seed, guidance_scale=guidance_scale, num_steps=num_steps)
     return glb_path, seed

-# -----------------------------------------------------------------------------
-# Chat Generation Function with support for @tts, @image, @3d, and now @web commands
-# -----------------------------------------------------------------------------
+# Chat Generation Function with support for @tts, @image, @3d, @web, and @rAgent commands

 @spaces.GPU
 def generate(
@@ -386,14 +422,12 @@ def generate(
     repetition_penalty: float = 1.2,
 ):
     """
-    Generates chatbot responses with support for multimodal input, TTS, image generation,
-    3D model generation, and web search/visit.
-
-    Special commands:
+    Generates chatbot responses with support for multimodal input and special commands:
     - "@tts1" or "@tts2": triggers text-to-speech.
     - "@image": triggers image generation using the SDXL pipeline.
     - "@3d": triggers 3D model generation using the ShapE pipeline.
-    - "@web": triggers a web search or webpage visit. Use "visit" after @web to fetch a page.
+    - "@web": triggers a web search or webpage visit.
+    - "@rAgent": initiates a reasoning chain using Llama mode OpenAI.
     """
     text = input_dict["text"]
     files = input_dict.get("files", [])
@@ -401,7 +435,7 @@
     # --- 3D Generation branch ---
     if text.strip().lower().startswith("@3d"):
         prompt = text[len("@3d"):].strip()
-        yield "Hold tight, generating a 3D mesh GLB file....."
+        yield "🌀 Hold tight, generating a 3D mesh GLB file....."
         glb_path, used_seed = generate_3d_fn(
             prompt=prompt,
             seed=1,
@@ -423,7 +457,7 @@
     # --- Image Generation branch ---
     if text.strip().lower().startswith("@image"):
         prompt = text[len("@image"):].strip()
-        yield "Generating image..."
+        yield "🪧 Generating image..."
         image_paths, used_seed = generate_image_fn(
             prompt=prompt,
             negative_prompt="",
@@ -446,19 +480,28 @@
         # If the command starts with "visit", then treat the rest as a URL
         if web_command.lower().startswith("visit"):
             url = web_command[len("visit"):].strip()
-            yield "Visiting webpage..."
+            yield "🌍 Visiting webpage..."
             visitor = VisitWebpageTool()
             content = visitor.forward(url)
             yield content
         else:
             # Otherwise, treat the rest as a search query.
             query = web_command
-            yield "Perform a web search ..."
+            yield "🧤 Performing a web search ..."
             searcher = DuckDuckGoSearchTool()
             results = searcher.forward(query)
             yield results
         return

+    # --- rAgent Reasoning branch ---
+    if text.strip().lower().startswith("@ragent"):
+        prompt = text[len("@ragent"):].strip()
+        yield "Initiating reasoning chain using Llama mode..."
+        # Pass the current chat history (cleaned) to help inform the chain.
+        for partial in ragent_reasoning(prompt, clean_chat_history(chat_history)):
+            yield partial
+        return
+
     # --- Text and TTS branch ---
     tts_prefix = "@tts"
     is_tts = any(text.strip().lower().startswith(f"{tts_prefix}{i}") for i in range(1, 3))
@@ -496,7 +539,7 @@ def generate(
     thread.start()

     buffer = ""
-    yield "Thinking..."
+    yield "🤔 Thinking..."
     for new_text in streamer:
         buffer += new_text
         buffer = buffer.replace("<|im_end|>", "")
@@ -535,9 +578,7 @@ def generate(
         output_file = asyncio.run(text_to_speech(final_response, voice))
         yield gr.Audio(output_file, autoplay=True)

-# -----------------------------------------------------------------------------
 # Gradio Chat Interface Setup and Launch
-# -----------------------------------------------------------------------------

 demo = gr.ChatInterface(
     fn=generate,
@@ -553,8 +594,9 @@ demo = gr.ChatInterface(
         ["@3d A birthday cupcake with cherry"],
         [{"text": "summarize the letter", "files": ["examples/1.png"]}],
         ["@image Chocolate dripping from a donut against a yellow background, in the style of brocore, hyper-realistic"],
-        ["Write a Python Code String Reverse With Example!"],
+        ["@rAgent Explain how a binary search algorithm works."],
         ["@web latest breakthroughs in renewable energy"],
+
     ],
     cache_examples=False,
     type="messages",
@@ -570,10 +612,8 @@ demo = gr.ChatInterface(
 if not os.path.exists("static"):
     os.makedirs("static")

-# Mount the static folder onto the FastAPI app so that GLB files are served.
 from fastapi.staticfiles import StaticFiles
 demo.app.mount("/static", StaticFiles(directory="static"), name="static")

 if __name__ == "__main__":
-    # Launch without the unsupported static_dirs parameter.
     demo.queue(max_size=20).launch(share=True)
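For reference, a minimal sketch of the streaming call pattern that the new ragent_reasoning helper relies on, runnable outside Gradio. It reuses the base URL, model ID, and sampling parameters from the diff above; the wrapper name stream_reasoning, the default system prompt, and the None-delta guard are illustrative additions rather than part of the commit, and the snippet assumes HF_TOKEN is set in the environment.

# Illustrative sketch (not part of the commit): exercises the same OpenAI-compatible
# streaming endpoint that ragent_reasoning() uses. Assumes HF_TOKEN is exported.
import os

from openai import OpenAI

client = OpenAI(
    base_url="https://api-inference.huggingface.co/v1/",
    api_key=os.getenv("HF_TOKEN"),
)


def stream_reasoning(prompt: str, system_prompt: str = "You are a helpful assistant."):
    """Yield the growing response text, skipping empty deltas the API may emit."""
    stream = client.chat.completions.create(
        model="meta-llama/Meta-Llama-3.1-8B-Instruct",
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt},
        ],
        max_tokens=512,
        temperature=0.7,
        top_p=0.95,
        stream=True,
    )
    response = ""
    for chunk in stream:
        token = chunk.choices[0].delta.content
        if token:  # the final chunk's delta.content can be None
            response += token
            yield response


if __name__ == "__main__":
    final = ""
    for final in stream_reasoning("Explain how a binary search algorithm works."):
        pass
    print(final)

Yielding the growing response string, rather than individual tokens, matches how generate() feeds the Gradio ChatInterface, which re-renders the partial message on each yield.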