Update gradio_demo.py

gradio_demo.py  (+104 -2)  CHANGED
@@ -9,6 +9,8 @@ import db_examples
 import datetime
 from pathlib import Path
 from io import BytesIO
+from hydra import initialize, compose
+
 
 from PIL import Image
 from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline
@@ -23,6 +25,7 @@ from enum import Enum
 from torch.hub import download_url_to_file
 import tempfile
 
+
 from sam2.build_sam import build_sam2
 
 from sam2.sam2_image_predictor import SAM2ImagePredictor
@@ -39,7 +42,6 @@ client = httpx.Client(timeout=httpx.Timeout(10.0)) # Set timeout to 10 seconds
 
 
 
-
 # from FLORENCE
 import spaces
 import supervision as sv
@@ -49,6 +51,7 @@ from PIL import Image
 from utils.sam import load_sam_image_model, run_sam_inference
 
 
+
 try:
     import xformers
     import xformers.ops
@@ -83,6 +86,9 @@ model.load_state_dict(torch.load('checkpoints/depth_anything_v2_vits.pth', map_l
 model = model.to(device)
 model.eval()
 
+SAM_IMAGE_MODEL = load_sam_image_model(device=device)
+
+
 # Change UNet
 
 with torch.no_grad():
@@ -826,8 +832,10 @@ def compress_image(image):
     compressed_img = np.array(Image.open("compressed_image.jpg"))
     return compressed_img
 
+
 @spaces.GPU(duration=60)
-@torch.inference_mode
+@torch.inference_mode
+@hydra.main(config_path="/home/user/app/configs", config_name="sam2_hiera_l")
 def process_image(input_image, input_text):
     """Main processing function for the Gradio interface"""
 
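Note on the new decorator stack above: hydra.main wraps the function it decorates so that, when called, Hydra parses command-line overrides and hands the function a single config object, which does not match the two-argument call Gradio makes into process_image; and since this commit only imports initialize and compose from hydra, the bare hydra.main reference also needs an import hydra somewhere else in the file. Below is a minimal sketch, not part of this commit, of the Compose-API route that the new import suggests; the config directory and config name are carried over from the decorator arguments, and the relative path is an assumption.

# Sketch only (not part of this commit): compose the SAM2 Hydra config once at
# startup instead of decorating the Gradio handler with @hydra.main.
from hydra import initialize, compose

# config_path is relative to this file; "configs" is assumed to mirror /home/user/app/configs.
with initialize(version_base=None, config_path="configs"):
    cfg = compose(config_name="sam2_hiera_l")

def process_image(input_image, input_text):
    """Gradio handler keeps its two-argument signature; cfg is closed over."""
    ...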
@@ -839,6 +847,8 @@ def process_image(input_image, input_text):
     OUTPUT_DIR = Path("outputs/grounded_sam2_dinox_demo")
     OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
 
+
+
     # Initialize DDS client
     config = Config(API_TOKEN)
     client = Client(config)
@@ -933,6 +943,98 @@ def process_image(input_image, input_text):
 
         return annotated_frame, cropped_rgba, gr.update(visible=True), gr.update(visible=True)
 
+
+    else:
+        # Run DINO-X detection
+        task = DinoxTask(
+            image_url=image_url,
+            prompts=[TextPrompt(text=input_text)]
+        )
+        client.run_task(task)
+        result = task.result
+        objects = result.objects
+
+
+
+        for obj in objects:
+            input_boxes.append(obj.bbox)
+            confidences.append(obj.score)
+            cls_name = obj.category.lower().strip()
+            class_names.append(cls_name)
+            class_ids.append(class_name_to_id[cls_name])
+
+        input_boxes = np.array(input_boxes)
+        class_ids = np.array(class_ids)
+
+        # Initialize SAM2
+        torch.autocast(device_type=DEVICE, dtype=torch.bfloat16).__enter__()
+        if torch.cuda.get_device_properties(0).major >= 8:
+            torch.backends.cuda.matmul.allow_tf32 = True
+            torch.backends.cudnn.allow_tf32 = True
+
+        sam2_model = build_sam2(SAM2_MODEL_CONFIG, SAM2_CHECKPOINT, device=DEVICE)
+        sam2_predictor = SAM2ImagePredictor(sam2_model)
+        sam2_predictor.set_image(input_image)
+
+        # sam2_predictor = run_sam_inference(SAM_IMAGE_MODEL, input_image, detections)
+
+
+        # Get masks from SAM2
+        masks, scores, logits = sam2_predictor.predict(
+            point_coords=None,
+            point_labels=None,
+            box=input_boxes,
+            multimask_output=False,
+        )
+        if masks.ndim == 4:
+            masks = masks.squeeze(1)
+
+        # Create visualization
+        labels = [f"{class_name} {confidence:.2f}"
+                  for class_name, confidence in zip(class_names, confidences)]
+
+        detections = sv.Detections(
+            xyxy=input_boxes,
+            mask=masks.astype(bool),
+            class_id=class_ids
+        )
+
+        box_annotator = sv.BoxAnnotator()
+        label_annotator = sv.LabelAnnotator()
+        mask_annotator = sv.MaskAnnotator()
+
+        annotated_frame = input_image.copy()
+        annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections)
+        annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
+        annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
+
+        # Create transparent mask for first detected object
+        if len(detections) > 0:
+            # Get first mask
+            first_mask = detections.mask[0]
+
+            # Get original RGB image
+            img = input_image.copy()
+            H, W, C = img.shape
+
+            # Create RGBA image
+            alpha = np.zeros((H, W, 1), dtype=np.uint8)
+            alpha[first_mask] = 255
+            rgba = np.dstack((img, alpha)).astype(np.uint8)
+
+            # Crop to mask bounds to minimize image size
+            y_indices, x_indices = np.where(first_mask)
+            y_min, y_max = y_indices.min(), y_indices.max()
+            x_min, x_max = x_indices.min(), x_indices.max()
+
+            # Crop the RGBA image
+            cropped_rgba = rgba[y_min:y_max+1, x_min:x_max+1]
+
+            # Set extracted foreground for mask mover
+            mask_mover.set_extracted_fg(cropped_rgba)
+
+            return annotated_frame, cropped_rgba, gr.update(visible=True), gr.update(visible=True)
+        return annotated_frame, None, gr.update(visible=False), gr.update(visible=False)
 
 
 block = gr.Blocks().queue()
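For reference, the box-prompted SAM2 step added in the hunk above reduces to the standalone sketch below. The checkpoint and config file names are assumptions (in the diff they come from SAM2_CHECKPOINT and SAM2_MODEL_CONFIG defined elsewhere in gradio_demo.py), and the image and boxes are placeholders standing in for the uploaded image and the DINO-X detections.

# Sketch, assuming a local SAM2 checkpoint and config; not part of this commit.
import numpy as np
import torch
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

device = "cuda" if torch.cuda.is_available() else "cpu"
sam2_model = build_sam2("sam2_hiera_l.yaml", "checkpoints/sam2_hiera_large.pt", device=device)
predictor = SAM2ImagePredictor(sam2_model)

image = np.zeros((480, 640, 3), dtype=np.uint8)             # placeholder HxWx3 RGB frame
boxes = np.array([[100, 120, 300, 340]], dtype=np.float32)  # placeholder xyxy boxes from the detector

predictor.set_image(image)          # compute the image embedding once per frame
masks, scores, _ = predictor.predict(
    point_coords=None,
    point_labels=None,
    box=boxes,
    multimask_output=False,
)
if masks.ndim == 4:                 # (N, 1, H, W) when several boxes are passed
    masks = masks.squeeze(1)        # -> (N, H, W); cast to bool before building sv.Detections

The rest of the new branch is NumPy and supervision plumbing: the boxes, masks, and class ids are wrapped in sv.Detections for annotation, and the first mask becomes an RGBA cutout by writing it into an alpha channel and cropping to its bounding box before it is handed to mask_mover.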