Spaces:
Running
Running
""" | |
Module to manage the pose detection model. | |
""" | |
from enum import Enum | |
from pathlib import Path | |
from typing import Any | |
import numpy as np | |
from PIL import Image | |
from ultralytics import YOLO | |
import yolo | |
from utils import get_angle_correction, get_keypoint, rotate_image_and_keypoints_xy | |
class FishSide(Enum):
    """
    Which side of the fish is shown in an image.

    Members are looked up by their lowercase value, e.g.
    ``FishSide("left") is FishSide.LEFT``.
    """

    RIGHT = "right"  # the fish's right side is displayed
    LEFT = "left"  # the fish's left side is displayed
def predict_fish_side(
    array_image: np.ndarray,
    keypoints_xy: np.ndarray,
    classes_dictionnary: dict[int, str],
) -> FishSide:
    """
    Predict which side of the fish is displayed on the image.

    The image and keypoints are first rotated by the angle returned by
    ``get_angle_correction`` (the angle that re-aligns the image
    horizontally). In that aligned frame, the x coordinate of the eye is
    compared with the x coordinate of the anal fin base.

    Args:
        array_image (np.ndarray): numpy array representing the image.
        keypoints_xy (np.ndarray): keypoints detected on ``array_image``, in
            xy format.
        classes_dictionnary (dict[int, str]): mapping of class index to
            class (keypoint) name.

    Returns:
        FishSide: ``FishSide.LEFT`` when the eye's x coordinate is less than
        or equal to the anal fin base's x coordinate, ``FishSide.RIGHT``
        otherwise.
    """
    angle = get_angle_correction(
        keypoints_xy=keypoints_xy,
        array_image=array_image,
        classes_dictionnary=classes_dictionnary,
    )
    aligned_keypoints = rotate_image_and_keypoints_xy(
        angle_rad=angle,
        array_image=array_image,
        keypoints_xy=keypoints_xy,
    )["keypoints_xy"]

    # Locate the eye and the anal fin base in the horizontally aligned frame.
    eye = get_keypoint(
        class_name="eye",
        keypoints=aligned_keypoints,
        classes_dictionnary=classes_dictionnary,
    )
    anal_fin_base = get_keypoint(
        class_name="anal_fin_base",
        keypoints=aligned_keypoints,
        classes_dictionnary=classes_dictionnary,
    )

    # Eye at or to the left of the anal fin base -> LEFT (ties count as LEFT).
    return FishSide.LEFT if eye[0] <= anal_fin_base[0] else FishSide.RIGHT
# Model prediction classes, keyed by integer class index. The position of each
# name in the tuple below defines its index (eye -> 0, ..., anal_fin_base -> 6).
# NOTE(review): presumably this order must match the keypoint order the model
# was trained with, so append new names rather than reordering — confirm
# against the dataset's data.yaml.
CLASSES_DICTIONNARY: dict[int, str] = dict(
    enumerate(
        (
            "eye",
            "front_fin_base",
            "tail_bottom_tip",
            "tail_top_tip",
            "dorsal_fin_base",
            "pelvic_fin_base",
            "anal_fin_base",
        )
    )
)
def load_pretrained_model(model_str: str) -> YOLO:
    """
    Load a pretrained YOLO model.

    This is a thin delegate to ``yolo.load_pretrained_model``.

    Args:
        model_str (str): model identifier forwarded unchanged to
            ``yolo.load_pretrained_model`` (presumably a weights path or an
            ultralytics model name — confirm against the ``yolo`` module).

    Returns:
        YOLO: the model returned by ``yolo.load_pretrained_model``.
    """
    pretrained = yolo.load_pretrained_model(model_str)
    return pretrained
def train(
    model: YOLO,
    data_yaml_path: Path,
    params: dict,
    project: Path = Path("data/04_models/yolo/"),
    experiment_name: str = "train",
):
    """
    Run a single training run. Its results are saved under
    ``project / experiment_name``.

    All arguments are forwarded by keyword to ``yolo.train``.

    Args:
        model (YOLO): model returned by ``load_pretrained_model``.
        data_yaml_path (Path): path to the data.yaml file that specifies the
            split and classes to train on.
        params (dict): training parameters to override. See
            https://docs.ultralytics.com/modes/train/#train-settings for the
            complete list.
        project (Path): root directory where run artifacts and results are
            stored.
        experiment_name (str): name of the experiment. It is appended to the
            project root path to form the run directory.

    Returns:
        Whatever ``yolo.train`` returns (not visible from this module).
    """
    run_kwargs = dict(
        model=model,
        data_yaml_path=data_yaml_path,
        params=params,
        project=project,
        experiment_name=experiment_name,
    )
    return yolo.train(**run_kwargs)
def predict(
    model: YOLO,
    pil_image: Image.Image,
    classes_dictionnary: dict[int, str] = CLASSES_DICTIONNARY,
) -> dict[str, Any]:
    """
    Run the pose model on a PIL image and return its keypoint predictions.

    Only the first result is used (one image in, one result out). Within it,
    only the first detected fish is used.

    Args:
        model (YOLO): loaded YOLO model for pose estimation.
        pil_image (Image.Image): image to run the model on.
        classes_dictionnary (dict[int, str]): mapping of class index to
            class (keypoint) name.

    Returns:
        dict[str, Any]: a map with the following keys:
            prediction: raw ultralytics result for the image.
            orig_image: the image array ultralytics stored on the result
                (``orig_img``), which is the image used for inference.
            keypoints_xy (np.ndarray): keypoints of the first detection, in
                xy format.
            keypoints_xyn (np.ndarray): the same keypoints, in xyn
                (normalized) format.
            theta (float): angle in radians to rotate the image to re-align
                it horizontally.
            side (FishSide): predicted side of the fish.

    Raises:
        ValueError: if the model returned no keypoints for the image.
    """
    result = model(pil_image)[0]
    keypoints = result.keypoints
    # Fail with a clear message instead of passing an empty (or missing)
    # keypoint array on to the angle-correction helpers.
    if keypoints is None or len(keypoints.xy) == 0:
        raise ValueError("No fish keypoints were detected in the image.")

    orig_image = result.orig_img
    # Select the first detection explicitly. `.squeeze()` only produced a
    # (num_keypoints, 2) array when exactly one fish was detected; with more
    # detections it kept the batch axis. For one detection this is identical.
    keypoints_xy = keypoints.xy[0].cpu().numpy()
    keypoints_xyn = keypoints.xyn[0].cpu().numpy()

    theta = get_angle_correction(
        keypoints_xy=keypoints_xy,
        array_image=orig_image,
        classes_dictionnary=classes_dictionnary,
    )
    # NOTE: predict_fish_side computes its own angle correction internally.
    side = predict_fish_side(
        array_image=orig_image,
        keypoints_xy=keypoints_xy,
        classes_dictionnary=classes_dictionnary,
    )
    return {
        "prediction": result,
        "orig_image": orig_image,
        "keypoints_xy": keypoints_xy,
        "keypoints_xyn": keypoints_xyn,
        "theta": theta,
        "side": side,
    }