import logging
import os
import shutil
import subprocess
from collections import Counter
from pathlib import Path
from typing import Any, Optional, OrderedDict
import cv2
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torchvision
import torchvision.models as models
from PIL import Image
from pytorch_metric_learning.utils.inference import InferenceModel
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.transforms import v2
from ultralytics import YOLO
# TODO: move metric learning functions into their own namespace
def sample_chips_from_bearid(
    bear_id: str,
    df_split: pd.DataFrame,
    n: int = 4,
) -> list[Path]:
    """Samples `n` chip filepaths for the given `bear_id` from `df_split`."""
    xs = df_split[df_split["bear_id"] == bear_id].sample(n=n)["path"].tolist()
    return [Path(x) for x in xs]
def make_indexed_samples(
    bear_ids: list[str],
    df_split: pd.DataFrame,
    n: int = 4,
) -> dict[str, list[Path]]:
    """Returns a dict mapping each bear_id to `n` sampled chip filepaths."""
    return {
        bear_id: sample_chips_from_bearid(bear_id=bear_id, df_split=df_split, n=n)
        for bear_id in bear_ids
    }
def _aux_get_k_nearest_individuals(
model: InferenceModel,
k_neighbors: int,
k_individuals: int,
query,
id_to_label: dict,
dataset: Dataset,
) -> dict:
"""Auxiliary helper function to get k nearest individuals.
Returns a dict with the following keys:
- k_neighbors: int - number of neighbors the KNN search extends to in order to find at least k_individuals
- dataset_indices: list[int] - list of indices to call get_item on the dataset
- dataset_labels: list[int] - labels of the dataset for the given dataset_indices
- dataset_images: list[torch.tensor] - chips of the bears
- distances: list[float] - distances from the query
Note: it can return more than k_individuals as it extends progressively the
KNN search to find at least k_individuals.
"""
assert k_individuals <= 20, f"Keep a small k_individuals: {k_individuals}"
distances, indices = model.get_nearest_neighbors(query=query, k=k_neighbors)
indices_on_cpu = indices.cpu()[0].tolist()
distances_on_cpu = distances.cpu()[0].tolist()
nearest_images, nearest_ids = list(zip(*[dataset[idx] for idx in indices_on_cpu]))
bearids = [id_to_label.get(nearest_id, "unknown") for nearest_id in nearest_ids]
counter = Counter(nearest_ids)
if len(counter.keys()) >= k_individuals:
return {
"k_neighbors": k_neighbors,
"dataset_indices": indices_on_cpu,
"dataset_labels": list(nearest_ids),
"dataset_images": list(nearest_images),
"bearids": bearids,
"distances": distances_on_cpu,
}
else:
new_k_neighbors = k_neighbors * 2
return _aux_get_k_nearest_individuals(
model,
k_neighbors=new_k_neighbors,
k_individuals=k_individuals,
query=query,
id_to_label=id_to_label,
dataset=dataset,
)
def _find_cutoff_index(k: int, dataset_labels: list[str]) -> Optional[int]:
    """Returns the index i such that dataset_labels[:i] is the shortest prefix
    spanning k distinct individuals (or the whole list when it contains fewer
    than k individuals). Returns None when dataset_labels is empty."""
    if not dataset_labels:
        return None
    selected_labels = set()
    cutoff_index = -1
    for idx, label in enumerate(dataset_labels):
        if len(selected_labels) == k:
            break
        selected_labels.add(label)
        cutoff_index = idx + 1
    return cutoff_index
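# Illustrative example (hedged; the labels below are made up): how
# `_find_cutoff_index` truncates a sorted neighbor list to the shortest
# prefix spanning k individuals.
def _demo_find_cutoff_index() -> None:
    labels = ["bf_480", "bf_480", "bf_755", "bf_755", "bf_203"]
    # The first 2 distinct individuals are covered by the first 3 entries.
    assert _find_cutoff_index(k=2, dataset_labels=labels) == 3
    # An empty neighbor list has no cutoff.
    assert _find_cutoff_index(k=2, dataset_labels=[]) is None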
def get_k_nearest_individuals(
model: InferenceModel,
k: int,
query,
id_to_label: dict,
dataset: Dataset,
) -> dict:
"""Returns the k nearest individuals using the inference model and a query.
A dict is returned with the following keys:
- dataset_indices: list[int] - list of indices to call get_item on the dataset
- dataset_labels: list[int] - labels of the dataset for the given dataset_indices
- dataset_images: list[torch.tensor] - chips of the bears
- distances: list[float] - distances from the query
"""
    # Heuristic: start the KNN search wide enough to usually span k
    # individuals; _aux_get_k_nearest_individuals doubles it as needed.
    k_neighbors = k * 5
    k_individuals = k
result = _aux_get_k_nearest_individuals(
model=model,
k_neighbors=k_neighbors,
k_individuals=k_individuals,
query=query,
id_to_label=id_to_label,
dataset=dataset,
)
cutoff_index = _find_cutoff_index(
k=k,
dataset_labels=result["dataset_labels"],
)
return {
"dataset_indices": result["dataset_indices"][:cutoff_index],
"dataset_labels": result["dataset_labels"][:cutoff_index],
"dataset_images": result["dataset_images"][:cutoff_index],
"bearids": result["bearids"][:cutoff_index],
"distances": result["distances"][:cutoff_index],
}
def index_by_bearid(k_nearest_individuals: dict) -> dict:
"""Returns a dict where keys are bearid labels (eg. 'bf_480') and the
values are list of the following dict shapes:
- dataset_label: int
- dataset_image: torch.tensor
- distance: float
- dataset_index: int
"""
result = {}
for dataset_label, dataset_image, distance, bearid, dataset_index in zip(
k_nearest_individuals["dataset_labels"],
k_nearest_individuals["dataset_images"],
k_nearest_individuals["distances"],
k_nearest_individuals["bearids"],
k_nearest_individuals["dataset_indices"],
):
row = {
"dataset_label": dataset_label,
"dataset_image": dataset_image,
"distance": distance,
"dataset_index": dataset_index,
}
        result.setdefault(bearid, []).append(row)
return result
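# Illustrative example (hedged; toy values mimicking the output of
# `get_k_nearest_individuals`): grouping neighbors by bear identifier.
def _demo_index_by_bearid() -> None:
    k_nearest_individuals = {
        "dataset_indices": [10, 42, 7],
        "dataset_labels": [0, 0, 1],
        "dataset_images": [torch.zeros(3, 4, 4) for _ in range(3)],
        "bearids": ["bf_480", "bf_480", "bf_755"],
        "distances": [0.12, 0.25, 0.34],
    }
    indexed = index_by_bearid(k_nearest_individuals)
    assert set(indexed.keys()) == {"bf_480", "bf_755"}
    assert len(indexed["bf_480"]) == 2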
def prefix_keys_with(weights: OrderedDict, prefix: str = "module.") -> OrderedDict:
    """Returns a copy of the weights where each key is prefixed with the
    provided `prefix`.

    Note: useful when using DataParallel, which registers parameters under
    the `module.` key prefix.
    """
    weights_copy = weights.copy()
    # Guard against an empty prefix: the rename loop below would otherwise
    # re-assign and then delete every key, leaving an empty state dict.
    if not prefix:
        return weights_copy
    for k, v in weights.items():
        weights_copy[f"{prefix}{k}"] = v
        del weights_copy[k]
    return weights_copy
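# Illustrative example: `prefix_keys_with` renames checkpoint keys so they
# match a DataParallel-wrapped module. The tiny state dict is made up.
def _demo_prefix_keys_with() -> None:
    from collections import OrderedDict as CollectionsOrderedDict

    weights = CollectionsOrderedDict([("fc.weight", torch.zeros(2, 2))])
    prefixed = prefix_keys_with(weights, prefix="module.")
    assert list(prefixed.keys()) == ["module.fc.weight"]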
def load_weights(
network: torch.nn.Module,
weights_filepath: Optional[Path] = None,
weights: Optional[OrderedDict] = None,
prefix: str = "",
) -> torch.nn.Module:
"""Loads the network weights.
Returns the network.
"""
if weights:
prefixed_weights = prefix_keys_with(weights, prefix=prefix)
network.load_state_dict(state_dict=prefixed_weights)
return network
elif weights_filepath:
assert weights_filepath.exists(), f"Invalid model_filepath {weights_filepath}"
weights = torch.load(weights_filepath)
prefixed_weights = prefix_keys_with(weights, prefix=prefix)
network.load_state_dict(state_dict=prefixed_weights)
return network
else:
raise Exception(f"Should provide at least weights or weights_filepath")
class MLP(nn.Module):
    """Simple MLP head: layer_sizes[0] is the input dimension and
    layer_sizes[-1] is the output (embedding) dimension."""

    def __init__(self, layer_sizes, final_relu=False):
super().__init__()
layer_list = []
layer_sizes = [int(x) for x in layer_sizes]
num_layers = len(layer_sizes) - 1
final_relu_layer = num_layers if final_relu else num_layers - 1
for i in range(len(layer_sizes) - 1):
input_size = layer_sizes[i]
curr_size = layer_sizes[i + 1]
if i <= final_relu_layer:
layer_list.append(nn.ReLU(inplace=False))
layer_list.append(nn.BatchNorm1d(input_size))
layer_list.append(nn.Linear(input_size, curr_size))
self.net = nn.Sequential(*layer_list)
self.last_linear = self.net[-1]
def forward(self, x):
return self.net(x)
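# Illustrative example: an MLP embedder head mapping 512-d trunk features to
# 128-d embeddings through a 1024-d hidden layer (the default sizes used in
# `make_model_dict` below).
def _demo_mlp() -> None:
    embedder = MLP([512, 1024, 128])
    features = torch.randn(4, 512)  # a batch of 4 trunk feature vectors
    embeddings = embedder(features)
    assert embeddings.shape == (4, 128)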
def check_backbone(pretrained_backbone: str) -> None:
allowed_backbones = {
"resnet18",
"resnet50",
"convnext_tiny",
"convnext_base",
"convnext_large",
"efficientnet_v2_s",
# "squeezenet1_1",
"vit_b_16",
}
assert (
pretrained_backbone in allowed_backbones
), f"pretrained_backbone {pretrained_backbone} is not implemented, only {allowed_backbones}"
def make_trunk(pretrained_backbone: str = "resnet18") -> nn.Module:
"""Returns a nn.Module with pretrained weights using a given
pretrained_backbone.
Note: The currently available backbones are resnet18, resnet50,
convnext_tiny, convnext_bas, efficientnet_v2_s, squeezenet1_1, vit_b_16
"""
check_backbone(pretrained_backbone)
if pretrained_backbone == "resnet18":
return torchvision.models.resnet18(
weights=models.ResNet18_Weights.IMAGENET1K_V1
)
elif pretrained_backbone == "resnet50":
return torchvision.models.resnet50(
weights=models.ResNet50_Weights.IMAGENET1K_V1
)
elif pretrained_backbone == "convnext_tiny":
return torchvision.models.convnext_tiny(
weights=models.ConvNeXt_Tiny_Weights.IMAGENET1K_V1
)
elif pretrained_backbone == "convnext_base":
return torchvision.models.convnext_base(
weights=models.ConvNeXt_Base_Weights.IMAGENET1K_V1
)
elif pretrained_backbone == "convnext_large":
return torchvision.models.convnext_large(
weights=models.ConvNeXt_Large_Weights.IMAGENET1K_V1
)
elif pretrained_backbone == "efficientnet_v2_s":
return torchvision.models.efficientnet_v2_s(
weights=models.EfficientNet_V2_S_Weights.IMAGENET1K_V1
)
elif pretrained_backbone == "squeezenet1_1":
return torchvision.models.squeezenet1_1(
weights=models.SqueezeNet1_1_Weights.IMAGENET1K_V1
)
elif pretrained_backbone == "vit_b_16":
return torchvision.models.vit_b_16(
weights=models.ViT_B_16_Weights.IMAGENET1K_SWAG_E2E_V1
)
else:
raise Exception(f"Cannot make trunk with backbone {pretrained_backbone}")
def make_embedder(
pretrained_backbone: str,
trunk: nn.Module,
embedding_size: int,
hidden_layer_sizes: list[int],
) -> nn.Module:
check_backbone(pretrained_backbone)
if pretrained_backbone in ["resnet18", "resnet50"]:
trunk_output_size = trunk.fc.in_features
trunk.fc = nn.Identity()
return MLP([trunk_output_size, *hidden_layer_sizes, embedding_size])
if pretrained_backbone in ["convnext_tiny", "convnext_base", "convnext_large"]:
trunk_output_size = trunk.classifier[-1].in_features
trunk.classifier[-1] = nn.Identity()
return MLP([trunk_output_size, *hidden_layer_sizes, embedding_size])
elif pretrained_backbone == "efficientnet_v2_s":
trunk_output_size = trunk.classifier[-1].in_features
trunk.classifier[-1] = nn.Identity()
return MLP([trunk_output_size, *hidden_layer_sizes, embedding_size])
elif pretrained_backbone == "vit_b_16":
trunk_output_size = trunk.heads.head.in_features
trunk.heads.head = nn.Identity()
return MLP([trunk_output_size, *hidden_layer_sizes, embedding_size])
else:
raise Exception(f"{pretrained_backbone} embedder not implemented yet")
def make_model_dict(
device: torch.device,
pretrained_backbone: str = "resnet18",
embedding_size: int = 128,
hidden_layer_sizes: list[int] = [1024],
) -> dict[str, nn.Module]:
"""
Returns a dict with the following keys:
- embedder: nn.Module - embedder model, usually an MLP.
- trunk: nn.Module - the backbone model, usually a pretrained model (like a ResNet).
"""
trunk = make_trunk(pretrained_backbone=pretrained_backbone)
embedder = make_embedder(
pretrained_backbone=pretrained_backbone,
embedding_size=embedding_size,
hidden_layer_sizes=hidden_layer_sizes,
trunk=trunk,
)
trunk = torch.nn.DataParallel(trunk.to(device))
embedder = torch.nn.DataParallel(embedder.to(device))
return {
"trunk": trunk,
"embedder": embedder,
}
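# Hedged usage sketch for `make_model_dict`: downloads ImageNet weights on
# first call, hence wrapped in a function rather than run at import time.
def _demo_make_model_dict() -> None:
    device = get_best_device()
    model_dict = make_model_dict(device=device, pretrained_backbone="resnet18")
    x = torch.randn(2, 3, 224, 224).to(device)
    features = model_dict["trunk"](x)  # (2, 512) resnet18 features
    embeddings = model_dict["embedder"](features)  # (2, 128) embeddings
    assert embeddings.shape == (2, 128)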
class BearDataset(Dataset):
    """Dataset of bear face chips: yields (image, id) pairs, where id is the
    integer mapped to the chip's bear label via `id_mapping`."""

    def __init__(self, dataframe, id_mapping, transform=None):
self.dataframe = dataframe
self.id_mapping = id_mapping
self.transform = transform
def __len__(self):
return len(self.dataframe)
def __getitem__(self, idx):
sample = self.dataframe.iloc[idx]
image_path = sample.path
bear_id = sample.bear_id
id_value = self.id_mapping.loc[self.id_mapping["label"] == bear_id, "id"].iloc[
0
]
image = Image.open(image_path)
if self.transform:
image = self.transform(image)
return image, id_value
def make_dataloaders(
batch_size: int,
df_split: pd.DataFrame,
transforms: dict,
) -> dict:
"""Returns a dict with top level keys in {dataset and loader}.
Each returns a dict with the train, val and test objects associated.
"""
df_train = df_split[df_split["split"] == "train"]
df_val = df_split[df_split["split"] == "val"]
df_test = df_split[df_split["split"] == "test"]
id_mapping = make_id_mapping(df=df_split)
train_dataset = BearDataset(
df_train,
id_mapping,
transform=transforms["train"],
)
train_loader = DataLoader(
train_dataset,
batch_size=batch_size,
shuffle=True,
drop_last=True,
)
val_dataset = BearDataset(
df_val,
id_mapping,
transform=transforms["val"],
)
val_loader = DataLoader(
val_dataset,
batch_size=batch_size,
)
test_dataset = BearDataset(
df_test,
id_mapping,
transform=transforms["test"],
)
test_loader = DataLoader(
test_dataset,
batch_size=batch_size,
)
viz_dataset = BearDataset(
df_train,
id_mapping,
transform=transforms["viz"],
)
viz_loader = DataLoader(
viz_dataset,
batch_size=batch_size,
shuffle=True,
drop_last=True,
)
full_dataset = BearDataset(
df_split,
id_mapping,
transform=transforms["val"],
)
return {
"dataset": {
"viz": viz_dataset,
"train": train_dataset,
"val": val_dataset,
"test": test_dataset,
"full": full_dataset,
},
"loader": {
"viz": viz_loader,
"train": train_loader,
"val": val_loader,
"test": test_loader,
},
}
def make_id_mapping(df: pd.DataFrame) -> pd.DataFrame:
"""Returns a dataframe that maps a bear label (eg.
bf_755) to a unique natural number (eg. 0). The dataFrame contains
two columns, namely id and label.
"""
return pd.DataFrame(
list(enumerate(df["bear_id"].unique())), columns=["id", "label"]
)
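# Illustrative example: `make_id_mapping` assigns a stable integer id to each
# bear label in order of first appearance. The labels are made up.
def _demo_make_id_mapping() -> None:
    df = pd.DataFrame({"bear_id": ["bf_755", "bf_480", "bf_755"]})
    id_mapping = make_id_mapping(df=df)
    assert id_mapping["id"].tolist() == [0, 1]
    assert id_mapping["label"].tolist() == ["bf_755", "bf_480"]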
def filter_none(xs: list) -> list:
return [x for x in xs if x is not None]
def get_dtype(dtype_str: str) -> torch.dtype:
if dtype_str == "float32":
return torch.float32
elif dtype_str == "int64":
return torch.int64
else:
logging.warning(
f"dtype_str {dtype_str} not implemented, returning default value"
)
return torch.float32
def get_transforms(
data_augmentation: dict = {},
trunk_preprocessing: dict = {},
) -> dict:
"""Returns a dict containing the transforms for the following splits:
train, val, test and viz (the latter is used for batch visualization).
"""
logging.info(f"data_augmentation config: {data_augmentation}")
logging.info(f"trunk preprocessing config: {trunk_preprocessing}")
DEFAULT_CROP_SIZE = 224
crop_size = (
trunk_preprocessing.get("crop_size", DEFAULT_CROP_SIZE),
trunk_preprocessing.get("crop_size", DEFAULT_CROP_SIZE),
)
# transform to persist a batch of data as an artefact
transform_viz = transforms.Compose(
[
transforms.Resize(crop_size),
transforms.ToTensor(),
]
)
mdtype: Optional[torch.dtype] = (
get_dtype(trunk_preprocessing["values"].get("dtype", None))
if trunk_preprocessing.get("values", None)
else None
)
mscale: Optional[bool] = (
trunk_preprocessing["values"].get("scale", None)
if trunk_preprocessing.get("values", None)
else None
)
mmean: Optional[list[float]] = (
trunk_preprocessing["normalization"].get("mean", None)
if trunk_preprocessing.get("normalization", None)
else None
)
mstd: Optional[list[float]] = (
trunk_preprocessing["normalization"].get("std", None)
if trunk_preprocessing.get("normalization", None)
else None
)
hue = (
data_augmentation["colorjitter"].get("hue", 0)
if data_augmentation.get("colorjitter", 0)
else 0
)
saturation = (
data_augmentation["colorjitter"].get("saturation", 0)
if data_augmentation.get("colorjitter", 0)
else 0
)
degrees = (
data_augmentation["rotation"].get("degrees", 0)
if data_augmentation.get("rotation", 0)
else 0
)
transformations_plain = [
transforms.Resize(crop_size),
transforms.ToTensor(),
v2.ToDtype(dtype=mdtype, scale=mscale) if mdtype and mscale else None,
transforms.Normalize(mean=mmean, std=mstd) if mmean and mstd else None,
]
transformations_train = [
transforms.Resize(crop_size),
(
transforms.ColorJitter(
hue=hue,
saturation=saturation,
)
if data_augmentation.get("colorjitter", None)
else None
), # Taken from Dolphin ID
(
v2.RandomRotation(degrees=degrees)
if data_augmentation.get("rotation", None)
else None
), # Taken from Dolphin ID
transforms.ToTensor(),
v2.ToDtype(dtype=mdtype, scale=mscale) if mdtype and mscale else None,
transforms.Normalize(mean=mmean, std=mstd) if mmean and mstd else None,
]
# Filtering out None transforms
transform_plain = transforms.Compose(filter_none(transformations_plain))
transform_train = transforms.Compose(filter_none(transformations_train))
return {
"viz": transform_viz,
"train": transform_train,
"val": transform_plain,
"test": transform_plain,
}
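# Hedged configuration sketch: the nested dict schema that `get_transforms`
# expects. The values below are illustrative assumptions, not the trained
# model's configuration (the real one is stored in the packaged checkpoint).
def _demo_get_transforms() -> dict:
    data_augmentation = {
        "colorjitter": {"hue": 0.1, "saturation": 0.1},
        "rotation": {"degrees": 10},
    }
    trunk_preprocessing = {
        "crop_size": 224,
        "values": {"dtype": "float32", "scale": True},
        "normalization": {
            "mean": [0.485, 0.456, 0.406],
            "std": [0.229, 0.224, 0.225],
        },
    }
    return get_transforms(
        data_augmentation=data_augmentation,
        trunk_preprocessing=trunk_preprocessing,
    )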
def resize(
mask: np.ndarray,
dim: tuple[int, int],
interpolation: int = cv2.INTER_LINEAR,
):
"""Resize the mask to the provided `dim` using the interpolation method.
`dim`: (W, H) format
"""
return cv2.resize(mask, dsize=dim, interpolation=interpolation)
def crop_from_yolov8(prediction_yolov8) -> np.ndarray:
    """Given a yolov8 segmentation prediction, returns an image containing the
    cropped bear head: the most confident detection is masked and cropped to
    its bounding box."""
H, W = prediction_yolov8.orig_shape
predictions_masks = prediction_yolov8.masks.data.to("cpu").numpy()
idx = np.argmax(prediction_yolov8.boxes.conf.to("cpu").numpy())
predictions_mask = predictions_masks[idx]
prediction_resized = resize(predictions_mask, dim=(W, H))
masked_image = prediction_yolov8.orig_img.copy()
black_pixel = [0, 0, 0]
masked_image[~prediction_resized.astype(bool)] = black_pixel
x0, y0, x1, y1 = prediction_yolov8.boxes[idx].xyxy[0].to("cpu").numpy()
return masked_image[int(y0) : int(y1), int(x0) : int(x1)]
def square_pad(img: np.ndarray):
    """Returns an image of dimension max(W, H) x max(W, H), padded with black
    pixels."""
    H, W, _ = img.shape
    K = max(H, W)
    top = (K - H) // 2
    bottom = K - H - top  # account for odd padding amounts
    left = (K - W) // 2
    right = K - W - left
    return cv2.copyMakeBorder(
        img.copy(),
        top,
        bottom,
        left,
        right,
        cv2.BORDER_CONSTANT,
        value=[0, 0, 0],
    )
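# Illustrative example: `square_pad` grows a 100x200 landscape crop into a
# 200x200 square by adding black rows above and below.
def _demo_square_pad() -> None:
    img = np.ones((100, 200, 3), dtype=np.uint8)
    padded = square_pad(img)
    assert padded.shape == (200, 200, 3)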
def get_best_device() -> torch.device:
"""Returns the best torch device depending on the hardware it is running
on."""
return torch.device("cuda" if torch.cuda.is_available() else "cpu")
def _setup_chips() -> None:
    """
    Sets up the database of chips used for the face recognition.
    """
    subprocess.run(["./scripts/chips/install.sh"], check=True)
def _setup_ml_pipeline(input_packaged_pipeline: Path, install_path: Path) -> None:
"""
Setup the ML pipeline, installing the model weights into their folders.
"""
logging.info(f"Installing the packaged pipeline in {install_path}")
os.makedirs(install_path, exist_ok=True)
packaged_pipeline_archive_filepath = input_packaged_pipeline
shutil.unpack_archive(
filename=packaged_pipeline_archive_filepath,
extract_dir=install_path,
)
metriclearning_model_filepath = install_path / "bearidentification" / "model.pt"
device = get_best_device()
bearidentification_model = torch.load(
metriclearning_model_filepath,
map_location=device,
)
df_split = pd.DataFrame(bearidentification_model["data_split"])
chips_root_dir = Path("/".join(df_split.iloc[0]["path"].split("/")[:-4]))
logging.info(f"Retrieved chips_root_dir: {chips_root_dir}")
os.makedirs(chips_root_dir, exist_ok=True)
shutil.copytree(
src=install_path / "chips",
dst=chips_root_dir,
dirs_exist_ok=True,
)
def setup(input_packaged_pipeline: Path, install_path: Path) -> None:
"""
Full setup of the project.
"""
_setup_chips()
_setup_ml_pipeline(
input_packaged_pipeline=input_packaged_pipeline, install_path=install_path
)
def bgr_to_rgb(a: np.ndarray) -> np.ndarray:
"""
Turn a BGR numpy array into a RGB numpy array when the array `a` represents
an image.
"""
return a[:, :, ::-1]
def load_segmentation_model(filepath_weights: Path) -> YOLO:
"""
Load the YOLO model given the filepath_weights.
"""
assert filepath_weights.exists()
return YOLO(filepath_weights)
def load_metric_learning_model(device: torch.device, filepath_weights: Path) -> Any:
assert filepath_weights.exists()
return torch.load(filepath_weights, map_location=device)
def load_models(
filepath_segmentation_weights: Path,
filepath_metric_learning_weights: Path,
) -> dict[str, Any]:
assert filepath_segmentation_weights.exists()
assert filepath_metric_learning_weights.exists()
device = get_best_device()
model_segmentation = load_segmentation_model(filepath_segmentation_weights)
model_metric_learning = load_metric_learning_model(
device=device,
filepath_weights=filepath_metric_learning_weights,
)
return {
"segmentation": model_segmentation,
"metric_learning": model_metric_learning,
}
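# Hedged usage sketch for `load_models`: the filepaths below are illustrative
# assumptions, not locations mandated by the pipeline.
def _demo_load_models() -> dict[str, Any]:
    return load_models(
        filepath_segmentation_weights=Path("data/models/segmentation/model.pt"),
        filepath_metric_learning_weights=Path(
            "data/models/bearidentification/model.pt"
        ),
    )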
def run_segmentation(model: YOLO, pil_image: Image.Image) -> dict[str, Any]:
    """Runs the segmentation stage, returning the YOLO prediction and an
    image with the prediction overlaid, or an empty dict when nothing is
    detected."""
predictions = model(pil_image)
if len(predictions) > 0:
prediction = predictions[0]
pil_image_with_prediction = Image.fromarray(bgr_to_rgb(prediction.plot()))
return {"pil_image": pil_image_with_prediction, "prediction": prediction}
else:
return {}
def run_crop(square_dim: int, yolo_prediction) -> dict[str, Any]:
"""
Run the crop stage on the yolo_prediction.
It resizes a square bear face based on `square_dim`.
"""
cropped_bear_head = crop_from_yolov8(prediction_yolov8=yolo_prediction)
padded_cropped_head = square_pad(cropped_bear_head)
resized_padded_cropped_head = resize(
padded_cropped_head, dim=(square_dim, square_dim)
)
pil_image_cropped_bear_head = Image.fromarray(bgr_to_rgb(cropped_bear_head))
    pil_image_padded_cropped_head = Image.fromarray(
        bgr_to_rgb(padded_cropped_head)
    )
pil_image_resized_padded_cropped_head = Image.fromarray(
bgr_to_rgb(resized_padded_cropped_head)
)
return {
"pil_images": {
"cropped": pil_image_cropped_bear_head,
"padded": pil_image_padded_cropped_head,
"resized": pil_image_resized_padded_cropped_head,
}
}
def make_id_to_label(id_mapping: pd.DataFrame) -> dict[int, str]:
return id_mapping.set_index("id")["label"].to_dict()
def run_identification(
loaded_model,
k: int,
knn_index_filepath: Path,
pil_image_chip: Image.Image,
n_samples_per_individual: int = 5,
) -> dict[str, Any]:
"""
Run the identification stage.
"""
device = get_best_device()
args = loaded_model["args"]
config = args.copy()
del config["run"]
    transforms_dict = get_transforms(
data_augmentation=config.get("data_augmentation", {}),
trunk_preprocessing=config["model"]["trunk"].get("preprocessing", {}),
)
logging.info("loading the df_split")
df_split = pd.DataFrame(loaded_model["data_split"])
df_split.info()
id_mapping = make_id_mapping(df=df_split)
dataloaders = make_dataloaders(
batch_size=config["batch_size"],
df_split=df_split,
        transforms=transforms_dict,
)
model_dict = make_model_dict(
device=device,
pretrained_backbone=config["model"]["trunk"]["backbone"],
embedding_size=config["model"]["embedder"]["embedding_size"],
hidden_layer_sizes=config["model"]["embedder"]["hidden_layer_sizes"],
)
trunk_weights = loaded_model["trunk"]
trunk = model_dict["trunk"]
trunk = load_weights(
network=trunk,
weights=trunk_weights,
prefix="module.",
)
embedder_weights = loaded_model["embedder"]
embedder = model_dict["embedder"]
embedder = load_weights(
network=embedder,
weights=embedder_weights,
prefix="module.",
)
model = InferenceModel(
trunk=trunk,
embedder=embedder,
)
dataset_full = dataloaders["dataset"]["full"]
assert (
knn_index_filepath.exists()
), f"knn_index_filepath invalid filepath: {knn_index_filepath}"
model.load_knn_func(filename=str(knn_index_filepath))
image = pil_image_chip
    transform_test = transforms_dict["test"]
model_input = transform_test(image)
query = model_input.unsqueeze(0)
id_to_label = make_id_to_label(id_mapping=id_mapping)
k_nearest_individuals = get_k_nearest_individuals(
model=model,
k=k,
query=query,
id_to_label=id_to_label,
dataset=dataset_full,
)
indexed_k_nearest_individuals = index_by_bearid(
k_nearest_individuals=k_nearest_individuals
)
bear_ids = list(indexed_k_nearest_individuals.keys())
indexed_samples = make_indexed_samples(
bear_ids=bear_ids,
df_split=df_split,
n=n_samples_per_individual,
)
return {
"bear_ids": bear_ids,
"k_nearest_individuals": k_nearest_individuals,
"indexed_k_nearest_individuals": indexed_k_nearest_individuals,
"indexed_samples": indexed_samples,
}
def run_pipeline(
loaded_models: dict[str, Any],
param_square_dim: int,
param_k: int,
param_n_samples_per_individual: int,
knn_index_filepath: Path,
pil_image: Image.Image,
) -> dict[str, Any]:
"""
Run the full pipeline on pil_image, using `pil_image` as an input.
Args:
loaded_models (dict[str, Any]): dict of all the loaded models needed to
run the pipeline. Usually loaded via the `load_model` function.
param_square_dim (int): size of the square chip.
param_k (int): how many closest individuals to query to compare it to
the chip
param_n_samples_per_individual (int): How many chips from each
individual do we want to compare it to?
knn_index_filepath (Path): filepath to the KNN index of the embedded
chips.
pil_image (PIL): Main input image of the pipeline
"""
    results_segmentation = run_segmentation(
        model=loaded_models["segmentation"], pil_image=pil_image
    )
    # run_segmentation returns an empty dict when no bear head is detected.
    if not results_segmentation:
        raise ValueError("the segmentation stage did not detect any bear head")
results_crop = run_crop(
square_dim=param_square_dim,
yolo_prediction=results_segmentation["prediction"],
)
pil_image_chip = results_crop["pil_images"]["resized"]
results_identification = run_identification(
loaded_model=loaded_models["metric_learning"],
k=param_k,
knn_index_filepath=knn_index_filepath,
pil_image_chip=pil_image_chip,
        n_samples_per_individual=param_n_samples_per_individual,
)
return {
"order": ["segmentation", "crop", "identification"],
"stages": {
"segmentation": {
"input": {"pil_image": pil_image},
"output": results_segmentation,
},
"crop": {
"input": {
"square_dim": param_square_dim,
"yolo_prediction": results_segmentation["prediction"],
},
"output": results_crop,
},
"identification": {
"input": {
"k": param_k,
"n_samples_per_individual": param_n_samples_per_individual,
"knn_index_filepath": knn_index_filepath,
"pil_image_chip": pil_image_chip,
},
"output": results_identification,
},
},
}
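# Hedged end-to-end sketch: wiring `setup`, `load_models` and `run_pipeline`
# together. All filepaths are illustrative assumptions; adapt them to where
# the packaged pipeline, weights, KNN index and input image actually live.
def _demo_run_pipeline() -> dict[str, Any]:
    install_path = Path("data/pipeline")
    setup(
        input_packaged_pipeline=Path("data/packaged_pipeline.zip"),
        install_path=install_path,
    )
    loaded_models = load_models(
        filepath_segmentation_weights=install_path
        / "bearfacesegmentation"
        / "model.pt",
        filepath_metric_learning_weights=install_path
        / "bearidentification"
        / "model.pt",
    )
    pil_image = Image.open("data/images/bear.jpg")
    return run_pipeline(
        loaded_models=loaded_models,
        param_square_dim=300,
        param_k=5,
        param_n_samples_per_individual=4,
        knn_index_filepath=install_path / "bearidentification" / "knn.index",
        pil_image=pil_image,
    )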