Samarth991 committed on
Commit
926ff01
·
1 Parent(s): 44414ac

adding changes for detection output

Browse files
Files changed (3) hide show
  1. QA_bot.py +0 -1
  2. extract_tools.py +9 -5
  3. tool_utils/yolo_world.py +19 -11
QA_bot.py CHANGED
@@ -14,7 +14,6 @@ def display_mask_image(image_path):
14
  image = Image.open(image_path)
15
  st.image(image, caption='Final Mask', use_column_width=True)
16
 
17
-
18
  def tyre_synap_bot(filter_agent,image_file_path):
19
  if "messages" not in st.session_state:
20
  st.session_state.messages = []
 
14
  image = Image.open(image_path)
15
  st.image(image, caption='Final Mask', use_column_width=True)
16
 
 
17
  def tyre_synap_bot(filter_agent,image_file_path):
18
  if "messages" not in st.session_state:
19
  st.session_state.messages = []
extract_tools.py CHANGED
@@ -9,7 +9,7 @@ from langchain_core.tools import tool,Tool
9
  from langchain_community.tools import DuckDuckGoSearchResults
10
  from langchain_groq import ChatGroq
11
  from utils import draw_panoptic_segmentation
12
-
13
  from tool_utils.clip_segmentation import CLIPSEG
14
  from tool_utils.yolo_world import YoloWorld
15
  from tool_utils.image_qualitycheck import brightness_check,gaussian_noise_check,snr_check
@@ -120,6 +120,7 @@ def clipsegmentation_mask(input_data:str)->str:
120
  @tool
121
  def generate_bounding_box_tool(input_data:str)->str:
122
  "use this tool when its is required to detect object and provide bounding boxes for the given image and list of objects"
 
123
  data = input_data.split(",")
124
  image_path = data[0]
125
  object_prompts = data[1:]
@@ -201,10 +202,13 @@ def get_all_tools():
201
  bounding_box_generator = Tool(
202
  name = 'Bounding Box Generator',
203
  func = generate_bounding_box_tool,
204
- description= "The tool helps to provide bounding boxes for the given image and list of objects\
205
- .Use this tool when user ask to provide bounding boxes for the objects.if user has not specified the names of the objects \
206
- then use the object extraction tool to identify the objects and then use this tool to generate the bounding boxes for the objects.\
207
- The input to this tool is the path of the image and list of objects for which bounding boxes are to be generated"
 
 
 
208
  )
209
 
210
  object_extractor = Tool(
 
9
  from langchain_community.tools import DuckDuckGoSearchResults
10
  from langchain_groq import ChatGroq
11
  from utils import draw_panoptic_segmentation
12
+ from typing import List
13
  from tool_utils.clip_segmentation import CLIPSEG
14
  from tool_utils.yolo_world import YoloWorld
15
  from tool_utils.image_qualitycheck import brightness_check,gaussian_noise_check,snr_check
 
120
  @tool
121
  def generate_bounding_box_tool(input_data:str)->str:
122
  "use this tool when its is required to detect object and provide bounding boxes for the given image and list of objects"
123
+ print(input_data)
124
  data = input_data.split(",")
125
  image_path = data[0]
126
  object_prompts = data[1:]
 
202
  bounding_box_generator = Tool(
203
  name = 'Bounding Box Generator',
204
  func = generate_bounding_box_tool,
205
+ description= """The tool helps to provide bounding boxes for the given image and list of objects
206
+ .Use this tool when user ask to provide bounding boxes for the objects.if user has not specified the names of the objects
207
+ then use the object extraction tool to identify the objects and then use this tool to generate the bounding boxes for the objects.
208
+ The input to this tool is the path of the image and list of objects for which bounding boxes are to be generated
209
+ For Example :
210
+ "action_input ": "image_store/<image_path>,person,dog,sand,"
211
+ """
212
  )
213
 
214
  object_extractor = Tool(
tool_utils/yolo_world.py CHANGED
@@ -6,7 +6,7 @@ from typing import List
6
  import torch
7
  import random
8
  from ultralytics import YOLOWorld
9
-
10
  class YoloWorld:
11
  def __init__(self,model_name = "yolov8x-worldv2.pt"):
12
  self.model = YOLOWorld(model_name)
@@ -57,7 +57,11 @@ class YoloWorld:
57
  cv2.putText(rgb_frame_copy, str(label), (c1[0], c1[1] - 2), 0, tl / 3, color_dict[label], thickness=tf, lineType=cv2.LINE_AA)
58
  return rgb_frame_copy
59
 
60
-
 
 
 
 
61
  def run_yolo_infer(self,image_path:str,object_prompts:List):
62
  processed_predictions = []
63
  bounding_boxes = []
@@ -73,20 +77,24 @@ class YoloWorld:
73
  labels.append(result.names[int(box.cls.cpu())])
74
  scores.append(round(float(box.conf.cpu()),2))
75
 
76
- processed_predictions.append(dict(boxes= torch.tensor(bounding_boxes),
 
77
  labels= labels,
78
- scores=torch.tensor(scores))
79
- )
 
 
80
  detected_image = self.draw_bboxes(rgb_frame=image_path,
81
  boxes=processed_predictions[0]['boxes'],
82
  labels=processed_predictions[0]['labels']
83
  )
84
-
 
85
  cv2.imwrite('final_mask.png', cv2.cvtColor(detected_image,cv2.COLOR_BGR2RGB))
86
- return "Predicted image : final_mask.png . Details :{}".format(processed_predictions[0])
87
 
88
- # if __name__ == "__main__":
89
- # yolo = YoloWorld()
90
- # predicted_data = yolo.run_yolo_infer('../image_store/demo2.jpg',['person','hat','building'])
91
- # print(predicted_data)
92
 
 
6
  import torch
7
  import random
8
  from ultralytics import YOLOWorld
9
+ import json
10
  class YoloWorld:
11
  def __init__(self,model_name = "yolov8x-worldv2.pt"):
12
  self.model = YOLOWorld(model_name)
 
57
  cv2.putText(rgb_frame_copy, str(label), (c1[0], c1[1] - 2), 0, tl / 3, color_dict[label], thickness=tf, lineType=cv2.LINE_AA)
58
  return rgb_frame_copy
59
 
60
+ def format_detections(self,boxes,labels):
61
+ text = ""
62
+ for box ,label in zip(boxes,labels):
63
+ text+="{}\tBounding Box :{}\n".format(label,box)
64
+ return (text)
65
  def run_yolo_infer(self,image_path:str,object_prompts:List):
66
  processed_predictions = []
67
  bounding_boxes = []
 
77
  labels.append(result.names[int(box.cls.cpu())])
78
  scores.append(round(float(box.conf.cpu()),2))
79
 
80
+ processed_predictions.append(dict(
81
+ boxes= torch.tensor(bounding_boxes),
82
  labels= labels,
83
+ scores= torch.tensor(scores)
84
+ )
85
+ )
86
+
87
  detected_image = self.draw_bboxes(rgb_frame=image_path,
88
  boxes=processed_predictions[0]['boxes'],
89
  labels=processed_predictions[0]['labels']
90
  )
91
+ predicted_data = self.format_detections(bounding_boxes,labels)
92
+ # save image
93
  cv2.imwrite('final_mask.png', cv2.cvtColor(detected_image,cv2.COLOR_BGR2RGB))
94
+ return "Predicted image : final_mask.png . \nDetails :\n{}".format(predicted_data)
95
 
96
+ if __name__ == "__main__":
97
+ yolo = YoloWorld()
98
+ predicted_data = yolo.run_yolo_infer('../image_store/demo2.jpg',['person','hat','building'])
99
+ print(predicted_data)
100