""" | |
Module to manage the identification model. One can load and run inference on a | |
new image. | |
""" | |
from dataclasses import dataclass | |
from pathlib import Path | |
from typing import Any | |
import numpy as np | |
import pandas as pd | |
import torch | |
from lightglue import ALIKED, DISK, SIFT, LightGlue, SuperPoint | |
from lightglue.utils import numpy_image_to_torch, rbd | |
from PIL import Image | |
from utils import ( | |
extractor_type_to_extractor, | |
extractor_type_to_matcher, | |
get_scores, | |
wasserstein, | |
) | |
from viz2d import keypoints_as_pil_image, matches_as_pil_image | |


@dataclass
class IdentificationModel:
    extractor: SIFT | ALIKED | DISK | SuperPoint
    extractor_type: str
    threshold_wasserstein: float
    n_keypoints: int
    matcher: LightGlue
    features_dict: dict[str, torch.Tensor]
    df_db: pd.DataFrame


def load(
    device: torch.device,
    filepath_features: Path,
    filepath_db: Path,
    extractor_type: str,
    n_keypoints: int,
    threshold_wasserstein: float,
) -> IdentificationModel:
    """
    Load an IdentificationModel from the provided arguments.

    Args:
        device (torch.device): cpu|cuda
        filepath_features (Path): filepath to the torch-cached features of the
          dataset one wants to predict on.
        filepath_db (Path): filepath to the csv file containing the dataset to
          compare with.
        extractor_type (str): one of {sift, disk, aliked, superpoint}.
        n_keypoints (int): maximum number of keypoints to extract with the
          extractor.
        threshold_wasserstein (float): minimum wasserstein score for the best
          candidate to be considered a match.

    Returns:
        IdentificationModel: an IdentificationModel instance.

    Raises:
        AssertionError: when extractor_type, n_keypoints or
          threshold_wasserstein is not valid.
    """
    allowed_extractor_types = ["sift", "disk", "aliked", "superpoint"]
    assert (
        extractor_type in allowed_extractor_types
    ), f"extractor_type should be in {allowed_extractor_types}"
    assert 1 <= n_keypoints <= 5000, "n_keypoints should be in range 1..5000"
    assert (
        0.0 <= threshold_wasserstein <= 1.0
    ), "threshold_wasserstein should be in 0..1"
    extractor = extractor_type_to_extractor(
        device=device,
        extractor_type=extractor_type,
        n_keypoints=n_keypoints,
    )
    matcher = extractor_type_to_matcher(
        device=device,
        extractor_type=extractor_type,
    )
    # map_location ensures cached features load onto the requested device.
    features_dict = torch.load(filepath_features, map_location=device)
    df_db = pd.read_csv(filepath_db)
    return IdentificationModel(
        extractor_type=extractor_type,
        n_keypoints=n_keypoints,
        extractor=extractor,
        matcher=matcher,
        features_dict=features_dict,
        df_db=df_db,
        threshold_wasserstein=threshold_wasserstein,
    )
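

# Usage sketch (illustrative only; the file paths and threshold below are
# assumptions, not values shipped with this module):
#
#   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#   model = load(
#       device=device,
#       filepath_features=Path("data/features.pt"),
#       filepath_db=Path("data/db.csv"),
#       extractor_type="aliked",
#       n_keypoints=1024,
#       threshold_wasserstein=0.1,
#   )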


def _make_prediction_dict(
    model: IdentificationModel,
    indexed_matches: dict[str, dict[str, torch.Tensor]],
    k: int = 3,
) -> dict[str, Any]:
    """
    Return the prediction dict. Two types of predictions can be made:
    1. A new individual
    2. A match from the dataset

    Args:
        model (IdentificationModel): the identification model that was used to
          generate the indexed_matches.
        indexed_matches (dict[str, dict[str, torch.Tensor]]): result of running
          _batch_inference with the model.
        k (int): number of top matches to return.

    Returns:
        type (str): new|match
        top_k (dict): dict containing the following keys.
            k (int): the k parameter used to get the top k matches.
            sorted (list): list of entries containing the same keys as match
              below.
        match (dict): dict containing the following keys if type==match.
            uuid (str): the uuid of the matched individual.
            pit (str): the PIT of the matched individual.
            name (str): the name of the matched individual.
            filepath_crop (Path): the filepath to the crop of the matched
              individual.
            features (torch.Tensor): LightGlue features of the matched
              individual.
            matches (torch.Tensor): LightGlue matches of the matched
              individual.
    """
    indexed_scores = {uuid: get_scores(v) for uuid, v in indexed_matches.items()}
    indexed_wasserstein = {uuid: wasserstein(v) for uuid, v in indexed_scores.items()}
    # Candidates sorted by decreasing wasserstein score: best match first.
    sorted_wasserstein = sorted(
        indexed_wasserstein.items(), key=lambda item: item[1], reverse=True
    )
    shared_record = {
        "indexed_matches": indexed_matches,
        "indexed_scores": indexed_scores,
        "indexed_wasserstein": indexed_wasserstein,
        "sorted_wasserstein": sorted_wasserstein,
    }

    def to_entry(uuid: str) -> dict[str, Any]:
        db_row = model.df_db[model.df_db["uuid"] == uuid].iloc[0]
        return {
            "uuid": uuid,
            "pit": db_row["pit"],
            "name": db_row["name"],
            "filepath_crop": db_row["filepath_crop"],
            "features": model.features_dict[uuid],
            "matches": indexed_matches[uuid],
        }

    top_k_results = {
        "k": k,
        "sorted": [to_entry(uuid=uuid) for uuid, _ in sorted_wasserstein[:k]],
    }
    # A match requires at least one candidate whose wasserstein score reaches
    # the threshold; otherwise the individual is considered new.
    if not sorted_wasserstein or sorted_wasserstein[0][1] < model.threshold_wasserstein:
        return {"type": "new", "top_k": top_k_results, **shared_record}
    return {
        "type": "match",
        "top_k": top_k_results,
        "match": top_k_results["sorted"][0],
        **shared_record,
    }
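

# Illustration of the decision rule in _make_prediction_dict (values are made
# up): with threshold_wasserstein = 0.1,
#   sorted_wasserstein = [("uuid-a", 0.23), ("uuid-b", 0.05)] -> type == "match"
#   sorted_wasserstein = [("uuid-a", 0.07)]                   -> type == "new"
#   sorted_wasserstein = []                                   -> type == "new"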


# FIXME: Properly run a batch inference here to make it fast on GPU.
def _batch_inference(
    model: IdentificationModel,
    feats0: dict,
) -> dict[str, dict[str, torch.Tensor]]:
    """
    Match feats0 against every cached entry of the IdentificationModel,
    sequentially for now (see FIXME above).

    Returns an indexed_matches data structure mapping each uuid in
    features_dict to the matcher output for that entry.
    """
    indexed_matches = {}
    for uuid in model.features_dict.keys():
        matches01 = model.matcher(
            {"image0": feats0, "image1": model.features_dict[uuid]}
        )
        indexed_matches[uuid] = matches01
    return indexed_matches
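

# One possible direction for the FIXME above (untested sketch, assuming the
# LightGlue matcher accepts batched feature dicts and that every cached entry
# holds exactly the same number of keypoints -- both to be verified):
#
#   uuids = list(model.features_dict.keys())
#   batched = {
#       key: torch.cat([model.features_dict[u][key] for u in uuids], dim=0)
#       for key in model.features_dict[uuids[0]].keys()
#   }
#   feats0_tiled = {key: value.expand(len(uuids), *value.shape[1:])
#                   for key, value in feats0.items()}
#   matches = model.matcher({"image0": feats0_tiled, "image1": batched})
#
# Entries with differing keypoint counts would need padding first, which is
# why the sequential loop is kept for now.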


def predict(model: IdentificationModel, pil_image: Image.Image, k: int = 3) -> dict:
    """
    Run inference on pil_image against all the features_dict entries of the
    IdentificationModel.

    Note: gradient computation is disabled to speed up inference; matching
    against the database is currently sequential (see _batch_inference).

    Args:
        model (IdentificationModel): identification model to run inference with.
        pil_image (Image.Image): input image to run the inference on.
        k (int): number of top matches to return.

    Returns:
        type (str): new|match.
        source (dict): contains the `features` of the input image.
        top_k (dict): dict containing the following keys.
            k (int): the k parameter used to get the top k matches.
            sorted (list): list of entries containing the same keys as match
              below.
        match (dict): dict containing the following keys if type==match.
            pit (str): the PIT of the matched individual.
            name (str): the name of the matched individual.
            filepath_crop (Path): the filepath to the matched individual.
            features (torch.Tensor): LightGlue features of the matched
              individual.
            matches (torch.Tensor): LightGlue matches of the matched
              individual.
    """
    # Disable gradient computation to make inference faster.
    torch.set_grad_enabled(False)
    torch_image = numpy_image_to_torch(np.array(pil_image))
    feats0 = model.extractor.extract(torch_image)
    indexed_matches = _batch_inference(model=model, feats0=feats0)
    prediction_dict = _make_prediction_dict(
        model=model,
        indexed_matches=indexed_matches,
        k=k,
    )
    return {"source": {"features": feats0}, **prediction_dict}


def _entry_to_visualization(
    pil_image: Image.Image,
    source: dict[str, Any],
    entry: dict[str, Any],
) -> dict[str, Image.Image]:
    """
    Build the match and keypoint visualizations for one top-k entry.
    """
    pil_image_masked_closest = Image.open(entry["filepath_crop"])
    torch_image0 = np.array(pil_image)
    torch_image1 = np.array(pil_image_masked_closest)
    torch_images = [torch_image0, torch_image1]
    feats0 = source["features"]
    feats1 = entry["features"]
    matches01 = entry["matches"]
    feats0, feats1, matches01 = [
        rbd(x) for x in [feats0, feats1, matches01]
    ]  # remove the batch dimension
    pil_image_matches = matches_as_pil_image(
        torch_images=torch_images,
        feats0=feats0,
        feats1=feats1,
        matches01=matches01,
        mode="column",
    )
    pil_image_keypoints_target = keypoints_as_pil_image(
        torch_image=torch_image1,
        feats=feats1,
        ps=23,
    )
    return {
        "matches": pil_image_matches,
        "keypoints_target": pil_image_keypoints_target,
    }


def generate_visualization(pil_image: Image.Image, prediction: dict) -> dict:
    """
    Generate the keypoint and match visualizations for a prediction returned
    by predict. Returns an empty dict when the prediction is malformed.
    """
    if "type" not in prediction or "top_k" not in prediction:
        return {}
    if prediction["type"] not in ["match", "new"]:
        return {}
    torch_image0 = np.array(pil_image)
    feats0 = prediction["source"]["features"]
    feats0 = rbd(feats0)  # remove the batch dimension
    pil_image_keypoints_source = keypoints_as_pil_image(
        torch_image=torch_image0,
        feats=feats0,
        ps=23,
    )
    top_k_sorted = prediction["top_k"]["sorted"]
    top_k_sorted_results = [
        _entry_to_visualization(
            pil_image=pil_image, source=prediction["source"], entry=entry
        )
        for entry in top_k_sorted
    ]
    return {
        "source": {"keypoints": pil_image_keypoints_source},
        "top_k": top_k_sorted_results,
    }
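

# Visualization sketch (continues the predict example above; the output
# filenames are illustrative):
#
#   viz = generate_visualization(pil_image=pil_image, prediction=prediction)
#   if viz:
#       viz["source"]["keypoints"].save("keypoints_source.png")
#       for i, entry in enumerate(viz["top_k"]):
#           entry["matches"].save(f"matches_{i}.png")
#           entry["keypoints_target"].save(f"keypoints_target_{i}.png")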