Commit 7877974 · Parent: bd1381f
Debug mask images

app.py CHANGED
@@ -225,100 +225,99 @@ def draw_box(box: torch.Tensor, draw: ImageDraw.Draw, label: Optional[str]) -> None
 
 def run_grounded_sam(input_image):
     """Main function to run GroundingDINO and SAM-HQ"""
-    …
+    # Create output directory
+    os.makedirs(OUTPUT_DIR, exist_ok=True)
+    text_prompt = 'car'
+    task_type = 'text'
+    box_threshold = 0.3
+    text_threshold = 0.25
+    iou_threshold = 0.8
+    hq_token_only = True
+
+    # Process input image
+    if isinstance(input_image, dict):
+        # Input from gradio sketch component
+        scribble = np.array(input_image["mask"])
+        image_pil = input_image["image"].convert("RGB")
+    else:
+        # Direct image input
+        image_pil = input_image.convert("RGB") if input_image else None
+        scribble = None
 
-    …
-            scribble = np.array(input_image["mask"])
-            image_pil = input_image["image"].convert("RGB")
-        else:
-            # Direct image input
-            image_pil = input_image.convert("RGB") if input_image else None
-            scribble = None
-
-        if image_pil is None:
-            logger.error("No input image provided")
-            return [Image.new('RGB', (400, 300), color='gray')]
+    if image_pil is None:
+        logger.error("No input image provided")
+        return [Image.new('RGB', (400, 300), color='gray')]
 
-    …
+    # Transform image for GroundingDINO
+    transformed_image = transform_image(image_pil)
+
+    # Load models as needed
+    ModelManager.load_model('groundingdino')
+    size = image_pil.size
+    H, W = size[1], size[0]
+
+    # Run GroundingDINO with provided text
+    boxes_filt, scores, pred_phrases = get_grounding_output(
+        transformed_image, text_prompt, box_threshold, text_threshold
+    )
 
-    …
-        # Apply non-maximum suppression if we have multiple boxes
-        if boxes_filt.size(0) > 1:
-            logger.info(f"Before NMS: {boxes_filt.shape[0]} boxes")
-            nms_idx = torchvision.ops.nms(boxes_filt, scores, iou_threshold).numpy().tolist()
-            boxes_filt = boxes_filt[nms_idx]
-            pred_phrases = [pred_phrases[idx] for idx in nms_idx]
-            logger.info(f"After NMS: {boxes_filt.shape[0]} boxes")
-
-        # Load SAM model
-        ModelManager.load_model('sam')
-        sam_predictor = ModelManager.get_model('sam_predictor')
-
-        # Set image for SAM
-        image = np.array(image_pil)
-        sam_predictor.set_image(image)
-
-        # Run SAM
-        # Use boxes for these task types
-        if boxes_filt.size(0) == 0:
-            logger.warning("No boxes detected")
-            return [image_pil, Image.new('RGBA', size, color=(0, 0, 0, 0))]
-
-        transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2]).to(device)
+    if boxes_filt is not None:
+        # Scale boxes to image dimensions
+        for i in range(boxes_filt.size(0)):
+            boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
+            boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
+            boxes_filt[i][2:] += boxes_filt[i][:2]
 
-    …
+    # Apply non-maximum suppression if we have multiple boxes
+    if boxes_filt.size(0) > 1:
+        logger.info(f"Before NMS: {boxes_filt.shape[0]} boxes")
+        nms_idx = torchvision.ops.nms(boxes_filt, scores, iou_threshold).numpy().tolist()
+        boxes_filt = boxes_filt[nms_idx]
+        pred_phrases = [pred_phrases[idx] for idx in nms_idx]
+        logger.info(f"After NMS: {boxes_filt.shape[0]} boxes")
+
+    # Load SAM model
+    ModelManager.load_model('sam')
+    sam_predictor = ModelManager.get_model('sam_predictor')
+
+    # Set image for SAM
+    image = np.array(image_pil)
+    sam_predictor.set_image(image)
+
+    # Run SAM
+    # Use boxes for these task types
+    if boxes_filt.size(0) == 0:
+        logger.warning("No boxes detected")
+        return [image_pil, Image.new('RGBA', size, color=(0, 0, 0, 0))]
+
+    transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2]).to(device)
+
+    masks, _, _ = sam_predictor.predict_torch(
+        point_coords=None,
+        point_labels=None,
+        boxes=transformed_boxes,
+        multimask_output=False,
+        hq_token_only=hq_token_only,
+    )
+
+    # Create mask image
+    mask_image = Image.new('RGBA', size, color=(0, 0, 0, 0))
+    mask_draw = ImageDraw.Draw(mask_image)
+
+    # Draw masks
+    for mask in masks:
+        draw_mask(mask[0].cpu().numpy(), mask_draw)
+
+    # Draw boxes and points on original image
+    image_draw = ImageDraw.Draw(image_pil)
 
-    …
+    for box, label in zip(boxes_filt, pred_phrases):
+        draw_box(box, image_draw, label)
 
-    …
+    return mask_image
 
-    except Exception as e:
-        …
+    # except Exception as e:
+    #     logger.error(f"Error in run_grounded_sam: {e}")
     # # Return original image on error
     # if isinstance(input_image, dict) and "image" in input_image:
     #     return [input_image["image"], Image.new('RGBA', input_image["image"].size, color=(0, 0, 0, 0))]
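The box-scaling loop in the new version converts GroundingDINO's normalized (cx, cy, w, h) outputs into pixel-space (x0, y0, x1, y1) corners, the layout that torchvision.ops.nms and SAM's box prompts expect. A minimal standalone sketch of the same arithmetic, vectorized; the function name and sample values are illustrative, not from the Space:

import torch

def cxcywh_norm_to_xyxy(boxes: torch.Tensor, W: int, H: int) -> torch.Tensor:
    # Scale normalized centers/sizes to pixels, as the diff's per-box loop does
    boxes = boxes * torch.tensor([W, H, W, H], dtype=boxes.dtype)
    xyxy = boxes.clone()
    xyxy[:, :2] -= boxes[:, 2:] / 2           # top-left = center - half size
    xyxy[:, 2:] = xyxy[:, :2] + boxes[:, 2:]  # bottom-right = top-left + size
    return xyxy

# A box centered in a 640x480 image, covering half of each dimension:
print(cxcywh_norm_to_xyxy(torch.tensor([[0.5, 0.5, 0.5, 0.5]]), W=640, H=480))
# tensor([[160., 120., 480., 360.]])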
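Note the changed contract: the function now returns the bare RGBA mask_image rather than a list of images for the Gradio gallery, which matches the "Debug mask images" intent of the commit. A hedged sketch of driving the debug version directly; the file names are hypothetical, and a plain PIL image takes the non-dict branch above:

from PIL import Image

img = Image.open("test_car.jpg")   # hypothetical input; 'car' is the hardcoded prompt
mask = run_grounded_sam(img)       # debug version: returns only the RGBA mask
mask.save("debug_mask.png")        # transparent wherever SAM-HQ predicted nothing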