#!/usr/bin/env python
# furniture_bbox_to_files.py ─────────────────────────────────────────
# Florence-2 + SAM-2 batch processor with retries *and* file-based images
# ---------------------------------------------------------------------
import os, json, random, time
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List

import torch, supervision as sv
from PIL import Image, ImageDraw, ImageColor, ImageOps
from tqdm.auto import tqdm
from datasets import load_dataset, Image as HFImage, disable_progress_bar

# ───── global models ─────────────────────────────────────────────────
from utils.florence import (
    load_florence_model, run_florence_inference,
    FLORENCE_OPEN_VOCABULARY_DETECTION_TASK,
)
from utils.sam import load_sam_image_model, run_sam_inference

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
FLORENCE_MODEL, FLORENCE_PROC = load_florence_model(device=DEVICE)
SAM_IMAGE_MODEL = load_sam_image_model(device=DEVICE)

# annotators
_PALETTE = sv.ColorPalette.from_hex(
    ['#FF1493', '#00BFFF', '#FF6347', '#FFD700', '#32CD32', '#8A2BE2'])
BOX_ANN   = sv.BoxAnnotator(color=_PALETTE, color_lookup=sv.ColorLookup.INDEX)
MASK_ANN  = sv.MaskAnnotator(color=_PALETTE, color_lookup=sv.ColorLookup.INDEX)
LABEL_ANN = sv.LabelAnnotator(
    color=_PALETTE, color_lookup=sv.ColorLookup.INDEX,
    text_position=sv.Position.CENTER_OF_MASS,
    text_color=sv.Color.from_hex("#000"), border_radius=5)

# ───── config ────────────────────────────────────────────────────────
os.environ["TOKENIZERS_PARALLELISM"] = "false"
disable_progress_bar()

DATASET_NAME  = "fotographerai/furniture_captioned_segment_prompt"
SPLIT         = "train"
IMAGE_COL     = "img2"
PROMPT_COL    = "segmenting_prompt"
INFLATE_RANGE = (0.01, 0.05)
FILL_COLOR    = "#00FF00"
TARGET_SIDE   = 1500

QA_DIR    = Path("bbox_review_recaptioned")
GREEN_DIR = QA_DIR / "green"; GREEN_DIR.mkdir(parents=True, exist_ok=True)
ANNO_DIR  = QA_DIR / "anno";  ANNO_DIR.mkdir(parents=True, exist_ok=True)
JSON_DIR  = QA_DIR / "json";  JSON_DIR.mkdir(parents=True, exist_ok=True)

MAX_WORKERS = 100
MAX_RETRIES = 5
RETRY_SLEEP = 0.3
FAILED_LOG  = QA_DIR / "failed_rows.jsonl"

PROMPT_MAP: dict[str, str] = {}   # optional overrides

# ───── helpers ───────────────────────────────────────────────────────
def make_square(img: Image.Image, side: int = TARGET_SIDE) -> Image.Image:
    """Letterbox `img` into a side×side square, padding with the top-left pixel colour."""
    img = ImageOps.contain(img, (side, side))
    pad_w, pad_h = side - img.width, side - img.height
    return ImageOps.expand(img,
                           border=(pad_w // 2, pad_h // 2,
                                   pad_w - pad_w // 2, pad_h - pad_h // 2),
                           fill=img.getpixel((0, 0)))

def img_to_file(img: Image.Image, fname: str, folder: Path) -> dict:
    """Save once, then return the {"path", "bytes"} dict that datasets.Image accepts."""
    path = folder / f"{fname}.png"
    if not path.exists():
        img.save(path)
    return {"path": str(path), "bytes": None}

# ───── core functions ────────────────────────────────────────────────
@torch.inference_mode()
@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
def detect_and_segment(img: Image.Image, prompts: str | List[str]) -> sv.Detections:
    """Run one Florence-2 open-vocabulary query per prompt, segment with SAM-2, merge."""
    if isinstance(prompts, str):
        prompts = [p.strip() for p in prompts.split(",") if p.strip()]
    all_dets = []
    for p in prompts:
        _, res = run_florence_inference(
            model=FLORENCE_MODEL, processor=FLORENCE_PROC, device=DEVICE,
            image=img, task=FLORENCE_OPEN_VOCABULARY_DETECTION_TASK, text=p)
        d = sv.Detections.from_lmm(
            sv.LMM.FLORENCE_2, res, resolution_wh=img.size)
        all_dets.append(run_sam_inference(SAM_IMAGE_MODEL, img, d))
    return sv.Detections.merge(all_dets)
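# Usage sketch (illustrative only, not executed): a comma-separated prompt
# fans out into one Florence-2 query per phrase, and the per-phrase SAM-2
# results are merged into a single sv.Detections with pixel masks attached.
# The file name below is a hypothetical placeholder:
#   dets = detect_and_segment(Image.open("room.png").convert("RGB"),
#                             "sofa, coffee table")
#   print(len(dets), dets.mask is not None)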
def fill_detected_bboxes(img: Image.Image, prompt: str,
                         inflate_pct: float) -> tuple[Image.Image, sv.Detections]:
    """Paint every detected box, inflated by `inflate_pct`, solid FILL_COLOR."""
    dets   = detect_and_segment(img, prompt)
    filled = img.copy()
    draw   = ImageDraw.Draw(filled)
    rgb    = ImageColor.getrgb(FILL_COLOR)
    w, h   = img.size
    for box in dets.xyxy:
        x1, y1, x2, y2 = box.astype(float)
        dw, dh = (x2 - x1) * inflate_pct, (y2 - y1) * inflate_pct
        draw.rectangle([max(0, x1 - dw), max(0, y1 - dh),
                        min(w, x2 + dw), min(h, y2 + dh)], fill=rgb)
    return filled, dets

# ───── threaded worker ───────────────────────────────────────────────
def process_row(idx: int, sample):
    prompt = PROMPT_MAP.get(sample[PROMPT_COL],
                            sample[PROMPT_COL].split(",", 1)[0].strip())
    img_sq = make_square(sample[IMAGE_COL].convert("RGB"))
    for attempt in range(1, MAX_RETRIES + 1):
        try:
            filled, dets = fill_detected_bboxes(
                img_sq, prompt, inflate_pct=random.uniform(*INFLATE_RANGE))
            if len(dets.xyxy) == 0:
                raise ValueError("no detections")
            sid    = f"{idx:06d}"
            json_p = JSON_DIR / f"{sid}_bbox.json"
            json_p.write_text(json.dumps({"xyxy": dets.xyxy.tolist()}))
            anno = img_sq.copy()
            for ann in (MASK_ANN, BOX_ANN, LABEL_ANN):
                anno = ann.annotate(anno, dets)
            return ("ok",
                    img_to_file(filled, sid, GREEN_DIR),
                    img_to_file(anno,   sid, ANNO_DIR),
                    json_p.read_text())
        except Exception as e:
            if attempt < MAX_RETRIES:
                time.sleep(RETRY_SLEEP)
            else:
                return ("fail", str(e))

# ───── run batch ─────────────────────────────────────────────────────
ds = load_dataset(DATASET_NAME, split=SPLIT, streaming=False)
N  = len(ds)
print("Rows:", N)

filled_col, anno_col, json_col = [None] * N, [None] * N, [None] * N
fails = 0

# NOTE: every thread funnels into the same two GPU models, so MAX_WORKERS
# mostly overlaps image I/O; the inference itself serialises on the GPU.
with ThreadPoolExecutor(MAX_WORKERS) as pool:
    fut2idx = {pool.submit(process_row, i, ds[i]): i for i in range(N)}
    for fut in tqdm(as_completed(fut2idx), total=N, desc="Florence+SAM"):
        idx = fut2idx[fut]
        status, *data = fut.result()
        if status == "ok":
            filled_col[idx], anno_col[idx], json_col[idx] = data
        else:
            fails += 1
            with FAILED_LOG.open("a") as f:   # append: keep one JSONL line per failure
                f.write(json.dumps({"idx": idx, "reason": data[0]}) + "\n")

print(f"❌ permanently failed rows: {fails}")

keep   = [i for i, x in enumerate(filled_col) if x]
new_ds = ds.select(keep)
new_ds = new_ds.add_column("bbox_filled", [filled_col[i] for i in keep])
new_ds = new_ds.add_column("annotated",   [anno_col[i]   for i in keep])
new_ds = new_ds.add_column("bbox_json",   [json_col[i]   for i in keep])
new_ds = new_ds.cast_column("bbox_filled", HFImage())
new_ds = new_ds.cast_column("annotated",  HFImage())

print(f"✅ successes: {len(new_ds)} / {N}")
print("Columns:", new_ds.column_names)
print("QA artefacts →", QA_DIR.resolve())

# optional push
new_ds.push_to_hub("fotographerai/surround_furniture_bboxfilled",
                   private=True, max_shard_size="500MB")
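# ───── optional spot-check (hedged sketch) ───────────────────────────
# A minimal, purely local sanity pass over the QA artefacts written above;
# not part of the pipeline, and it assumes at least a few rows succeeded:
#   for p in sorted(JSON_DIR.glob("*_bbox.json"))[:3]:
#       print(p.name, json.loads(p.read_text())["xyxy"])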