Spaces:
Sleeping
Sleeping
| import os, io, json, logging | |
| from typing import List, Dict, Any | |
| import numpy as np | |
| from fastapi import FastAPI, UploadFile, File, HTTPException, Request | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import JSONResponse | |
| from PIL import Image | |
| import tensorflow as tf | |
| # optional gatekeep | |
| try: | |
| import cv2 | |
| HAS_OPENCV = True | |
| except Exception: | |
| HAS_OPENCV = False | |
| # HF Hub (สำหรับดึง derm-foundation) | |
| from huggingface_hub import snapshot_download | |
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("skinclassify")

# ---------------------- Config ----------------------
DERM_MODEL_ID = os.getenv("DERM_MODEL_ID", "google/derm-foundation")
DERM_LOCAL_DIR = os.getenv("DERM_LOCAL_DIR", "")  # path to a local SavedModel, for offline use
HEAD_PATH = os.getenv("HEAD_PATH", "Models/mlp_best.keras")
THRESHOLDS_PATH = os.getenv("THRESHOLDS_PATH", "Models/mlp_thresholds.npy")
MU_PATH = os.getenv("MU_PATH", "Models/mu.npy")
SD_PATH = os.getenv("SD_PATH", "Models/sd.npy")
LABELS_PATH = os.getenv("LABELS_PATH", "Models/class_names.json")
NPZ_PATH = os.getenv("NPZ_PATH", "")  # optional single file providing mu/sd/class_names
TOPK = int(os.getenv("TOPK", "5"))

# Gatekeeping thresholds (an image failing any check is rejected before inference)
MIN_W, MIN_H = int(os.getenv("MIN_W", "128")), int(os.getenv("MIN_H", "128"))
MIN_ASPECT, MAX_ASPECT = float(os.getenv("MIN_ASPECT", "0.5")), float(os.getenv("MAX_ASPECT", "2.0"))
MIN_BRIGHT, MAX_BRIGHT = float(os.getenv("MIN_BRIGHT", "20")), float(os.getenv("MAX_BRIGHT", "235"))
MIN_SKIN_RATIO = float(os.getenv("MIN_SKIN_RATIO", "0.15"))
MIN_SHARPNESS = float(os.getenv("MIN_SHARPNESS", "30.0"))

# Performance: keep thread pools small to avoid OOM on a free Space.
# NOTE(review): numpy and tensorflow are imported before this point, so variables
# that are read when those native libraries load (OMP_NUM_THREADS in particular)
# may not take effect from here; setting them in the container env is safer.
os.environ.setdefault("TF_NUM_INTRAOP_THREADS", "1")
os.environ.setdefault("TF_NUM_INTEROP_THREADS", "1")
os.environ.setdefault("OMP_NUM_THREADS", "1")
os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "2")
# Also apply the TF thread limits through the public API. This only works before
# the TF runtime is initialised; if it is too late, the failure is just logged.
try:
    tf.config.threading.set_intra_op_parallelism_threads(int(os.environ["TF_NUM_INTRAOP_THREADS"]))
    tf.config.threading.set_inter_op_parallelism_threads(int(os.environ["TF_NUM_INTEROP_THREADS"]))
except (RuntimeError, ValueError) as exc:
    logger.warning("Could not apply TF thread limits: %s", exc)

# Upload size cap (DoS guard)
MAX_UPLOAD = int(os.getenv("MAX_UPLOAD", str(6 * 1024 * 1024)))  # 6MB

DF_SIZE = (448, 448)  # size images are resized to before being sent to derm-foundation

app = FastAPI(title="SkinClassify API (Derm-Foundation)", version="2.0.0")

# ALLOW_ORIGINS is a comma-separated list. Whitespace around entries and empty
# entries are ignored, so "https://a.example, https://b.example" works.
ALLOW_ORIGINS = [o.strip() for o in os.getenv("ALLOW_ORIGINS", "*").split(",") if o.strip()]
app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOW_ORIGINS,
    # With a wildcard origin plus allow_credentials, Starlette echoes back the
    # caller's Origin and allows credentials, so any site could make credentialed
    # requests. Credentials are therefore enabled only for explicit origin lists.
    allow_credentials="*" not in ALLOW_ORIGINS,
    allow_methods=["*"],
    allow_headers=["*"],
)
| # ---------------------- Load labels ---------------------- | |
| def _load_json(path): | |
| with open(path, "r", encoding="utf-8") as f: | |
| return json.load(f) | |
def _load_class_names(labels_path=None, npz_path=None):
    """Return the ordered list of class names used to label the head's outputs.

    Looks first at the JSON file at ``labels_path`` (default ``LABELS_PATH``),
    then at the ``class_names`` array inside the ``.npz`` at ``npz_path``
    (default ``NPZ_PATH``).

    Raises:
        RuntimeError: if neither source provides class names, or the loaded
            value is not a non-empty list.
    """
    labels_path = LABELS_PATH if labels_path is None else labels_path
    npz_path = NPZ_PATH if npz_path is None else npz_path

    if os.path.exists(labels_path):
        names = _load_json(labels_path)
        logger.info(f"Loaded class_names from {labels_path}")
    elif npz_path and os.path.exists(npz_path):
        # allow_pickle=True is needed if class_names was saved as an object array.
        # Only point NPZ_PATH at trusted files, because unpickling can run code.
        # The context manager closes the archive's file handle.
        with np.load(npz_path, allow_pickle=True) as arr:
            if "class_names" not in arr:
                raise RuntimeError("No LABELS_PATH and class_names not found in NPZ")
            names = [
                n.decode("utf-8") if isinstance(n, bytes) else str(n)
                for n in arr["class_names"]
            ]
        logger.info(f"Loaded class_names from {npz_path}:class_names")
    else:
        raise RuntimeError("LABELS_PATH not found and NPZ_PATH not provided.")

    # Fail at startup, not on the first request, when the labels are unusable.
    if not isinstance(names, list) or not names:
        raise RuntimeError("class_names must be a non-empty list of strings")
    return names


CLASS_NAMES: List[str] = _load_class_names()
C = len(CLASS_NAMES)  # number of classes; thresholds and outputs are indexed 0..C-1
| # ---------------------- Load head ---------------------- | |
# The classifier head is loaded without compiling. This service only runs
# inference with it, so no optimizer/loss configuration is needed.
logger.info(f"Loading head from {HEAD_PATH}")
head = tf.keras.models.load_model(filepath=HEAD_PATH, compile=False)
| # ---------------------- Load mu/sd ---------------------- | |
| def _load_mu_sd(): | |
| if os.path.exists(MU_PATH) and os.path.exists(SD_PATH): | |
| mu_ = np.load(MU_PATH).astype("float32") | |
| sd_ = np.load(SD_PATH).astype("float32") | |
| return mu_, sd_ | |
| if NPZ_PATH and os.path.exists(NPZ_PATH): | |
| arr = np.load(NPZ_PATH, allow_pickle=True) | |
| mu_ = arr["mu"].astype("float32") | |
| sd_ = arr["sd"].astype("float32") | |
| return mu_, sd_ | |
| raise RuntimeError("mu/sd not found (MU_PATH/SD_PATH or NPZ_PATH).") | |
mu, sd = _load_mu_sd()
logger.info("Loaded mu/sd")

# ---------------------- Load thresholds ----------------------
# Per-class decision thresholds for the multilabel output (one value per class).
if os.path.exists(THRESHOLDS_PATH):
    # Flatten so (C,), (1, C) and (C, 1) files all become shape (C,). A 0-d
    # array would otherwise crash on shape[0], and a (C, 1) array would later
    # broadcast `probs >= best_th` into a C x C matrix.
    best_th = np.load(THRESHOLDS_PATH).astype("float32").reshape(-1)
    if best_th.size != C:
        raise RuntimeError(f"thresholds size {best_th.size} != #classes {C}")
else:
    logger.warning("THRESHOLDS_PATH not found -> default 0.5 for all classes")
    best_th = np.full(C, 0.5, dtype="float32")

# ---------------------- Wrap head with standardization ----------------------
# clf takes a raw embedding and applies (e - mu) / (sd + 1e-6) before the head,
# so callers never standardize by hand. The 1e-6 guards against zero std.
inp = tf.keras.Input(shape=(mu.shape[-1],), name="embedding")
x = tf.keras.layers.Lambda(lambda e: (e - mu) / (sd + 1e-6), name="standardize")(inp)
out = head(x)
clf = tf.keras.Model(inp, out, name="head_with_norm")
# ---------------------- Load derm-foundation ----------------------
# Use snapshot_download + tf.saved_model.load (the loading path that matches
# Google's published SavedModel).
logger.info("Loading Derm Foundation (first time may take a while)...")
_DERM_USE_LOCAL = bool(
    DERM_LOCAL_DIR
    and os.path.isdir(DERM_LOCAL_DIR)
    and os.path.exists(os.path.join(DERM_LOCAL_DIR, "saved_model.pb"))
)
if DERM_LOCAL_DIR and not _DERM_USE_LOCAL:
    # Warn instead of silently falling back, so a wrong DERM_LOCAL_DIR is visible.
    logger.warning(
        f"DERM_LOCAL_DIR={DERM_LOCAL_DIR!r} has no saved_model.pb; downloading from the Hub instead"
    )
try:
    if _DERM_USE_LOCAL:
        derm_dir = DERM_LOCAL_DIR
    else:
        logger.info(f"Downloading derm-foundation from hub: {DERM_MODEL_ID}")
        derm_dir = snapshot_download(
            repo_id=DERM_MODEL_ID,
            repo_type="model",
            allow_patterns=["saved_model.pb", "variables/*"],  # only the SavedModel files
        )
        logger.info(f"Derm Foundation downloaded to: {derm_dir}")
    derm = tf.saved_model.load(derm_dir)
    infer = derm.signatures["serving_default"]  # call with key 'inputs'
except Exception as e:
    # Chain the original error so its traceback is reported as the direct cause.
    raise RuntimeError(
        f"Failed to load derm-foundation: {e}. "
        "Make sure you accepted the model terms on Hugging Face, "
        "or set DERM_LOCAL_DIR to a local SavedModel path."
    ) from e
if _DERM_USE_LOCAL:
    # Logged only after tf.saved_model.load has actually succeeded.
    logger.info(f"Loaded Derm Foundation from local: {DERM_LOCAL_DIR}")
| # ---------------------- Utils ---------------------- | |
def pil_to_png_bytes_448(pil_img: Image.Image) -> bytes:
    """Return *pil_img* as PNG-encoded bytes, converted to RGB and resized to DF_SIZE."""
    resized = pil_img.convert("RGB").resize(DF_SIZE)
    pixels = np.array(resized, dtype=np.uint8)
    encoded = tf.io.encode_png(pixels)
    return encoded.numpy()
| def _brightness(np_img_rgb: np.ndarray) -> float: | |
| r,g,b = np_img_rgb[...,0], np_img_rgb[...,1], np_img_rgb[...,2] | |
| y = 0.2126*r + 0.7152*g + 0.0722*b | |
| return float(y.mean()) | |
def _sharpness(np_img_rgb: np.ndarray) -> float:
    """Return the variance of the Laplacian of the grayscale image.

    The gate treats low values as blurry. When OpenCV is not available,
    the fixed value 100.0 is returned instead.
    """
    if HAS_OPENCV:
        gray = cv2.cvtColor(np_img_rgb, cv2.COLOR_RGB2GRAY)
        laplacian = cv2.Laplacian(gray, cv2.CV_64F)
        return float(laplacian.var())
    return 100.0
def _skin_ratio(np_img_rgb: np.ndarray) -> float:
    """Return the fraction of pixels whose chroma lies in a fixed skin-tone box.

    The image is converted with PIL to YCbCr. A pixel counts as skin-like when
    77 <= Cb <= 127 and 133 <= Cr <= 173.
    """
    ycbcr = np.array(Image.fromarray(np_img_rgb).convert("YCbCr"))
    cb, cr = ycbcr[..., 1], ycbcr[..., 2]
    cb_ok = (77 <= cb) & (cb <= 127)
    cr_ok = (133 <= cr) & (cr <= 173)
    return float(np.mean(cb_ok & cr_ok))
def gatekeep_image(img_bytes: bytes) -> Dict[str, Any]:
    """Run cheap quality checks on an uploaded image before inference.

    Returns a dict with:
      ok      -- True when no check failed
      reasons -- failed-check codes in check order: too_small, weird_aspect,
                 too_dark, too_bright, too_blurry (OpenCV only), not_skin_like
      metrics -- width, height, aspect, brightness, sharpness (OpenCV only)
                 and skin_ratio
    Bytes that PIL cannot decode yield ok=False with reasons ["invalid_image"].
    """
    try:
        img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
    except Exception:
        return {"ok": False, "reasons": ["invalid_image"], "metrics": {}}

    width, height = img.size
    aspect = width / height
    # NOTE(review): every metric below is computed on the full-resolution image,
    # so very large photos create large intermediate arrays. Downscaling would
    # save memory but would change the sharpness values.
    pixels = np.array(img)

    failed: List[str] = []
    metrics: Dict[str, Any] = {"width": width, "height": height}

    if width < MIN_W or height < MIN_H:
        failed.append("too_small")

    metrics["aspect"] = float(aspect)
    if aspect < MIN_ASPECT or aspect > MAX_ASPECT:
        failed.append("weird_aspect")

    brightness = _brightness(pixels)
    metrics["brightness"] = brightness
    if brightness < MIN_BRIGHT:
        failed.append("too_dark")
    if brightness > MAX_BRIGHT:
        failed.append("too_bright")

    if HAS_OPENCV:
        sharpness = _sharpness(pixels)
        metrics["sharpness"] = sharpness
        if sharpness < MIN_SHARPNESS:
            failed.append("too_blurry")

    skin = _skin_ratio(pixels)
    metrics["skin_ratio"] = skin
    if skin < MIN_SKIN_RATIO:
        failed.append("not_skin_like")

    return {"ok": not failed, "reasons": failed, "metrics": metrics}
def predict_probs(img_bytes: bytes) -> np.ndarray:
    """Return the classifier outputs for one encoded image.

    Steps:
    1. Decode the bytes with PIL.
    2. Re-encode as a DF_SIZE RGB PNG (pil_to_png_bytes_448).
    3. Wrap the PNG in a serialized ``tf.train.Example`` under the
       ``image/encoded`` feature.
    4. Run the derm-foundation ``serving_default`` signature.
    5. Feed the resulting embedding to ``clf`` (standardization + head).

    Returns a 1-D array with one entry per head output (row 0 of the batch).

    Raises:
        RuntimeError: if the derm-foundation output has no "embedding" key.
    """
    # pil_to_png_bytes_448 already converts to RGB and resizes to DF_SIZE.
    # Doing that here as well only produced redundant copies of the image.
    png = pil_to_png_bytes_448(Image.open(io.BytesIO(img_bytes)))
    example = tf.train.Example(features=tf.train.Features(
        feature={"image/encoded": tf.train.Feature(bytes_list=tf.train.BytesList(value=[png]))}
    )).SerializeToString()
    out = infer(inputs=tf.constant([example]))
    # The key is normally named "embedding".
    if "embedding" not in out:
        raise RuntimeError(f"Unexpected derm-foundation outputs: {list(out.keys())}")
    emb = out["embedding"].numpy().astype("float32")  # (1, 6144)
    # For a single-row batch, calling the model directly avoids Model.predict's
    # per-call setup overhead. training=False keeps inference-mode behaviour.
    probs = np.asarray(clf(emb, training=False))[0]
    return probs
| # ---------------------- Endpoints ---------------------- | |
# NOTE(review): the source had no route decorator here, so this handler was never
# registered with the app. The path below is assumed; confirm it with clients.
@app.get("/health")
def health():
    """Report service status.

    Returns:
    - ``ok``: always True once the module has loaded.
    - ``classes``: the number of classes.
    - ``derm``: the derm-foundation source actually in use (``DERM_LOCAL_DIR``
      when that SavedModel was loaded, otherwise the Hub repo id).
    - ``has_opencv``: whether OpenCV is available for the sharpness check.
    """
    # DERM_MODEL_ID has a non-empty default, so `DERM_MODEL_ID or DERM_LOCAL_DIR`
    # would always report the Hub id. Decide based on the directory actually loaded.
    used_local = bool(DERM_LOCAL_DIR) and derm_dir == DERM_LOCAL_DIR
    return {
        "ok": True,
        "classes": len(CLASS_NAMES),
        "derm": DERM_LOCAL_DIR if used_local else DERM_MODEL_ID,
        "has_opencv": HAS_OPENCV,
    }
# NOTE(review): the source had no route decorator here, so this handler was never
# registered with the app. The path below is assumed; confirm it with clients.
@app.post("/predict")
async def predict(request: Request, file: UploadFile = File(...)):
    """Classify one uploaded skin image.

    Responses:
      * 413 if the declared Content-Length or the file itself exceeds
        MAX_UPLOAD (400 if Content-Length is not an integer).
      * 200 with ``{"ok": False, "reason": "gate_reject", "gate": ...}`` when
        gatekeep_image rejects the image.
      * 200 with ``{"ok": True, "gate": ..., "result": ...}`` otherwise.
        ``result`` contains:
        - the per-class thresholds;
        - the classes whose output reaches their threshold ("positives");
        - the TOPK highest-scoring classes;
        - every class output.
    """
    from fastapi.concurrency import run_in_threadpool

    # Cheap early reject based on the declared request size. This covers the
    # whole multipart body, so it is only an upper bound on the file size.
    cl = request.headers.get("content-length")
    if cl:
        try:
            declared = int(cl)
        except ValueError:
            raise HTTPException(400, "Invalid Content-Length header") from None
        if declared > MAX_UPLOAD:
            raise HTTPException(413, "File too large")

    # Read at most one byte past the limit. That is enough to detect an
    # oversized file without loading all of it into memory.
    img_bytes = await file.read(MAX_UPLOAD + 1)
    if len(img_bytes) > MAX_UPLOAD:
        raise HTTPException(413, "File too large")

    # Image decoding and TF inference are synchronous and CPU-bound. Running them
    # in the threadpool keeps the event loop (and the other endpoints) responsive.
    gate = await run_in_threadpool(gatekeep_image, img_bytes)
    if not gate["ok"]:
        return JSONResponse(status_code=200, content={"ok": False, "reason": "gate_reject", "gate": gate})

    probs = await run_in_threadpool(predict_probs, img_bytes)
    order = np.argsort(probs)[::-1]  # class indices, highest output first
    # max(..., 0): a negative TOPK yields no entries instead of slicing from the end.
    top = [{"label": CLASS_NAMES[i], "prob": float(probs[i])} for i in order[:max(TOPK, 0)]]
    preds = (probs >= best_th).astype(np.int32)
    positives = [{"label": CLASS_NAMES[i], "prob": float(probs[i])} for i in range(C) if preds[i] == 1]
    return {
        "ok": True,
        "gate": gate,
        "result": {
            "type": "multilabel",
            "thresholds_used": {CLASS_NAMES[i]: float(best_th[i]) for i in range(C)},
            "positives": positives,
            "topk": top,
            "probs": {CLASS_NAMES[i]: float(probs[i]) for i in range(C)},
        },
    }
# For running outside Docker (e.g. local testing).
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(
        app,
        host="0.0.0.0",
        port=int(os.getenv("PORT", "7860")),
        workers=1,
    )