prithivMLmods committed on
Commit
f344c9a
·
verified ·
1 Parent(s): 59d9c9e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +132 -42
app.py CHANGED
@@ -186,7 +186,7 @@ class VisitWebpageTool(Tool):
186
  return f"Error fetching the webpage: {str(e)}"
187
  except Exception as e:
188
  return f"An unexpected error occurred: {str(e)}"
189
-
190
  # rAgent Reasoning using Llama mode OpenAI
191
 
192
  from openai import OpenAI
@@ -199,12 +199,12 @@ ragent_client = OpenAI(
199
 
200
  SYSTEM_PROMPT = """
201
 
202
- "You are an expert assistant who solves tasks using Python code. Follow these steps:\n"
203
- "1. **Thought**: Explain your reasoning and plan for solving the task.\n"
204
- "2. **Code**: Write Python code to implement your solution.\n"
205
- "3. **Observation**: Analyze the output of the code and summarize the results.\n"
206
- "4. **Final Answer**: Provide a concise conclusion or final result.\n\n"
207
- f"Task: {task}"
208
 
209
  """
210
 
@@ -222,17 +222,17 @@ def ragent_reasoning(prompt: str, history: list[dict], max_tokens: int = 2048, t
222
  messages.append({"role": "user", "content": prompt})
223
  response = ""
224
  stream = ragent_client.chat.completions.create(
225
- model="meta-llama/Meta-Llama-3.1-8B-Instruct",
226
- max_tokens=max_tokens,
227
- stream=True,
228
- temperature=temperature,
229
- top_p=top_p,
230
- messages=messages,
231
  )
232
  for message in stream:
233
- token = message.choices[0].delta.content
234
- response += token
235
- yield response
236
 
237
  # Define prompt structure for Phi-4
238
  phi4_user_prompt = '<|user|>'
@@ -255,15 +255,15 @@ DESCRIPTION = """
255
 
256
  css = '''
257
  h1 {
258
- text-align: center;
259
- display: block;
260
  }
261
 
262
  #duplicate-button {
263
- margin: auto;
264
- color: #fff;
265
- background: #1565c0;
266
- border-radius: 100vh;
267
  }
268
  '''
269
 
@@ -292,7 +292,7 @@ TTS_VOICES = [
292
  ]
293
 
294
  # Load multimodal processor and model (e.g. for OCR and image processing)
295
- MODEL_ID = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
296
  processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
297
  model_m = Qwen2VLForConditionalGeneration.from_pretrained(
298
  MODEL_ID,
@@ -321,10 +321,10 @@ def clean_chat_history(chat_history):
321
  cleaned.append(msg)
322
  return cleaned
323
 
324
- # Stable Diffusion XL Pipeline for Image Generation
325
  # Model In Use : SG161222/RealVisXL_V5.0_Lightning
326
 
327
- MODEL_ID_SD = os.getenv("MODEL_VAL_PATH") # SDXL Model repository path via env variable
328
  MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
329
  USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
330
  ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
@@ -430,17 +430,96 @@ def detect_objects(image: np.ndarray):
430
  """Runs object detection on the input image."""
431
  results = yolo_detector(image, verbose=False)[0]
432
  detections = sv.Detections.from_ultralytics(results).with_nms()
433
-
434
  box_annotator = sv.BoxAnnotator()
435
  label_annotator = sv.LabelAnnotator()
436
-
437
  annotated_image = image.copy()
438
  annotated_image = box_annotator.annotate(scene=annotated_image, detections=detections)
439
  annotated_image = label_annotator.annotate(scene=annotated_image, detections=detections)
440
-
441
  return Image.fromarray(annotated_image)
442
 
443
- # Chat Generation Function with support for @tts, @image, @3d, @web, @rAgent, @yolo, and now @phi4 commands
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
444
 
445
  @spaces.GPU
446
  def generate(
@@ -454,17 +533,27 @@ def generate(
454
  ):
455
  """
456
  Generates chatbot responses with support for multimodal input and special commands:
457
- - "@tts1" or "@tts2": triggers text-to-speech.
458
- - "@image": triggers image generation using the SDXL pipeline.
459
- - "@3d": triggers 3D model generation using the ShapE pipeline.
460
- - "@web": triggers a web search or webpage visit.
461
- - "@rAgent": initiates a reasoning chain using Llama mode.
462
- - "@yolo": triggers object detection using YOLO.
463
- - **"@phi4": triggers multimodal (image/audio) processing using the Phi-4 model.**
 
464
  """
465
  text = input_dict["text"]
466
  files = input_dict.get("files", [])
467
 
 
 
 
 
 
 
 
 
 
468
  # --- 3D Generation branch ---
469
  if text.strip().lower().startswith("@3d"):
470
  prompt = text[len("@3d"):].strip()
@@ -483,7 +572,7 @@ def generate(
483
  new_filename = f"mesh_{uuid.uuid4()}.glb"
484
  new_filepath = os.path.join(static_folder, new_filename)
485
  shutil.copy(glb_path, new_filepath)
486
-
487
  yield gr.File(new_filepath)
488
  return
489
 
@@ -597,7 +686,7 @@ def generate(
597
 
598
  # Initialize the streamer
599
  streamer = TextIteratorStreamer(phi4_processor, skip_prompt=True, skip_special_tokens=True)
600
-
601
  # Prepare generation kwargs
602
  generation_kwargs = {
603
  **inputs,
@@ -623,7 +712,7 @@ def generate(
623
  tts_prefix = "@tts"
624
  is_tts = any(text.strip().lower().startswith(f"{tts_prefix}{i}") for i in range(1, 3))
625
  voice_index = next((i for i in range(1, 3) if text.strip().lower().startswith(f"{tts_prefix}{i}")), None)
626
-
627
  if is_tts and voice_index:
628
  voice = TTS_VOICES[voice_index - 1]
629
  text = text.replace(f"{tts_prefix}{voice_index}", "").strip()
@@ -717,6 +806,7 @@ demo = gr.ChatInterface(
717
  ["@rAgent Explain how a binary search algorithm works."],
718
  ["@web Is Grok-3 Beats DeepSeek-R1 at Reasoning ?"],
719
  ["@tts1 Explain Tower of Hanoi"],
 
720
  ],
721
  cache_examples=False,
722
  type="messages",
@@ -724,10 +814,10 @@ demo = gr.ChatInterface(
724
  css=css,
725
  fill_height=True,
726
  textbox=gr.MultimodalTextbox(
727
- label="Query Input",
728
  file_types=["image", "audio"],
729
- file_count="multiple",
730
- placeholder="@tts1, @tts2, @image, @3d, @phi4 [image, audio], @rAgent, @web, @yolo, default [plain text]"
731
  ),
732
  stop_btn="Stop Generation",
733
  multimodal=True,
 
186
  return f"Error fetching the webpage: {str(e)}"
187
  except Exception as e:
188
  return f"An unexpected error occurred: {str(e)}"
189
+
190
  # rAgent reasoning using a Llama model through the OpenAI client
191
 
192
  from openai import OpenAI
 
199
 
200
  SYSTEM_PROMPT = """
201
 
202
+ "You are an expert assistant who solves tasks using Python code. Follow these steps:\n"
203
+ "1. **Thought**: Explain your reasoning and plan for solving the task.\n"
204
+ "2. **Code**: Write Python code to implement your solution.\n"
205
+ "3. **Observation**: Analyze the output of the code and summarize the results.\n"
206
+ "4. **Final Answer**: Provide a concise conclusion or final result.\n\n"
207
+ f"Task: {task}"
208
 
209
  """
210
 
 
222
  messages.append({"role": "user", "content": prompt})
223
  response = ""
224
  stream = ragent_client.chat.completions.create(
225
+ model="meta-llama/Meta-Llama-3.1-8B-Instruct",
226
+ max_tokens=max_tokens,
227
+ stream=True,
228
+ temperature=temperature,
229
+ top_p=top_p,
230
+ messages=messages,
231
  )
232
  for message in stream:
233
+ token = message.choices[0].delta.content
234
+ response += token
235
+ yield response
236
 
237
  # Define prompt structure for Phi-4
238
  phi4_user_prompt = '<|user|>'
 
255
 
256
  css = '''
257
  h1 {
258
+ text-align: center;
259
+ display: block;
260
  }
261
 
262
  #duplicate-button {
263
+ margin: auto;
264
+ color: #fff;
265
+ background: #1565c0;
266
+ border-radius: 100vh;
267
  }
268
  '''
269
 
 
292
  ]
293
 
294
  # Load multimodal processor and model (e.g. for OCR and image processing)
295
+ MODEL_ID = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
296
  processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
297
  model_m = Qwen2VLForConditionalGeneration.from_pretrained(
298
  MODEL_ID,
 
321
  cleaned.append(msg)
322
  return cleaned
323
 
324
+ # Stable Diffusion XL Pipeline for Image Generation
325
  # Model In Use : SG161222/RealVisXL_V5.0_Lightning
326
 
327
+ MODEL_ID_SD = os.getenv("MODEL_VAL_PATH") # SDXL Model repository path via env variable
328
  MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
329
  USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
330
  ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
 
430
  """Runs object detection on the input image."""
431
  results = yolo_detector(image, verbose=False)[0]
432
  detections = sv.Detections.from_ultralytics(results).with_nms()
433
+
434
  box_annotator = sv.BoxAnnotator()
435
  label_annotator = sv.LabelAnnotator()
436
+
437
  annotated_image = image.copy()
438
  annotated_image = box_annotator.annotate(scene=annotated_image, detections=detections)
439
  annotated_image = label_annotator.annotate(scene=annotated_image, detections=detections)
440
+
441
  return Image.fromarray(annotated_image)
442
 
443
+ # GRPO Model Loading and Functions
444
+
445
+ grpo_model_name = "prithivMLmods/SmolLM2-360M-Grpo-r999"
446
+ grpo_tokenizer = AutoTokenizer.from_pretrained(grpo_model_name)
447
+ grpo_model = AutoModelForCausalLM.from_pretrained(grpo_model_name).to(device)
448
+
449
def get_user_prompt(prompt: str) -> str:
    """Extract the user's message from a chat-formatted prompt.

    If the prompt contains a ``<|im_start|>user ... <|im_end|>`` section, the
    text inside the first such section is returned, stripped.

    Otherwise a line-based fallback is used: lines that begin with the role
    label ``system`` are dropped, the role label ``user`` is removed from the
    start of lines that begin with it, and all other lines are kept as-is.

    A role label only counts when it is not followed by another word
    character, so ordinary text such as "Users love cookies" or
    "Systematic review" is left untouched instead of being truncated/dropped.

    Args:
        prompt: Raw prompt text, with or without chat-template markers.

    Returns:
        The extracted user text with surrounding whitespace removed.
    """
    match = re.search(r"<\|im_start\|>user\s*(.*?)\s*<\|im_end\|>", prompt, re.DOTALL)
    if match:
        return match.group(1).strip()

    def is_role_line(stripped: str, role: str) -> bool:
        # "user" / "user:" / "user hi" are labels; "users", "user_id" are not.
        lowered = stripped.lower()
        return lowered.startswith(role) and not re.match(r"\w", lowered[len(role):])

    kept = []
    for line in prompt.splitlines():
        stripped = line.strip()
        if is_role_line(stripped, "system"):
            continue
        if is_role_line(stripped, "user"):
            # Drop only the label itself; keep whatever follows on the line.
            kept.append(stripped[len("user"):].strip())
        else:
            kept.append(line)
    return "\n".join(kept).strip()
455
+
456
def get_assistant_response(text: str) -> str:
    """Extract the assistant's reply from chat-formatted model output.

    If the text contains a ``<|im_start|>assistant ... <|im_end|>`` section,
    the text inside the first such section is returned, stripped.

    Otherwise every line that begins with the role label ``assistant`` is
    removed and the remaining lines are joined back together. The label only
    counts when it is not followed by another word character, so a content
    line such as "Assistance is available." is preserved.

    Args:
        text: Raw generated text, with or without chat-template markers.

    Returns:
        The extracted assistant text with surrounding whitespace removed.
    """
    match = re.search(r"<\|im_start\|>assistant\s*(.*?)\s*<\|im_end\|>", text, re.DOTALL)
    if match:
        return match.group(1).strip()

    label = "assistant"

    def is_label_line(line: str) -> bool:
        # "assistant" / "assistant:" are labels; "assistants", "assistance" are not.
        lowered = line.strip().lower()
        return lowered.startswith(label) and not re.match(r"\w", lowered[len(label):])

    return "\n".join(
        line for line in text.splitlines() if not is_label_line(line)
    ).strip()
461
+
462
def _grpo_display_text(full_output: str) -> str:
    """Return the text to show for the GRPO output accumulated so far.

    Once the output contains ``<thinking>``, a following ``</thinking>`` and a
    following ``<answer>``, only the answer text is returned (up to the first
    ``</answer>`` if it has arrived, otherwise to the end of the output).

    Until then, the whole output is returned with the four structural tags
    removed, so the reasoning is streamed while the model is still thinking.
    """
    think_open = full_output.find("<thinking>")
    if think_open != -1:
        think_close = full_output.find("</thinking>", think_open)
        if think_close != -1:
            # Only switch to "answer" mode when the tag has really arrived;
            # "</thinking>" and "<answer>" usually come in different chunks.
            answer_open = full_output.find("<answer>", think_close)
            if answer_open != -1:
                start = answer_open + len("<answer>")
                end = full_output.find("</answer>", start)
                if end == -1:
                    end = len(full_output)
                return full_output[start:end].strip()

    visible = full_output
    for tag in ("<thinking>", "</thinking>", "<answer>", "</answer>"):
        visible = visible.replace(tag, "")
    return visible.strip()


def generate_grpo_fn(prompt: str):
    """Stream a response from the GRPO model for ``prompt``.

    Builds a two-message chat (fixed system instruction + user prompt), starts
    ``grpo_model.generate`` on a background thread, and consumes the
    ``TextIteratorStreamer`` as text arrives.

    Args:
        prompt: The user's question, without the ``@grpo`` command prefix.

    Yields:
        After each streamed chunk, the text to display so far: the tag-free
        running output while the model is reasoning, then only the answer
        text once the ``<answer>`` section has started.
    """
    messages = [
        {"role": "system", "content": "Please respond in this specific format ONLY:\n<thinking>\n input your reasoning behind your answer in between these reasoning tags.\n</thinking>\n<answer>\nyour answer in between these answer tags.\n</answer>\n"},
        {"role": "user", "content": prompt}
    ]
    # NOTE(review): add_generation_prompt is not passed here, so the rendered
    # prompt may not end with the assistant header — confirm this is intended.
    input_text = grpo_tokenizer.apply_chat_template(messages, tokenize=False)
    inputs = grpo_tokenizer.encode(input_text, return_tensors="pt").to(device)
    streamer = TextIteratorStreamer(grpo_tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = {
        "input_ids": inputs,
        "streamer": streamer,
        "max_new_tokens": 200,
        "temperature": 0.2,
        "top_p": 0.9,
        "do_sample": True,
        "use_cache": False
    }
    # Generation runs on a worker thread so this generator can yield chunks
    # as the streamer receives them.
    thread = Thread(target=grpo_model.generate, kwargs=generation_kwargs)
    thread.start()

    full_output = ""
    for new_text in streamer:
        full_output += new_text
        # Re-derive the visible text from the whole output each time; this
        # avoids latching onto a bogus answer offset when tags arrive late.
        yield _grpo_display_text(full_output)
520
+
521
+
522
+ # Chat Generation Function with support for @tts, @image, @3d, @web, @rAgent, @yolo, @phi4 and @grpo commands
523
 
524
  @spaces.GPU
525
  def generate(
 
533
  ):
534
  """
535
  Generates chatbot responses with support for multimodal input and special commands:
536
+ - "@tts1" or "@tts2": triggers text-to-speech.
537
+ - "@image": triggers image generation using the SDXL pipeline.
538
+ - "@3d": triggers 3D model generation using the ShapE pipeline.
539
+ - "@web": triggers a web search or webpage visit.
540
+ - "@rAgent": initiates a reasoning chain using Llama mode.
541
+ - "@yolo": triggers object detection using YOLO.
542
+ - "@phi4": triggers multimodal (image/audio) processing using the Phi-4 model.
543
+ - "@grpo": triggers text generation using GRPO model with structured output.
544
  """
545
  text = input_dict["text"]
546
  files = input_dict.get("files", [])
547
 
548
+ # --- GRPO Generation branch ---
549
+ if text.strip().lower().startswith("@grpo"):
550
+ prompt = text[len("@grpo"):].strip()
551
+ yield "💡 Thinking using GRPO model..."
552
+ for partial_response in generate_grpo_fn(prompt):
553
+ yield partial_response
554
+ return
555
+
556
+
557
  # --- 3D Generation branch ---
558
  if text.strip().lower().startswith("@3d"):
559
  prompt = text[len("@3d"):].strip()
 
572
  new_filename = f"mesh_{uuid.uuid4()}.glb"
573
  new_filepath = os.path.join(static_folder, new_filename)
574
  shutil.copy(glb_path, new_filepath)
575
+
576
  yield gr.File(new_filepath)
577
  return
578
 
 
686
 
687
  # Initialize the streamer
688
  streamer = TextIteratorStreamer(phi4_processor, skip_prompt=True, skip_special_tokens=True)
689
+
690
  # Prepare generation kwargs
691
  generation_kwargs = {
692
  **inputs,
 
712
  tts_prefix = "@tts"
713
  is_tts = any(text.strip().lower().startswith(f"{tts_prefix}{i}") for i in range(1, 3))
714
  voice_index = next((i for i in range(1, 3) if text.strip().lower().startswith(f"{tts_prefix}{i}")), None)
715
+
716
  if is_tts and voice_index:
717
  voice = TTS_VOICES[voice_index - 1]
718
  text = text.replace(f"{tts_prefix}{voice_index}", "").strip()
 
806
  ["@rAgent Explain how a binary search algorithm works."],
807
  ["@web Is Grok-3 Beats DeepSeek-R1 at Reasoning ?"],
808
  ["@tts1 Explain Tower of Hanoi"],
809
+ ["@grpo If there are 12 cookies in a dozen and you have 5 dozen, how many cookies do you have?"],
810
  ],
811
  cache_examples=False,
812
  type="messages",
 
814
  css=css,
815
  fill_height=True,
816
  textbox=gr.MultimodalTextbox(
817
+ label="Query Input",
818
  file_types=["image", "audio"],
819
+ file_count="multiple",
820
+ placeholder="@tts1, @tts2, @image, @3d, @phi4 [image, audio], @rAgent, @web, @yolo, @grpo, default [plain text]"
821
  ),
822
  stop_btn="Stop Generation",
823
  multimodal=True,