"""Face-recognition tool: matches faces found in cached images/frames against a local face database."""
import pickle | |
from pathlib import Path | |
from typing import List, Optional | |
import cv2 | |
import face_recognition | |
import numpy as np | |
from PIL import Image | |
from pydantic import model_validator | |
from ....models.od.schemas import Target | |
from ....utils.registry import registry | |
from ...base import ArgSchema, BaseModelTool | |
# Call-time argument schema: empty — this tool takes no arguments at invocation.
ARGSCHEMA = {}
class FaceRecognition(BaseModelTool):
    """Recognize faces in cached images/frames and identify each person.

    Known identities are loaded from ``face_db``: every image file under that
    directory contributes one face encoding, labelled with the name of the
    directory the image sits in. Incoming images are matched against those
    encodings by face distance (lower = more similar); matches at or below
    ``threshold`` are reported.
    """

    args_schema: ArgSchema = ArgSchema(**ARGSCHEMA)
    description: str = (
        "This tool can recognize facial information in images/extracted frames and identify who the person is."
        "Images/extracted frames are already loaded."
    )
    # Maximum face distance accepted as a match (lower distance = more similar).
    threshold: float = 0.6
    # Re-sampling passes when computing encodings (higher = slower, more stable).
    num_jitters: int = 1
    # Root directory of the known-faces database (one sub-directory per person).
    face_db: str = "data/face_db"
    # face_recognition encoding model: "large" or "small".
    model: str = "large"
    # Lazily populated {"embeddings": [...], "names": [...]} mapping.
    loaded_face_db: Optional[dict] = None

    # `model_validator` is imported for this purpose; without the decorator the
    # validator never runs and `loaded_face_db` stays None at inference time.
    @model_validator(mode="after")
    def face_db_validator(self) -> "FaceRecognition":
        """Load the face DB from disk, or validate one supplied directly.

        Raises:
            ValueError: if the DB path is missing, the supplied DB is not a
                dict, or it lacks the 'embeddings'/'names' keys.
        """
        if self.loaded_face_db is None:
            if Path(self.face_db).exists():
                self.loaded_face_db = self._load_face_db(self.face_db)
            else:
                raise ValueError(f"Face database not found at {self.face_db}")
        elif isinstance(self.loaded_face_db, dict):
            if (
                "embeddings" not in self.loaded_face_db
                or "names" not in self.loaded_face_db
            ):
                raise ValueError(
                    "Face database must have 'embeddings' and 'names' keys."
                )
        else:
            raise ValueError("Face database must be a dictionary.")
        return self

    def _load_face_db(self, path: str) -> dict:
        """Build the face DB from image files under *path*, using a pickle cache.

        Returns a dict with parallel lists: 'embeddings' (face encodings) and
        'names' (the parent-directory name of each source image).
        """
        face_db = Path(path)
        cached_model = face_db / f"representations_{self.model}_face.pkl"
        if cached_model.exists():
            # Reuse the cache written at the end of a previous build.
            # NOTE(review): pickle.load is only safe if the cache directory is
            # trusted — confirm this path is not writable by untrusted users.
            with open(cached_model, "rb") as f:
                return pickle.load(f)
        embeddings = []
        names = []
        for known_image in face_db.rglob("*"):
            if known_image.suffix not in (".jpg", ".png", ".webp"):
                continue
            loaded_image = np.array(Image.open(known_image).convert("RGB"))
            loaded_image = cv2.cvtColor(loaded_image, cv2.COLOR_RGB2BGR)
            encodings = face_recognition.face_encodings(
                loaded_image, num_jitters=self.num_jitters, model=self.model
            )
            if not encodings:
                # Skip images with no detectable face instead of raising
                # IndexError on [0].
                continue
            embeddings.append(encodings[0])
            # The parent directory name is the person's identity label.
            names.append(known_image.parent.name)
        loaded_face_db = {"embeddings": embeddings, "names": names}
        # Context manager so the cache file handle is always closed.
        with open(cached_model, "wb") as f:
            pickle.dump(loaded_face_db, f)
        return loaded_face_db

    def infer(self, image: Image.Image) -> List[Target]:
        """Detect faces in *image* and return a Target per recognized identity.

        Faces whose best match exceeds ``threshold`` are silently dropped.
        """
        # Guard: without known embeddings, np.argmin below would raise.
        if not self.loaded_face_db or not self.loaded_face_db.get("embeddings"):
            return []
        img = np.array(image.convert("RGB"))
        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
        face_locations = face_recognition.face_locations(img)
        face_encodings = face_recognition.face_encodings(img, face_locations)
        known_embeddings = self.loaded_face_db["embeddings"]
        known_names = self.loaded_face_db["names"]
        rec_res: List[Target] = []
        for (top, right, bottom, left), face_encoding in zip(
            face_locations, face_encodings
        ):
            face_distances = face_recognition.face_distance(
                known_embeddings, face_encoding
            )
            best_match_index = int(np.argmin(face_distances))
            if face_distances[best_match_index] <= self.threshold:
                name = known_names[best_match_index]
                bbox = [left, top, right, bottom]
                # NOTE(review): `conf` carries the raw face distance (lower is
                # better), not a similarity score — confirm Target's semantics.
                rec_res.append(
                    Target(label=name, bbox=bbox, conf=face_distances[best_match_index])
                )
        return rec_res

    def _run(self):
        """Annotate every cached image with recognized faces; return a summary."""
        names = set()
        for key in self.stm.image_cache.keys():
            anno = self.infer(self.stm.image_cache[key])
            self.stm.image_cache[key] = self.visual_prompting(
                self.stm.image_cache[key], anno
            )
            names.update([item.label for item in anno])
        # Reports the number of distinct identities, not total face detections.
        return f"Recognized {len(names)} faces: {', '.join(names)}"

    async def _arun(self):
        """Async entry point; recognition itself is synchronous."""
        return self._run()