"""
Module to manage the pose detection model.
"""

import logging
from enum import Enum
from pathlib import Path
from typing import Any

import numpy as np
from PIL import Image
from ultralytics import YOLO

import yolo
from utils import get_angle_correction, get_keypoint, rotate_image_and_keypoints_xy

# Module-level logger: diagnostics go through `logging` so that callers decide
# whether (and where) they are displayed, instead of being printed to stdout.
_logger = logging.getLogger(__name__)


class FishSide(Enum):
    """
    Represents the Side of the Fish.
    """

    RIGHT = "right"
    LEFT = "left"


def predict_fish_side(
    array_image: np.ndarray,
    keypoints_xy: np.ndarray,
    classes_dictionnary: dict[int, str],
) -> FishSide:
    """
    Predict which side of the fish is displayed on the image.

    The image and its keypoints are first rotated by the correction angle
    returned by `get_angle_correction`. The side is then decided by comparing
    the x coordinate of the "eye" keypoint with the x coordinate of the
    "anal_fin_base" keypoint in the rotated frame.

    Args:
        array_image (np.ndarray): numpy array representing the image.
        keypoints_xy (np.ndarray): detected keypoints on array_image in
            xy format.
        classes_dictionnary (dict[int, str]): mapping of class index to
            class name.

    Returns:
        FishSide: Predicted side of the fish. `FishSide.LEFT` when the eye x
        coordinate is lower than or equal to the anal fin base x coordinate
        after rotation, `FishSide.RIGHT` otherwise.
    """
    theta = get_angle_correction(
        keypoints_xy=keypoints_xy,
        array_image=array_image,
        classes_dictionnary=classes_dictionnary,
    )
    rotation_results = rotate_image_and_keypoints_xy(
        angle_rad=theta, array_image=array_image, keypoints_xy=keypoints_xy
    )
    rotated_keypoints = rotation_results["keypoints_xy"]

    # We check if the eye is on the left/right of one of the fins.
    k_eye = get_keypoint(
        class_name="eye",
        keypoints=rotated_keypoints,
        classes_dictionnary=classes_dictionnary,
    )
    k_anal_fin_base = get_keypoint(
        class_name="anal_fin_base",
        keypoints=rotated_keypoints,
        classes_dictionnary=classes_dictionnary,
    )

    if k_eye[0] <= k_anal_fin_base[0]:
        return FishSide.LEFT
    return FishSide.RIGHT


# Model prediction classes
CLASSES_DICTIONNARY = {
    0: "eye",
    1: "front_fin_base",
    2: "tail_bottom_tip",
    3: "tail_top_tip",
    4: "dorsal_fin_base",
    5: "pelvic_fin_base",
    6: "anal_fin_base",
}


def load_pretrained_model(model_str: str) -> YOLO:
    """
    Load the pretrained model.

    Thin wrapper that delegates to `yolo.load_pretrained_model`.

    Args:
        model_str (str): model identifier forwarded as-is to
            `yolo.load_pretrained_model`.

    Returns:
        YOLO: the loaded model.
    """
    return yolo.load_pretrained_model(model_str)


def train(
    model: YOLO,
    data_yaml_path: Path,
    params: dict,
    project: Path = Path("data/04_models/yolo/"),
    experiment_name: str = "train",
):
    """Main function for running a train run.

    It saves the results under `project / experiment_name`.

    Args:
        model (YOLO): result of `load_pretrained_model`.
        data_yaml_path (Path): filepath to the data.yaml file that specifies
            the split and classes to train on
        params (dict): parameters to override when running the training. See
            https://docs.ultralytics.com/modes/train/#train-settings for a
            complete list of parameters.
        project (Path): root path to store the run artifacts and results.
        experiment_name (str): name of the experiment, that is added to the
            project root path to store the run.

    Returns:
        Whatever `yolo.train` returns for this run (forwarded unchanged).
    """
    return yolo.train(
        model=model,
        data_yaml_path=data_yaml_path,
        params=params,
        project=project,
        experiment_name=experiment_name,
    )


def predict(
    model: YOLO,
    pil_image: Image.Image,
    classes_dictionnary: dict[int, str] = CLASSES_DICTIONNARY,
) -> dict[str, Any]:
    """
    Given a loaded model and a PIL image, it returns a map containing the
    keypoints predictions.

    Args:
        model (YOLO): loaded YOLO model for pose estimation.
        pil_image (PIL): image to run the model on.
        classes_dictionnary (dict[int, str]): mapping of class index to
            class name.

    Returns:
        prediction: Raw prediction from the model.
        orig_image: original image used for inference after the preprocessing
            stages applied by ultralytics.
        keypoints_xy (np.ndarray): keypoints in xy format.
        keypoints_xyn (np.ndarray): keypoints in xyn format.
        theta (float): angle in radians to rotate the image to re-align it
            horizontally.
        side (FishSide): Predicted side of the fish.
    """
    predictions = model(pil_image)
    # Debug-level log instead of an unconditional print: inference is library
    # code and should not write to stdout on every call.
    _logger.debug("Raw model predictions: %s", predictions)

    # Only the first result is used: a single image is passed to the model.
    first_prediction = predictions[0]
    orig_image = first_prediction.orig_img
    keypoints_xy = first_prediction.keypoints.xy.cpu().numpy().squeeze()

    theta = get_angle_correction(
        keypoints_xy=keypoints_xy,
        array_image=orig_image,
        classes_dictionnary=classes_dictionnary,
    )
    side = predict_fish_side(
        array_image=orig_image,
        keypoints_xy=keypoints_xy,
        classes_dictionnary=classes_dictionnary,
    )

    return {
        "prediction": first_prediction,
        "orig_image": orig_image,
        "keypoints_xy": keypoints_xy,
        "keypoints_xyn": first_prediction.keypoints.xyn.cpu().numpy().squeeze(),
        "theta": theta,
        "side": side,
    }