Rausda6 committed on
Commit 34b1e95 · verified · 1 Parent(s): 0f5112d

Update app.py

Files changed (1)
  1. app.py +249 -37
app.py CHANGED
@@ -1,6 +1,6 @@
 import gradio as gr
-import json, os, re, traceback, contextlib
-from typing import Any, List, Dict
+import json, os, re, traceback, contextlib, math, random
+from typing import Any, List, Dict, Optional, Tuple
 
 import spaces
 import torch
@@ -132,6 +132,10 @@ def run_inference_localization(
     pil_image_for_processing: Image.Image,
     device: str,
     dtype: torch.dtype,
+    do_sample: bool = False,
+    temperature: float = 0.6,
+    top_p: float = 0.9,
+    max_new_tokens: int = 128,
 ) -> str:
     text_prompt = apply_chat_template_compat(processor, messages_for_template)
 
@@ -151,12 +155,15 @@ def run_inference_localization(
     else:
         amp_ctx = contextlib.nullcontext()
 
+    gen_kwargs = dict(
+        max_new_tokens=max_new_tokens,
+        do_sample=do_sample,
+        temperature=temperature,
+        top_p=top_p,
+    )
+
     with amp_ctx:
-        generated_ids = model.generate(
-            **inputs,
-            max_new_tokens=128,
-            do_sample=False,
-        )
+        generated_ids = model.generate(**inputs, **gen_kwargs)
 
     generated_ids_trimmed = trim_generated(generated_ids, inputs)
     decoded_output = batch_decode_compat(
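The hunk above routes all decoding options through a single gen_kwargs dict so the same code path serves both the deterministic pass and the stochastic confidence samples. A minimal sketch of that split (build_gen_kwargs is a hypothetical helper, not part of app.py; recent transformers releases may warn if temperature/top_p are passed while do_sample=False, so the sketch only attaches them when sampling):

def build_gen_kwargs(do_sample: bool, temperature: float = 0.6,
                     top_p: float = 0.9, max_new_tokens: int = 128) -> dict:
    # Deterministic (greedy) decode by default; sampling knobs only when requested.
    kwargs = {"max_new_tokens": max_new_tokens, "do_sample": do_sample}
    if do_sample:
        kwargs.update(temperature=temperature, top_p=top_p)
    return kwargs

greedy_kwargs = build_gen_kwargs(do_sample=False)    # single reproducible pass
sampled_kwargs = build_gen_kwargs(do_sample=True)    # one of the N confidence samples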
@@ -167,10 +174,159 @@ def run_inference_localization(
     )
     return decoded_output[0] if decoded_output else ""
 
+# ---------- Confidence helpers ----------
+CLICK_RE = re.compile(r"Click\((\d+),\s*(\d+)\)")
+
+def parse_click(s: str) -> Optional[Tuple[int, int]]:
+    m = CLICK_RE.search(s)
+    if not m:
+        return None
+    try:
+        return int(m.group(1)), int(m.group(2))
+    except Exception:
+        return None
+
+@torch.inference_mode()
+def sample_clicks(
+    messages: List[dict],
+    img: Image.Image,
+    device: str,
+    dtype: torch.dtype,
+    n_samples: int = 7,
+    temperature: float = 0.6,
+    top_p: float = 0.9,
+    seed: Optional[int] = None,
+) -> List[Optional[Tuple[int, int]]]:
+    """
+    Run multiple stochastic decodes to estimate self-consistency.
+    Returns a list of (x, y) or None (if parsing failed) for each sample.
+    """
+    clicks: List[Optional[Tuple[int, int]]] = []
+    # If model respects torch random, set seed for reproducibility (optional)
+    if seed is not None:
+        torch.manual_seed(seed)
+        random.seed(seed)
+    for i in range(n_samples):
+        # Vary seed slightly each iteration to avoid identical sampling patterns
+        if seed is not None:
+            torch.manual_seed(seed + i + 1)
+            random.seed((seed + i + 1) & 0xFFFFFFFF)
+        out = run_inference_localization(
+            messages, img, device, dtype,
+            do_sample=True, temperature=temperature, top_p=top_p
+        )
+        clicks.append(parse_click(out))
+    return clicks
+
+def cluster_and_confidence(
+    clicks: List[Optional[Tuple[int, int]]],
+    img_w: int,
+    img_h: int,
+) -> Dict[str, Any]:
+    """
+    Simple robust consensus:
+      - Keep only valid points
+      - Compute median point (x_med, y_med)
+      - Compute distances to median
+      - Inlier threshold = max(8 px, 2% of min(img_w, img_h))
+      - Confidence = (#inliers / #total_samples) * clamp(1 - (rms_inlier_dist / thr), 0, 1)
+    Returns dict with consensus point, confidence, dispersion, and counts.
+    """
+    valid = [xy for xy in clicks if xy is not None]
+    total = len(clicks)
+    if total == 0:
+        return dict(ok=False, reason="no_samples")
+
+    if not valid:
+        return dict(ok=False, reason="no_valid_points", total=total)
+
+    xs = sorted([x for x, _ in valid])
+    ys = sorted([y for _, y in valid])
+    mid = len(valid) // 2
+    if len(valid) % 2 == 1:
+        x_med = xs[mid]
+        y_med = ys[mid]
+    else:
+        x_med = (xs[mid - 1] + xs[mid]) // 2
+        y_med = (ys[mid - 1] + ys[mid]) // 2
+
+    thr = max(8.0, 0.02 * min(img_w, img_h))  # ~2% of smaller side, at least 8 px
+    dists = [math.hypot(x - x_med, y - y_med) for (x, y) in valid]
+    inliers = [(xy, d) for xy, d in zip(valid, dists) if d <= thr]
+    outliers = [(xy, d) for xy, d in zip(valid, dists) if d > thr]
+    inlier_count = len(inliers)
+
+    # RMS of inlier distances (0 if perfect agreement)
+    if inliers:
+        rms = math.sqrt(sum(d * d for _, d in inliers) / len(inliers))
+    else:
+        rms = float("inf")
+
+    # Confidence: agreement ratio * sharpness factor
+    if inliers:
+        sharp = max(0.0, min(1.0, 1.0 - (rms / thr)))
+    else:
+        sharp = 0.0
+    confidence = (inlier_count / total) * sharp
+
+    return dict(
+        ok=True,
+        x=x_med, y=y_med,
+        confidence=confidence,
+        total_samples=total,
+        valid_samples=len(valid),
+        inliers=inlier_count,
+        outliers=len(outliers),
+        sigma_px=rms if math.isfinite(rms) else None,
+        inlier_threshold_px=thr,
+        all_points=valid,
+        inlier_points=[xy for xy, _ in inliers],
+        outlier_points=[xy for xy, _ in outliers],
+    )
+
+def draw_samples(
+    base_img: Image.Image,
+    consensus_xy: Optional[Tuple[int, int]],
+    inliers: List[Tuple[int, int]],
+    outliers: List[Tuple[int, int]],
+    ring_color: str = "red",
+) -> Image.Image:
+    """
+    Overlay all sampled points: green=inliers, red=outliers, plus a ring for consensus.
+    """
+    img = base_img.copy().convert("RGB")
+    draw = ImageDraw.Draw(img)
+    w, h = img.size
+    # Dot radius scales with image size
+    r = max(3, min(w, h) // 200)
+
+    # Draw inliers
+    for (x, y) in inliers:
+        draw.ellipse((x - r, y - r, x + r, y + r), fill="green", outline=None)
+
+    # Draw outliers
+    for (x, y) in outliers:
+        draw.ellipse((x - r, y - r, x + r, y + r), fill="red", outline=None)
+
+    # Consensus ring
+    if consensus_xy is not None:
+        cx, cy = consensus_xy
+        ring_r = max(5, min(w, h) // 100, r * 3)
+        draw.ellipse((cx - ring_r, cy - ring_r, cx + ring_r, cy + ring_r), outline=ring_color, width=max(2, ring_r // 4))
+    return img
+
 # --- Gradio processing function (ZeroGPU-visible) ---
 # Decorate the function Gradio calls so Spaces detects a GPU entry point.
 @spaces.GPU(duration=120)  # keep GPU attached briefly between calls (seconds)
-def predict_click_location(input_pil_image: Image.Image, instruction: str):
+def predict_click_location(
+    input_pil_image: Image.Image,
+    instruction: str,
+    estimate_confidence: bool = True,
+    num_samples: int = 7,
+    temperature: float = 0.6,
+    top_p: float = 0.9,
+    seed: Optional[int] = None,
+):
     if not model_loaded or not processor or not model:
         return f"Model not loaded. Error: {load_error_message}", None, "device: n/a | dtype: n/a"
     if not input_pil_image:
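The score defined in cluster_and_confidence is (inliers / total samples) * clamp(1 - rms / thr, 0, 1). A small self-contained illustration (synthetic points, not part of app.py) that walks the same arithmetic for the odd-sample-count case on a 1000x600 image:

import math

# Seven parsed clicks from hypothetical stochastic decodes; six agree, one is a stray.
clicks = [(412, 305), (410, 303), (415, 306), (411, 304), (413, 305), (409, 302), (650, 120)]
img_w, img_h = 1000, 600

xs = sorted(x for x, _ in clicks)
ys = sorted(y for _, y in clicks)
x_med, y_med = xs[len(xs) // 2], ys[len(ys) // 2]             # component-wise median: (412, 304)

thr = max(8.0, 0.02 * min(img_w, img_h))                      # inlier threshold: 12.0 px
dists = [math.hypot(x - x_med, y - y_med) for x, y in clicks]
inlier_d = [d for d in dists if d <= thr]                     # 6 of 7 points fall inside
rms = math.sqrt(sum(d * d for d in inlier_d) / len(inlier_d))
confidence = (len(inlier_d) / len(clicks)) * max(0.0, min(1.0, 1.0 - rms / thr))

print(f"consensus=({x_med}, {y_med}), inliers={len(inlier_d)}/{len(clicks)}, "
      f"sigma={rms:.1f}px, confidence={confidence:.2f}")      # ~2.4 px spread, confidence ~0.68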
@@ -220,33 +376,70 @@ def predict_click_location(input_pil_image: Image.Image, instruction: str):
     # 2) Build messages with image + instruction
     messages = get_localization_prompt(resized_image, instruction)
 
-    # 3) Run inference
+    # 3) Inference and (optionally) confidence estimation
     try:
-        coordinates_str = run_inference_localization(messages, resized_image, device, dtype)
+        if estimate_confidence and num_samples >= 3:
+            # Monte-Carlo sampling
+            clicks = sample_clicks(
+                messages, resized_image, device, dtype,
+                n_samples=int(num_samples),
+                temperature=float(temperature),
+                top_p=float(top_p),
+                seed=seed
+            )
+            summary = cluster_and_confidence(clicks, resized_image.width, resized_image.height)
+
+            if not summary.get("ok", False):
+                # Fallback: deterministic decode
+                coord_str = run_inference_localization(messages, resized_image, device, dtype, do_sample=False)
+                out_img = resized_image.copy().convert("RGB")
+                match = CLICK_RE.search(coord_str or "")
+                if match:
+                    x, y = int(match.group(1)), int(match.group(2))
+                    out_img = draw_samples(out_img, (x, y), [], [])
+                coords_text = f"{coord_str} | confidence=0.00 (fallback)"
+                return coords_text, out_img, f"device: {device} | dtype: {str(dtype).replace('torch.', '')}"
+
+            # Build final string + visualization
+            x, y = int(summary["x"]), int(summary["y"])
+            conf = summary["confidence"]
+            inliers = summary["inlier_points"]
+            outliers = summary["outlier_points"]
+            sigma = summary["sigma_px"]
+            thr = summary["inlier_threshold_px"]
+            total = summary["total_samples"]
+            valid = summary["valid_samples"]
+
+            # Compose output string in the same canonical format plus diagnostics
+            coord_str = f"Click({x}, {y})"
+            diag = (
+                f"confidence={conf:.2f} | samples(valid/total)={valid}/{total} | "
+                f"inliers={len(inliers)} | σ={sigma:.1f}px | thr={thr:.1f}px | "
+                f"T={temperature:.2f}, p={top_p:.2f}"
+            )
+
+            # Draw: all samples + consensus ring
+            out_img = draw_samples(resized_image, (x, y), inliers, outliers)
+            return f"{coord_str} | {diag}", out_img, f"device: {device} | dtype: {str(dtype).replace('torch.', '')}"
+
+        else:
+            # Fast deterministic single pass (no confidence)
+            coord_str = run_inference_localization(messages, resized_image, device, dtype, do_sample=False)
+            out_img = resized_image.copy().convert("RGB")
+            match = CLICK_RE.search(coord_str or "")
+            if match:
+                x = int(match.group(1))
+                y = int(match.group(2))
+                # draw a simple ring around the predicted point
+                out_img = draw_samples(out_img, (x, y), [], [])
+            else:
+                print(f"Could not parse 'Click(x, y)' from model output: {coord_str}")
+            return coord_str, out_img, f"device: {device} | dtype: {str(dtype).replace('torch.', '')}"
+
     except Exception as e:
         traceback.print_exc()
         return f"Error during model inference: {e}", resized_image.copy().convert("RGB"), f"device: {device} | dtype: {dtype}"
 
-    # 4) Parse coordinates and draw marker
-    output_image_with_click = resized_image.copy().convert("RGB")
-    match = re.search(r"Click\((\d+),\s*(\d+)\)", coordinates_str)
-    if match:
-        try:
-            x = int(match.group(1))
-            y = int(match.group(2))
-            draw = ImageDraw.Draw(output_image_with_click)
-            radius = max(5, min(resized_width // 100, resized_height // 100, 15))
-            bbox = (x - radius, y - radius, x + radius, y + radius)
-            draw.ellipse(bbox, outline="red", width=max(2, radius // 4))
-            print(f"Predicted and drawn click at: ({x}, {y}) on resized image ({resized_width}x{resized_height})")
-        except Exception as e:
-            print(f"Error drawing on image: {e}")
-            traceback.print_exc()
-    else:
-        print(f"Could not parse 'Click(x, y)' from model output: {coordinates_str}")
-
-    return coordinates_str, output_image_with_click, f"device: {device} | dtype: {str(dtype).replace('torch.', '')}"
-
 # --- Load Example Data ---
 example_image = None
 example_instruction = "Enter the server address readyforquantum.com to check its security"
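On the success path above, draw_samples renders every sampled click (green inliers, red outliers) plus a ring at the consensus point. A standalone Pillow sketch of the same marker style on a synthetic canvas (synthetic points and a hypothetical output filename; nothing here touches the model):

from PIL import Image, ImageDraw

img = Image.new("RGB", (1000, 600), "white")       # stand-in for the resized screenshot
draw = ImageDraw.Draw(img)
w, h = img.size
r = max(3, min(w, h) // 200)                       # per-sample dot radius, as in the commit

inliers = [(412, 305), (410, 303), (415, 306), (411, 304), (413, 305), (409, 302)]
outliers = [(650, 120)]
for x, y in inliers:
    draw.ellipse((x - r, y - r, x + r, y + r), fill="green")
for x, y in outliers:
    draw.ellipse((x - r, y - r, x + r, y + r), fill="red")

cx, cy = 412, 304                                  # consensus point (median of the samples)
ring_r = max(5, min(w, h) // 100, r * 3)
draw.ellipse((cx - ring_r, cy - ring_r, cx + ring_r, cy + ring_r),
             outline="red", width=max(2, ring_r // 4))
img.save("overlay_preview.png")                    # hypothetical output path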
@@ -292,16 +485,21 @@ else:
                 placeholder="e.g., Click the 'Login' button",
                 info="Type the action you want the model to localize on the image."
             )
+            estimate_conf = gr.Checkbox(value=True, label="Estimate confidence (slower)")
+            num_samples_slider = gr.Slider(3, 15, value=7, step=1, label="Samples (for confidence)")
+            temperature_slider = gr.Slider(0.2, 1.2, value=0.6, step=0.05, label="Temperature")
+            top_p_slider = gr.Slider(0.5, 0.99, value=0.9, step=0.01, label="Top-p")
+            seed_box = gr.Number(value=None, precision=0, label="Seed (optional, for reproducibility)")
             submit_button = gr.Button("Localize Click", variant="primary")
 
         with gr.Column(scale=1):
             output_coords_component = gr.Textbox(
-                label="Predicted Coordinates (Format: Click(x, y))",
+                label="Predicted Coordinates + Confidence",
                 interactive=False
             )
             output_image_component = gr.Image(
                 type="pil",
-                label="Image with Predicted Click Point",
+                label="Image with Samples (green=inliers, red=outliers) and Final Ring",
                 height=400,
                 interactive=False
             )
@@ -313,8 +511,16 @@ else:
 
         if example_image:
             gr.Examples(
-                examples=[[example_image, example_instruction]],
-                inputs=[input_image_component, instruction_component],
+                examples=[[example_image, example_instruction, True, 7, 0.6, 0.9, None]],
+                inputs=[
+                    input_image_component,
+                    instruction_component,
+                    estimate_conf,
+                    num_samples_slider,
+                    temperature_slider,
+                    top_p_slider,
+                    seed_box,
+                ],
                 outputs=[output_coords_component, output_image_component, runtime_info],
                 fn=predict_click_location,
                 cache_examples="lazy",
@@ -322,11 +528,17 @@ else:
 
         submit_button.click(
             fn=predict_click_location,
-            inputs=[input_image_component, instruction_component],
+            inputs=[
+                input_image_component,
+                instruction_component,
+                estimate_conf,
+                num_samples_slider,
+                temperature_slider,
+                top_p_slider,
+                seed_box,
+            ],
             outputs=[output_coords_component, output_image_component, runtime_info]
         )
 
 if __name__ == "__main__":
-    # Do NOT pass 'concurrency_count' or ZeroGPU-specific launch args.
     demo.launch(debug=True)
-
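The new Seed input feeds sample_clicks, which reseeds torch before each stochastic decode (seed + i + 1) so a run can be repeated exactly while the individual samples still differ from one another. A minimal illustration of that seeding scheme using torch's RNG (toy random draws stand in for the model's sampling; not part of app.py):

import torch

def draws(seed: int, n: int = 3) -> list:
    out = []
    for i in range(n):
        torch.manual_seed(seed + i + 1)    # same per-sample reseeding pattern as sample_clicks
        out.append(round(torch.rand(1).item(), 6))
    return out

print(draws(42) == draws(42))    # True: a fixed seed reproduces the whole sample set
print(draws(42) == draws(43))    # False: a different seed gives different samples
print(draws(42))                 # samples within one run still vary (different per-sample seeds)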
 