Spaces:
Running
Running
File size: 4,773 Bytes
641857b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 |
"""
Module to manage the pose detection model.
"""
import logging
from enum import Enum
from pathlib import Path
from typing import Any

import numpy as np
from PIL import Image
from ultralytics import YOLO

import yolo
from utils import get_angle_correction, get_keypoint, rotate_image_and_keypoints_xy
class FishSide(Enum):
    """
    Side of the fish that is visible on an image.

    Members carry lowercase string values ("right" / "left").
    """

    # Declaration order (RIGHT, then LEFT) is also the iteration order.
    RIGHT = "right"
    LEFT = "left"
def predict_fish_side(
    array_image: np.ndarray,
    keypoints_xy: np.ndarray,
    classes_dictionnary: dict[int, str],
) -> FishSide:
    """
    Predict which side of the fish is displayed on the image.

    The image and its keypoints are first rotated by the angle returned by
    `get_angle_correction`. The side is then decided by comparing the first
    (x) coordinate of the "eye" keypoint with that of the "anal_fin_base"
    keypoint in the rotated keypoints.

    Args:
        array_image (np.ndarray): numpy array representing the image.
        keypoints_xy (np.ndarray): keypoints detected on `array_image`, in
            xy format.
        classes_dictionnary (dict[int, str]): mapping of class instance
            (keypoint index) to class name.

    Returns:
        FishSide: `FishSide.LEFT` when the eye's x coordinate is less than or
        equal to the anal fin base's x coordinate after rotation,
        `FishSide.RIGHT` otherwise.
    """
    angle = get_angle_correction(
        keypoints_xy=keypoints_xy,
        array_image=array_image,
        classes_dictionnary=classes_dictionnary,
    )
    rotated = rotate_image_and_keypoints_xy(
        angle_rad=angle, array_image=array_image, keypoints_xy=keypoints_xy
    )
    rotated_keypoints = rotated["keypoints_xy"]

    def lookup(class_name: str):
        # Fetch the named keypoint from the rotated keypoints.
        return get_keypoint(
            class_name=class_name,
            keypoints=rotated_keypoints,
            classes_dictionnary=classes_dictionnary,
        )

    eye = lookup("eye")
    anal_fin_base = lookup("anal_fin_base")
    # Compare x coordinates: the eye lying left of (or level with) the anal
    # fin base means the fish is seen from its left side.
    return FishSide.LEFT if eye[0] <= anal_fin_base[0] else FishSide.RIGHT
# Model prediction classes: the class index is the position of the name in
# the tuple below.
CLASSES_DICTIONNARY = dict(
    enumerate(
        (
            "eye",
            "front_fin_base",
            "tail_bottom_tip",
            "tail_top_tip",
            "dorsal_fin_base",
            "pelvic_fin_base",
            "anal_fin_base",
        )
    )
)
def load_pretrained_model(model_str: str) -> YOLO:
    """
    Load the pretrained model by delegating to `yolo.load_pretrained_model`.

    Args:
        model_str (str): value forwarded unchanged to
            `yolo.load_pretrained_model`.

    Returns:
        YOLO: the model returned by `yolo.load_pretrained_model`.
    """
    model = yolo.load_pretrained_model(model_str)
    return model
def train(
    model: YOLO,
    data_yaml_path: Path,
    params: dict,
    project: Path = Path("data/04_models/yolo/"),
    experiment_name: str = "train",
):
    """
    Run a training run by delegating to `yolo.train`.

    The results are saved under `project / experiment_name`.

    Args:
        model (YOLO): result of `load_pretrained_model`.
        data_yaml_path (Path): filepath to the data.yaml file that specifies
            the split and classes to train on.
        params (dict): parameters to override when running the training. See
            https://docs.ultralytics.com/modes/train/#train-settings for a
            complete list of parameters.
        project (Path): root path to store the run artifacts and results.
        experiment_name (str): name of the experiment, appended to the
            project root path to store the run.

    Returns:
        Whatever `yolo.train` returns, passed through unchanged.
    """
    train_kwargs = {
        "model": model,
        "data_yaml_path": data_yaml_path,
        "params": params,
        "project": project,
        "experiment_name": experiment_name,
    }
    return yolo.train(**train_kwargs)
def predict(
    model: YOLO,
    pil_image: Image.Image,
    classes_dictionnary: dict[int, str] = CLASSES_DICTIONNARY,
) -> dict[str, Any]:
    """
    Run the pose model on a PIL image and post-process its keypoints.

    Args:
        model (YOLO): loaded YOLO model for pose estimation.
        pil_image (Image.Image): image to run the model on.
        classes_dictionnary (dict[int, str]): mapping of class instance
            (keypoint index) to class name.

    Returns:
        dict[str, Any]: a map with the following keys:
            prediction: raw prediction from the model (its first result).
            orig_image: image exposed by that result's `orig_img` attribute.
            keypoints_xy (np.ndarray): keypoints in xy format.
            keypoints_xyn (np.ndarray): keypoints in xyn format.
            theta (float): angle in radians to rotate the image to re-align
                it horizontally.
            side (FishSide): predicted side of the fish.
    """
    predictions = model(pil_image)
    # Debug output goes through logging rather than an unconditional print,
    # so callers' stdout is not polluted on every inference.
    logging.getLogger(__name__).debug("Raw model predictions: %s", predictions)

    # The model is called on a single image, so only the first result is used.
    result = predictions[0]
    orig_image = result.orig_img
    # NOTE(review): assumes `keypoints.xy` is (num_detections, num_keypoints, 2),
    # so `.squeeze()` only yields a (num_keypoints, 2) array for exactly one
    # detection — confirm behaviour when zero or several fish are detected.
    keypoints_xy = result.keypoints.xy.cpu().numpy().squeeze()
    # `predict_fish_side` recomputes this angle internally; it is computed
    # here as well so that it can be returned to the caller.
    theta = get_angle_correction(
        keypoints_xy=keypoints_xy,
        array_image=orig_image,
        classes_dictionnary=classes_dictionnary,
    )
    side = predict_fish_side(
        array_image=orig_image,
        keypoints_xy=keypoints_xy,
        classes_dictionnary=classes_dictionnary,
    )
    return {
        "prediction": result,
        "orig_image": orig_image,
        "keypoints_xy": keypoints_xy,
        "keypoints_xyn": result.keypoints.xyn.cpu().numpy().squeeze(),
        "theta": theta,
        "side": side,
    }
|