import os import base64 import json import requests import torch import numpy as np import cv2 from PIL import Image, ImageFilter from scipy.ndimage import binary_dilation from omegaconf import OmegaConf # ----------------------------- # 1) הגדרת המפתח API של Gemini כפרמטר # ----------------------------- SYSTEM_INST = """\ You are given an image. You must return information about the main character in the image. Do not write anything else beyond this! **Guidelines for identifying a character in the image:** 1. **Male:** - Infant (0–2) → "baby boy" - Toddler (2–5) → "toddler boy" - Child (6–11) → "boy" - Teenager (12–17) → "teen boy" - Young adul (18–35) → "young man" - adul (36–59) → "man" - Elderly (60+) → "elderly man" 2. **Female:** - Infant (0–2) → "baby girl" - Toddler (2–5) → "toddler girl" - Child (6–11) → "girl" - Teenager (12–17) → "teen girl" - Young adul (18–35) → "young woman" - adul (36–59) → "woman" - Elderly (60+) → "elderly woman" 3. **Unclear identification:** - Ambiguous character → "unidentified" - Ambiguous infant/toddler → "baby" or "toddler" 4. **No character in the image:** - Respond: "no person" 5. **Multiple characters:** - Identify the most central or prominent character. Notes: - If data is insufficient to classify → "insufficient data". """ conversation = [] # נשמור כאן את השיחה הנוכחית female_keywords = { "baby girl", "toddler girl", "girl", "teen girl", "young woman", "woman", "elderly woman" } def is_female_from_text(gemini_text: str) -> bool: """בודק האם התשובה מ-Gemini מצביעה על אישה לפי מילות המפתח שהוגדרו.""" return gemini_text.lower().strip() in female_keywords def encode_image_to_base64(image: Image.Image) -> str: import io buffer = io.BytesIO() image.save(buffer, format='JPEG') encoded_str = base64.b64encode(buffer.getvalue()).decode('utf-8') return encoded_str def add_user_text(message: str): conversation.append({ "role": "user", "parts": [ {"text": message} ] }) def add_user_image_from_pil(image: Image.Image, mime_type: str = "image/jpeg"): encoded_str = encode_image_to_base64(image) conversation.append({ "role": "user", "parts": [ { "inline_data": { "mime_type": mime_type, "data": encoded_str } } ] }) def send_and_receive(api_key: str) -> str: url = "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent" params = {"key": api_key} headers = {"Content-Type": "application/json"} payload = { "systemInstruction": { "role": "system", "parts": [ {"text": SYSTEM_INST} ] }, "contents": conversation } response = requests.post(url, params=params, headers=headers, json=payload) if response.status_code != 200: print(f"[Gemini] שגיאה בסטטוס קוד: {response.status_code}") return "NO_ANSWER" resp_json = response.json() candidates = resp_json.get("candidates", []) if not candidates: print("[Gemini] לא התקבלה תשובה.") return "NO_ANSWER" model_content = candidates[0].get("content", {}) model_parts = model_content.get("parts", []) if not model_parts: print("[Gemini] לא נמצא תוכן בתשובת המודל.") return "NO_ANSWER" model_text = model_parts[0].get("text", "").strip() conversation.append({ "role": "model", "parts": [ {"text": model_text} ] }) return model_text # ----------------------------- # 3) טעינת מודל YOLO # ----------------------------- from ultralytics import YOLO YOLO_MODEL_PATH = '../../models/yolo11m.pt' try: yolo_model = YOLO(YOLO_MODEL_PATH) yolo_model.to("cpu") print("[YOLO] מודל YOLO נטען בהצלחה.") except Exception as e: print(f"[YOLO] לא מצליח לטעון את המודל בנתיב: {YOLO_MODEL_PATH}. שגיאה: {e}") yolo_model = None TARGET_CLASS = "person" CONF_THRESHOLD = 0.2 # ----------------------------- # 4) הכנה ל-SAM2 # ----------------------------- from sam2.build_sam import build_sam2 from sam2.sam2_image_predictor import SAM2ImagePredictor # נתיבים יחסיים ל-Space של Hugging Face SAM2_CHECKPOINT = "checkpoints/sam2.1_hiera_tiny.pt" MODEL_CFG = "configs/sam2.1/sam2.1_hiera_t.yaml" sam2_predictor = None # אתחול כ-None device = "cuda" if torch.cuda.is_available() else "cpu" def load_sam2_model(): """טוען את מודל SAM2 באופן גלובלי.""" global sam2_predictor try: # טעינת המודל sam2_model = build_sam2(MODEL_CFG, SAM2_CHECKPOINT, device=device) sam2_predictor = SAM2ImagePredictor(sam2_model) print("[SAM2] מודל SAM2 נטען בהצלחה.") except FileNotFoundError as e: print(f"[ERROR] קובץ SAM2 לא נמצא: {e}") print(f" - ודא שקובץ המודל '{SAM2_CHECKPOINT}' וקובץ הקונפיג '{MODEL_CFG}' קיימים בנתיבים הנכונים בתוך ה-Space שלך.") except Exception as e: print(f"[ERROR] שגיאה כללית בטעינת SAM2: {e}") print(f" - סוג השגיאה: {type(e).__name__}") print(f" - הודעת השגיאה: {e}") import traceback print(f" - Traceback:") traceback.print_exc() print(f" - בדוק את התאימות בין גרסאות הספריות (torch, torchvision, sam2) ואת תקינות קובץ המודל.") # טעינת המודל בעת טעינת המודול load_sam2_model() # ----------------------------- # 5) פונקציית טשטוש # ----------------------------- def blur_regions_with_mask( image: Image.Image, mask: np.ndarray, blur_radius=20, pixel_size=20, expansion_pixels=1 ): processed_image = image.copy() img_np = np.array(processed_image) structure = np.ones((expansion_pixels, expansion_pixels), dtype=bool) expanded_mask = binary_dilation(mask, structure=structure) blurred_whole = processed_image.filter(ImageFilter.GaussianBlur(radius=blur_radius)) blurred_whole_np = np.array(blurred_whole) ys, xs = np.where(expanded_mask) if len(xs) == 0 or len(ys) == 0: return processed_image x_min, x_max = xs.min(), xs.max() y_min, y_max = ys.min(), ys.max() region = blurred_whole_np[y_min:y_max, x_min:x_max] from PIL import Image as PILImage small = PILImage.fromarray(region).resize( ((x_max - x_min) // pixel_size, (y_max - y_min) // pixel_size), resample=Image.BILINEAR ) pixelated = small.resize((x_max - x_min, y_max - y_min), PILImage.NEAREST) pixelated_np = np.array(pixelated) combined = img_np.copy() mask_region = expanded_mask[y_min:y_max, x_min:x_max] combined[y_min:y_max, x_min:x_max][mask_region] = pixelated_np[mask_region] return Image.fromarray(combined) # ----------------------------- # 6) הפונקציה המרכזית # ----------------------------- def process_image( pil_image: Image.Image, gemini_api_key: str, progress_callback=None ) -> Image.Image: if not gemini_api_key: raise ValueError("מפתח API של Gemini אינו מוזן.") """ פונקציה המקבלת תמונת PIL, מפתח API של Gemini, ומחזירה את התמונה לאחר טשטוש נשים. """ if progress_callback is None: # אם לא הועברה פונקציה לעדכון התקדמות, ניצור פונקציה ריקה def progress_callback(x, desc=""): pass conversation.clear() add_user_text("Processing a new image (backend)!") # 1) שלב YOLO progress_callback(0.0, "מתחיל זיהוי אנשים (YOLO)...") if yolo_model is None: print("[process_image] מודל YOLO לא נטען כראוי.") return pil_image np_image = np.array(pil_image) results = yolo_model.predict(np_image) bboxes_person = [] for result in results: boxes = result.boxes for box in boxes: cls_name = yolo_model.names[int(box.cls)] conf = box.conf.item() if cls_name == TARGET_CLASS and conf >= CONF_THRESHOLD: x1, y1, x2, y2 = box.xyxy[0] bboxes_person.append([int(x1), int(y1), int(x2), int(y2)]) progress_callback(0.1, f"נמצאו {len(bboxes_person)} בוקסי 'person' ב-YOLO") # 2) שלב Gemini (עבור כל בוקס בנפרד) women_boxes = [] n_bboxes = len(bboxes_person) if bboxes_person else 1 for i, bbox in enumerate(bboxes_person, start=1): fraction = 0.1 + (0.5 * i / n_bboxes) # נניח חצי מההתקדמות מוקצה ל-Gemini progress_callback(fraction, f"[Gemini] בודק בוקס #{i} מתוך {len(bboxes_person)}") x1, y1, x2, y2 = bbox cropped = pil_image.crop((x1, y1, x2, y2)) add_user_image_from_pil(cropped) add_user_text("---") gemini_text = send_and_receive(gemini_api_key) if is_female_from_text(gemini_text): women_boxes.append(bbox) # 3) שלב SAM2 (עבור בוקסים של נשים) if sam2_predictor is None: print("[process_image] SAM2 לא זמין/נטען. מחזירים תמונה ללא טשטוש.") raise ValueError("SAM2 model is not loaded.") progress_callback(0.6, f"מתחיל פילוח SAM2 על {len(women_boxes)} נשים...") sam2_predictor.set_image(np.array(pil_image)) women_masks = [] n_women = len(women_boxes) if women_boxes else 1 for j, bbox in enumerate(women_boxes, start=1): fraction = 0.6 + (0.3 * j / n_women) # עדכון עד 90% progress_callback(fraction, f"[SAM2] מפלח בוקס #{j} מתוך {len(women_boxes)}") box_np = np.array([bbox]) masks, scores, _ = sam2_predictor.predict( point_coords=None, point_labels=None, box=box_np, multimask_output=False, ) if masks.ndim == 4 and masks.shape[1] == 1: mask = masks.squeeze(1)[0].astype(bool) elif masks.ndim == 3: mask = masks[0].astype(bool) else: raise ValueError(f"[SAM2] צורת masks לא צפויה: {masks.shape}") women_masks.append((bbox, mask)) # 4) שלב טשטוש progress_callback(0.9, "מתחיל טשטוש האזורים המזוהים (Blur + פיקסול)...") final_image = pil_image.copy() for (bbox, mask) in women_masks: final_image = blur_regions_with_mask(final_image, mask) progress_callback(1.0, "סיימנו! מחזירים את התוצאה הסופית.") # המרת התמונה ל-base64 encoded_image = encode_image_to_base64(final_image) return encoded_image