Update gradio_demo.py

gradio_demo.py  (+82 -40)  CHANGED
@@ -953,56 +953,92 @@ def process_image(input_image, input_text):
     task = DinoxTask(
         image_url=image_url,
         prompts=[TextPrompt(text=input_text)]
+        targets=[DetectionTarget.BBox, DetectionTarget.Mask]
     )
+
     client.run_task(task)
     result = task.result
     objects = result.objects



-    for obj in objects:
-        input_boxes.append(obj.bbox)
+    # for obj in objects:
+    #     input_boxes.append(obj.bbox)
+    #     confidences.append(obj.score)
+    #     cls_name = obj.category.lower().strip()
+    #     class_names.append(cls_name)
+    #     class_ids.append(class_name_to_id[cls_name])
+
+    # input_boxes = np.array(input_boxes)
+    # class_ids = np.array(class_ids)
+
+    classes = [x.strip().lower() for x in TEXT_PROMPT.split('.') if x]
+    class_name_to_id = {name: id for id, name in enumerate(classes)}
+    class_id_to_name = {id: name for name, id in class_name_to_id.items()}
+
+    boxes = []
+    masks = []
+    confidences = []
+    class_names = []
+    class_ids = []
+
+    for idx, obj in enumerate(predictions):
+        boxes.append(obj.bbox)
+        masks.append(DetectionTask.rle2mask(DetectionTask.string2rle(obj.mask.counts), obj.mask.size))  # convert mask to np.array using DDS API
         confidences.append(obj.score)
         cls_name = obj.category.lower().strip()
         class_names.append(cls_name)
         class_ids.append(class_name_to_id[cls_name])
-
-    input_boxes = np.array(input_boxes)
+
+    boxes = np.array(boxes)
+    masks = np.array(masks)
     class_ids = np.array(class_ids)
-
+    labels = [
+        f"{class_name} {confidence:.2f}"
+        for class_name, confidence
+        in zip(class_names, confidences)
+    ]
+
     # Initialize SAM2
-    torch.autocast(device_type=DEVICE, dtype=torch.bfloat16).__enter__()
-    if torch.cuda.get_device_properties(0).major >= 8:
-        torch.backends.cuda.matmul.allow_tf32 = True
-        torch.backends.cudnn.allow_tf32 = True
+    # torch.autocast(device_type=DEVICE, dtype=torch.bfloat16).__enter__()
+    # if torch.cuda.get_device_properties(0).major >= 8:
+    #     torch.backends.cuda.matmul.allow_tf32 = True
+    #     torch.backends.cudnn.allow_tf32 = True

-    sam2_model = build_sam2(SAM2_MODEL_CONFIG, SAM2_CHECKPOINT, device=DEVICE)
-    sam2_predictor = SAM2ImagePredictor(sam2_model)
-    sam2_predictor.set_image(input_image)
+    # sam2_model = build_sam2(SAM2_MODEL_CONFIG, SAM2_CHECKPOINT, device=DEVICE)
+    # sam2_predictor = SAM2ImagePredictor(sam2_model)
+    # sam2_predictor.set_image(input_image)

     # sam2_predictor = run_sam_inference(SAM_IMAGE_MODEL, input_image, detections)


     # Get masks from SAM2
-    masks, scores, logits = sam2_predictor.predict(
-        point_coords=None,
-        point_labels=None,
-        box=input_boxes,
-        multimask_output=False,
-    )
+    # masks, scores, logits = sam2_predictor.predict(
+    #     point_coords=None,
+    #     point_labels=None,
+    #     box=input_boxes,
+    #     multimask_output=False,
+    # )
+
     if masks.ndim == 4:
         masks = masks.squeeze(1)

     # Create visualization
-    labels = [f"{class_name} {confidence:.2f}"
-              for class_name, confidence in zip(class_names, confidences)]
+    # labels = [f"{class_name} {confidence:.2f}"
+    #           for class_name, confidence in zip(class_names, confidences)]

+    # detections = sv.Detections(
+    #     xyxy=input_boxes,
+    #     mask=masks.astype(bool),
+    #     class_id=class_ids
+    # )
+
     detections = sv.Detections(
-        xyxy=input_boxes,
-        mask=masks.astype(bool),
-        class_id=class_ids
-    )
-
+        xyxy = boxes,
+        mask = masks.astype(bool),
+        class_id = class_ids,
+    )
+
     box_annotator = sv.BoxAnnotator()
     label_annotator = sv.LabelAnnotator()
     mask_annotator = sv.MaskAnnotator()
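With this change the masks come straight from the DINO-X task (targets=[DetectionTarget.BBox, DetectionTarget.Mask]) instead of a local SAM2 pass: each returned object carries an RLE-encoded mask that DetectionTask.string2rle / DetectionTask.rle2mask decode into a numpy array. (Two apparent slips in the hunk: the added targets line needs a comma after prompts=[...], and the new loop iterates predictions where objects = result.objects was assigned.) For reference, a minimal decoder equivalent to that call chain, assuming COCO-style uncompressed RLE with column-major runs that start with background:

import numpy as np

def rle_to_mask(counts, size):
    # size is (height, width); counts are run lengths alternating
    # background/foreground, laid out in column-major (Fortran) order.
    h, w = size
    flat = np.zeros(h * w, dtype=bool)
    pos, inside = 0, False
    for run in counts:
        if inside:
            flat[pos:pos + run] = True
        pos += run
        inside = not inside
    return flat.reshape((w, h)).T  # transpose back to (height, width)

Stacking one decoded array per object yields the (N, H, W) boolean stack that sv.Detections(mask=...) expects.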
@@ -1157,6 +1193,8 @@ with block:
         with gr.Group():
             gr.Markdown("Extract Foreground")
             input_image = gr.Image(type="numpy", label="Input Image", height=480)
+            with gr.Row():
+                with gr.Group():
             find_objects_button = gr.Button(value="(Option 1) Segment Object from text")
             text_prompt = gr.Textbox(
                 label="Text Prompt",
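Note that the two added lines open nested blocks while the context lines that follow keep their original indentation, so as committed this hunk will not parse. A sketch of the presumably intended Gradio layout (widget names taken from the hunk, everything else illustrative):

import gradio as gr

with gr.Blocks() as block:
    with gr.Group():
        gr.Markdown("Extract Foreground")
        input_image = gr.Image(type="numpy", label="Input Image", height=480)
        with gr.Row():
            with gr.Group():  # Option 1 gets its own column of the row
                find_objects_button = gr.Button(value="(Option 1) Segment Object from text")
                text_prompt = gr.Textbox(label="Text Prompt")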
@@ -1311,25 +1349,29 @@ with block:

     # # return mask_mover.create_composite(self.original_bg, x_pos, y_pos, scale)

-    def update_position(background, x_pos, y_pos, scale):
-        if background is None:
+    def update_position(self, background, x_pos, y_pos, scale, *args):
+        if self.original_bg is None:
+            print("No original background set.")
             return None
-        fresh_bg = bg_manager.original_bg.copy()
-        # Composite the foreground once
+        fresh_bg = self.original_bg.copy()  # Start from a clean original background
         return mask_mover.create_composite(fresh_bg, float(x_pos), float(y_pos), float(scale))

-
+
     # Create an instance of BackgroundManager
     bg_manager = BackgroundManager()

-    def update_position(background, x_pos, y_pos, scale):
-        if background is None:
-            return None
-        fresh_bg = bg_manager.original_bg.copy()  # Start from a clean original background
-        # Composite the extracted foreground onto fresh_bg
-        return mask_mover.create_composite(fresh_bg, float(x_pos), float(y_pos), float(scale))
+    def update_position_wrapper(background, x_pos, y_pos, scale):
+        return bg_manager.update_position(background, x_pos, y_pos, scale)
+
+
+
+    # def update_position(background, x_pos, y_pos, scale):
+    #     if background is None:
+    #         return None
+    #     fresh_bg = bg_manager.original_bg.copy()  # Start from a clean original background
+    #     # Composite the extracted foreground onto fresh_bg
+    #     return mask_mover.create_composite(fresh_bg, float(x_pos), float(y_pos), float(scale))


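The hunk moves update_position onto the manager so every slider event composites onto a fresh copy of the stored original background, rather than onto the already-composited image the event passes in; update_position_wrapper then gives the events a plain function whose signature matches the four inputs. A self-contained toy version of the pattern (composite_fn stands in for mask_mover.create_composite, and the real BackgroundManager lives elsewhere in gradio_demo.py):

class ToyBackgroundManager:
    def __init__(self, composite_fn):
        self.original_bg = None            # pristine background, set once
        self.composite_fn = composite_fn

    def set_background(self, image):       # hypothetical setter
        self.original_bg = None if image is None else image.copy()

    def update_position(self, background, x_pos, y_pos, scale, *args):
        if self.original_bg is None:
            return None
        fresh_bg = self.original_bg.copy()  # compositing never touches the original
        return self.composite_fn(fresh_bg, float(x_pos), float(y_pos), float(scale))

Because the composite always starts from original_bg, dragging a slider back and forth cannot stack multiple copies of the foreground.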
@@ -1341,19 +1383,19 @@ with block:
     )

     x_slider.change(
-        fn=update_position,
+        fn=update_position_wrapper,
         inputs=[input_bg, x_slider, y_slider, fg_scale_slider],
         outputs=[input_bg]
     )

     y_slider.change(
-        fn=update_position,
+        fn=update_position_wrapper,
         inputs=[input_bg, x_slider, y_slider, fg_scale_slider],
         outputs=[input_bg]
     )

     fg_scale_slider.change(
-        fn=update_position,
+        fn=update_position_wrapper,
         inputs=[input_bg, x_slider, y_slider, fg_scale_slider],
         outputs=[input_bg]
     )
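All three sliders are rewired to the module-level wrapper. For illustration, an equivalent wiring uses a lambda directly, since Gradio event handlers accept any callable whose positional parameters match the inputs list:

x_slider.change(
    fn=lambda bg, x, y, s: bg_manager.update_position(bg, x, y, s),
    inputs=[input_bg, x_slider, y_slider, fg_scale_slider],
    outputs=[input_bg]
)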
|