sayantan47 committed
Commit da6ac83 · verified · 1 Parent(s): aac0cdc

Create app.py

Files changed (1)
  1. app.py +236 -0
app.py ADDED
import sys
import traceback

import numpy as np
import onnxruntime as ort
from huggingface_hub import hf_hub_download
from transformers import CLIPProcessor
from PIL import Image
import gradio as gr

# ============================================================
# Config
# ============================================================
REPO_ID = "sayantan47/clip-vit-b32-onnx"  # <-- change this to your own repo
MODEL_FILENAME = "onnx/model_q4.onnx"
PROVIDERS = ["CPUExecutionProvider"]  # stay on CPU to avoid CUDA DLL issues
DEFAULT_OUTPUT = (0.0, 0.0, 0.0, 0.0, "unknown", "unknown")
FIXED_IMG_W = 300
FIXED_IMG_H = 300
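# Note: REPO_ID must host both MODEL_FILENAME and the CLIP processor/tokenizer
# files, since both are loaded from that one repo below.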


# ============================================================
# Utils
# ============================================================
def _print_exc(prefix: str):
    """Log a prefixed traceback to stderr without crashing the app."""
    print(prefix, file=sys.stderr)
    traceback.print_exc()


def _softmax_safe(x: np.ndarray, axis: int = -1) -> np.ndarray:
    """Numerically stable softmax; falls back to uniform probabilities on error."""
    try:
        x = x - np.max(x, axis=axis, keepdims=True)  # shift logits for stability
        ex = np.exp(x)
        denom = np.sum(ex, axis=axis, keepdims=True)
        denom = np.where(denom == 0, 1.0, denom)  # guard against division by zero
        return ex / denom
    except Exception:
        _print_exc("[_softmax_safe] failed")
        return np.ones_like(x) / x.shape[-1]


def _ensure_int64(feed_dict):
    """Cast any int32 arrays in a feed dict to int64."""
    out = {}
    for k, v in feed_dict.items():
        if isinstance(v, np.ndarray) and v.dtype == np.int32:
            out[k] = v.astype(np.int64)
        else:
            out[k] = v
    return out
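
# Note (assumption): ONNX exports of CLIP typically declare `input_ids` and
# `attention_mask` as int64, while the processor's NumPy output can come back
# as int32 on some platforms; casting up front avoids a dtype mismatch error
# from onnxruntime.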


def _dummy_image(width=FIXED_IMG_W, height=FIXED_IMG_H):
    """Return a flat mid-gray RGB image used as a fallback input."""
    return Image.fromarray(np.full((height, width, 3), 127, dtype=np.uint8), "RGB")


# ============================================================
# Load from HF Hub
# ============================================================
def load_from_hub():
    # download the ONNX model file
    model_path = hf_hub_download(
        repo_id=REPO_ID,
        filename=MODEL_FILENAME,
        local_dir="hf_cache",
        # the two kwargs below are deprecated (and ignored) on recent
        # huggingface_hub releases; kept in case an older version is pinned
        local_dir_use_symlinks=False,
        resume_download=True,
    )
    # load the processor (tokenizer + image preprocessing config) from the same repo
    proc = CLIPProcessor.from_pretrained(REPO_ID)
    sess = ort.InferenceSession(model_path, providers=PROVIDERS)
    return proc, sess


try:
    processor, session = load_from_hub()
except Exception:
    _print_exc("[GLOBAL INIT] Failed to download/load model from HF Hub.")
    processor, session = None, None


# ============================================================
# Core helpers
# ============================================================
def _run_clip(image_pil: Image.Image, texts):
    """Return softmax probabilities of `texts` for the image, or None on failure."""
    if processor is None or session is None:
        return None
    try:
        inputs = processor(
            text=texts, images=image_pil, return_tensors="np", padding=True
        )
        ort_inputs = _ensure_int64(inputs)
        outputs = session.run(None, ort_inputs)
        logits_per_image = outputs[0]  # (1, n_texts)
        probs = _softmax_safe(logits_per_image, axis=-1)[0]
        return probs
    except Exception:
        _print_exc("[_run_clip] Inference failed")
        return None
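
# Assumption: outputs[0] is `logits_per_image`, as in the usual
# transformers/Optimum CLIP ONNX export. If an export orders its outputs
# differently, resolve by name instead:
#   out_names = [o.name for o in session.get_outputs()]
#   logits_per_image = outputs[out_names.index("logits_per_image")]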


def detect_gender(image_pil: Image.Image) -> str:
    texts = ["a man", "a woman"]
    probs = _run_clip(image_pil, texts)
    if probs is None:
        return "unknown"
    return "man" if int(np.argmax(probs)) == 0 else "woman"


def detect_age_group(image_pil: Image.Image) -> str:
    texts = ["a young person", "a middle-aged person", "an old person"]
    probs = _run_clip(image_pil, texts)
    if probs is None:
        return "unknown"
    return ["young", "middle-aged", "old"][int(np.argmax(probs))]


def score_with_terms(image_pil: Image.Image, positive_terms, negative_terms):
    probs_all = []
    for pos, neg in zip(positive_terms, negative_terms):
        probs = _run_clip(image_pil, [pos, neg])
        if probs is None or len(probs) != 2:
            return (
                DEFAULT_OUTPUT[0],
                DEFAULT_OUTPUT[1],
                DEFAULT_OUTPUT[2],
                DEFAULT_OUTPUT[3],
            )
        probs_all.append(probs)

    positive_probs = [p[0] for p in probs_all]
    negative_probs = [p[1] for p in probs_all]

    # per-pair scores: map the positive-vs-negative gap from [-1, 1] onto [0, 100]
    s1 = round((probs_all[0][0] - probs_all[0][1] + 1) * 50, 2)
    s2 = round((probs_all[1][0] - probs_all[1][1] + 1) * 50, 2)
    s3 = round((probs_all[2][0] - probs_all[2][1] + 1) * 50, 2)

    hot_score = float(np.mean(positive_probs))
    ugly_score = float(np.mean(negative_probs))
    composite = round(((hot_score - ugly_score) + 1) * 50, 2)

    return composite, s1, s2, s3
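
# Worked example: if a prompt pair scores p_pos = 0.8 and p_neg = 0.2, the gap
# is 0.6 and the mapped score is (0.6 + 1) * 50 = 80.0.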


# ============================================================
# Gradio callback
# ============================================================
def hotornot(image):
    if processor is None or session is None:
        return DEFAULT_OUTPUT

    if image is None:
        image_pil = _dummy_image()
    else:
        try:
            image_pil = Image.fromarray(image.astype("uint8"), "RGB")
        except Exception:
            _print_exc("[hotornot] Failed to convert input to PIL. Using dummy image.")
            image_pil = _dummy_image()

    try:
        gender = detect_gender(image_pil)
        age_group = detect_age_group(image_pil)

        # pick prompt pairs matched to the detected gender
        if gender == "man":
            positive_terms = ["a handsome man", "a charming man", "an attractive man"]
            negative_terms = ["an ugly man", "a gross man", "a hideous man"]
        elif gender == "woman":
            positive_terms = [
                "a beautiful woman",
                "a cute woman",
                "an attractive woman",
            ]
            negative_terms = ["an ugly woman", "a gross woman", "a hideous woman"]
        else:
            positive_terms = [
                "a hot person",
                "a beautiful person",
                "an attractive person",
            ]
            negative_terms = ["an ugly person", "a gross person", "a hideous person"]

        composite, hotness, charm, attractiveness = score_with_terms(
            image_pil, positive_terms, negative_terms
        )
        return composite, hotness, charm, attractiveness, gender, age_group

    except Exception:
        _print_exc("[hotornot] Unexpected error")
        return DEFAULT_OUTPUT


# ============================================================
# UI
# ============================================================
CSS = f"""
#fixed_img_component img,
#fixed_img_component canvas {{
    width: {FIXED_IMG_W}px !important;
    height: {FIXED_IMG_H}px !important;
    object-fit: contain !important;
}}
"""
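# Note: the doubled braces above escape literal `{` / `}` inside the f-string.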

with gr.Blocks(css=CSS) as demo:
    gr.Markdown("# Hot or Not (CLIP ONNX from Hugging Face Hub)")
    gr.Markdown(
        "Loads ONNX + tokenizer from HF Hub, runs on CPU, auto-detects gender & age, and scores appearance."
    )

    with gr.Row():
        image_in = gr.Image(
            label="Upload Image",
            type="numpy",
            image_mode="RGB",
            height=FIXED_IMG_H,
            width=FIXED_IMG_W,
            elem_id="fixed_img_component",
        )

    with gr.Row():
        out_total = gr.Textbox(label="Total Hot or Not™ Score")
        out_hot = gr.Textbox(label="Hotness Score")
        out_mid = gr.Textbox(label="Charm / Cuteness Score")
        out_attr = gr.Textbox(label="Attractiveness Score")
        out_gender = gr.Textbox(label="Predicted Gender")
        out_age = gr.Textbox(label="Predicted Age Group")

    run_btn = gr.Button("Rate")
    run_btn.click(
        fn=hotornot,
        inputs=[image_in],
        outputs=[out_total, out_hot, out_mid, out_attr, out_gender, out_age],
    )

if __name__ == "__main__":
    demo.launch()
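
# To run locally (a sketch; assumes these deps are installed):
#   pip install gradio transformers onnxruntime huggingface_hub pillow numpy
#   python app.py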