"""Hot-or-Not / core.py: CLIP-based image scoring with ONNX Runtime."""
import os
import sys
import traceback
import numpy as np
import onnxruntime as ort
from transformers import CLIPProcessor
from PIL import Image
from typing import Optional, List, Tuple, Union
from abc import ABC, abstractmethod
# ============================================================
# Configuration
# ============================================================
class Config:
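    # Fallback result, matching evaluate_image's return order:
    # (composite, s1, s2, s3, gender, age_group)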
DEFAULT_OUTPUT = (0.0, 0.0, 0.0, 0.0, "unknown", "unknown")
FIXED_IMG_W = 300
FIXED_IMG_H = 300
PROVIDERS = ["CPUExecutionProvider"] # keep CPU to avoid CUDA DLL issues
# ============================================================
# Utilities
# ============================================================
def print_exc(prefix: str):
"""Print exception with prefix to stderr."""
print(prefix, file=sys.stderr)
traceback.print_exc()
def softmax_safe(x: np.ndarray, axis: int = -1) -> np.ndarray:
"""Safe softmax implementation that handles edge cases."""
try:
x = x - np.max(x, axis=axis, keepdims=True)
ex = np.exp(x)
denom = np.sum(ex, axis=axis, keepdims=True)
denom = np.where(denom == 0, 1.0, denom)
return ex / denom
except Exception:
print_exc("[softmax_safe] failed")
        return np.ones_like(x) / x.shape[axis]
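# Example: softmax_safe(np.array([[1000.0, 1000.0]]))[0] -> [0.5, 0.5];
# subtracting the row max first keeps np.exp from overflowing on large logits.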
def ensure_int64(feed_dict: dict) -> dict:
"""Convert int32 arrays to int64 for ONNX compatibility."""
out = {}
for k, v in feed_dict.items():
if isinstance(v, np.ndarray) and v.dtype == np.int32:
out[k] = v.astype(np.int64)
else:
out[k] = v
return out
def create_dummy_image(width: int = Config.FIXED_IMG_W, height: int = Config.FIXED_IMG_H) -> Image.Image:
"""Create a dummy gray image for fallback cases."""
return Image.fromarray(np.full((height, width, 3), 127, dtype=np.uint8), "RGB")
# ============================================================
# Abstract Model Interface
# ============================================================
class ModelInterface(ABC):
"""Abstract interface for CLIP models."""
@abstractmethod
def is_loaded(self) -> bool:
"""Check if model is properly loaded."""
pass
@abstractmethod
def run_inference(self, image_pil: Image.Image, texts: List[str]) -> Optional[np.ndarray]:
"""Run CLIP inference on image and texts."""
pass
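# Illustrative only (an assumption, not part of the original app): any object
# satisfying this interface can back HotOrNotScorer, e.g. a trivial stub that
# returns a uniform distribution over the prompts, handy in unit tests.
class _UniformStubModel(ModelInterface):
    """Hypothetical test stub: always 'loaded', uniform over prompts."""
    def is_loaded(self) -> bool:
        return True
    def run_inference(self, image_pil: Image.Image, texts: List[str]) -> Optional[np.ndarray]:
        return np.ones(len(texts), dtype=np.float32) / len(texts)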
# ============================================================
# Model Implementations
# ============================================================
class HuggingFaceModel(ModelInterface):
"""CLIP model loaded from Hugging Face Hub."""
def __init__(self, repo_id: str, model_filename: str):
self.repo_id = repo_id
self.model_filename = model_filename
self.processor = None
self.session = None
self._load_model()
def _load_model(self):
"""Load model and processor from Hugging Face Hub."""
try:
from huggingface_hub import hf_hub_download
# Download model.onnx
model_path = hf_hub_download(
repo_id=self.repo_id,
filename=self.model_filename,
local_dir="hf_cache",
            )
# Load processor (tokenizer + preproc files) from the same repo
self.processor = CLIPProcessor.from_pretrained(self.repo_id)
self.session = ort.InferenceSession(model_path, providers=Config.PROVIDERS)
except Exception:
print_exc("[HuggingFaceModel] Failed to download/load model from HF Hub.")
self.processor, self.session = None, None
def is_loaded(self) -> bool:
"""Check if model is properly loaded."""
return self.processor is not None and self.session is not None
def run_inference(self, image_pil: Image.Image, texts: List[str]) -> Optional[np.ndarray]:
"""Run CLIP inference on image and texts."""
if not self.is_loaded():
return None
try:
inputs = self.processor(
text=texts, images=image_pil, return_tensors="np", padding=True
)
ort_inputs = ensure_int64(inputs)
outputs = self.session.run(None, ort_inputs)
logits_per_image = outputs[0] # (1, n_texts)
probs = softmax_safe(logits_per_image, axis=-1)[0]
return probs
except Exception:
print_exc("[HuggingFaceModel] Inference failed")
return None
class LocalModel(ModelInterface):
"""CLIP model loaded from local files."""
def __init__(self, model_path: str, processor_path: Optional[str] = None):
self.model_path = model_path
self.processor_path = processor_path
self.processor = None
self.session = None
self._load_model()
def _load_model(self):
"""Load model and processor from local files."""
try:
# Load ONNX model
if not os.path.exists(self.model_path):
raise FileNotFoundError(f"Model file not found: {self.model_path}")
self.session = ort.InferenceSession(self.model_path, providers=Config.PROVIDERS)
# Load processor
if self.processor_path and os.path.exists(self.processor_path):
self.processor = CLIPProcessor.from_pretrained(self.processor_path)
else:
# Fallback to a default processor if local processor not available
print("[LocalModel] Using default CLIP processor")
self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
except Exception:
print_exc("[LocalModel] Failed to load local model.")
self.processor, self.session = None, None
def is_loaded(self) -> bool:
"""Check if model is properly loaded."""
return self.processor is not None and self.session is not None
def run_inference(self, image_pil: Image.Image, texts: List[str]) -> Optional[np.ndarray]:
"""Run CLIP inference on image and texts."""
if not self.is_loaded():
return None
try:
inputs = self.processor(
text=texts, images=image_pil, return_tensors="np", padding=True
)
ort_inputs = ensure_int64(inputs)
outputs = self.session.run(None, ort_inputs)
logits_per_image = outputs[0] # (1, n_texts)
probs = softmax_safe(logits_per_image, axis=-1)[0]
return probs
except Exception:
print_exc("[LocalModel] Inference failed")
return None
# ============================================================
# Core Scoring Logic
# ============================================================
class HotOrNotScorer:
"""Core logic for hot-or-not scoring using CLIP models."""
def __init__(self, model: ModelInterface):
self.model = model
def _run_clip(self, image_pil: Image.Image, texts: List[str]) -> Optional[np.ndarray]:
"""Run CLIP inference wrapper."""
return self.model.run_inference(image_pil, texts)
def detect_gender(self, image_pil: Image.Image) -> str:
"""Detect gender from image."""
texts = ["a man", "a woman"]
probs = self._run_clip(image_pil, texts)
if probs is None:
return "unknown"
return "man" if int(np.argmax(probs)) == 0 else "woman"
def detect_age_group(self, image_pil: Image.Image) -> str:
"""Detect age group from image."""
texts = ["a young person", "a middle-aged person", "an old person"]
probs = self._run_clip(image_pil, texts)
if probs is None:
return "unknown"
return ["young", "middle-aged", "old"][int(np.argmax(probs))]
def score_with_terms(self, image_pil: Image.Image, positive_terms: List[str], negative_terms: List[str]) -> Tuple[float, float, float, float]:
"""Score image with positive and negative terms."""
probs_all = []
for pos, neg in zip(positive_terms, negative_terms):
probs = self._run_clip(image_pil, [pos, neg])
            if probs is None or len(probs) != 2:
                # Any failed pair invalidates the whole result.
                return Config.DEFAULT_OUTPUT[:4]
probs_all.append(probs)
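        # Each pair score maps the probability gap (p_pos - p_neg), which lies
        # in [-1, 1], onto a 0-100 scale; the composite applies the same
        # mapping to the mean gap across all pairs.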
        s1, s2, s3 = (round((p[0] - p[1] + 1) * 50, 2) for p in probs_all)
positive_probs = [p[0] for p in probs_all]
negative_probs = [p[1] for p in probs_all]
hot_score = float(np.mean(positive_probs))
ugly_score = float(np.mean(negative_probs))
composite = round(((hot_score - ugly_score) + 1) * 50, 2)
return composite, s1, s2, s3
def evaluate_image(self, image: Union[np.ndarray, Image.Image, None]) -> Tuple[float, float, float, float, str, str]:
"""Main evaluation function that returns complete scoring."""
if not self.model.is_loaded():
return Config.DEFAULT_OUTPUT
# Handle input image
if image is None:
image_pil = create_dummy_image()
else:
try:
                if isinstance(image, np.ndarray):
                    # Let PIL infer the source mode, then normalize to RGB so
                    # grayscale and RGBA arrays do not raise here.
                    image_pil = Image.fromarray(image.astype("uint8")).convert("RGB")
                elif isinstance(image, Image.Image):
                    image_pil = image.convert("RGB")
else:
raise ValueError("Unsupported image type")
except Exception:
print_exc("[evaluate_image] Failed to convert input to PIL. Using dummy image.")
image_pil = create_dummy_image()
try:
# Detect attributes
gender = self.detect_gender(image_pil)
age_group = self.detect_age_group(image_pil)
# Define terms based on detected gender
if gender == "man":
positive_terms = ["a handsome man", "a charming man", "an attractive man"]
negative_terms = ["an ugly man", "a gross man", "a hideous man"]
elif gender == "woman":
positive_terms = [
"a beautiful woman",
"a cute woman",
"an attractive woman",
]
negative_terms = ["an ugly woman", "a gross woman", "a hideous woman"]
else:
positive_terms = [
"a hot person",
"a beautiful person",
"an attractive person",
]
negative_terms = ["an ugly person", "a gross person", "a hideous person"]
# Calculate scores
            composite, s1, s2, s3 = self.score_with_terms(
                image_pil, positive_terms, negative_terms
            )
            return composite, s1, s2, s3, gender, age_group
except Exception:
print_exc("[evaluate_image] Unexpected error")
return Config.DEFAULT_OUTPUT
# ============================================================
# Factory Functions
# ============================================================
def create_huggingface_scorer(repo_id: str = "sayantan47/clip-vit-b32-onnx", model_filename: str = "onnx/model.onnx") -> HotOrNotScorer:
"""Create a scorer using HuggingFace model."""
model = HuggingFaceModel(repo_id, model_filename)
return HotOrNotScorer(model)
def create_local_scorer(model_path: str, processor_path: Optional[str] = None) -> HotOrNotScorer:
"""Create a scorer using local model."""
model = LocalModel(model_path, processor_path)
return HotOrNotScorer(model)
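# ============================================================
# Usage sketch
# ============================================================
# Minimal demo, not part of the library surface. It assumes network access
# for the default Hugging Face repo; "photo.jpg" is an illustrative
# placeholder path, not a file shipped with this project.
if __name__ == "__main__":
    scorer = create_huggingface_scorer()
    if os.path.exists("photo.jpg"):
        img = Image.open("photo.jpg")
    else:
        img = create_dummy_image()
    composite, s1, s2, s3, gender, age = scorer.evaluate_image(img)
    print(f"composite={composite} s1={s1} s2={s2} s3={s3} gender={gender} age={age}")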