NHLOCAL committed
Commit af8184f · 1 Parent(s): 9de88c1

Add application file

__pycache__/backend.cpython-310.pyc ADDED
Binary file (8.34 kB).

__pycache__/webui.cpython-310.pyc ADDED
Binary file (2.09 kB).

app.py ADDED
@@ -0,0 +1,63 @@
+ import gradio as gr
+ from PIL import Image
+
+ from backend import process_image
+
+ # gr.Progress must be declared as a default parameter so Gradio injects and
+ # tracks it for this event (an instance created inside the body is not tracked).
+ def inference(image: Image.Image, gemini_api_key: str, progress=gr.Progress()):
+     """
+     Detects and blurs women in the image, updating the
+     progress bar accordingly.
+     """
+     if not gemini_api_key.strip():
+         raise gr.Error("Please enter a Gemini API key to continue.")
+
+     def progress_callback(fraction, description=""):
+         """
+         Inner function called by the backend at each stage.
+         fraction - a value between 0 and 1 (e.g. 0.3 = 30%)
+         description - text describing the current stage
+         """
+         progress(fraction, desc=description)
+
+     # Call process_image with the ability to report progress
+     result_image = process_image(image, gemini_api_key, progress_callback=progress_callback)
+     return result_image
+
+
+ title_str = "Detect and Blur Women in an Image"
+ description_str = """<p style='text-align: right; direction: rtl'>
+ Upload an image, enter your Gemini API key,<br>
+ and click "Run" to automatically detect and blur women in the image.
+ </p>
+ """
+
+ # Path to the example image
+ EXAMPLE_IMAGE = "example_images/example.jpg"
+
+ demo = gr.Interface(
+     fn=inference,
+     inputs=[
+         gr.Image(type="pil", label="Choose an image to analyze, or drag one here"),
+         gr.Textbox(
+             label="Gemini API key",
+             placeholder="Enter your API key here",
+             type="password"
+         )
+     ],
+     outputs=gr.Image(type="pil", label="Final result"),
+     title=title_str,
+     description=description_str,
+     examples=[
+         [EXAMPLE_IMAGE, ""]  # example image; the API key field is left empty
+     ],
+     allow_flagging="never",
+     theme="compact"  # lightweight interface theme
+ )
+
+
+ if __name__ == "__main__":
+     # Set share=True to share outside the local network
+     demo.launch(server_name="127.0.0.1", server_port=7860, debug=True)
+
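A note on the boundary between the two files: app.py depends on the backend only through the process_image(image, gemini_api_key, progress_callback=...) call above. To test the UI without the heavy YOLO/SAM2/Gemini stack, a stand-in backend could look like the following minimal sketch (hypothetical, not part of this commit; it blurs the whole image instead of detecting anyone):

# stub_backend.py - hypothetical stand-in for backend.py (UI testing only)
from PIL import Image, ImageFilter

def process_image(pil_image: Image.Image, gemini_api_key: str, progress_callback=None) -> Image.Image:
    # Honor the same (fraction, description) callback protocol as the real backend
    if progress_callback is None:
        def progress_callback(x, desc=""):
            pass
    progress_callback(0.0, "Stub: starting")
    result = pil_image.filter(ImageFilter.GaussianBlur(radius=10))  # blur everything
    progress_callback(1.0, "Stub: done")
    return result
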
backend.py ADDED
@@ -0,0 +1,321 @@
+ import base64
+ import requests
+ import torch
+ import numpy as np
+ from PIL import Image, ImageFilter
+ from scipy.ndimage import binary_dilation
+
+ # -----------------------------
+ # 1) The Gemini API key is passed in as a parameter
+ # -----------------------------
+
+ SYSTEM_INST = """\
+ You are given an image. You must return information about the main character in the image.
+ Do not write anything else beyond this!
+
+ **Guidelines for identifying a character in the image:**
+ 1. **Male:**
+    - Infant (0–2) → "baby boy"
+    - Toddler (2–5) → "toddler boy"
+    - Child (6–11) → "boy"
+    - Teenager (12–17) → "teen boy"
+    - Young adult (18–35) → "young man"
+    - Adult (36–59) → "man"
+    - Elderly (60+) → "elderly man"
+
+ 2. **Female:**
+    - Infant (0–2) → "baby girl"
+    - Toddler (2–5) → "toddler girl"
+    - Child (6–11) → "girl"
+    - Teenager (12–17) → "teen girl"
+    - Young adult (18–35) → "young woman"
+    - Adult (36–59) → "woman"
+    - Elderly (60+) → "elderly woman"
+
+ 3. **Unclear identification:**
+    - Ambiguous character → "unidentified"
+    - Ambiguous infant/toddler → "baby" or "toddler"
+
+ 4. **No character in the image:**
+    - Respond: "no person"
+
+ 5. **Multiple characters:**
+    - Identify the most central or prominent character.
+
+ Notes:
+ - If data is insufficient to classify → "insufficient data".
+ """
+
+ # -----------------------------
+ # 2) Gemini conversation helpers
+ # -----------------------------
+
+ conversation = []  # holds the current conversation
+
+ female_keywords = {
+     "baby girl", "toddler girl", "girl",
+     "teen girl", "young woman", "woman",
+     "elderly woman"
+ }
+
+ def is_female_from_text(gemini_text: str) -> bool:
+     """Checks whether the Gemini answer indicates a woman, per the defined keywords."""
+     return gemini_text.lower().strip() in female_keywords
+
+
+ def encode_image_to_base64(image: Image.Image) -> str:
+     import io
+     buffer = io.BytesIO()
+     image.save(buffer, format='JPEG')
+     encoded_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
+     return encoded_str
+
+
+ def add_user_text(message: str):
+     conversation.append({
+         "role": "user",
+         "parts": [
+             {"text": message}
+         ]
+     })
+
+
+ def add_user_image_from_pil(image: Image.Image, mime_type: str = "image/jpeg"):
+     encoded_str = encode_image_to_base64(image)
+     conversation.append({
+         "role": "user",
+         "parts": [
+             {
+                 "inline_data": {
+                     "mime_type": mime_type,
+                     "data": encoded_str
+                 }
+             }
+         ]
+     })
+
+
+ def send_and_receive(api_key: str) -> str:
+     url = "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent"
+     params = {"key": api_key}
+     headers = {"Content-Type": "application/json"}
+
+     payload = {
+         "systemInstruction": {
+             "role": "system",
+             "parts": [
+                 {"text": SYSTEM_INST}
+             ]
+         },
+         "contents": conversation
+     }
+
+     response = requests.post(url, params=params, headers=headers, json=payload)
+     if response.status_code != 200:
+         print(f"[Gemini] Error status code: {response.status_code}")
+         return "NO_ANSWER"
+
+     resp_json = response.json()
+     candidates = resp_json.get("candidates", [])
+     if not candidates:
+         print("[Gemini] No answer received.")
+         return "NO_ANSWER"
+
+     model_content = candidates[0].get("content", {})
+     model_parts = model_content.get("parts", [])
+     if not model_parts:
+         print("[Gemini] No content found in the model's answer.")
+         return "NO_ANSWER"
+
+     model_text = model_parts[0].get("text", "").strip()
+     conversation.append({
+         "role": "model",
+         "parts": [
+             {"text": model_text}
+         ]
+     })
+     return model_text
+
+
+ # -----------------------------
+ # 3) Loading the YOLO model
+ # -----------------------------
+ from ultralytics import YOLO
+ YOLO_MODEL_PATH = '../../models/yolo11m.pt'
+
+ try:
+     yolo_model = YOLO(YOLO_MODEL_PATH)
+     yolo_model.to("cpu")
+ except Exception as e:
+     print(f"[YOLO] Failed to load the model from: {YOLO_MODEL_PATH} ({e})")
+     yolo_model = None
+
+ TARGET_CLASS = "person"
+ CONF_THRESHOLD = 0.2
+
+ # -----------------------------
+ # 4) Preparing SAM2
+ # -----------------------------
+ try:
+     from hydra import initialize
+     from sam2.sam2_image_predictor import SAM2ImagePredictor
+
+     SAM2_CONFIG_PATH = "../../models/sam2.1/"
+     SAM2_MODEL_NAME = "facebook/sam2.1-hiera-tiny"
+
+     sam2_predictor = None
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+
+     with initialize(config_path=SAM2_CONFIG_PATH):
+         sam2_predictor = SAM2ImagePredictor.from_pretrained(SAM2_MODEL_NAME)
+         sam2_predictor.model.to(device)
+
+ except Exception as e:
+     print(f"[SAM2] Failed to load SAM2; check that the path and config are correct. ({e})")
+     sam2_predictor = None
+
+ # -----------------------------
+ # 5) Blurring function
+ # -----------------------------
+ def blur_regions_with_mask(
+     image: Image.Image,
+     mask: np.ndarray,
+     blur_radius=20,
+     pixel_size=20,
+     expansion_pixels=1
+ ):
+     processed_image = image.copy()
+     img_np = np.array(processed_image)
+
+     structure = np.ones((expansion_pixels, expansion_pixels), dtype=bool)
+     expanded_mask = binary_dilation(mask, structure=structure)
+
+     blurred_whole = processed_image.filter(ImageFilter.GaussianBlur(radius=blur_radius))
+     blurred_whole_np = np.array(blurred_whole)
+
+     ys, xs = np.where(expanded_mask)
+     if len(xs) == 0 or len(ys) == 0:
+         return processed_image
+
+     x_min, x_max = xs.min(), xs.max()
+     y_min, y_max = ys.min(), ys.max()
+
+     region = blurred_whole_np[y_min:y_max, x_min:x_max]
+
+     # Guard against a zero-size target when the region is smaller than pixel_size
+     small = Image.fromarray(region).resize(
+         (max(1, (x_max - x_min) // pixel_size), max(1, (y_max - y_min) // pixel_size)),
+         resample=Image.BILINEAR
+     )
+     pixelated = small.resize((x_max - x_min, y_max - y_min), Image.NEAREST)
+     pixelated_np = np.array(pixelated)
+
+     combined = img_np.copy()
+     mask_region = expanded_mask[y_min:y_max, x_min:x_max]
+     combined[y_min:y_max, x_min:x_max][mask_region] = pixelated_np[mask_region]
+
+     return Image.fromarray(combined)
+
+
+ # -----------------------------
+ # 6) The main function
+ # -----------------------------
+ def process_image(
+     pil_image: Image.Image,
+     gemini_api_key: str,
+     progress_callback=None
+ ) -> Image.Image:
+     """
+     Takes a PIL image and a Gemini API key, and returns the image with women
+     blurred, in well-defined stages:
+     - detecting people with YOLO
+     - determining whether each one is a woman with Gemini
+     - segmenting with SAM2
+     - blurring
+     The progress_callback parameter: a function receiving (fraction, description).
+     """
+
+     if progress_callback is None:
+         # If no progress function was passed, use a no-op
+         def progress_callback(x, desc=""):
+             pass
+
+     conversation.clear()
+     add_user_text("Processing a new image (backend)!")
+
+     # 1) YOLO stage
+     progress_callback(0.0, "Starting person detection (YOLO)...")
+     if yolo_model is None:
+         print("[process_image] The YOLO model was not loaded properly.")
+         return pil_image
+
+     np_image = np.array(pil_image)
+     results = yolo_model.predict(np_image)
+     bboxes_person = []
+
+     for result in results:
+         boxes = result.boxes
+         for box in boxes:
+             cls_name = yolo_model.names[int(box.cls)]
+             conf = box.conf.item()
+             if cls_name == TARGET_CLASS and conf >= CONF_THRESHOLD:
+                 x1, y1, x2, y2 = box.xyxy[0]
+                 bboxes_person.append([int(x1), int(y1), int(x2), int(y2)])
+
+     progress_callback(0.1, f"YOLO found {len(bboxes_person)} 'person' boxes")
+
+     # 2) Gemini stage (each box separately)
+     women_boxes = []
+     n_bboxes = len(bboxes_person) if bboxes_person else 1
+     for i, bbox in enumerate(bboxes_person, start=1):
+         fraction = 0.1 + (0.5 * i / n_bboxes)  # assume half the progress budget goes to Gemini
+         progress_callback(fraction, f"[Gemini] Checking box #{i} of {len(bboxes_person)}")
+
+         x1, y1, x2, y2 = bbox
+         cropped = pil_image.crop((x1, y1, x2, y2))
+
+         add_user_image_from_pil(cropped)
+         add_user_text("---")
+
+         gemini_text = send_and_receive(gemini_api_key)
+         if is_female_from_text(gemini_text):
+             women_boxes.append(bbox)
+
+     # 3) SAM2 stage (for the women's boxes)
+     if sam2_predictor is None:
+         print("[process_image] SAM2 is unavailable/not loaded. Returning the image unblurred.")
+         return pil_image
+
+     progress_callback(0.6, f"Starting SAM2 segmentation on {len(women_boxes)} women...")
+     sam2_predictor.set_image(np.array(pil_image))
+
+     women_masks = []
+     n_women = len(women_boxes) if women_boxes else 1
+     for j, bbox in enumerate(women_boxes, start=1):
+         fraction = 0.6 + (0.3 * j / n_women)  # advance up to 90%
+         progress_callback(fraction, f"[SAM2] Segmenting box #{j} of {len(women_boxes)}")
+
+         box_np = np.array([bbox])
+         masks, scores, _ = sam2_predictor.predict(
+             point_coords=None,
+             point_labels=None,
+             box=box_np,
+             multimask_output=False,
+         )
+
+         if masks.ndim == 4 and masks.shape[1] == 1:
+             mask = masks.squeeze(1)[0].astype(bool)
+         elif masks.ndim == 3:
+             mask = masks[0].astype(bool)
+         else:
+             raise ValueError(f"[SAM2] Unexpected masks shape: {masks.shape}")
+
+         women_masks.append((bbox, mask))
+
+     # 4) Blurring stage
+     progress_callback(0.9, "Blurring the detected regions (blur + pixelation)...")
+     final_image = pil_image.copy()
+     for (bbox, mask) in women_masks:
+         final_image = blur_regions_with_mask(final_image, mask)
+
+     progress_callback(1.0, "Done! Returning the final result.")
+     return final_image
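Because process_image is a plain function, the pipeline can also be exercised without the Gradio UI. A minimal headless sketch (the file name test.jpg and the GEMINI_API_KEY environment variable are hypothetical):

# run_headless.py - minimal sketch, assumes backend.py is importable
import os
from PIL import Image
from backend import process_image

def console_progress(fraction, description=""):
    # Print the same (fraction, description) updates the Gradio UI would show
    print(f"{fraction:5.0%}  {description}")

if __name__ == "__main__":
    img = Image.open("test.jpg")            # hypothetical input image
    key = os.environ["GEMINI_API_KEY"]      # hypothetical environment variable
    out = process_image(img, key, progress_callback=console_progress)
    out.save("test_blurred.jpg")
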
example_images/example.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ gradio
+ torch
+ numpy
+ opencv-python
+ Pillow
+ requests
+ ultralytics
+ scipy
+ hydra-core
+ git+https://github.com/facebookresearch/sam2.git