prithivMLmods committed on
Commit
d0a3095
·
verified ·
1 Parent(s): f344c9a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -135
app.py CHANGED
@@ -186,7 +186,7 @@ class VisitWebpageTool(Tool):
186
  return f"Error fetching the webpage: {str(e)}"
187
  except Exception as e:
188
  return f"An unexpected error occurred: {str(e)}"
189
-
190
  # rAgent Reasoning using Llama mode OpenAI
191
 
192
  from openai import OpenAI
@@ -199,12 +199,12 @@ ragent_client = OpenAI(
199
 
200
  SYSTEM_PROMPT = """
201
 
202
- "You are an expert assistant who solves tasks using Python code. Follow these steps:\n"
203
- "1. **Thought**: Explain your reasoning and plan for solving the task.\n"
204
- "2. **Code**: Write Python code to implement your solution.\n"
205
- "3. **Observation**: Analyze the output of the code and summarize the results.\n"
206
- "4. **Final Answer**: Provide a concise conclusion or final result.\n\n"
207
- f"Task: {task}"
208
 
209
  """
210
 
@@ -222,18 +222,21 @@ def ragent_reasoning(prompt: str, history: list[dict], max_tokens: int = 2048, t
222
  messages.append({"role": "user", "content": prompt})
223
  response = ""
224
  stream = ragent_client.chat.completions.create(
225
- model="meta-llama/Meta-Llama-3.1-8B-Instruct",
226
- max_tokens=max_tokens,
227
- stream=True,
228
- temperature=temperature,
229
- top_p=top_p,
230
- messages=messages,
231
  )
232
  for message in stream:
233
- token = message.choices[0].delta.content
234
- response += token
235
- yield response
236
 
 
 
 
237
  # Define prompt structure for Phi-4
238
  phi4_user_prompt = '<|user|>'
239
  phi4_assistant_prompt = '<|assistant|>'
@@ -250,20 +253,24 @@ phi4_model = AutoModelForCausalLM.from_pretrained(
250
  _attn_implementation="eager",
251
  )
252
 
 
 
 
 
253
  DESCRIPTION = """
254
  # Agent Dino 🌠"""
255
 
256
  css = '''
257
  h1 {
258
- text-align: center;
259
- display: block;
260
  }
261
 
262
  #duplicate-button {
263
- margin: auto;
264
- color: #fff;
265
- background: #1565c0;
266
- border-radius: 100vh;
267
  }
268
  '''
269
 
@@ -292,7 +299,7 @@ TTS_VOICES = [
292
  ]
293
 
294
  # Load multimodal processor and model (e.g. for OCR and image processing)
295
- MODEL_ID = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
296
  processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
297
  model_m = Qwen2VLForConditionalGeneration.from_pretrained(
298
  MODEL_ID,
@@ -321,10 +328,10 @@ def clean_chat_history(chat_history):
321
  cleaned.append(msg)
322
  return cleaned
323
 
324
- # Stable Diffusion XL Pipeline for Image Generation
325
  # Model In Use : SG161222/RealVisXL_V5.0_Lightning
326
 
327
- MODEL_ID_SD = os.getenv("MODEL_VAL_PATH") # SDXL Model repository path via env variable
328
  MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
329
  USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
330
  ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
@@ -430,96 +437,17 @@ def detect_objects(image: np.ndarray):
430
  """Runs object detection on the input image."""
431
  results = yolo_detector(image, verbose=False)[0]
432
  detections = sv.Detections.from_ultralytics(results).with_nms()
433
-
434
  box_annotator = sv.BoxAnnotator()
435
  label_annotator = sv.LabelAnnotator()
436
-
437
  annotated_image = image.copy()
438
  annotated_image = box_annotator.annotate(scene=annotated_image, detections=detections)
439
  annotated_image = label_annotator.annotate(scene=annotated_image, detections=detections)
440
-
441
  return Image.fromarray(annotated_image)
442
 
443
- # GRPO Model Loading and Functions
444
-
445
- grpo_model_name = "prithivMLmods/SmolLM2-360M-Grpo-r999"
446
- grpo_tokenizer = AutoTokenizer.from_pretrained(grpo_model_name)
447
- grpo_model = AutoModelForCausalLM.from_pretrained(grpo_model_name).to(device)
448
-
449
- def get_user_prompt(prompt: str) -> str:
450
- match = re.search(r"<\|im_start\|>user\s*(.*?)\s*<\|im_end\|>", prompt, re.DOTALL)
451
- return match.group(1).strip() if match else "\n".join(
452
- line.strip()[4:].strip() if line.strip().lower().startswith("user") else line
453
- for line in prompt.splitlines() if not line.strip().lower().startswith("system")
454
- ).strip()
455
-
456
- def get_assistant_response(text: str) -> str:
457
- match = re.search(r"<\|im_start\|>assistant\s*(.*?)\s*<\|im_end\|>", text, re.DOTALL)
458
- return match.group(1).strip() if match else "\n".join(
459
- line for line in text.splitlines() if not line.strip().lower().startswith("assistant")
460
- ).strip()
461
-
462
- def generate_grpo_fn(prompt: str):
463
- messages = [
464
- {"role": "system", "content": "Please respond in this specific format ONLY:\n<thinking>\n input your reasoning behind your answer in between these reasoning tags.\n</thinking>\n<answer>\nyour answer in between these answer tags.\n</answer>\n"},
465
- {"role": "user", "content": prompt}
466
- ]
467
- input_text = grpo_tokenizer.apply_chat_template(messages, tokenize=False)
468
- inputs = grpo_tokenizer.encode(input_text, return_tensors="pt").to(device)
469
- streamer = TextIteratorStreamer(grpo_tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
470
- generation_kwargs = {
471
- "input_ids": inputs,
472
- "streamer": streamer,
473
- "max_new_tokens": 200,
474
- "temperature": 0.2,
475
- "top_p": 0.9,
476
- "do_sample": True,
477
- "use_cache": False
478
- }
479
- thread = Thread(target=grpo_model.generate, kwargs=generation_kwargs)
480
- thread.start()
481
-
482
- outputs = []
483
- thinking_started = False
484
- answer_started = False
485
- collected_thinking = ""
486
- collected_answer = ""
487
-
488
- for new_text in streamer:
489
- outputs.append(new_text)
490
- full_output = "".join(outputs)
491
-
492
- if "<thinking>" in full_output and not thinking_started:
493
- thinking_started = True
494
- thinking_start_index = full_output.find("<thinking>") + len("<thinking>")
495
- collected_thinking = full_output[thinking_start_index:]
496
-
497
- elif thinking_started and "</thinking>" not in full_output:
498
- collected_thinking += new_text
499
-
500
- elif thinking_started and "</thinking>" in full_output and not answer_started:
501
- thinking_ended_index = full_output.find("</thinking>")
502
- collected_thinking = full_output[thinking_start_index:thinking_ended_index]
503
- answer_started = True
504
- answer_start_index = full_output.find("<answer>") + len("<answer>")
505
- collected_answer = full_output[answer_start_index:]
506
-
507
- elif answer_started and "</answer>" not in full_output:
508
- collected_answer += new_text
509
-
510
- elif answer_started and "</answer>" in full_output:
511
- answer_ended_index = full_output.find("</answer>")
512
- collected_answer = full_output[answer_start_index:answer_ended_index]
513
-
514
- if answer_started:
515
- # Yield only the answer part once answer section started
516
- yield collected_answer.strip()
517
- else:
518
- # While in thinking phase or before, yield full output for streaming effect
519
- yield "".join(outputs).replace("<thinking>", "").replace("</thinking>", "").replace("<answer>", "").replace("</answer>", "").strip()
520
-
521
-
522
- # Chat Generation Function with support for @tts, @image, @3d, @web, @rAgent, @yolo, @phi4 and @grpo commands
523
 
524
  @spaces.GPU
525
  def generate(
@@ -533,27 +461,17 @@ def generate(
533
  ):
534
  """
535
  Generates chatbot responses with support for multimodal input and special commands:
536
- - "@tts1" or "@tts2": triggers text-to-speech.
537
- - "@image": triggers image generation using the SDXL pipeline.
538
- - "@3d": triggers 3D model generation using the ShapE pipeline.
539
- - "@web": triggers a web search or webpage visit.
540
- - "@rAgent": initiates a reasoning chain using Llama mode.
541
- - "@yolo": triggers object detection using YOLO.
542
- - "@phi4": triggers multimodal (image/audio) processing using the Phi-4 model.
543
- - "@grpo": triggers text generation using GRPO model with structured output.
544
  """
545
  text = input_dict["text"]
546
  files = input_dict.get("files", [])
547
 
548
- # --- GRPO Generation branch ---
549
- if text.strip().lower().startswith("@grpo"):
550
- prompt = text[len("@grpo"):].strip()
551
- yield "💡 Thinking using GRPO model..."
552
- for partial_response in generate_grpo_fn(prompt):
553
- yield partial_response
554
- return
555
-
556
-
557
  # --- 3D Generation branch ---
558
  if text.strip().lower().startswith("@3d"):
559
  prompt = text[len("@3d"):].strip()
@@ -572,7 +490,7 @@ def generate(
572
  new_filename = f"mesh_{uuid.uuid4()}.glb"
573
  new_filepath = os.path.join(static_folder, new_filename)
574
  shutil.copy(glb_path, new_filepath)
575
-
576
  yield gr.File(new_filepath)
577
  return
578
 
@@ -616,8 +534,8 @@ def generate(
616
  return
617
 
618
  # --- rAgent Reasoning branch ---
619
- if text.strip().lower().startswith("@ragent"):
620
- prompt = text[len("@ragent"):].strip()
621
  yield "📝 Initiating reasoning chain using Llama mode..."
622
  # Pass the current chat history (cleaned) to help inform the chain.
623
  for partial in ragent_reasoning(prompt, clean_chat_history(chat_history)):
@@ -686,7 +604,7 @@ def generate(
686
 
687
  # Initialize the streamer
688
  streamer = TextIteratorStreamer(phi4_processor, skip_prompt=True, skip_special_tokens=True)
689
-
690
  # Prepare generation kwargs
691
  generation_kwargs = {
692
  **inputs,
@@ -712,7 +630,7 @@ def generate(
712
  tts_prefix = "@tts"
713
  is_tts = any(text.strip().lower().startswith(f"{tts_prefix}{i}") for i in range(1, 3))
714
  voice_index = next((i for i in range(1, 3) if text.strip().lower().startswith(f"{tts_prefix}{i}")), None)
715
-
716
  if is_tts and voice_index:
717
  voice = TTS_VOICES[voice_index - 1]
718
  text = text.replace(f"{tts_prefix}{voice_index}", "").strip()
@@ -803,10 +721,9 @@ demo = gr.ChatInterface(
803
  ["@tts2 What causes rainbows to form?"],
804
  [{"text": "Summarize the letter", "files": ["examples/1.png"]}],
805
  [{"text": "@yolo", "files": ["examples/yolo.jpeg"]}],
806
- ["@rAgent Explain how a binary search algorithm works."],
807
  ["@web Is Grok-3 Beats DeepSeek-R1 at Reasoning ?"],
808
  ["@tts1 Explain Tower of Hanoi"],
809
- ["@grpo If there are 12 cookies in a dozen and you have 5 dozen, how many cookies do you have?"],
810
  ],
811
  cache_examples=False,
812
  type="messages",
@@ -814,10 +731,10 @@ demo = gr.ChatInterface(
814
  css=css,
815
  fill_height=True,
816
  textbox=gr.MultimodalTextbox(
817
- label="Query Input",
818
  file_types=["image", "audio"],
819
- file_count="multiple",
820
- placeholder="@tts1, @tts2, @image, @3d, @phi4 [image, audio], @rAgent, @web, @yolo, @grpo, default [plain text]"
821
  ),
822
  stop_btn="Stop Generation",
823
  multimodal=True,
 
186
  return f"Error fetching the webpage: {str(e)}"
187
  except Exception as e:
188
  return f"An unexpected error occurred: {str(e)}"
189
+
190
  # rAgent Reasoning using Llama mode OpenAI
191
 
192
  from openai import OpenAI
 
199
 
200
  SYSTEM_PROMPT = """
201
 
202
+ "You are an expert assistant who solves tasks using Python code. Follow these steps:\n"
203
+ "1. **Thought**: Explain your reasoning and plan for solving the task.\n"
204
+ "2. **Code**: Write Python code to implement your solution.\n"
205
+ "3. **Observation**: Analyze the output of the code and summarize the results.\n"
206
+ "4. **Final Answer**: Provide a concise conclusion or final result.\n\n"
207
+ f"Task: {task}"
208
 
209
  """
210
 
 
222
  messages.append({"role": "user", "content": prompt})
223
  response = ""
224
  stream = ragent_client.chat.completions.create(
225
+ model="meta-llama/Meta-Llama-3.1-8B-Instruct",
226
+ max_tokens=max_tokens,
227
+ stream=True,
228
+ temperature=temperature,
229
+ top_p=top_p,
230
+ messages=messages,
231
  )
232
  for message in stream:
233
+ token = message.choices[0].delta.content
234
+ response += token
235
+ yield response
236
 
237
+ # ------------------------------------------------------------------------------
238
+ # New Phi-4 Multimodal Feature (Image & Audio)
239
+ # ------------------------------------------------------------------------------
240
  # Define prompt structure for Phi-4
241
  phi4_user_prompt = '<|user|>'
242
  phi4_assistant_prompt = '<|assistant|>'
 
253
  _attn_implementation="eager",
254
  )
255
 
256
+ # ------------------------------------------------------------------------------
257
+ # Gradio UI configuration
258
+ # ------------------------------------------------------------------------------
259
+
260
  DESCRIPTION = """
261
  # Agent Dino 🌠"""
262
 
263
  css = '''
264
  h1 {
265
+ text-align: center;
266
+ display: block;
267
  }
268
 
269
  #duplicate-button {
270
+ margin: auto;
271
+ color: #fff;
272
+ background: #1565c0;
273
+ border-radius: 100vh;
274
  }
275
  '''
276
 
 
299
  ]
300
 
301
  # Load multimodal processor and model (e.g. for OCR and image processing)
302
+ MODEL_ID = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
303
  processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
304
  model_m = Qwen2VLForConditionalGeneration.from_pretrained(
305
  MODEL_ID,
 
328
  cleaned.append(msg)
329
  return cleaned
330
 
331
+ # Stable Diffusion XL Pipeline for Image Generation
332
  # Model In Use : SG161222/RealVisXL_V5.0_Lightning
333
 
334
+ MODEL_ID_SD = os.getenv("MODEL_VAL_PATH") # SDXL Model repository path via env variable
335
  MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
336
  USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
337
  ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
 
437
  """Runs object detection on the input image."""
438
  results = yolo_detector(image, verbose=False)[0]
439
  detections = sv.Detections.from_ultralytics(results).with_nms()
440
+
441
  box_annotator = sv.BoxAnnotator()
442
  label_annotator = sv.LabelAnnotator()
443
+
444
  annotated_image = image.copy()
445
  annotated_image = box_annotator.annotate(scene=annotated_image, detections=detections)
446
  annotated_image = label_annotator.annotate(scene=annotated_image, detections=detections)
447
+
448
  return Image.fromarray(annotated_image)
449
 
450
+ # Chat generation function with support for the @tts, @image, @3d, @web, @rAgent, @yolo, and @phi4 commands
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
451
 
452
  @spaces.GPU
453
  def generate(
 
461
  ):
462
  """
463
  Generates chatbot responses with support for multimodal input and special commands:
464
+ - "@tts1" or "@tts2": triggers text-to-speech.
465
+ - "@image": triggers image generation using the SDXL pipeline.
466
+ - "@3d": triggers 3D model generation using the ShapE pipeline.
467
+ - "@web": triggers a web search or webpage visit.
468
+ - "@rAgent": initiates a reasoning chain using Llama mode.
469
+ - "@yolo": triggers object detection using YOLO.
470
+ - "@phi4": triggers multimodal (image/audio) processing using the Phi-4 model.
 
471
  """
472
  text = input_dict["text"]
473
  files = input_dict.get("files", [])
474
 
 
 
 
 
 
 
 
 
 
475
  # --- 3D Generation branch ---
476
  if text.strip().lower().startswith("@3d"):
477
  prompt = text[len("@3d"):].strip()
 
490
  new_filename = f"mesh_{uuid.uuid4()}.glb"
491
  new_filepath = os.path.join(static_folder, new_filename)
492
  shutil.copy(glb_path, new_filepath)
493
+
494
  yield gr.File(new_filepath)
495
  return
496
 
 
534
  return
535
 
536
  # --- rAgent Reasoning branch ---
537
+ if text.strip().lower().startswith("@rAgent"):
538
+ prompt = text[len("@rAgent"):].strip()
539
  yield "📝 Initiating reasoning chain using Llama mode..."
540
  # Pass the current chat history (cleaned) to help inform the chain.
541
  for partial in ragent_reasoning(prompt, clean_chat_history(chat_history)):
 
604
 
605
  # Initialize the streamer
606
  streamer = TextIteratorStreamer(phi4_processor, skip_prompt=True, skip_special_tokens=True)
607
+
608
  # Prepare generation kwargs
609
  generation_kwargs = {
610
  **inputs,
 
630
  tts_prefix = "@tts"
631
  is_tts = any(text.strip().lower().startswith(f"{tts_prefix}{i}") for i in range(1, 3))
632
  voice_index = next((i for i in range(1, 3) if text.strip().lower().startswith(f"{tts_prefix}{i}")), None)
633
+
634
  if is_tts and voice_index:
635
  voice = TTS_VOICES[voice_index - 1]
636
  text = text.replace(f"{tts_prefix}{voice_index}", "").strip()
 
721
  ["@tts2 What causes rainbows to form?"],
722
  [{"text": "Summarize the letter", "files": ["examples/1.png"]}],
723
  [{"text": "@yolo", "files": ["examples/yolo.jpeg"]}],
724
+ ["@ragent Explain how a binary search algorithm works."],
725
  ["@web Is Grok-3 Beats DeepSeek-R1 at Reasoning ?"],
726
  ["@tts1 Explain Tower of Hanoi"],
 
727
  ],
728
  cache_examples=False,
729
  type="messages",
 
731
  css=css,
732
  fill_height=True,
733
  textbox=gr.MultimodalTextbox(
734
+ label="Query Input",
735
  file_types=["image", "audio"],
736
+ file_count="multiple",
737
+ placeholder="@tts1, @tts2, @image, @3d, @phi4 [image, audio], @rAgent, @web, @yolo, default [plain text]"
738
  ),
739
  stop_btn="Stop Generation",
740
  multimodal=True,