trout-reID / pose.py
achouffe's picture
feat: initial commit
641857b verified
"""
Module to manage the pose detection model.
"""
from enum import Enum
from pathlib import Path
from typing import Any
import numpy as np
from PIL import Image
from ultralytics import YOLO
import yolo
from utils import get_angle_correction, get_keypoint, rotate_image_and_keypoints_xy
class FishSide(Enum):
    """
    Which flank of the fish is visible in an image.

    The members carry the lowercase string values "right" and "left".
    """

    # Declaration order (RIGHT first, then LEFT) is kept as-is because it
    # defines the enum's iteration order.
    RIGHT = "right"
    LEFT = "left"
def predict_fish_side(
    array_image: np.ndarray,
    keypoints_xy: np.ndarray,
    classes_dictionnary: dict[int, str],
) -> FishSide:
    """
    Determine which flank of the fish is shown on the image.

    First, the image and keypoints are rotated by the correction angle
    returned by `get_angle_correction`. Then the x-coordinate of the "eye"
    keypoint is compared with that of the "anal_fin_base" keypoint in the
    rotated frame.

    Args:
        array_image (np.ndarray): numpy array representing the image.
        keypoints_xy (np.ndarray): keypoints detected on `array_image`, in
            xy format.
        classes_dictionnary (dict[int, str]): mapping from keypoint class
            index to keypoint class name (e.g. 0 -> "eye").

    Returns:
        FishSide: LEFT when the eye's x-coordinate is less than or equal to
        the anal fin base's x-coordinate, RIGHT otherwise.
    """
    angle = get_angle_correction(
        keypoints_xy=keypoints_xy,
        array_image=array_image,
        classes_dictionnary=classes_dictionnary,
    )
    rotated = rotate_image_and_keypoints_xy(
        angle_rad=angle,
        array_image=array_image,
        keypoints_xy=keypoints_xy,
    )
    rotated_keypoints = rotated["keypoints_xy"]

    # Look up the two reference keypoints, eye first, in the rotated frame.
    landmark = {
        name: get_keypoint(
            class_name=name,
            keypoints=rotated_keypoints,
            classes_dictionnary=classes_dictionnary,
        )
        for name in ("eye", "anal_fin_base")
    }

    # Eye at or to the left of the anal fin base (ties included) => LEFT.
    eye_is_left = landmark["eye"][0] <= landmark["anal_fin_base"][0]
    return FishSide.LEFT if eye_is_left else FishSide.RIGHT
# Keypoint classes predicted by the pose model. The position of each name in
# the tuple below is its class index (0 -> "eye", ..., 6 -> "anal_fin_base").
CLASSES_DICTIONNARY = dict(
    enumerate(
        (
            "eye",
            "front_fin_base",
            "tail_bottom_tip",
            "tail_top_tip",
            "dorsal_fin_base",
            "pelvic_fin_base",
            "anal_fin_base",
        )
    )
)
def load_pretrained_model(model_str: str) -> YOLO:
    """
    Load the pretrained pose model.

    This is a thin wrapper that delegates to `yolo.load_pretrained_model`.

    Args:
        model_str (str): model identifier forwarded unchanged to
            `yolo.load_pretrained_model` (presumably a checkpoint name or
            path — confirm against that helper).

    Returns:
        YOLO: the model returned by `yolo.load_pretrained_model`.
    """
    pretrained = yolo.load_pretrained_model(model_str)
    return pretrained
def train(
    model: YOLO,
    data_yaml_path: Path,
    params: dict,
    project: Path = Path("data/04_models/yolo/"),
    experiment_name: str = "train",
):
    """Run a training session and save its results under
    `project / experiment_name`.

    All work is delegated to `yolo.train`; this function only forwards its
    arguments.

    Args:
        model (YOLO): result of `load_pretrained_model`.
        data_yaml_path (Path): filepath to the data.yaml file that specifies
            the split and classes to train on.
        params (dict): training parameters to override. See
            https://docs.ultralytics.com/modes/train/#train-settings for the
            complete list of parameters.
        project (Path): root path under which the run artifacts and results
            are stored.
        experiment_name (str): name of the experiment; it is appended to the
            project root path to form the run directory.

    Returns:
        Whatever `yolo.train` returns, passed through unchanged.
    """
    run_settings = {
        "model": model,
        "data_yaml_path": data_yaml_path,
        "params": params,
        "project": project,
        "experiment_name": experiment_name,
    }
    return yolo.train(**run_settings)
def predict(
    model: YOLO,
    pil_image: Image.Image,
    classes_dictionnary: dict[int, str] = CLASSES_DICTIONNARY,
) -> dict[str, Any]:
    """
    Run the pose model on a PIL image and post-process its keypoints.

    Only the first ultralytics result (``predictions[0]``) is used.

    Args:
        model (YOLO): loaded YOLO model for pose estimation.
        pil_image (Image.Image): image to run the model on.
        classes_dictionnary (dict[int, str]): mapping from keypoint class
            index to class name. Defaults to `CLASSES_DICTIONNARY`.

    Returns:
        dict[str, Any]: a map with the following keys.
            prediction: first raw result returned by the model.
            orig_image: the ``orig_img`` array of that result, i.e. the image
                as ultralytics holds it for inference (NOTE(review): confirm
                channel order against the ultralytics docs).
            keypoints_xy (np.ndarray): keypoints in xy format, squeezed.
            keypoints_xyn (np.ndarray): keypoints in xyn format, squeezed.
            theta (float): angle in radians to rotate the image to re-align
                it horizontally.
            side (FishSide): predicted side of the fish.
    """
    predictions = model(pil_image)
    # A stray debug `print(predictions)` used to live here. It dumped every
    # raw result to stdout on each call, so it has been removed.
    result = predictions[0]
    orig_image = result.orig_img

    # NOTE(review): `.squeeze()` only drops size-1 axes. When several fish
    # are detected, these arrays keep a leading detection axis. Confirm that
    # downstream helpers only ever receive a single detection.
    keypoints_xy = result.keypoints.xy.cpu().numpy().squeeze()
    keypoints_xyn = result.keypoints.xyn.cpu().numpy().squeeze()

    theta = get_angle_correction(
        keypoints_xy=keypoints_xy,
        array_image=orig_image,
        classes_dictionnary=classes_dictionnary,
    )
    # predict_fish_side computes the same correction angle again internally.
    side = predict_fish_side(
        array_image=orig_image,
        keypoints_xy=keypoints_xy,
        classes_dictionnary=classes_dictionnary,
    )
    return {
        "prediction": result,
        "orig_image": orig_image,
        "keypoints_xy": keypoints_xy,
        "keypoints_xyn": keypoints_xyn,
        "theta": theta,
        "side": side,
    }