#!/usr/bin/env python
"""
reassemble_bbox_dataset_resume.py
---------------------------------
Incrementally rebuilds `bbox_filled / annotated / bbox_json` columns from
QA artefacts and pushes the final dataset **privately** to HF Hub.
β€’ Safe to ^C / rerun (uses on-disk Arrow cache)
β€’ When NOTHING is left to process it *just* loads the cache and pushes.
β€’ Uses path-only image columns (HFImage(decode=False)) to keep RAM tiny.
"""
import os, sys, json
from pathlib import Path
from tqdm.auto import tqdm
from datasets import (
    load_dataset, load_from_disk, Dataset, disable_progress_bar, Features,
    Value, Image as HFImage,
)
from PIL import Image
from huggingface_hub.utils import HfHubHTTPError
disable_progress_bar()
# ══════ CONFIG ══════════════════════════════════════════════════════
DATASET_NAME = "fotographerai/furniture_captioned_segment_prompt"
SPLIT = "train"
QA_DIR = Path("bbox_review_recaptioned") # artefacts
CACHE_DIR = Path("rebuild_cache") # incremental Arrow cache
CACHE_DIR.mkdir(exist_ok=True)
TARGET_SIDE = 1500
GREEN_RGB = (0, 255, 0)
BATCH_SAVE = 500
HUB_REPO = "fotographerai/furniture_bboxfilled_rebuild"
HF_TOKEN = os.environ.get("HF_TOKEN", "").strip() # needs write+private
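# Expected QA artefact layout in QA_DIR, inferred from the filename patterns
# used in the re-assembly loop below (the index 000042 is illustrative):
#   000042_green.png   bbox-filled render
#   000042_anno.png    annotated render
#   000042_bbox.json   {"xyxy": [[x0, y0, x1, y1], ...]}
#   000042_mask.png    solid-green fallback, created on demand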
# ══════ HELPERS ═════════════════════════════════════════════════════
def img_ref(p: Path) -> dict:
    """Path-only image dict: keeps the column lazy under HFImage(decode=False)."""
    return {"path": str(p), "bytes": None}

def make_green_png(p: Path):
    """Write a solid-green placeholder PNG if `p` does not exist yet."""
    if not p.exists():
        Image.new("RGB", (TARGET_SIDE, TARGET_SIDE), GREEN_RGB).save(p)

def ensure_full_bbox(p: Path):
    """Write a full-frame bbox JSON if the QA artefact is missing."""
    if not p.exists():
        p.write_text(json.dumps({"xyxy": [[0, 0, TARGET_SIDE, TARGET_SIDE]]}))
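# For reference (illustrative filename): img_ref(QA_DIR / "000001_green.png")
# returns {"path": "bbox_review_recaptioned/000001_green.png", "bytes": None}.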
# ══════ LOAD SOURCE DATASET ═════════════════════════════════════════
base_ds = load_dataset(DATASET_NAME, split=SPLIT, streaming=False)
N_TOTAL = len(base_ds)
print("Original rows:", N_TOTAL)
# ══════ LOAD OR INIT CACHE ══════════════════════════════════════════
if (CACHE_DIR / "dataset_info.json").exists():
    cache_ds = load_from_disk(CACHE_DIR)
    done = set(cache_ds["__row_idx__"])
    print(f"Cache found → {len(done)} rows already processed.")
    records = {k: list(v) for k, v in cache_ds.to_dict().items()}
else:
    done, records = set(), {"__row_idx__": [], "bbox_filled": [],
                            "annotated": [], "bbox_json": []}
missing = [i for i in range(N_TOTAL) if i not in done]
print("Rows still to process:", len(missing))
# ══════ NO WORK LEFT? push & exit ══════════════════════════════════
if not missing:
    print("💀 nothing new to process – pushing cached dataset…")
    try:
        url = cache_ds.push_to_hub(
            HUB_REPO, private=True, token=HF_TOKEN, max_shard_size="500MB"
        )
        print("🚀 dataset pushed to:", url)
    except HfHubHTTPError as e:
        print("❌ push failed:", e)
    sys.exit(0)
# ══════ PROCESS MISSING ROWS ═══════════════════════════════════════
for n, i in enumerate(tqdm(missing, desc="Re-assembling")):
    g_png  = QA_DIR / f"{i:06d}_green.png"
    a_png  = QA_DIR / f"{i:06d}_anno.png"
    bbox_j = QA_DIR / f"{i:06d}_bbox.json"
    if not (g_png.exists() and a_png.exists() and bbox_j.exists()):
        # QA artefacts incomplete → fall back to a solid-green placeholder
        # image and a full-frame bbox so the row is never silently dropped.
        mask_png = QA_DIR / f"{i:06d}_mask.png"
        make_green_png(mask_png)
        g_png = a_png = mask_png
        ensure_full_bbox(bbox_j)
    row = base_ds[i]                                   # copy original cols once
    records["__row_idx__"].append(i)
    for k, v in row.items():
        if k in ("bbox_filled", "annotated", "bbox_json"):
            continue  # these columns are rebuilt below from the QA artefacts
        records.setdefault(k, []).append(v)
    records["bbox_filled"].append(img_ref(g_png))
    records["annotated"].append(img_ref(a_png))
    records["bbox_json"].append(bbox_j.read_text())
    if (n + 1) % BATCH_SAVE == 0:
        # Periodic checkpoint: a ^C or crash loses at most BATCH_SAVE rows.
        Dataset.from_dict(records).save_to_disk(CACHE_DIR)
        print(f"⏫ cached at {n+1}/{len(missing)}")
# ══════ FINAL DATASET FEATURES & SAVE ═══════════════════════════════
features = Features({
    "__row_idx__" : Value("int32"),
    "bbox_filled" : HFImage(decode=False),
    "annotated"   : HFImage(decode=False),
    "bbox_json"   : Value("string"),
    # original columns copied over from the source schema below
})
for k in base_ds.features:
    if k not in features:
        features[k] = base_ds.features[k]
final_ds = Dataset.from_dict(records, features=features)
final_ds.save_to_disk(CACHE_DIR)
print("✅ cached dataset saved to", CACHE_DIR.resolve())
# ══════ PUSH PRIVATE ═══════════════════════════════════════════════
if not HF_TOKEN:
    print("⚠️ HF_TOKEN env-var not set – skipping push.")
else:
    try:
        url = final_ds.push_to_hub(
            HUB_REPO, private=True, token=HF_TOKEN, max_shard_size="500MB"
        )
        print("🚀 dataset pushed to:", url)
    except HfHubHTTPError as e:
        print("❌ push failed:", e)