Ashoka74 committed on
Commit d068918 · verified · 1 Parent(s): a0b057f

Update gradio_demo.py

Files changed (1)
  1. gradio_demo.py +82 -40
gradio_demo.py CHANGED
@@ -953,56 +953,92 @@ def process_image(input_image, input_text):
     task = DinoxTask(
         image_url=image_url,
-        prompts=[TextPrompt(text=input_text)]
+        prompts=[TextPrompt(text=input_text)],
+        targets=[DetectionTarget.BBox, DetectionTarget.Mask]
     )
+
     client.run_task(task)
     result = task.result
     objects = result.objects
 
 
-    for obj in objects:
-        input_boxes.append(obj.bbox)
+    # for obj in objects:
+    #     input_boxes.append(obj.bbox)
+    #     confidences.append(obj.score)
+    #     cls_name = obj.category.lower().strip()
+    #     class_names.append(cls_name)
+    #     class_ids.append(class_name_to_id[cls_name])
+
+    # input_boxes = np.array(input_boxes)
+    # class_ids = np.array(class_ids)
+
+    classes = [x.strip().lower() for x in input_text.split('.') if x]
+    class_name_to_id = {name: id for id, name in enumerate(classes)}
+    class_id_to_name = {id: name for name, id in class_name_to_id.items()}
+
+    boxes = []
+    masks = []
+    confidences = []
+    class_names = []
+    class_ids = []
+
+    for obj in objects:
+        boxes.append(obj.bbox)
+        masks.append(DetectionTask.rle2mask(DetectionTask.string2rle(obj.mask.counts), obj.mask.size))  # convert RLE mask to np.array using the DDS API
         confidences.append(obj.score)
         cls_name = obj.category.lower().strip()
         class_names.append(cls_name)
         class_ids.append(class_name_to_id[cls_name])
-
-    input_boxes = np.array(input_boxes)
+
+    boxes = np.array(boxes)
+    masks = np.array(masks)
     class_ids = np.array(class_ids)
-
+    labels = [
+        f"{class_name} {confidence:.2f}"
+        for class_name, confidence
+        in zip(class_names, confidences)
+    ]
+
     # Initialize SAM2
-    torch.autocast(device_type=DEVICE, dtype=torch.bfloat16).__enter__()
-    if torch.cuda.get_device_properties(0).major >= 8:
-        torch.backends.cuda.matmul.allow_tf32 = True
-        torch.backends.cudnn.allow_tf32 = True
+    # torch.autocast(device_type=DEVICE, dtype=torch.bfloat16).__enter__()
+    # if torch.cuda.get_device_properties(0).major >= 8:
+    #     torch.backends.cuda.matmul.allow_tf32 = True
+    #     torch.backends.cudnn.allow_tf32 = True
 
-    sam2_model = build_sam2(SAM2_MODEL_CONFIG, SAM2_CHECKPOINT, device=DEVICE)
-    sam2_predictor = SAM2ImagePredictor(sam2_model)
-    sam2_predictor.set_image(input_image)
+    # sam2_model = build_sam2(SAM2_MODEL_CONFIG, SAM2_CHECKPOINT, device=DEVICE)
+    # sam2_predictor = SAM2ImagePredictor(sam2_model)
+    # sam2_predictor.set_image(input_image)
 
     # sam2_predictor = run_sam_inference(SAM_IMAGE_MODEL, input_image, detections)
 
 
     # Get masks from SAM2
-    masks, scores, logits = sam2_predictor.predict(
-        point_coords=None,
-        point_labels=None,
-        box=input_boxes,
-        multimask_output=False,
-    )
+    # masks, scores, logits = sam2_predictor.predict(
+    #     point_coords=None,
+    #     point_labels=None,
+    #     box=input_boxes,
+    #     multimask_output=False,
+    # )
+
     if masks.ndim == 4:
         masks = masks.squeeze(1)
 
     # Create visualization
-    labels = [f"{class_name} {confidence:.2f}"
-              for class_name, confidence in zip(class_names, confidences)]
+    # labels = [f"{class_name} {confidence:.2f}"
+    #           for class_name, confidence in zip(class_names, confidences)]
 
+    # detections = sv.Detections(
+    #     xyxy=input_boxes,
+    #     mask=masks.astype(bool),
+    #     class_id=class_ids
+    # )
+
     detections = sv.Detections(
-        xyxy=input_boxes,
-        mask=masks.astype(bool),
-        class_id=class_ids
-    )
-
+        xyxy=boxes,
+        mask=masks.astype(bool),
+        class_id=class_ids,
+    )
+
     box_annotator = sv.BoxAnnotator()
     label_annotator = sv.LabelAnnotator()
     mask_annotator = sv.MaskAnnotator()
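
The core change in this hunk: instead of running SAM2 on DINO-X boxes, the task now requests `DetectionTarget.Mask` directly and decodes each object's run-length-encoded mask with the DDS helpers (`DetectionTask.string2rle` / `DetectionTask.rle2mask`). As a rough illustration of what that decoding step produces, here is a minimal standalone decoder for uncompressed COCO-style RLE; the column-major convention and the `counts`/`size` layout are assumptions based on the diff, not the DDS implementation:

```python
import numpy as np

def rle_to_mask(counts, size):
    """Decode uncompressed COCO-style RLE into a binary mask.

    `counts` alternates run lengths of 0s and 1s, starting with 0s;
    `size` is (height, width). COCO stores pixels column-major.
    """
    h, w = size
    flat = np.zeros(h * w, dtype=np.uint8)
    pos, value = 0, 0
    for run in counts:
        flat[pos:pos + run] = value  # write the current run
        pos += run
        value = 1 - value            # runs alternate 0 / 1
    return flat.reshape((h, w), order="F")  # column-major -> (h, w)

# Example: a 2x2 mask with only the bottom-left pixel set.
# Column-major pixel order is (0,0), (1,0), (0,1), (1,1).
mask = rle_to_mask([1, 1, 2], (2, 2))
assert mask[1, 0] == 1 and mask.sum() == 1
```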
@@ -1157,6 +1193,8 @@ with block:
         with gr.Group():
             gr.Markdown("Extract Foreground")
             input_image = gr.Image(type="numpy", label="Input Image", height=480)
-            find_objects_button = gr.Button(value="(Option 1) Segment Object from text")
-            text_prompt = gr.Textbox(
-                label="Text Prompt",
+            with gr.Row():
+                with gr.Group():
+                    find_objects_button = gr.Button(value="(Option 1) Segment Object from text")
+                    text_prompt = gr.Textbox(
+                        label="Text Prompt",
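
Note that the two added `with` statements only take effect if the widgets that follow are re-indented into their scope, as the hunk above reflects. A minimal sketch of the resulting nesting, limited to the widgets visible in this hunk (the enclosing layout is an assumption):

```python
import gradio as gr

with gr.Blocks() as block:
    with gr.Group():
        gr.Markdown("Extract Foreground")
        input_image = gr.Image(type="numpy", label="Input Image", height=480)
        with gr.Row():        # added: row to hold the segmentation options
            with gr.Group():  # added: groups the text-prompt option
                find_objects_button = gr.Button(value="(Option 1) Segment Object from text")
                text_prompt = gr.Textbox(label="Text Prompt")

# block.launch()
```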
@@ -1311,25 +1349,29 @@ with block:
     # # return mask_mover.create_composite(self.original_bg, x_pos, y_pos, scale)
 
-    def update_position(background, x_pos, y_pos, scale):
-        if background is None:
+    def update_position(self, background, x_pos, y_pos, scale, *args):
+        if self.original_bg is None:
+            print("No original background set.")
             return None
-        # Restore a fresh copy of the original background
-        fresh_bg = bg_manager.original_bg.copy()
-        # Composite the foreground once
+        fresh_bg = self.original_bg.copy()  # Start from a clean original background
         return mask_mover.create_composite(fresh_bg, float(x_pos), float(y_pos), float(scale))
 
-
+
     # Create an instance of BackgroundManager
     bg_manager = BackgroundManager()
 
-    def update_position(background, x_pos, y_pos, scale):
-        if background is None:
-            return None
-        fresh_bg = bg_manager.original_bg.copy()  # Start from a clean original background
-        # Composite the extracted foreground onto fresh_bg
-        return mask_mover.create_composite(fresh_bg, float(x_pos), float(y_pos), float(scale))
+    def update_position_wrapper(background, x_pos, y_pos, scale):
+        return bg_manager.update_position(background, x_pos, y_pos, scale)
+
+
+
+    # def update_position(background, x_pos, y_pos, scale):
+    #     if background is None:
+    #         return None
+    #     fresh_bg = bg_manager.original_bg.copy()  # Start from a clean original background
+    #     # Composite the extracted foreground onto fresh_bg
+    #     return mask_mover.create_composite(fresh_bg, float(x_pos), float(y_pos), float(scale))
 
 
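
The rewritten `update_position` moves onto the manager and recomposites from a pristine copy of `self.original_bg`, so repeated slider events never stack composites on top of each other. A self-contained sketch of that copy-then-composite pattern, with a naive paste standing in for `mask_mover.create_composite` (a hypothetical stand-in, not the app's compositor):

```python
import numpy as np

class BackgroundManager:
    """Holds a pristine background so each update composites onto the
    original image rather than onto the previous composite."""

    def __init__(self):
        self.original_bg = None

    def set_background(self, background):
        self.original_bg = background.copy()

    def update_position(self, background, x_pos, y_pos, scale=1.0, *args):
        if self.original_bg is None:
            print("No original background set.")
            return None
        fresh_bg = self.original_bg.copy()  # never draw onto an already-composited frame
        fg = np.full((8, 8, 3), 255, dtype=np.uint8)  # hypothetical foreground patch
        fresh_bg[int(y_pos):int(y_pos) + 8, int(x_pos):int(x_pos) + 8] = fg
        return fresh_bg

bg_manager = BackgroundManager()
bg_manager.set_background(np.zeros((64, 64, 3), dtype=np.uint8))
frame = bg_manager.update_position(None, x_pos=10, y_pos=20)
assert frame is not None and frame[20, 10, 0] == 255
```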
@@ -1341,19 +1383,19 @@ with block:
     )
 
     x_slider.change(
-        fn=lambda bg, x, y, scale: bg_manager.update_position(bg, x, y, scale),
+        fn=update_position_wrapper,
         inputs=[input_bg, x_slider, y_slider, fg_scale_slider],
         outputs=[input_bg]
     )
 
     y_slider.change(
-        fn=lambda bg, x, y, scale: bg_manager.update_position(bg, x, y, scale),
+        fn=update_position_wrapper,
         inputs=[input_bg, x_slider, y_slider, fg_scale_slider],
         outputs=[input_bg]
     )
 
     fg_scale_slider.change(
-        fn=lambda bg, x, y, scale: bg_manager.update_position(bg, x, y, scale),
+        fn=update_position_wrapper,
         inputs=[input_bg, x_slider, y_slider, fg_scale_slider],
         outputs=[input_bg]
     )
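
With the logic on `BackgroundManager`, the three identical lambdas collapse into the single named `update_position_wrapper`, which also resolves `bg_manager` when the event fires. A runnable sketch of the wiring; the slider ranges and the stub manager are assumptions, while the widget and callback names match the diff:

```python
import gradio as gr

class BackgroundManager:  # stub standing in for the app's manager
    def update_position(self, background, x_pos, y_pos, scale):
        return background  # identity here; the real method recomposites

bg_manager = BackgroundManager()

def update_position_wrapper(background, x_pos, y_pos, scale):
    # One named callback instead of three identical lambdas: easier to
    # test, and it looks up bg_manager at call time.
    return bg_manager.update_position(background, x_pos, y_pos, scale)

with gr.Blocks() as block:
    input_bg = gr.Image(type="numpy", label="Background")
    x_slider = gr.Slider(0, 1024, value=0, label="X Position")
    y_slider = gr.Slider(0, 1024, value=0, label="Y Position")
    fg_scale_slider = gr.Slider(0.1, 3.0, value=1.0, label="Foreground Scale")

    for control in (x_slider, y_slider, fg_scale_slider):
        control.change(
            fn=update_position_wrapper,
            inputs=[input_bg, x_slider, y_slider, fg_scale_slider],
            outputs=[input_bg],
        )
```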
 