"""
Module to manage the identification model. One can load and run inference on a
new image.
"""
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import numpy as np
import pandas as pd
import torch
from lightglue import ALIKED, DISK, SIFT, LightGlue, SuperPoint
from lightglue.utils import numpy_image_to_torch, rbd
from PIL import Image
from utils import (
    extractor_type_to_extractor,
    extractor_type_to_matcher,
    get_scores,
    wasserstein,
)
from viz2d import keypoints_as_pil_image, matches_as_pil_image


@dataclass
class IdentificationModel:
    """
    Bundle of everything needed to run identification: the keypoint
    extractor, its LightGlue matcher, the cached LightGlue features of the
    reference dataset keyed by uuid, and the dataset metadata.
    """

    extractor: SIFT | ALIKED | DISK | SuperPoint
    extractor_type: str
    threshold_wasserstein: float
    n_keypoints: int
    matcher: LightGlue
    features_dict: dict[str, torch.Tensor]
    df_db: pd.DataFrame


def load(
    device: torch.device,
    filepath_features: Path,
    filepath_db: Path,
    extractor_type: str,
    n_keypoints: int,
    threshold_wasserstein: float,
) -> IdentificationModel:
    """
    Load an IdentificationModel from the provided arguments.

    Args:
        device (torch.device): cpu|cuda.
        filepath_features (Path): filepath to the torch-cached features of the
            dataset one wants to predict on.
        filepath_db (Path): filepath to the csv file containing the dataset to
            compare with.
        extractor_type (str): one of {sift, disk, aliked, superpoint}.
        n_keypoints (int): maximum number of keypoints to extract with the
            extractor.
        threshold_wasserstein (float): threshold on the wasserstein distance
            above which a comparison is considered a match.

    Returns:
        IdentificationModel: an IdentificationModel instance.

    Raises:
        AssertionError: when extractor_type, n_keypoints or
            threshold_wasserstein is not valid.
    """
    allowed_extractor_types = ["sift", "disk", "aliked", "superpoint"]
    assert (
        extractor_type in allowed_extractor_types
    ), f"extractor_type should be in {allowed_extractor_types}"
    assert 1 <= n_keypoints <= 5000, "n_keypoints should be in range 1..5000"
    assert (
        0.0 <= threshold_wasserstein <= 1.0
    ), "threshold_wasserstein should be in 0..1"

    extractor = extractor_type_to_extractor(
        device=device,
        extractor_type=extractor_type,
        n_keypoints=n_keypoints,
    )
    matcher = extractor_type_to_matcher(
        device=device,
        extractor_type=extractor_type,
    )
    # map_location makes the cached features loadable on CPU-only machines
    # even when they were saved from a GPU run.
    features_dict = torch.load(filepath_features, map_location=device)
    df_db = pd.read_csv(filepath_db)
    return IdentificationModel(
        extractor_type=extractor_type,
        n_keypoints=n_keypoints,
        extractor=extractor,
        matcher=matcher,
        features_dict=features_dict,
        df_db=df_db,
        threshold_wasserstein=threshold_wasserstein,
    )


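# Example usage of load() (a minimal sketch: the file paths and parameter
# values below are placeholders, not files shipped with this repo):
#
#   model = load(
#       device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
#       filepath_features=Path("data/features.pt"),
#       filepath_db=Path("data/db.csv"),
#       extractor_type="aliked",
#       n_keypoints=1024,
#       threshold_wasserstein=0.5,
#   )

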
def _make_prediction_dict(
    model: IdentificationModel,
    indexed_matches: dict[str, dict[str, torch.Tensor]],
    k: int = 3,
) -> dict[str, Any]:
    """
    Return the prediction dict. Two types of predictions can be made:

    1. A new individual
    2. A match from the dataset

    Args:
        model (IdentificationModel): the identification model that was used
            to generate the indexed_matches.
        indexed_matches (dict[str, dict[str, torch.Tensor]]): result of
            running predict with the model.
        k (int): number of top k matches to return.

    Returns:
        type (str): new|match.
        top_k (dict): dict containing the following keys.
            k (int): the k parameter used to get the top k matches.
            sorted (list): list of entries containing the same keys as match
                below.
        match (dict): dict containing the following keys if type==match.
            uuid (str): the uuid of the matched individual.
            pit (str): the PIT of the matched individual.
            name (str): the name of the matched individual.
            filepath_crop (Path): the filepath to the matched individual's crop.
            features (torch.Tensor): LightGlue features of the matched individual.
            matches (torch.Tensor): LightGlue matches of the matched individual.
    """
    # Rank every database entry by the wasserstein distance of its match
    # score distribution; in this scheme, higher means more similar.
    indexed_scores = {uuid: get_scores(m) for uuid, m in indexed_matches.items()}
    indexed_wasserstein = {
        uuid: wasserstein(scores) for uuid, scores in indexed_scores.items()
    }
    sorted_wasserstein = sorted(
        indexed_wasserstein.items(), key=lambda item: item[1], reverse=True
    )
    shared_record = {
        "indexed_matches": indexed_matches,
        "indexed_scores": indexed_scores,
        "indexed_wasserstein": indexed_wasserstein,
        "sorted_wasserstein": sorted_wasserstein,
    }

    def to_entry(uuid: str) -> dict[str, Any]:
        db_row = model.df_db[model.df_db["uuid"] == uuid].iloc[0]
        return {
            "uuid": uuid,
            "pit": db_row["pit"],
            "name": db_row["name"],
            "filepath_crop": db_row["filepath_crop"],
            "features": model.features_dict[uuid],
            "matches": indexed_matches[uuid],
        }

    top_k_results = {
        "k": k,
        "sorted": [to_entry(uuid=uuid) for uuid, _ in sorted_wasserstein[:k]],
    }
    if not sorted_wasserstein or sorted_wasserstein[0][1] < model.threshold_wasserstein:
        return {"type": "new", "top_k": top_k_results, **shared_record}
    else:
        return {
            "type": "match",
            "top_k": top_k_results,
            "match": top_k_results["sorted"][0],
            **shared_record,
        }


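# Decision rule sketch (the values are made up for illustration): with
# threshold_wasserstein=0.05 and
#
#   sorted_wasserstein = [("uuid-a", 0.12), ("uuid-b", 0.03)]
#
# the top score 0.12 >= 0.05, so the prediction is
# {"type": "match", "match": <entry for "uuid-a">, ...}; had the best score
# been below 0.05, the image would instead be labeled {"type": "new"}
# (an individual not yet in the database).

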
# FIXME: Properly run a batch inference here to make it fast on GPU.
def _batch_inference(
    model: IdentificationModel,
    feats0: dict,
) -> dict[str, dict[str, torch.Tensor]]:
    """
    Match feats0 against every entry of the model's features_dict.

    Despite the name, this currently loops over the database one pair at a
    time (see the FIXME above). Returns an indexed_matches structure mapping
    each uuid in features_dict to its LightGlue match result.
    """
    indexed_matches = {}
    for uuid, feats1 in model.features_dict.items():
        indexed_matches[uuid] = model.matcher({"image0": feats0, "image1": feats1})
    return indexed_matches


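# Shape sketch of the returned structure, assuming LightGlue's usual output
# keys ("uuid-1" is a placeholder and only a couple of keys are shown):
#
#   {
#       "uuid-1": {"matches0": <tensor>, "matching_scores0": <tensor>, ...},
#       ...
#   }

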
def predict(model: IdentificationModel, pil_image: Image.Image, k: int = 3) -> dict:
    """
    Run inference on pil_image against all the features_dict entries of the
    IdentificationModel.

    Note: it will try to optimize inference depending on the available
    device (cpu|gpu).

    Args:
        model (IdentificationModel): identification model to run inference with.
        pil_image (Image.Image): input image to run the inference on.
        k (int): top k matches to return.

    Returns:
        type (str): new|match.
        source (dict): contains the `features` of the input image.
        top_k (dict): dict containing the following keys.
            k (int): the k parameter used to get the top k matches.
            sorted (list): list of entries containing the same keys as match
                below.
        match (dict): dict containing the following keys if type==match.
            uuid (str): the uuid of the matched individual.
            pit (str): the PIT of the matched individual.
            name (str): the name of the matched individual.
            filepath_crop (Path): the filepath to the matched individual's crop.
            features (torch.Tensor): LightGlue features of the matched individual.
            matches (torch.Tensor): LightGlue matches of the matched individual.
    """
    # Disable gradient tracking: we only run inference, so autograd
    # bookkeeping is wasted work.
    torch.set_grad_enabled(False)
    torch_image = numpy_image_to_torch(np.array(pil_image))
    feats0 = model.extractor.extract(torch_image)
    indexed_matches = _batch_inference(model=model, feats0=feats0)
    prediction_dict = _make_prediction_dict(
        model=model,
        indexed_matches=indexed_matches,
        k=k,
    )
    return {"source": {"features": feats0}, **prediction_dict}


def _entry_to_visualization(
    pil_image: Image.Image,
    source: dict[str, Any],
    entry: dict[str, Any],
) -> dict[str, Image.Image]:
    """
    Render the match and keypoint visualizations for one top-k entry,
    comparing the input image with the entry's stored crop.
    """
    pil_image_masked_closest = Image.open(entry["filepath_crop"])
    # Note: despite the torch_* names, these are numpy arrays, which is what
    # the call sites of the viz2d helpers pass in.
    torch_image0 = np.array(pil_image)
    torch_image1 = np.array(pil_image_masked_closest)
    torch_images = [torch_image0, torch_image1]
    feats0 = source["features"]
    feats1 = entry["features"]
    matches01 = entry["matches"]
    feats0, feats1, matches01 = [
        rbd(x) for x in [feats0, feats1, matches01]
    ]  # remove batch dimension
    pil_image_matches = matches_as_pil_image(
        torch_images=torch_images,
        feats0=feats0,
        feats1=feats1,
        matches01=matches01,
        mode="column",
    )
    pil_image_keypoints_target = keypoints_as_pil_image(
        torch_image=torch_image1,
        feats=feats1,
        ps=23,
    )
    return {
        "matches": pil_image_matches,
        "keypoints_target": pil_image_keypoints_target,
    }


def generate_visualization(pil_image: Image.Image, prediction: dict) -> dict:
    """
    Build PIL visualizations (source keypoints, per-match keypoints and
    matches) for a prediction returned by `predict`. Returns an empty dict
    when the prediction is missing the expected keys.
    """
    if "type" not in prediction or "top_k" not in prediction:
        return {}

    torch_image0 = np.array(pil_image)
    feats0 = prediction["source"]["features"]
    feats0 = rbd(feats0)  # remove the batch dimension
    pil_image_keypoints_source = keypoints_as_pil_image(
        torch_image=torch_image0,
        feats=feats0,
        ps=23,
    )
    top_k_sorted = prediction["top_k"]["sorted"]
    top_k_sorted_results = [
        _entry_to_visualization(
            pil_image=pil_image, source=prediction["source"], entry=entry
        )
        for entry in top_k_sorted
    ]
    if prediction["type"] in ["match", "new"]:
        return {
            "source": {"keypoints": pil_image_keypoints_source},
            "top_k": top_k_sorted_results,
        }
    else:
        return {}


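# Example usage of generate_visualization() (a minimal sketch continuing the
# predict() example above; the output filenames are hypothetical):
#
#   viz = generate_visualization(pil_image=pil_image, prediction=prediction)
#   if viz:
#       viz["source"]["keypoints"].save("keypoints_source.png")
#       for i, result in enumerate(viz["top_k"]):
#           result["matches"].save(f"matches_{i}.png")
#           result["keypoints_target"].save(f"keypoints_target_{i}.png")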