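"""Hot or Not demo: CLIP ViT-B/32 (ONNX) pulled from the Hugging Face Hub.

Runs zero-shot image-text matching on CPU via onnxruntime, auto-detects
gender and age group, and scores appearance through a small Gradio UI.
"""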

import sys
import traceback

import numpy as np
import onnxruntime as ort
from huggingface_hub import hf_hub_download
from transformers import CLIPProcessor
from PIL import Image
import gradio as gr

# ============================================================
# Config
# ============================================================
REPO_ID = "sayantan47/clip-vit-b32-onnx"  # <-- change this
MODEL_FILENAME = "onnx/model_q4.onnx"
PROVIDERS = ["CPUExecutionProvider"]  # keep CPU to avoid CUDA DLL issues
DEFAULT_OUTPUT = (0.0, 0.0, 0.0, 0.0, "unknown", "unknown")
FIXED_IMG_W = 300
FIXED_IMG_H = 300
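# Note: DEFAULT_OUTPUT mirrors the six Gradio outputs wired up below
# (composite score, three pairwise scores, gender, age group), so failures
# degrade to a well-formed response.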

# ============================================================
# Utils
# ============================================================
def _print_exc(prefix: str):
    print(prefix, file=sys.stderr)
    traceback.print_exc()


def _softmax_safe(x: np.ndarray, axis: int = -1) -> np.ndarray:
    try:
        x = x - np.max(x, axis=axis, keepdims=True)
        ex = np.exp(x)
        denom = np.sum(ex, axis=axis, keepdims=True)
        denom = np.where(denom == 0, 1.0, denom)
        return ex / denom
    except Exception:
        _print_exc("[_softmax_safe] failed")
        # Fall back to a uniform distribution over the last axis.
        return np.ones_like(x) / x.shape[-1]


def _ensure_int64(feed_dict):
    # Exported CLIP ONNX graphs typically declare int64 token inputs; promote
    # any int32 arrays so onnxruntime does not reject the feed.
    out = {}
    for k, v in feed_dict.items():
        if isinstance(v, np.ndarray) and v.dtype == np.int32:
            out[k] = v.astype(np.int64)
        else:
            out[k] = v
    return out


def _dummy_image(width=FIXED_IMG_W, height=FIXED_IMG_H):
    # Flat mid-gray placeholder used when no usable image is provided.
    return Image.fromarray(np.full((height, width, 3), 127, dtype=np.uint8), "RGB")


# ============================================================
# Load from HF Hub
# ============================================================
def load_from_hub():
    # Download the quantized ONNX model file (recent huggingface_hub versions
    # resume interrupted downloads automatically).
    model_path = hf_hub_download(
        repo_id=REPO_ID,
        filename=MODEL_FILENAME,
        local_dir="hf_cache",
    )
    # Load the processor (tokenizer + image preprocessing) from the same repo.
    proc = CLIPProcessor.from_pretrained(REPO_ID)
    sess = ort.InferenceSession(model_path, providers=PROVIDERS)
    return proc, sess


try:
    processor, session = load_from_hub()
except Exception:
    _print_exc("[GLOBAL INIT] Failed to download/load model from HF Hub.")
    processor, session = None, None
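
# Optional sanity check; uncomment to see which input names/shapes/dtypes the
# ONNX graph actually expects (handy if the feed built below is rejected):
#
#   for inp in session.get_inputs():
#       print(inp.name, inp.shape, inp.type)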


# ============================================================
# Core helpers
# ============================================================
def _run_clip(image_pil: Image.Image, texts):
    if processor is None or session is None:
        return None
    try:
        inputs = processor(
            text=texts, images=image_pil, return_tensors="np", padding=True
        )
        ort_inputs = _ensure_int64(inputs)
        outputs = session.run(None, ort_inputs)
        logits_per_image = outputs[0]  # (1, n_texts)
        probs = _softmax_safe(logits_per_image, axis=-1)[0]
        return probs
    except Exception:
        _print_exc("[_run_clip] Inference failed")
        return None


def detect_gender(image_pil: Image.Image) -> str:
    texts = ["a man", "a woman"]
    probs = _run_clip(image_pil, texts)
    if probs is None:
        return "unknown"
    return "man" if int(np.argmax(probs)) == 0 else "woman"


def detect_age_group(image_pil: Image.Image) -> str:
    texts = ["a young person", "a middle-aged person", "an old person"]
    probs = _run_clip(image_pil, texts)
    if probs is None:
        return "unknown"
    return ["young", "middle-aged", "old"][int(np.argmax(probs))]


def score_with_terms(image_pil: Image.Image, positive_terms, negative_terms):
    probs_all = []
    for pos, neg in zip(positive_terms, negative_terms):
        probs = _run_clip(image_pil, [pos, neg])
        if probs is None or len(probs) != 2:
            return DEFAULT_OUTPUT[:4]
        probs_all.append(probs)
    positive_probs = [p[0] for p in probs_all]
    negative_probs = [p[1] for p in probs_all]
    s1 = round((probs_all[0][0] - probs_all[0][1] + 1) * 50, 2)
    s2 = round((probs_all[1][0] - probs_all[1][1] + 1) * 50, 2)
    s3 = round((probs_all[2][0] - probs_all[2][1] + 1) * 50, 2)
    hot_score = float(np.mean(positive_probs))
    ugly_score = float(np.mean(negative_probs))
    composite = round(((hot_score - ugly_score) + 1) * 50, 2)
    return composite, s1, s2, s3
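
# Each pairwise score maps the probability gap onto [0, 100]: probs of
# (0.8, 0.2), say, give (0.8 - 0.2 + 1) * 50 = 80.0, while a perfectly
# ambiguous (0.5, 0.5) pair lands on 50.0.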


# ============================================================
# Gradio callback
# ============================================================
def hotornot(image):
    if processor is None or session is None:
        return DEFAULT_OUTPUT
    if image is None:
        image_pil = _dummy_image()
    else:
        try:
            image_pil = Image.fromarray(image.astype("uint8"), "RGB")
        except Exception:
            _print_exc("[hotornot] Failed to convert input to PIL. Using dummy image.")
            image_pil = _dummy_image()
    try:
        gender = detect_gender(image_pil)
        age_group = detect_age_group(image_pil)
        if gender == "man":
            positive_terms = ["a handsome man", "a charming man", "an attractive man"]
            negative_terms = ["an ugly man", "a gross man", "a hideous man"]
        elif gender == "woman":
            positive_terms = ["a beautiful woman", "a cute woman", "an attractive woman"]
            negative_terms = ["an ugly woman", "a gross woman", "a hideous woman"]
        else:
            positive_terms = ["a hot person", "a beautiful person", "an attractive person"]
            negative_terms = ["an ugly person", "a gross person", "a hideous person"]
        composite, hotness, second, attractiveness = score_with_terms(
            image_pil, positive_terms, negative_terms
        )
        return composite, hotness, second, attractiveness, gender, age_group
    except Exception:
        _print_exc("[hotornot] Unexpected error")
        return DEFAULT_OUTPUT


# ============================================================
# UI
# ============================================================
# Doubled braces escape the literal CSS block inside the f-string.
CSS = f"""
#fixed_img_component img,
#fixed_img_component canvas {{
    width: {FIXED_IMG_W}px !important;
    height: {FIXED_IMG_H}px !important;
    object-fit: contain !important;
}}
"""

with gr.Blocks(css=CSS) as demo:
    gr.Markdown("# Hot or Not (CLIP ONNX from Hugging Face Hub)")
    gr.Markdown(
        "Loads ONNX + tokenizer from HF Hub, runs on CPU, "
        "auto-detects gender & age, and scores appearance."
    )
    with gr.Row():
        image_in = gr.Image(
            label="Upload Image",
            type="numpy",
            image_mode="RGB",
            height=FIXED_IMG_H,
            width=FIXED_IMG_W,
            elem_id="fixed_img_component",
        )
    with gr.Row():
        out_total = gr.Textbox(label="Total Hot or Not™ Score")
        out_hot = gr.Textbox(label="Hotness Score")
        out_mid = gr.Textbox(label="Charm / Cuteness Score")
        out_attr = gr.Textbox(label="Attractiveness Score")
        out_gender = gr.Textbox(label="Predicted Gender")
        out_age = gr.Textbox(label="Predicted Age Group")
    run_btn = gr.Button("Rate")
    run_btn.click(
        fn=hotornot,
        inputs=[image_in],
        outputs=[out_total, out_hot, out_mid, out_attr, out_gender, out_age],
    )

if __name__ == "__main__":
    demo.launch()
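
# A minimal client-side sketch (assumes `pip install gradio_client`; the
# endpoint name below is an assumption, since Gradio auto-names click
# endpoints after the callback, i.e. "/hotornot"):
#
#   from gradio_client import Client, handle_file
#   client = Client("http://127.0.0.1:7860")
#   print(client.predict(handle_file("face.jpg"), api_name="/hotornot"))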