from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torchvision
from matplotlib import font_manager
from matplotlib.figure import Figure
from matplotlib.gridspec import GridSpec
from PIL import Image

IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
DISTANCE_THRESHOLD_NEW_INDIVIDUAL = 0.7


def get_inverse_normalize_transform(mean, std):
    """Return a transform that undoes ``Normalize(mean=mean, std=std)``.

    ``Normalize`` computes ``(x - mean) / std`` per channel. Applying
    ``Normalize(mean=-mean / std, std=1 / std)`` to that result computes
    ``(y + mean / std) * std = y * std + mean``, which restores ``x``.
    """
    return torchvision.transforms.Normalize(
        mean=[-m / s for m, s in zip(mean, std)],
        std=[1 / s for s in std],
    )


def get_color(
    distance: float,
    distance_threshold_new_individual: float = DISTANCE_THRESHOLD_NEW_INDIVIDUAL,
    margin: float = 0.10,
) -> str:
    """Map a distance to a traffic-light color around a threshold.

    Args:
        distance: Distance between the query chip and a dataset chip.
        distance_threshold_new_individual: Nominal threshold separating a
            match (below) from a new individual (above).
        margin: Relative band around the threshold reported as "unsure".

    Returns:
        ``"green"`` when ``distance < threshold * (1 - margin)``,
        ``"orange"`` when ``distance < threshold * (1 + margin)``,
        ``"red"`` otherwise.
    """
    threshold_unsure = distance_threshold_new_individual * (1.0 - margin)
    threshold_new_individual = distance_threshold_new_individual * (1.0 + margin)
    if distance < threshold_unsure:
        return "green"
    if distance < threshold_new_individual:
        return "orange"
    return "red"


def _to_displayable_image(dataset_image, inv_normalize) -> np.ndarray:
    """Undo the normalization of a channel-first tensor and return an HWC array.

    Assumes ``dataset_image`` is a (C, H, W) torch tensor normalized with the
    statistics ``inv_normalize`` inverts — TODO confirm against the caller.
    """
    # detach()/cpu() are no-ops for plain CPU tensors and make .numpy() safe
    # for tensors that require grad or live on an accelerator.
    image = inv_normalize(dataset_image).detach().cpu().numpy()
    image = np.transpose(image, (1, 2, 0))
    # Float round-off can land slightly outside [0, 1]; imshow would clip
    # RGB floats anyway (with a warning), so clip explicitly.
    return np.clip(image, 0.0, 1.0)


def _draw_dataset_chip(ax, bear_id, entry: dict, inv_normalize) -> None:
    """Draw one dataset chip, titled ``"<bear_id>: <distance>"`` in its distance color."""
    distance = entry["distance"]
    ax.set_axis_off()
    ax.set_title(label=f"{bear_id}: {distance:.2f}", color=get_color(distance=distance))
    ax.imshow(_to_displayable_image(entry["dataset_image"], inv_normalize))


def draw_extracted_chip(ax, chip_image) -> None:
    """Show the query chip on ``ax`` under the title "Extracted chip"."""
    ax.set_title("Extracted chip")
    ax.set_axis_off()
    ax.imshow(chip_image)


# Backward-compatible alias for the original (misspelled) function name.
draw_extrated_chip = draw_extracted_chip


def draw_closest_neighbors(
    fig: Figure,
    gs: GridSpec,
    i_start: int,
    k_closest_neighbors: int,
    indexed_k_nearest_individuals: dict,
) -> None:
    """Draw the nearest dataset chips, pooled across all individuals.

    Every candidate chip of every individual is pooled. The pool is sorted by
    ascending ``"distance"`` (stable, so ties keep insertion order), and the
    first ``k_closest_neighbors`` chips are drawn left to right in grid row
    ``i_start``.

    Args:
        fig: Figure to add subplots to.
        gs: Grid spec of ``fig``. It needs at least as many columns as chips drawn.
        i_start: Grid row to draw into.
        k_closest_neighbors: Maximum number of chips to draw.
        indexed_k_nearest_individuals: Maps a bear id to a list of dicts,
            each with at least ``"distance"`` and ``"dataset_image"`` keys.
    """
    inv_normalize = get_inverse_normalize_transform(mean=IMAGENET_MEAN, std=IMAGENET_STD)
    # Flatten into (bear_id, entry) pairs so chips of different individuals
    # compete for the k slots; the input dicts are not mutated.
    candidates = [
        (bear_id, entry)
        for bear_id, entries in indexed_k_nearest_individuals.items()
        for entry in entries
    ]
    candidates.sort(key=lambda pair: pair[1]["distance"])
    for j, (bear_id, entry) in enumerate(candidates[:k_closest_neighbors]):
        _draw_dataset_chip(fig.add_subplot(gs[i_start, j]), bear_id, entry, inv_normalize)


def draw_top_k_individuals(
    fig: Figure,
    gs: GridSpec,
    i_start: int,
    i_end: int,
    indexed_k_nearest_individuals: dict,
    bear_ids: list[str],
    indexed_samples: dict,
):
    """Draw one column per individual: its nearest chip, then sample images.

    For the bear id in column ``j``:

    - Row ``i_start`` shows its closest dataset chip (minimum ``"distance"``).
    - Rows ``i_start + 1`` .. ``i_end - 1`` show the image files listed in
      ``indexed_samples[bear_id]``, in order. Surplus rows, and falsy paths,
      leave the cell empty.

    Args:
        fig: Figure to add subplots to.
        gs: Grid spec of ``fig`` with at least ``len(bear_ids)`` columns.
        i_start: First grid row used.
        i_end: One past the last grid row used.
        indexed_k_nearest_individuals: Maps a bear id to a non-empty list of
            dicts with ``"distance"`` and ``"dataset_image"`` keys.
        bear_ids: Bear ids to draw, one per column, in this order.
        indexed_samples: Maps a bear id to a list of image file paths.
    """
    if i_end <= i_start:
        return
    inv_normalize = get_inverse_normalize_transform(mean=IMAGENET_MEAN, std=IMAGENET_STD)
    n_sample_rows = i_end - i_start - 1
    for j, bear_id in enumerate(bear_ids):
        # min() returns the first minimal entry, i.e. entries[0] when sorted.
        nearest = min(indexed_k_nearest_individuals[bear_id], key=lambda e: e["distance"])
        _draw_dataset_chip(fig.add_subplot(gs[i_start, j]), bear_id, nearest, inv_normalize)

        samples = indexed_samples.get(bear_id, ())
        for i, filepath in enumerate(samples[:n_sample_rows], start=i_start + 1):
            if not filepath:
                continue
            ax = fig.add_subplot(gs[i, j])
            ax.set_axis_off()
            # imshow is called while the file is still open.
            with Image.open(filepath) as image:
                ax.imshow(image)


def _draw_section_title(ax, text: str, font_properties) -> None:
    """Turn an axis-less (typically full-width) subplot into a left-aligned section header."""
    ax.set_axis_off()
    ax.text(y=0.2, x=0, s=text, font_properties=font_properties)


def _nearest_distance(entries) -> float:
    """Smallest ``"distance"`` among ``entries`` (``inf`` when there are none)."""
    return min((e["distance"] for e in entries), default=float("inf"))


def bearid_ui(
    pil_image_chip: Image.Image,
    indexed_k_nearest_individuals: dict,
    indexed_samples: dict,
    save_filepath: Path,
    k_closest_neighbors: int = 5,
) -> None:
    """Main UI for identifying bears: render the summary figure and save it.

    Grid layout, top to bottom (one grid row each):

    0. the extracted query chip (``pil_image_chip``);
    1. a "Closest faces" header;
    2. the ``k_closest_neighbors`` nearest chips over all individuals;
    3. a "Closest N individuals" header;
    4. the nearest chip of each individual, one column per individual;
    5. onwards: sample images of each individual (``indexed_samples``), one
       row per sample.

    Args:
        pil_image_chip: Query chip drawn in the top-left cell.
        indexed_k_nearest_individuals: Maps a bear id to a list of dicts with
            ``"distance"`` and ``"dataset_image"`` keys.
        indexed_samples: Maps a bear id to a list of image file paths.
        save_filepath: Where the figure is written (``Figure.savefig``).
        k_closest_neighbors: Number of pooled nearest chips shown in row 2.
    """
    # Order individuals by their nearest distance. The sort is stable, so an
    # already-sorted mapping keeps its original order.
    bear_ids = sorted(
        indexed_k_nearest_individuals,
        key=lambda bear_id: _nearest_distance(indexed_k_nearest_individuals[bear_id]),
    )

    row_chip = 0
    row_neighbors_title = 1
    row_neighbors = 2
    row_individuals_title = 3
    row_individuals = 4
    max_samples = max((len(xs) for xs in indexed_samples.values()), default=0)
    # Fixed rows 0..row_individuals, then one row per sample.
    nrows = row_individuals + 1 + max_samples
    # At least one column is needed for the query chip.
    ncols = max(len(bear_ids), k_closest_neighbors, 1)

    fig = plt.figure(constrained_layout=True, figsize=(3 * ncols, 3 * nrows))
    try:
        gs = GridSpec(nrows=nrows, ncols=ncols, figure=fig)
        font_properties_section = font_manager.FontProperties(size=35)

        draw_extracted_chip(fig.add_subplot(gs[row_chip, 0]), pil_image_chip)

        _draw_section_title(
            fig.add_subplot(gs[row_neighbors_title, :]),
            "Closest faces",
            font_properties_section,
        )
        draw_closest_neighbors(
            fig=fig,
            gs=gs,
            i_start=row_neighbors,
            k_closest_neighbors=k_closest_neighbors,
            indexed_k_nearest_individuals=indexed_k_nearest_individuals,
        )

        _draw_section_title(
            fig.add_subplot(gs[row_individuals_title, :]),
            f"Closest {len(bear_ids)} individuals",
            font_properties_section,
        )
        draw_top_k_individuals(
            fig=fig,
            gs=gs,
            i_start=row_individuals,
            i_end=nrows,
            indexed_k_nearest_individuals=indexed_k_nearest_individuals,
            bear_ids=bear_ids,
            indexed_samples=indexed_samples,
        )

        fig.savefig(save_filepath, bbox_inches="tight")
    finally:
        # Always release the figure from pyplot's registry, even if drawing fails.
        plt.close(fig)