Spaces:
Sleeping
Sleeping
import os | |
import base64 | |
import json | |
import requests | |
import torch | |
import numpy as np | |
import cv2 | |
from PIL import Image, ImageFilter | |
from scipy.ndimage import binary_dilation | |
from omegaconf import OmegaConf | |
# -----------------------------
# 1) The Gemini API key is supplied as a parameter (see send_and_receive)
# -----------------------------
# System instruction sent with every Gemini request. The quoted labels in the
# "Female" list are the same strings held in `female_keywords`, which is what
# the model's reply is matched against.
SYSTEM_INST = """\
You are given an image. You must return information about the main character in the image.
Do not write anything else beyond this!
**Guidelines for identifying a character in the image:**
1. **Male:**
- Infant (0–2) → "baby boy"
- Toddler (2–5) → "toddler boy"
- Child (6–11) → "boy"
- Teenager (12–17) → "teen boy"
- Young adul (18–35) → "young man"
- adul (36–59) → "man"
- Elderly (60+) → "elderly man"
2. **Female:**
- Infant (0–2) → "baby girl"
- Toddler (2–5) → "toddler girl"
- Child (6–11) → "girl"
- Teenager (12–17) → "teen girl"
- Young adul (18–35) → "young woman"
- adul (36–59) → "woman"
- Elderly (60+) → "elderly woman"
3. **Unclear identification:**
- Ambiguous character → "unidentified"
- Ambiguous infant/toddler → "baby" or "toddler"
4. **No character in the image:**
- Respond: "no person"
5. **Multiple characters:**
- Identify the most central or prominent character.
Notes:
- If data is insufficient to classify → "insufficient data".
"""
# The current Gemini conversation is kept here: a list of turn dicts that is
# sent as the request's "contents" and mutated in place by the helpers below.
conversation = []
# Labels the system instruction asks Gemini to use for female subjects.
female_keywords = {
    "baby girl", "toddler girl", "girl",
    "teen girl", "young woman", "woman",
    "elderly woman",
}


def is_female_from_text(gemini_text: str) -> bool:
    """Return True when Gemini's reply is one of the female labels.

    The reply is lower-cased, runs of whitespace are collapsed, and
    surrounding quotes, markdown emphasis and sentence punctuation are
    stripped before an exact match against ``female_keywords``. This
    tolerates replies such as ``'"Young woman".'`` while still rejecting
    anything that is not exactly one of the labels.

    Args:
        gemini_text: Raw text returned by the model (may be empty or None).

    Returns:
        True if the normalized reply is in ``female_keywords``, else False.
    """
    if not gemini_text:
        return False
    normalized = " ".join(gemini_text.lower().split())
    # The prompt shows the labels inside double quotes, so the model may echo
    # them; trailing periods and markdown asterisks are also common.
    normalized = normalized.strip(" \"'`*.!,:;")
    return normalized in female_keywords
def encode_image_to_base64(image: Image.Image) -> str:
    """Encode a PIL image as a base64 string of its JPEG bytes.

    Args:
        image: Image to encode. Modes the JPEG writer cannot store
            (for example RGBA, LA or P) are converted to RGB first.

    Returns:
        The JPEG file contents, base64-encoded and decoded as UTF-8 text.
    """
    import io

    # Pillow raises OSError ("cannot write mode ... as JPEG") for modes with
    # an alpha channel or a palette, so normalise those before saving.
    if image.mode not in ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr"):
        image = image.convert("RGB")
    buffer = io.BytesIO()
    image.save(buffer, format="JPEG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")
def add_user_text(message: str):
    """Append a user turn with a single text part to the shared conversation."""
    text_part = {"text": message}
    user_turn = {"role": "user", "parts": [text_part]}
    conversation.append(user_turn)
def add_user_image_from_pil(image: Image.Image, mime_type: str = "image/jpeg"):
    """Append a user turn carrying the image as inline base64 data.

    The bytes come from ``encode_image_to_base64``, which always writes JPEG;
    ``mime_type`` only sets the label attached to the payload.
    """
    inline_part = {
        "inline_data": {
            "mime_type": mime_type,
            "data": encode_image_to_base64(image),
        }
    }
    conversation.append({"role": "user", "parts": [inline_part]})
def send_and_receive(api_key: str, timeout: float = 60.0) -> str:
    """Send the shared conversation to Gemini and return the model's text.

    On success the reply is also appended to ``conversation`` as a "model"
    turn. Every failure path prints a diagnostic and returns the sentinel
    string ``"NO_ANSWER"`` without modifying the conversation.

    Args:
        api_key: Gemini API key, sent as the ``key`` query parameter.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        The stripped text of the first part of the first candidate, or
        ``"NO_ANSWER"`` on a transport error, a non-200 status, a body that
        is not JSON, or a response without candidates/parts.
    """
    url = "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent"
    payload = {
        "systemInstruction": {
            "role": "system",
            "parts": [
                {"text": SYSTEM_INST}
            ]
        },
        "contents": conversation
    }
    try:
        # Without a timeout requests can block indefinitely on a stalled server.
        response = requests.post(
            url,
            params={"key": api_key},
            headers={"Content-Type": "application/json"},
            json=payload,
            timeout=timeout,
        )
    except requests.RequestException as exc:
        # Print only the exception type: its message can include the request
        # URL, and the URL carries the API key.
        print(f"[Gemini] שגיאת רשת: {type(exc).__name__}")
        return "NO_ANSWER"

    if response.status_code != 200:
        print(f"[Gemini] שגיאה בסטטוס קוד: {response.status_code}")
        return "NO_ANSWER"

    try:
        resp_json = response.json()
    except ValueError:
        # The body was not valid JSON; treat it like a missing answer.
        print("[Gemini] לא התקבלה תשובה.")
        return "NO_ANSWER"

    candidates = resp_json.get("candidates", [])
    if not candidates:
        print("[Gemini] לא התקבלה תשובה.")
        return "NO_ANSWER"

    model_parts = candidates[0].get("content", {}).get("parts", [])
    if not model_parts:
        print("[Gemini] לא נמצא תוכן בתשובת המודל.")
        return "NO_ANSWER"

    model_text = model_parts[0].get("text", "").strip()
    # Record the reply so that later requests carry the full exchange.
    conversation.append({
        "role": "model",
        "parts": [
            {"text": model_text}
        ]
    })
    return model_text
# -----------------------------
# 3) Load the YOLO model
# -----------------------------
from ultralytics import YOLO

# Relative path to the YOLO weights file.
YOLO_MODEL_PATH = '../../models/yolo11m.pt'
try:
    yolo_model = YOLO(YOLO_MODEL_PATH)
    yolo_model.to("cpu")
    print("[YOLO] מודל YOLO נטען בהצלחה.")
except Exception as e:
    # Loading failures are reported but do not abort the import;
    # process_image checks for the None value set here.
    print(f"[YOLO] לא מצליח לטעון את המודל בנתיב: {YOLO_MODEL_PATH}. שגיאה: {e}")
    yolo_model = None
# Only detections of this class with at least this confidence are kept.
TARGET_CLASS = "person"
CONF_THRESHOLD = 0.2
# -----------------------------
# 4) SAM2 setup
# -----------------------------
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

# Paths relative to the Hugging Face Space
SAM2_CHECKPOINT = "checkpoints/sam2.1_hiera_tiny.pt"
MODEL_CFG = "configs/sam2.1/sam2.1_hiera_t.yaml"
sam2_predictor = None  # Initialised to None; replaced by load_sam2_model() on success
# Use the GPU when torch reports one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
def load_sam2_model():
    """Build the SAM2 predictor and store it in the global ``sam2_predictor``.

    Errors are printed rather than raised, so after a failed load the global
    keeps whatever value it had before the call.
    """
    global sam2_predictor
    try:
        # Build the network, then wrap it in the image predictor.
        network = build_sam2(MODEL_CFG, SAM2_CHECKPOINT, device=device)
        sam2_predictor = SAM2ImagePredictor(network)
        print("[SAM2] מודל SAM2 נטען בהצלחה.")
    except FileNotFoundError as missing:
        print(f"[ERROR] קובץ SAM2 לא נמצא: {missing}")
        print(f" - ודא שקובץ המודל '{SAM2_CHECKPOINT}' וקובץ הקונפיג '{MODEL_CFG}' קיימים בנתיבים הנכונים בתוך ה-Space שלך.")
    except Exception as err:
        import traceback

        diagnostics = (
            f"[ERROR] שגיאה כללית בטעינת SAM2: {err}",
            f" - סוג השגיאה: {type(err).__name__}",
            f" - הודעת השגיאה: {err}",
            " - Traceback:",
        )
        for line in diagnostics:
            print(line)
        traceback.print_exc()
        print(" - בדוק את התאימות בין גרסאות הספריות (torch, torchvision, sam2) ואת תקינות קובץ המודל.")


# Load the model when the module is imported
load_sam2_model()
# ----------------------------- | |
# 5) פונקציית טשטוש | |
# ----------------------------- | |
def blur_regions_with_mask(
    image: Image.Image,
    mask: np.ndarray,
    blur_radius=20,
    pixel_size=20,
    expansion_pixels=1
):
    """Return a copy of ``image`` with the masked pixels blurred and pixelated.

    The image is Gaussian-blurred, the bounding box of the mask is
    down-sampled by ``pixel_size`` and scaled back up with nearest-neighbour
    resampling, and the result is pasted only where the mask is set.

    Args:
        image: Source image; it is not modified.
        mask: 2-D array whose non-zero entries mark pixels to obscure.
            Assumed to have the same height and width as ``image``.
        blur_radius: Radius passed to ``ImageFilter.GaussianBlur``.
        pixel_size: Down-sampling factor of the pixelation; values below 1
            are treated as 1.
        expansion_pixels: Side length of the square structuring element used
            to dilate the mask. A value of 1 (the default) or less leaves the
            mask unchanged.

    Returns:
        A new PIL image. When the mask is empty it is a plain copy of
        ``image``.
    """
    expanded_mask = np.asarray(mask, dtype=bool)
    if expansion_pixels > 1:
        structure = np.ones((expansion_pixels, expansion_pixels), dtype=bool)
        expanded_mask = binary_dilation(expanded_mask, structure=structure)

    ys, xs = np.nonzero(expanded_mask)
    if xs.size == 0:
        # Nothing to obscure; skip the blur of the whole image.
        return image.copy()

    # min()/max() are inclusive pixel indices; slices need an exclusive end,
    # otherwise the last masked row and column are left untouched.
    x0, x1 = int(xs.min()), int(xs.max()) + 1
    y0, y1 = int(ys.min()), int(ys.max()) + 1
    width, height = x1 - x0, y1 - y0

    blurred = image.filter(ImageFilter.GaussianBlur(radius=blur_radius))
    region = Image.fromarray(np.array(blurred)[y0:y1, x0:x1])

    # Keep both dimensions at least 1 so masks smaller than pixel_size do not
    # ask PIL for a zero-sized image.
    step = max(1, int(pixel_size))
    small = region.resize(
        (max(1, width // step), max(1, height // step)),
        resample=Image.BILINEAR,
    )
    pixelated = np.array(small.resize((width, height), resample=Image.NEAREST))

    combined = np.array(image)
    region_mask = expanded_mask[y0:y1, x0:x1]
    # Basic slicing yields a view, so the boolean assignment writes into
    # ``combined`` itself.
    combined[y0:y1, x0:x1][region_mask] = pixelated[region_mask]
    return Image.fromarray(combined)
# ----------------------------- | |
# 6) הפונקציה המרכזית | |
# ----------------------------- | |
def process_image(
    pil_image: Image.Image,
    gemini_api_key: str,
    progress_callback=None
) -> str:
    """Blur the women found in a PIL image and return the result as base64.

    Pipeline: YOLO detects "person" boxes, each crop is classified by Gemini,
    SAM2 segments the boxes classified as female, and the resulting masks are
    blurred and pixelated.

    Args:
        pil_image: Input image. It is converted to RGB when it has another
            mode.
        gemini_api_key: Gemini API key; must be non-empty.
        progress_callback: Optional callable ``(fraction, desc)`` invoked with
            a progress fraction between 0.0 and 1.0 and a status message.

    Returns:
        The processed image as a base64-encoded JPEG string. When the YOLO
        model is not loaded, the unprocessed image is returned in the same
        encoding.

    Raises:
        ValueError: If the API key is empty, if SAM2 is not loaded, or if
            SAM2 returns masks with an unexpected shape.
    """
    if not gemini_api_key:
        raise ValueError("מפתח API של Gemini אינו מוזן.")

    if progress_callback is None:
        # No progress hook was supplied; use a no-op.
        def progress_callback(x, desc=""):
            pass

    # The numpy conversion for SAM2 and the JPEG encoder expect three channels.
    if pil_image.mode != "RGB":
        pil_image = pil_image.convert("RGB")

    conversation.clear()
    add_user_text("Processing a new image (backend)!")

    # 1) YOLO stage
    progress_callback(0.0, "מתחיל זיהוי אנשים (YOLO)...")
    if yolo_model is None:
        print("[process_image] מודל YOLO לא נטען כראוי.")
        # Same return type as the normal path, so callers always get a string.
        return encode_image_to_base64(pil_image)
    if sam2_predictor is None:
        # Fail before spending Gemini requests whose results could not be used.
        print("[process_image] SAM2 לא זמין/נטען. מחזירים תמונה ללא טשטוש.")
        raise ValueError("SAM2 model is not loaded.")

    # Pass the PIL image itself: Ultralytics treats numpy input as BGR,
    # whereas PIL input is read as RGB.
    results = yolo_model.predict(pil_image)
    bboxes_person = []
    for result in results:
        for box in result.boxes:
            cls_name = yolo_model.names[int(box.cls)]
            conf = box.conf.item()
            if cls_name != TARGET_CLASS or conf < CONF_THRESHOLD:
                continue
            x1, y1, x2, y2 = (int(v) for v in box.xyxy[0].tolist())
            if x2 <= x1 or y2 <= y1:
                # An empty crop cannot be encoded as JPEG.
                continue
            bboxes_person.append([x1, y1, x2, y2])
    progress_callback(0.1, f"נמצאו {len(bboxes_person)} בוקסי 'person' ב-YOLO")

    # 2) Gemini stage (one request per box)
    # NOTE(review): every request resends the whole `conversation`, so earlier
    # crops are uploaded again on each call.
    women_boxes = []
    n_bboxes = len(bboxes_person) or 1
    for i, bbox in enumerate(bboxes_person, start=1):
        fraction = 0.1 + (0.5 * i / n_bboxes)  # Gemini gets half of the progress bar
        progress_callback(fraction, f"[Gemini] בודק בוקס #{i} מתוך {len(bboxes_person)}")
        cropped = pil_image.crop(tuple(bbox))
        add_user_image_from_pil(cropped)
        add_user_text("---")
        gemini_text = send_and_receive(gemini_api_key)
        if is_female_from_text(gemini_text):
            women_boxes.append(bbox)

    # 3) SAM2 stage (boxes classified as women)
    progress_callback(0.6, f"מתחיל פילוח SAM2 על {len(women_boxes)} נשים...")
    women_masks = []
    if women_boxes:
        # Only hand the image to SAM2 when there is something to segment.
        sam2_predictor.set_image(np.array(pil_image))
    n_women = len(women_boxes) or 1
    for j, bbox in enumerate(women_boxes, start=1):
        fraction = 0.6 + (0.3 * j / n_women)  # progress up to 90%
        progress_callback(fraction, f"[SAM2] מפלח בוקס #{j} מתוך {len(women_boxes)}")
        masks, _, _ = sam2_predictor.predict(
            point_coords=None,
            point_labels=None,
            box=np.array([bbox]),
            multimask_output=False,
        )
        if masks.ndim == 4 and masks.shape[1] == 1:
            mask = masks.squeeze(1)[0].astype(bool)
        elif masks.ndim == 3:
            mask = masks[0].astype(bool)
        else:
            raise ValueError(f"[SAM2] צורת masks לא צפויה: {masks.shape}")
        women_masks.append(mask)

    # 4) Blur stage
    progress_callback(0.9, "מתחיל טשטוש האזורים המזוהים (Blur + פיקסול)...")
    final_image = pil_image.copy()
    for mask in women_masks:
        final_image = blur_regions_with_mask(final_image, mask)
    progress_callback(1.0, "סיימנו! מחזירים את התוצאה הסופית.")

    # Convert the image to base64
    return encode_image_to_base64(final_image)