Ashoka74 committed
Commit 01f030e · verified · 1 Parent(s): 9615931

Update gradio_demo.py

Files changed (1):
  gradio_demo.py (+104 -2)
gradio_demo.py CHANGED
@@ -9,6 +9,8 @@ import db_examples
 import datetime
 from pathlib import Path
 from io import BytesIO
+from hydra import initialize, compose
+

 from PIL import Image
 from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline
@@ -23,6 +25,7 @@ from enum import Enum
 from torch.hub import download_url_to_file
 import tempfile

+
 from sam2.build_sam import build_sam2

 from sam2.sam2_image_predictor import SAM2ImagePredictor
@@ -39,7 +42,6 @@ client = httpx.Client(timeout=httpx.Timeout(10.0)) # Set timeout to 10 seconds



-
 # from FLORENCE
 import spaces
 import supervision as sv
@@ -49,6 +51,7 @@ from PIL import Image
 from utils.sam import load_sam_image_model, run_sam_inference


+
 try:
     import xformers
     import xformers.ops
@@ -83,6 +86,9 @@ model.load_state_dict(torch.load('checkpoints/depth_anything_v2_vits.pth', map_l
 model = model.to(device)
 model.eval()

+SAM_IMAGE_MODEL = load_sam_image_model(device=device)
+
+
 # Change UNet

 with torch.no_grad():
@@ -826,8 +832,10 @@ def compress_image(image):
     compressed_img = np.array(Image.open("compressed_image.jpg"))
     return compressed_img

+
 @spaces.GPU(duration=60)
-@torch.inference_mode()
+@torch.inference_mode
+@hydra.main(config_path="/home/user/app/configs", config_name="sam2_hiera_l")
 def process_image(input_image, input_text):
     """Main processing function for the Gradio interface"""

@@ -839,6 +847,8 @@ def process_image(input_image, input_text):
     OUTPUT_DIR = Path("outputs/grounded_sam2_dinox_demo")
     OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

+
+
     # Initialize DDS client
     config = Config(API_TOKEN)
     client = Client(config)
@@ -933,6 +943,98 @@ def process_image(input_image, input_text):

         return annotated_frame, cropped_rgba, gr.update(visible=True), gr.update(visible=True)

+
+    else:
+        # Run DINO-X detection
+        task = DinoxTask(
+            image_url=image_url,
+            prompts=[TextPrompt(text=input_text)]
+        )
+        client.run_task(task)
+        result = task.result
+        objects = result.objects
+
+
+
+        for obj in objects:
+            input_boxes.append(obj.bbox)
+            confidences.append(obj.score)
+            cls_name = obj.category.lower().strip()
+            class_names.append(cls_name)
+            class_ids.append(class_name_to_id[cls_name])
+
+        input_boxes = np.array(input_boxes)
+        class_ids = np.array(class_ids)
+
+        # Initialize SAM2
+        torch.autocast(device_type=DEVICE, dtype=torch.bfloat16).__enter__()
+        if torch.cuda.get_device_properties(0).major >= 8:
+            torch.backends.cuda.matmul.allow_tf32 = True
+            torch.backends.cudnn.allow_tf32 = True
+
+        sam2_model = build_sam2(SAM2_MODEL_CONFIG, SAM2_CHECKPOINT, device=DEVICE)
+        sam2_predictor = SAM2ImagePredictor(sam2_model)
+        sam2_predictor.set_image(input_image)
+
+        # sam2_predictor = run_sam_inference(SAM_IMAGE_MODEL, input_image, detections)
+
+
+        # Get masks from SAM2
+        masks, scores, logits = sam2_predictor.predict(
+            point_coords=None,
+            point_labels=None,
+            box=input_boxes,
+            multimask_output=False,
+        )
+        if masks.ndim == 4:
+            masks = masks.squeeze(1)
+
+        # Create visualization
+        labels = [f"{class_name} {confidence:.2f}"
+                  for class_name, confidence in zip(class_names, confidences)]
+
+        detections = sv.Detections(
+            xyxy=input_boxes,
+            mask=masks.astype(bool),
+            class_id=class_ids
+        )
+
+        box_annotator = sv.BoxAnnotator()
+        label_annotator = sv.LabelAnnotator()
+        mask_annotator = sv.MaskAnnotator()
+
+        annotated_frame = input_image.copy()
+        annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections)
+        annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
+        annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
+
+        # Create transparent mask for first detected object
+        if len(detections) > 0:
+            # Get first mask
+            first_mask = detections.mask[0]
+
+            # Get original RGB image
+            img = input_image.copy()
+            H, W, C = img.shape
+
+            # Create RGBA image
+            alpha = np.zeros((H, W, 1), dtype=np.uint8)
+            alpha[first_mask] = 255
+            rgba = np.dstack((img, alpha)).astype(np.uint8)
+
+            # Crop to mask bounds to minimize image size
+            y_indices, x_indices = np.where(first_mask)
+            y_min, y_max = y_indices.min(), y_indices.max()
+            x_min, x_max = x_indices.min(), x_indices.max()
+
+            # Crop the RGBA image
+            cropped_rgba = rgba[y_min:y_max+1, x_min:x_max+1]
+
+            # Set extracted foreground for mask mover
+            mask_mover.set_extracted_fg(cropped_rgba)
+
+            return annotated_frame, cropped_rgba, gr.update(visible=True), gr.update(visible=True)
+    return annotated_frame, None, gr.update(visible=False), gr.update(visible=False)


 block = gr.Blocks().queue()
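
For reference on the initialize/compose import added above: these two calls are the programmatic counterpart to the @hydra.main decorator, which is normally reserved for a single CLI entry point rather than a per-request Gradio callback. A minimal sketch of composing the same config by hand, assuming the configs/sam2_hiera_l.yaml layout implied by the decorator arguments:

# Minimal sketch: composing the SAM2 Hydra config without @hydra.main.
# The config_path and config_name mirror the decorator arguments in this
# commit; treat the relative "configs" directory as an assumption.
from hydra import initialize, compose

with initialize(version_base=None, config_path="configs"):
    cfg = compose(config_name="sam2_hiera_l")

print(cfg)  # inspect the composed config before building the SAM2 model

Composing the config once at module import, like the SAM_IMAGE_MODEL initialization above, keeps Hydra out of the per-request path.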
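
The SAM2 block in this commit enters torch.autocast(...).__enter__() without a visible matching exit; the scoped form below is the more common pattern. This is only a sketch: the config and checkpoint values stand in for the SAM2_MODEL_CONFIG and SAM2_CHECKPOINT constants the diff assumes are defined elsewhere in gradio_demo.py.

# Minimal sketch: scoping autocast around box-prompted SAM2 prediction
# instead of calling __enter__() by hand. Paths below are assumptions
# standing in for SAM2_MODEL_CONFIG / SAM2_CHECKPOINT in the real file.
import numpy as np
import torch
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

SAM2_MODEL_CONFIG = "sam2_hiera_l.yaml"              # assumed config name
SAM2_CHECKPOINT = "checkpoints/sam2_hiera_large.pt"  # assumed checkpoint path

def predict_masks(image_rgb: np.ndarray, boxes_xyxy: np.ndarray, device: str = "cuda") -> np.ndarray:
    sam2_model = build_sam2(SAM2_MODEL_CONFIG, SAM2_CHECKPOINT, device=device)
    predictor = SAM2ImagePredictor(sam2_model)
    with torch.autocast(device_type=device, dtype=torch.bfloat16):
        predictor.set_image(image_rgb)
        masks, scores, logits = predictor.predict(
            point_coords=None,
            point_labels=None,
            box=boxes_xyxy,
            multimask_output=False,
        )
    # mirror the diff above: collapse the extra mask channel when present
    return masks.squeeze(1) if masks.ndim == 4 else masks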