File size: 4,773 Bytes
641857b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
"""
Module to manage the pose detection model.
"""

from enum import Enum
from pathlib import Path
from typing import Any

import numpy as np
from PIL import Image
from ultralytics import YOLO

import yolo
from utils import get_angle_correction, get_keypoint, rotate_image_and_keypoints_xy


class FishSide(Enum):
    """
    Side (flank) of the fish shown in an image.

    Members carry lowercase string values: "right" and "left".
    """

    # Declaration order (RIGHT, then LEFT) determines enum iteration order.
    RIGHT = "right"
    LEFT = "left"


def predict_fish_side(
    array_image: np.ndarray,
    keypoints_xy: np.ndarray,
    classes_dictionnary: dict[int, str],
) -> FishSide:
    """
    Predict which side of the fish is displayed on the image.

    The keypoints are first rotated by the angle computed with
    `get_angle_correction` (via `rotate_image_and_keypoints_xy`). Then the
    x coordinate of the "eye" keypoint is compared with the x coordinate of
    the "anal_fin_base" keypoint in that rotated frame.

    Args:
        array_image (np.ndarray): numpy array representing the image.
        keypoints_xy (np.ndarray): keypoints detected on `array_image`, in
            xy format.
        classes_dictionnary (dict[int, str]): mapping from class index to
            keypoint class name.

    Returns:
        FishSide: `FishSide.LEFT` when the rotated eye x coordinate is less
        than or equal to the rotated anal fin base x coordinate,
        `FishSide.RIGHT` otherwise.
    """
    angle = get_angle_correction(
        keypoints_xy=keypoints_xy,
        array_image=array_image,
        classes_dictionnary=classes_dictionnary,
    )
    rotated_keypoints = rotate_image_and_keypoints_xy(
        angle_rad=angle, array_image=array_image, keypoints_xy=keypoints_xy
    )["keypoints_xy"]

    # Look up the x coordinate of the eye and of the anal fin base, in that
    # order, from the re-aligned keypoints.
    eye_x, anal_fin_base_x = (
        get_keypoint(
            class_name=name,
            keypoints=rotated_keypoints,
            classes_dictionnary=classes_dictionnary,
        )[0]
        for name in ("eye", "anal_fin_base")
    )

    return FishSide.LEFT if eye_x <= anal_fin_base_x else FishSide.RIGHT


# Model prediction classes, keyed by class index (0 to 6). Indices follow the
# position of each name in the tuple below.
CLASSES_DICTIONNARY = dict(
    enumerate(
        (
            "eye",
            "front_fin_base",
            "tail_bottom_tip",
            "tail_top_tip",
            "dorsal_fin_base",
            "pelvic_fin_base",
            "anal_fin_base",
        )
    )
)


def load_pretrained_model(model_str: str) -> YOLO:
    """
    Load the pretrained model by delegating to `yolo.load_pretrained_model`.

    Args:
        model_str (str): model identifier, forwarded unchanged.

    Returns:
        YOLO: the model returned by `yolo.load_pretrained_model`.
    """
    model = yolo.load_pretrained_model(model_str)
    return model


def train(
    model: YOLO,
    data_yaml_path: Path,
    params: dict,
    project: Path = Path("data/04_models/yolo/"),
    experiment_name: str = "train",
):
    """
    Run a training job by delegating to `yolo.train`.

    Run artifacts and results are saved under `project / experiment_name`.

    Args:
        model (YOLO): result of `load_pretrained_model`.
        data_yaml_path (Path): filepath to the data.yaml file that specifies
            the split and classes to train on.
        params (dict): parameters to override when running the training. See
            https://docs.ultralytics.com/modes/train/#train-settings for a
            complete list of parameters.
        project (Path): root path to store the run artifacts and results.
        experiment_name (str): name of the experiment, appended to the
            project root path to store the run.

    Returns:
        Whatever `yolo.train` returns.
    """
    # Every argument is forwarded to yolo.train by keyword, unchanged.
    train_kwargs = {
        "model": model,
        "data_yaml_path": data_yaml_path,
        "params": params,
        "project": project,
        "experiment_name": experiment_name,
    }
    return yolo.train(**train_kwargs)


def predict(
    model: YOLO,
    pil_image: Image.Image,
    classes_dictionnary: dict[int, str] = CLASSES_DICTIONNARY,
) -> dict[str, Any]:
    """
    Run the pose model on a PIL image and post-process its keypoints.

    Args:
        model (YOLO): loaded YOLO model for pose estimation.
        pil_image (Image.Image): image to run the model on.
        classes_dictionnary (dict[int, str]): mapping from class index to
            keypoint class name.

    Returns:
        dict[str, Any]: a map with the following keys:
            prediction: first result returned by the model.
            orig_image: the `orig_img` attribute of that result, i.e. the
                image array ultralytics associates with the inference.
            keypoints_xy (np.ndarray): keypoints in xy format.
            keypoints_xyn (np.ndarray): keypoints in xyn format.
            theta (float): angle in radians to rotate the image to re-align
                it horizontally.
            side (FishSide): predicted side of the fish.
    """
    # Only the first result is used (a single input image yields one result).
    result = model(pil_image)[0]
    orig_image = result.orig_img

    # squeeze() drops the leading detection axis when exactly one fish is
    # detected. NOTE(review): with several detections the detection axis is
    # kept, and the downstream helpers presumably expect a single fish. Confirm.
    keypoints_xy = result.keypoints.xy.cpu().numpy().squeeze()

    theta = get_angle_correction(
        keypoints_xy=keypoints_xy,
        array_image=orig_image,
        classes_dictionnary=classes_dictionnary,
    )

    side = predict_fish_side(
        array_image=orig_image,
        keypoints_xy=keypoints_xy,
        classes_dictionnary=classes_dictionnary,
    )

    return {
        "prediction": result,
        "orig_image": orig_image,
        "keypoints_xy": keypoints_xy,
        "keypoints_xyn": result.keypoints.xyn.cpu().numpy().squeeze(),
        "theta": theta,
        "side": side,
    }