"""
MobileCLIP-B Zero-Shot Image Classifier (Hugging Face Inference Endpoint)
==========================================================================

* One container instance is created per replica; the ``EndpointHandler``
  object below is instantiated exactly once at start-up.
* At request time (``__call__``) we receive a base64-encoded image, run a
  single forward pass, and return class probabilities. Admin operations
  (``upsert_labels``, ``remove_labels``, ``reload_labels``) manage the label
  set at runtime and persist it as versioned snapshots in a Hub dataset repo.

Design choices
--------------
1. **Model & transforms come from OpenCLIP.**
   This guarantees we apply identical preprocessing to what the model was
   trained with (224 × 224 crop + mean/std normalisation).
2. **Re-parameterisation for inference.**
   MobileCLIP uses MobileOne blocks that carry extra convolution branches
   for training; ``reparameterize_model`` fuses them so inference is fast
   and deterministic.
3. **Text embeddings are cached.**
   The class prompts (e.g. ``"a photo of a cat"``) are encoded once at
   start-up or on label updates. Each request therefore encodes only the
   image and performs a single matrix multiplication.
4. **Mixed precision on GPU.**
   If the container has CUDA, we cast the model and inputs to ``float16``.
   That halves memory and roughly doubles throughput on most modern GPUs.
   On CPU we stay in ``float32`` for numerical stability.
"""

import base64
import contextlib
import io
import json
import os
import threading

import torch
from PIL import Image

import open_clip
from huggingface_hub import CommitOperationAdd, create_commit, hf_hub_download
from safetensors.torch import load_file, save_file

from reparam import reparameterize_model  # local copy (~60 LoC) of Apple's helper

ADMIN_TOKEN = os.getenv("ADMIN_TOKEN", "")
HF_LABEL_REPO = os.getenv("HF_LABEL_REPO", "")  # e.g. "org/mobileclip-labels"
HF_WRITE_TOKEN = os.getenv("HF_WRITE_TOKEN", "")
HF_READ_TOKEN = os.getenv("HF_READ_TOKEN", HF_WRITE_TOKEN)


def _fingerprint(dtype: torch.dtype) -> dict:
    """Identify the model/runtime that produced a set of text embeddings.

    Snapshots whose fingerprint does not match the running model are refused
    at load time, so stale embeddings are never paired with a different
    checkpoint or dtype.
    """
    return {
        "model_id": "MobileCLIP-B",
        "pretrained": "datacompdr",
        "open_clip": getattr(open_clip, "__version__", "unknown"),
        "torch": torch.__version__,
        "cuda": torch.version.cuda if torch.cuda.is_available() else None,
        "dtype_runtime": str(dtype),
        "text_norm": "L2",
        "logit_scale": 100.0,
    }
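# The seed label file (``items.json``) is expected to be a JSON array of
# ``{"id", "name", "prompt"}`` records. Illustrative example (values are
# placeholders, not shipped with the repo):
#
#   [
#     {"id": 1, "name": "cat", "prompt": "a photo of a cat"},
#     {"id": 2, "name": "dog", "prompt": "a photo of a dog"}
#   ]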
class EndpointHandler:
    def __init__(self, path: str = ""):
        # ``path`` is the repository root; HF mounts the code under
        # /repository and we use it to locate items.json.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.dtype = torch.float16 if self.device == "cuda" else torch.float32

        # 1) Load model + transforms. Using OpenCLIP's own transforms keeps
        #    preprocessing identical to training.
        model, _, self.preprocess = open_clip.create_model_and_transforms(
            "MobileCLIP-B", pretrained="datacompdr"
        )
        model.eval()  # disable dropout / BN updates
        model = reparameterize_model(model)  # fuse MobileOne branches
        model.to(self.device)
        if self.device == "cuda":
            model = model.to(torch.float16)  # FP16 for throughput
        self.model = model
        self.tokenizer = open_clip.get_tokenizer("MobileCLIP-B")
        self.fingerprint = _fingerprint(self.dtype)
        self._lock = threading.Lock()

        # 2) Try to load the latest label snapshot from the Hub; otherwise
        #    seed the label set from the bundled items.json.
        loaded = False
        if HF_LABEL_REPO:
            with contextlib.suppress(Exception):
                loaded = self._load_snapshot_from_hub_latest()
        if not loaded:
            items_path = "items.json" if not path else f"{path}/items.json"
            with open(items_path, "r", encoding="utf-8") as f:
                items = json.load(f)
            prompts = [it["prompt"] for it in items]
            self.class_ids = [int(it["id"]) for it in items]
            self.class_names = [it["name"] for it in items]
            feats = self._encode_text(prompts)
            self.text_features_cpu = feats.detach().cpu().to(torch.float32).contiguous()
            self._to_device()
            self.labels_version = 1

    def __call__(self, data):
        payload = data.get("inputs", data)
        op = payload.get("op")

        # Admin op: upsert_labels
        if op == "upsert_labels":
            if payload.get("token") != ADMIN_TOKEN:
                return {"error": "unauthorized"}
            items = payload.get("items", []) or []
            added = self._upsert_items(items)
            if added > 0:
                new_ver = int(getattr(self, "labels_version", 1)) + 1
                try:
                    self._persist_snapshot_to_hub(new_ver)
                    self.labels_version = new_ver
                except Exception as e:
                    return {"status": "error", "added": added, "detail": str(e)}
            return {
                "status": "ok",
                "added": added,
                "labels_version": getattr(self, "labels_version", 1),
            }

        # Admin op: reload_labels
        if op == "reload_labels":
            if payload.get("token") != ADMIN_TOKEN:
                return {"error": "unauthorized"}
            try:
                ver = int(payload.get("version"))
            except Exception:
                return {"error": "invalid_version"}
            try:
                ok = self._load_snapshot_from_hub_version(ver)
            except Exception as e:
                return {"status": "error", "detail": str(e)}
            return {
                "status": "ok" if ok else "nochange",
                "labels_version": getattr(self, "labels_version", 0),
            }

        # Admin op: remove_labels
        if op == "remove_labels":
            if payload.get("token") != ADMIN_TOKEN:
                return {"error": "unauthorized"}
            ids_to_remove = set(payload.get("ids", []))
            if not ids_to_remove:
                return {"error": "no_ids_provided"}
            removed = self._remove_items(ids_to_remove)
            if removed > 0:
                new_ver = int(getattr(self, "labels_version", 1)) + 1
                try:
                    self._persist_snapshot_to_hub(new_ver)
                    self.labels_version = new_ver
                except Exception as e:
                    return {"status": "error", "removed": removed, "detail": str(e)}
            return {
                "status": "ok",
                "removed": removed,
                "labels_version": getattr(self, "labels_version", 1),
            }

        # Freshness guard (optional): pull a newer snapshot before classifying.
        min_ver = payload.get("min_labels_version")
        if isinstance(min_ver, int) and min_ver > getattr(self, "labels_version", 0):
            with contextlib.suppress(Exception):
                self._load_snapshot_from_hub_version(min_ver)

        # Classification path (unchanged contract).
        img_b64 = payload["image"]
        image = Image.open(io.BytesIO(base64.b64decode(img_b64))).convert("RGB")
        img_tensor = self.preprocess(image).unsqueeze(0).to(self.device)  # [1, 3, 224, 224]
        if self.device == "cuda":
            img_tensor = img_tensor.to(torch.float16)

        # Take a consistent view of the label state so a concurrent upsert or
        # removal cannot mis-pair ids, names, and embeddings mid-request.
        with self._lock:
            class_ids = list(self.class_ids)
            class_names = list(self.class_names)
            text_features = self.text_features

        with torch.no_grad():
            img_feat = self.model.encode_image(img_tensor)  # [1, 512]
            img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)  # L2-normalise
            # cosine similarity -> logits -> softmax probabilities
            probs = (100.0 * img_feat @ text_features.T).softmax(dim=-1)[0]

        results = zip(class_ids, class_names, probs.detach().cpu().tolist())
        top_k = int(payload.get("top_k", len(class_ids)))
        return sorted(
            [{"id": i, "label": name, "score": float(p)} for i, name, p in results],
            key=lambda x: x["score"],
            reverse=True,
        )[:top_k]

    # ------------- helpers -------------
    def _encode_text(self, prompts):
        with torch.no_grad():
            toks = self.tokenizer(prompts).to(self.device)
            feats = self.model.encode_text(toks)
            feats = feats / feats.norm(dim=-1, keepdim=True)
        return feats

    def _to_device(self):
        self.text_features = self.text_features_cpu.to(
            self.device, dtype=(torch.float16 if self.device == "cuda" else torch.float32)
        )  # [num_classes, 512]

    def _upsert_items(self, new_items):
        if not new_items:
            return 0
        with self._lock:
            # All IDs and names currently known; names are compared
            # case-insensitively.
            known_ids = set(getattr(self, "class_ids", []))
            known_names_lower = {name.lower() for name in getattr(self, "class_names", [])}

            # Filter out duplicates against the current state *and* within
            # this batch itself.
            batch = []
            for it in new_items:
                raw_id, item_name = it.get("id"), it.get("name")
                if raw_id is None or item_name is None or "prompt" not in it:
                    continue  # skip malformed entries
                item_id = int(raw_id)
                if item_id in known_ids:
                    continue  # skip duplicate ID
                if item_name.lower() in known_names_lower:
                    continue  # skip duplicate name (case-insensitive)
                known_ids.add(item_id)
                known_names_lower.add(item_name.lower())
                batch.append(it)

            if not batch:
                return 0

            # Encode the accepted prompts and append to the persistent state.
            prompts = [it["prompt"] for it in batch]
            feats = self._encode_text(prompts).detach().cpu().to(torch.float32)
            if not hasattr(self, "text_features_cpu"):
                self.text_features_cpu = feats.contiguous()
                self.class_ids = [int(it["id"]) for it in batch]
                self.class_names = [it["name"] for it in batch]
            else:
                self.text_features_cpu = torch.cat(
                    [self.text_features_cpu, feats], dim=0
                ).contiguous()
                self.class_ids.extend(int(it["id"]) for it in batch)
                self.class_names.extend(it["name"] for it in batch)
            self._to_device()
            return len(batch)

    def _remove_items(self, ids_to_remove):
        if not ids_to_remove or not hasattr(self, "class_ids"):
            return 0
        with self._lock:
            ids_to_remove = {int(i) for i in ids_to_remove}

            # Find the indices that survive the removal.
            indices_to_keep = [
                i for i, class_id in enumerate(self.class_ids)
                if class_id not in ids_to_remove
            ]
            removed_count = len(self.class_ids) - len(indices_to_keep)
            if removed_count == 0:
                return 0

            # Filter the tensor and lists in lock-step.
            if indices_to_keep:
                self.text_features_cpu = self.text_features_cpu[indices_to_keep].contiguous()
                self.class_ids = [self.class_ids[i] for i in indices_to_keep]
                self.class_names = [self.class_names[i] for i in indices_to_keep]
            else:
                # All items removed; reset to an empty label set.
                self.text_features_cpu = torch.empty(0, self.text_features_cpu.shape[1])
                self.class_ids = []
                self.class_names = []
            self._to_device()
            return removed_count
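    # Snapshot layout on the Hub (dataset repo referenced by HF_LABEL_REPO),
    # as written by ``_persist_snapshot_to_hub`` below and read back by the
    # two loaders:
    #
    #   snapshots/v{N}/embeddings.safetensors   float32 [count, dims] text features
    #   snapshots/v{N}/meta.json                items, fingerprint, dims, count, version
    #   snapshots/latest.json                   {"version": N} pointer to newest snapshot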
    def _persist_snapshot_to_hub(self, version: int):
        if not HF_LABEL_REPO:
            raise RuntimeError("HF_LABEL_REPO not set")
        if not HF_WRITE_TOKEN:
            raise RuntimeError("HF_WRITE_TOKEN not set for publishing")
        emb_path = "/tmp/embeddings.safetensors"
        meta_path = "/tmp/meta.json"
        latest_bytes = io.BytesIO(json.dumps({"version": int(version)}).encode("utf-8"))
        save_file({"embeddings": self.text_features_cpu.to(torch.float32)}, emb_path)
        meta = {
            "items": [
                {"id": int(i), "name": n}
                for i, n in zip(self.class_ids, self.class_names)
            ],
            "fingerprint": self.fingerprint,
            "dims": int(self.text_features_cpu.shape[1]),
            "count": int(self.text_features_cpu.shape[0]),
            "version": int(version),
        }
        with open(meta_path, "w", encoding="utf-8") as f:
            json.dump(meta, f)
        ops = [
            CommitOperationAdd(
                path_in_repo=f"snapshots/v{version}/embeddings.safetensors",
                path_or_fileobj=emb_path,
            ),
            CommitOperationAdd(
                path_in_repo=f"snapshots/v{version}/meta.json",
                path_or_fileobj=meta_path,
            ),
            CommitOperationAdd(
                path_in_repo="snapshots/latest.json",
                path_or_fileobj=latest_bytes,
            ),
        ]
        create_commit(
            repo_id=HF_LABEL_REPO,
            repo_type="dataset",
            operations=ops,
            token=HF_WRITE_TOKEN,
            commit_message=f"labels v{version}",
        )

    def _load_snapshot_from_hub_version(self, version: int) -> bool:
        if not HF_LABEL_REPO:
            return False
        with self._lock:
            emb_p = hf_hub_download(
                HF_LABEL_REPO,
                f"snapshots/v{version}/embeddings.safetensors",
                repo_type="dataset",
                token=HF_READ_TOKEN,
                force_download=True,
            )
            meta_p = hf_hub_download(
                HF_LABEL_REPO,
                f"snapshots/v{version}/meta.json",
                repo_type="dataset",
                token=HF_READ_TOKEN,
                force_download=True,
            )
            with open(meta_p, "r", encoding="utf-8") as f:
                meta = json.load(f)
            if meta.get("fingerprint") != self.fingerprint:
                raise RuntimeError("Embedding/model fingerprint mismatch")
            feats = load_file(emb_p)["embeddings"]  # float32, CPU
            self.text_features_cpu = feats.contiguous()
            self.class_ids = [int(x["id"]) for x in meta.get("items", [])]
            self.class_names = [x["name"] for x in meta.get("items", [])]
            self.labels_version = int(meta.get("version", version))
            self._to_device()
            return True

    def _load_snapshot_from_hub_latest(self) -> bool:
        if not HF_LABEL_REPO:
            return False
        try:
            latest_p = hf_hub_download(
                HF_LABEL_REPO,
                "snapshots/latest.json",
                repo_type="dataset",
                token=HF_READ_TOKEN,
            )
        except Exception:
            return False
        with open(latest_p, "r", encoding="utf-8") as f:
            latest = json.load(f)
        ver = int(latest.get("version", 0))
        if ver <= 0:
            return False
        return self._load_snapshot_from_hub_version(ver)
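# Example admin payloads accepted by ``__call__`` (token values are
# placeholders; the real value comes from the ADMIN_TOKEN env var):
#
#   {"op": "upsert_labels", "token": "<ADMIN_TOKEN>",
#    "items": [{"id": 7, "name": "fox", "prompt": "a photo of a fox"}]}
#
#   {"op": "remove_labels", "token": "<ADMIN_TOKEN>", "ids": [7]}
#
#   {"op": "reload_labels", "token": "<ADMIN_TOKEN>", "version": 3}
#
# Classification requests may additionally carry "min_labels_version": N to
# force a snapshot refresh before scoring.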
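# Example classification exchange (illustrative values). Scores are the
# softmax probabilities over the *current* label set, so over the full set
# they sum to 1.0:
#
#   request:  {"inputs": {"image": "<base64 JPEG/PNG>", "top_k": 2}}
#   response: [{"id": 1, "label": "cat", "score": 0.93},
#              {"id": 2, "label": "dog", "score": 0.05}]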
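if __name__ == "__main__":
    # Minimal local smoke test; a sketch, not part of the endpoint contract.
    # It downloads the MobileCLIP-B checkpoint on first run and assumes a
    # local test image at ``test.jpg`` (hypothetical path).
    handler = EndpointHandler(".")
    with open("test.jpg", "rb") as f:
        img_b64 = base64.b64encode(f.read()).decode("ascii")
    preds = handler({"inputs": {"image": img_b64, "top_k": 3}})
    print(json.dumps(preds, indent=2))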