import logging
import math
from enum import Enum
from pathlib import Path
from typing import Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
from lightglue import ALIKED, DISK, SIFT, LightGlue, SuperPoint
from PIL import Image
from scipy.stats import wasserstein_distance

logger = logging.getLogger(__name__)


def select_best_device() -> torch.device:
    """
    Select the best available device: cuda when available, cpu otherwise.
    """
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def bgr_to_rgb(a: np.ndarray) -> np.ndarray:
    """
    Turn a BGR numpy array into a RGB numpy array.

    The last axis (channels) of an (H, W, C) array is reversed. The result is
    a numpy view of `a`, not a copy.
    """
    return a[:, :, ::-1]


ALLOWED_EXTRACTOR_TYPES = ["sift", "disk", "superpoint", "aliked"]


def extractor_type_to_extractor(
    device: torch.device,
    extractor_type: str,
    n_keypoints: int = 1024,
):
    """
    Given an extractor_type in {'sift', 'superpoint', 'aliked', 'disk'},
    returns a LightGlue extractor.

    Args:
        device (torch.device): cpu or cuda
        extractor_type (str): in {sift, superpoint, aliked, disk}
        n_keypoints (int): number of max keypoints to generate with the
            extractor. The higher, the better the accuracy but the slower.

    Returns:
        LightGlueExtractor: ALIKED | DISK | SIFT | SuperPoint, in eval mode
        and moved to `device`.

    Raises:
        AssertionError: when the n_keypoints are outside the valid range 0..5000
        AssertionError: when extractor_type is not valid
        ValueError: when extractor_type is not valid and asserts are disabled
            (python -O).
    """
    assert 0 <= n_keypoints <= 5000, "n_keypoints should be in range 0..5000"
    assert (
        extractor_type in ALLOWED_EXTRACTOR_TYPES
    ), f"extractor type {extractor_type} should be in {ALLOWED_EXTRACTOR_TYPES}."
    extractor_classes = {
        "sift": SIFT,
        "superpoint": SuperPoint,
        "disk": DISK,
        "aliked": ALIKED,
    }
    extractor_class = extractor_classes.get(extractor_type)
    if extractor_class is None:
        # Only reachable when assertions are stripped (python -O).
        raise ValueError("extractor_type is not valid")
    return extractor_class(max_num_keypoints=n_keypoints).eval().to(device)


def extractor_type_to_matcher(device: torch.device, extractor_type: str) -> LightGlue:
    """
    Return the LightGlue matcher given an `extractor_type`.

    Args:
        device (torch.device): cpu or cuda
        extractor_type (str): in {sift, superpoint, aliked, disk}

    Returns:
        LightGlue Matcher, in eval mode and moved to `device`.

    Raises:
        AssertionError: when extractor_type is not valid
    """
    assert (
        extractor_type in ALLOWED_EXTRACTOR_TYPES
    ), f"extractor type {extractor_type} should be in {ALLOWED_EXTRACTOR_TYPES}."
    return LightGlue(features=extractor_type).eval().to(device)


def get_scores(matches: dict[str, torch.Tensor]) -> np.ndarray:
    """
    Given a `matches` dict from the LightGlue matcher output, return the
    scores of the first batch element ("matching_scores0"[0]) as a numpy
    array on the cpu.
    """
    return matches["matching_scores0"][0].to("cpu").numpy()


def wasserstein(scores: np.ndarray) -> float:
    """
    Return the Wasserstein distance of the scores against the null
    distribution (a point mass at 0.0). The greater the distance, the farther
    away it is from the null distribution.
    """
    x_null_distribution = [0.0] * 1024
    # float() works whether scipy returns a numpy scalar or a Python float
    # (a Python float has no `.item()` method).
    return float(wasserstein_distance(x_null_distribution, scores))


class PictureLayout(Enum):
    """
    Layout of a picture.
    """

    PORTRAIT = "portrait"
    LANDSCAPE = "landscape"
    SQUARE = "square"


def crop(
    pil_image: Image.Image,
    box: Tuple[int, int, int, int],
) -> Image.Image:
    """
    Crop a pil_image based on the provided rectangle in (x1, y1, x2, y2)
    format, with the upper left corner given first.
    """
    return pil_image.crop(box=box)


def get_picture_layout(pil_image: Image.Image) -> PictureLayout:
    """
    Return the picture layout (landscape, square or portrait) based on the
    image width and height.
    """
    width, height = pil_image.size
    if width > height:
        return PictureLayout.LANDSCAPE
    elif width == height:
        return PictureLayout.SQUARE
    else:
        return PictureLayout.PORTRAIT


def get_segmentation_mask_crop_box(
    pil_image_mask: Image.Image,
    padding: int = 0,
) -> Tuple[int, int, int, int]:
    """
    Return a crop box for the given pil_image that contains the segmentation
    mask (black and white). Any non-zero pixel is considered part of the mask.

    Args:
        pil_image_mask (PIL): image containing the segmentation mask
        padding (int): how much to pad around the segmentation mask. The box
            is not clamped to the image bounds.

    Returns:
        Rectangle (Tuple[int, int, int, int]): 4 tuple representing a
        rectangle (x1, y1, x2, y2) with the upper left corner given first.
        x2 and y2 are exclusive (PIL crop convention), so cropping with this
        box keeps every mask pixel.

    Raises:
        ValueError: when the mask contains no non-zero pixel.
    """
    array_image_mask = np.array(pil_image_mask)
    # Only the first two axes (rows=y, cols=x) matter, even for RGB masks.
    ys, xs = np.nonzero(array_image_mask)[:2]
    if ys.size == 0:
        raise ValueError("The segmentation mask is empty: no non-zero pixel found.")
    y_min, y_max = int(ys.min()), int(ys.max())
    x_min, x_max = int(xs.min()), int(xs.max())
    # PIL crop boxes exclude the right/bottom edge: +1 keeps the last mask
    # column/row inside the box.
    return (
        x_min - padding,
        y_min - padding,
        x_max + 1 + padding,
        y_max + 1 + padding,
    )


def scale_keypoints_to_image_size(
    image_width: int,
    image_height: int,
    keypoints_xyn: np.ndarray,
) -> np.ndarray:
    """
    Given keypoints in xyn format, it returns new keypoints in xy format.

    Args:
        image_width (int): width of the image
        image_height (int): height of the image
        keypoints_xyn (np.ndarray): 2D numpy array representing the keypoints
            in xyn format. Columns beyond the first two are copied unchanged.

    Returns:
        keypoints_xy (np.ndarray): 2D numpy array representing the keypoints
        in xy format (a new array; the input is not modified).
    """
    keypoints_xyn = np.asarray(keypoints_xyn)
    keypoints_xy = keypoints_xyn.copy()
    keypoints_xy[:, 0] = keypoints_xyn[:, 0] * image_width
    keypoints_xy[:, 1] = keypoints_xyn[:, 1] * image_height
    return keypoints_xy


def normalize_keypoints_to_image_size(
    image_width: int,
    image_height: int,
    keypoints_xy: np.ndarray,
) -> np.ndarray:
    """
    Given keypoints in xy format, it returns new keypoints in xyn format.

    Args:
        image_width (int): width of the image
        image_height (int): height of the image
        keypoints_xy (np.ndarray): 2D numpy array representing the keypoints
            in xy format. Columns beyond the first two are copied unchanged.

    Returns:
        keypoints_xyn (np.ndarray): 2D numpy array representing the keypoints
        in xyn format. Integer inputs are promoted to float64.
    """
    keypoints_xy = np.asarray(keypoints_xy)
    # Writing normalized values back into an integer array would silently
    # truncate them (mostly to 0), so promote non-float inputs first.
    float_dtype = (
        keypoints_xy.dtype
        if np.issubdtype(keypoints_xy.dtype, np.floating)
        else np.float64
    )
    keypoints_xyn = keypoints_xy.astype(float_dtype, copy=True)
    keypoints_xyn[:, 0] = keypoints_xy[:, 0] / image_width
    keypoints_xyn[:, 1] = keypoints_xy[:, 1] / image_height
    return keypoints_xyn


# Colour cycle used to draw keypoints; class ids wrap around it.
_KEYPOINT_COLORS = ["r", "g", "b", "c", "m", "y", "w"]


def show_keypoints_xy(
    array_image: np.ndarray,
    keypoints_xy: np.ndarray,
    classes_dictionnary: dict[int, str],
    verbose: bool = True,
) -> None:
    """
    Show keypoints on top of an `array_image` on the current matplotlib
    figure, useful in jupyter notebooks for instance.

    Args:
        array_image (np.ndarray): numpy array representing an image.
        keypoints_xy (np.ndarray): 2D numpy array representing the keypoints
            in xy format, indexed by class id.
        classes_dictionnary (dict[int, str]): Model prediction classes.
        verbose (bool): should we make the image verbose by adding some label
            for each keypoint?
    """
    plt.imshow(array_image)
    label_margin = 20
    for class_inst, class_name in classes_dictionnary.items():
        # Wrap around so more classes than colours does not raise IndexError.
        color = _KEYPOINT_COLORS[class_inst % len(_KEYPOINT_COLORS)]
        x, y = keypoints_xy[class_inst]
        plt.scatter(x=[x], y=[y], c=color)
        if verbose:
            plt.annotate(class_name, (x - label_margin, y - label_margin), c="w")


def draw_keypoints_xy_on_ax(
    ax,
    array_image: np.ndarray,
    keypoints_xy: np.ndarray,
    classes_dictionnary: dict,
    verbose: bool = True,
) -> None:
    """
    Draw keypoints on top of an `array_image` on the given matplotlib `ax`,
    plus the line through the pelvic fin base and the anal fin base.

    Args:
        ax: matplotlib axes to draw on.
        array_image (np.ndarray): numpy array representing an image.
        keypoints_xy (np.ndarray): 2D numpy array representing the keypoints
            in xy format, indexed by class id.
        classes_dictionnary (dict[int, str]): Model prediction classes. Must
            contain "pelvic_fin_base" and "anal_fin_base".
        verbose (bool): should we make the image verbose by adding some label
            for each keypoint?

    Raises:
        AssertionError: when the fin classes are missing (see get_keypoint).
    """
    ax.imshow(array_image)
    label_margin = 20
    for class_inst, class_name in classes_dictionnary.items():
        # Wrap around so more classes than colours does not raise IndexError.
        color = _KEYPOINT_COLORS[class_inst % len(_KEYPOINT_COLORS)]
        x, y = keypoints_xy[class_inst]
        ax.scatter(x=[x], y=[y], c=color)
        if verbose:
            ax.annotate(class_name, (x - label_margin, y - label_margin), c="w")

    k_pelvic_fin_base = get_keypoint(
        class_name="pelvic_fin_base",
        keypoints=keypoints_xy,
        classes_dictionnary=classes_dictionnary,
    )
    k_anal_fin_base = get_keypoint(
        class_name="anal_fin_base",
        keypoints=keypoints_xy,
        classes_dictionnary=classes_dictionnary,
    )
    # Infinite line used as the alignment axis of the fish.
    ax.axline(k_pelvic_fin_base, k_anal_fin_base, c="lime")


def show_keypoints_xyn(
    array_image: np.ndarray,
    keypoints_xyn: np.ndarray,
    classes_dictionnary: dict,
    verbose: bool = True,
) -> None:
    """
    Draw normalized keypoints on top of an `array_image`, useful in jupyter
    notebooks for instance. Keypoints are scaled to the image size first.

    Args:
        array_image (np.ndarray): numpy array representing an (H, W, C) image.
        keypoints_xyn (np.ndarray): 2D numpy array representing the keypoints
            in xyn (normalized) format.
        classes_dictionnary (dict[int, str]): Model prediction classes.
        verbose (bool): should we make the image verbose by adding some label
            for each keypoint?
    """
    height, width, _ = array_image.shape
    keypoints_xy = scale_keypoints_to_image_size(
        image_height=height,
        image_width=width,
        keypoints_xyn=keypoints_xyn,
    )
    show_keypoints_xy(
        array_image=array_image,
        keypoints_xy=keypoints_xy,
        classes_dictionnary=classes_dictionnary,
        verbose=verbose,
    )


def rotate_point(
    clockwise: bool,
    origin: Tuple[float, float],
    point: Tuple[float, float],
    angle: float,
) -> Tuple[float, float]:
    """
    Rotate a point clockwise or counterclockwise by a given angle around a
    given origin.

    The direction is expressed in a y-up frame: `clockwise=True` negates the
    angle before applying the standard rotation matrix. In image coordinates
    (y pointing down) the on-screen direction is therefore reversed.

    Args:
        clockwise (bool): should the rotation be clockwise?
        origin (Tuple[float, float]): origin 2D point to perform the rotation.
        point (Tuple[float, float]): 2D point to rotate.
        angle (float): angle in radian.

    Returns:
        rotated_point (Tuple[float, float]): rotated point after applying the
        2D transformation.
    """
    if clockwise:
        angle = 0 - angle

    ox, oy = origin
    px, py = point

    qx = ox + math.cos(angle) * (px - ox) - math.sin(angle) * (py - oy)
    qy = oy + math.sin(angle) * (px - ox) + math.cos(angle) * (py - oy)
    return qx, qy


def rotate_image(angle_rad: float, array_image: np.ndarray, expand=False) -> np.ndarray:
    """
    Rotate an `array_image` by an angle defined in radians around its center.

    PIL's `Image.rotate` is used, which rotates counterclockwise as displayed
    for positive angles.

    Args:
        angle_rad (float): angle in radian.
        array_image (np.ndarray): numpy array representing the image to rotate.
        expand (bool): should we expand the image as we rotate it to not
            truncate some parts of it if the image is not square?
    """
    angle_degrees = math.degrees(angle_rad)
    return np.array(Image.fromarray(array_image).rotate(angle_degrees, expand=expand))


def rotate_keypoints_xy(
    angle_rad: float,
    keypoints_xy: np.ndarray,
    origin: Tuple[float, float],
    clockwise: bool = True,
) -> np.ndarray:
    """
    Rotate keypoints by an angle defined in radians, clockwise or
    counterclockwise (in a y-up frame, see rotate_point) around `origin`.

    Args:
        angle_rad (float): angle in radian.
        keypoints_xy (np.ndarray): 2D numpy array representing the keypoints
            in xy format.
        origin (Tuple[float, float]): origin 2D point to perform the rotation.
        clockwise (bool): should the rotation be clockwise?

    Returns:
        rotated_keypoints_xy (np.ndarray): rotated keypoints in xy format.
    """
    return np.array(
        [
            rotate_point(
                clockwise=clockwise,
                origin=origin,
                point=(kp[0].item(), kp[1].item()),
                angle=angle_rad,
            )
            for kp in keypoints_xy
        ]
    )


def rotate_image_and_keypoints_xy(
    angle_rad: float,
    array_image: np.ndarray,
    keypoints_xy: np.ndarray,
) -> dict[str, np.ndarray]:
    """
    Rotate the image and its keypoints around the image center so that both
    stay consistent (the image is not expanded).

    Args:
        angle_rad (float): angle in radian.
        array_image (np.ndarray): numpy array representing the image to rotate.
        keypoints_xy (np.ndarray): 2D numpy array representing the keypoints
            in xy format.

    Returns:
        dict with keys:
            "array_image" (np.ndarray): rotated image.
            "keypoints_xy" (np.ndarray): rotated keypoints in xy format.
    """
    height, width, _ = array_image.shape
    # PIL rotates around the float center (w / 2, h / 2); using the same
    # origin avoids a sub-pixel mismatch for odd image sizes.
    origin = (width / 2, height / 2)
    image_rotated = rotate_image(angle_rad=angle_rad, array_image=array_image)
    # clockwise=True in the y-up frame matches PIL's on-screen
    # counterclockwise rotation in y-down image coordinates.
    keypoints_xy_rotated = rotate_keypoints_xy(
        angle_rad=angle_rad, keypoints_xy=keypoints_xy, origin=origin, clockwise=True
    )
    return {
        "array_image": image_rotated,
        "keypoints_xy": keypoints_xy_rotated,
    }


def get_keypoint(
    class_name: str,
    keypoints: np.ndarray,
    classes_dictionnary: dict[int, str],
) -> np.ndarray:
    """
    Return the keypoint for the provided `class_name` (eg. eye,
    front_fin_base, etc).

    Raises:
        AssertionError: when the provided class_name is not compatible or
        when the number of keypoints does not match.
    """
    assert (
        class_name in classes_dictionnary.values()
    ), f"class_name should be in {classes_dictionnary.values()}"
    assert len(classes_dictionnary) == len(
        keypoints
    ), "Number of provided keypoints does not match the number of class names"
    class_name_to_class_inst = {v: k for k, v in classes_dictionnary.items()}
    return keypoints[class_name_to_class_inst[class_name]]


def to_direction_vector(p1: np.ndarray, p2: np.ndarray) -> np.ndarray:
    """
    Return the direction vector from point p1 to point p2 (p2 - p1).
    """
    assert len(p1) == len(p2), "p1 and p2 should have the same length"
    return p2 - p1


def is_upside_down(
    keypoints_xy: np.ndarray,
    classes_dictionnary: dict[int, str],
) -> bool:
    """
    Is the fish upside down?

    In image coordinates (y pointing down) the fish is considered upside down
    when its dorsal fin base lies below its pelvic fin base.
    """
    k_pelvic_fin_base = get_keypoint(
        class_name="pelvic_fin_base",
        keypoints=keypoints_xy,
        classes_dictionnary=classes_dictionnary,
    )
    k_anal_fin_base = get_keypoint(
        class_name="anal_fin_base",
        keypoints=keypoints_xy,
        classes_dictionnary=classes_dictionnary,
    )
    k_dorsal_fin_base = get_keypoint(
        class_name="dorsal_fin_base",
        keypoints=keypoints_xy,
        classes_dictionnary=classes_dictionnary,
    )
    logger.debug("dorsal_fin_base: %s", k_dorsal_fin_base)
    logger.debug("pelvic_fin_base: %s", k_pelvic_fin_base)
    logger.debug("anal_fin_base: %s", k_anal_fin_base)
    # bool() works for numpy scalars and plain Python comparisons alike.
    return bool(k_dorsal_fin_base[1] > k_pelvic_fin_base[1])


def get_direction_vector(
    keypoints_xy: np.ndarray, classes_dictionnary: dict[int, str]
) -> np.ndarray:
    """
    Get the direction vector for the realignment: from the pelvic fin base to
    the anal fin base.
    """
    # Align horizontally the fish based on its pelvic fin base and its anal fin base
    k_pelvic_fin_base = get_keypoint(
        class_name="pelvic_fin_base",
        keypoints=keypoints_xy,
        classes_dictionnary=classes_dictionnary,
    )
    k_anal_fin_base = get_keypoint(
        class_name="anal_fin_base",
        keypoints=keypoints_xy,
        classes_dictionnary=classes_dictionnary,
    )
    return to_direction_vector(
        p1=k_pelvic_fin_base, p2=k_anal_fin_base
    )  # line between the pelvic and anal fins


def get_reference_vector() -> np.ndarray:
    """
    Get the reference vector to align the direction vector to.
    """
    return np.array([1, 0])  # horizontal axis


def get_angle(v1: np.ndarray, v2: np.ndarray) -> float:
    """
    Return the negated unsigned angle in radians between vectors v1 and v2.

    The result is in [-pi, 0]. The actual rotation direction is resolved by
    the caller (see get_angle_correction_sign).

    Raises:
        ValueError: when v1 or v2 has zero length (the angle is undefined).
    """
    norms = np.linalg.norm(v1, ord=2) * np.linalg.norm(v2, ord=2)
    if norms == 0:
        raise ValueError("Cannot compute an angle with a zero-length vector")
    # Divide by the product of the norms; without parentheses the second
    # norm would multiply instead of divide.
    cos_theta = float(np.dot(v1, v2) / norms)
    # Floating-point error can push the cosine slightly outside [-1, 1],
    # which would make math.acos raise.
    cos_theta = max(-1.0, min(1.0, cos_theta))
    return -math.acos(cos_theta)


def is_aligned(
    keypoints_xy: np.ndarray,
    classes_dictionnary: dict[int, str],
    tolerance: float = 0.001,
) -> bool:
    """
    Return whether the keypoints are now properly aligned with the reference
    (horizontal) vector, i.e. the pelvic-to-anal direction vector points
    along +x within `tolerance` radians.
    """
    v1 = get_direction_vector(
        keypoints_xy=keypoints_xy, classes_dictionnary=classes_dictionnary
    )
    v_ref = get_reference_vector()
    theta = get_angle(v1, v_ref)
    return abs(theta) <= tolerance


def get_angle_correction_sign(
    angle_rad: float,
    array_image: np.ndarray,
    keypoints_xy: np.ndarray,
    classes_dictionnary: dict[int, str],
) -> int:
    """
    Return 1 if rotating by `angle_rad` aligns the keypoints with the
    reference vector, -1 otherwise (the opposite sign should be used).
    """
    rotation_results = rotate_image_and_keypoints_xy(
        angle_rad=angle_rad, array_image=array_image, keypoints_xy=keypoints_xy
    )
    aligned = is_aligned(
        keypoints_xy=rotation_results["keypoints_xy"],
        classes_dictionnary=classes_dictionnary,
    )
    return 1 if aligned else -1


def get_angle_correction(
    keypoints_xy: np.ndarray,
    array_image: np.ndarray,
    classes_dictionnary: dict[int, str],
) -> float:
    """
    Get the angle correction in radians that aligns the fish (based on the
    keypoints) horizontally. An extra pi is added when the aligned fish is
    upside down.
    """
    v1 = get_direction_vector(
        keypoints_xy=keypoints_xy, classes_dictionnary=classes_dictionnary
    )
    v_ref = get_reference_vector()
    theta = get_angle(v1, v_ref)
    # get_angle is unsigned: try the rotation and flip the sign if needed.
    angle_sign = get_angle_correction_sign(
        angle_rad=theta,
        array_image=array_image,
        keypoints_xy=keypoints_xy,
        classes_dictionnary=classes_dictionnary,
    )
    theta = angle_sign * theta
    rotation_results = rotate_image_and_keypoints_xy(
        angle_rad=theta, array_image=array_image, keypoints_xy=keypoints_xy
    )

    # Check whether the fish is upside down
    if is_upside_down(
        keypoints_xy=rotation_results["keypoints_xy"],
        classes_dictionnary=classes_dictionnary,
    ):
        logger.debug("the fish is upside down...")
        return theta + math.pi
    else:
        logger.debug("The fish is not upside down")
        return theta  # No need to rotate the fish more


def show_algorithm_steps(
    image_filepath: Path,
    keypoints_xy: np.ndarray,
    rotation_results: dict,
    theta: float,
    classes_dictionnary: dict,
) -> None:
    """
    Display a matplotlib figure that details step by step the result of the
    rotation: original image, predicted keypoints, rotated image with
    keypoints, and final image rotated by `theta` (expanded).
    """
    # Open the file once and close it deterministically.
    with Image.open(image_filepath) as pil_image:
        array_image = np.array(pil_image)
        array_image_final = np.array(
            pil_image.rotate(math.degrees(theta), expand=True)
        )

    fig, axs = plt.subplots(1, 4, figsize=(20, 4))
    fig.suptitle(f"{image_filepath.name}")
    logger.debug("image_filepath: %s", image_filepath)

    # Hiding the x and y axis ticks
    for ax in axs:
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)

    axs[0].set_title("original")
    axs[0].imshow(array_image)

    axs[1].set_title("predicted keypoints")
    draw_keypoints_xy_on_ax(
        ax=axs[1],
        array_image=array_image,
        keypoints_xy=keypoints_xy,
        classes_dictionnary=classes_dictionnary,
    )

    axs[2].set_title(f"rotation of {math.degrees(theta):.1f} degrees")
    draw_keypoints_xy_on_ax(
        ax=axs[2],
        array_image=rotation_results["array_image"],
        keypoints_xy=rotation_results["keypoints_xy"],
        classes_dictionnary=classes_dictionnary,
    )

    axs[3].set_title("final")
    axs[3].imshow(array_image_final)