Spaces:
Running
Running
from pathlib import Path | |
from typing import Any | |
import numpy as np | |
import torch | |
from PIL import Image | |
from ultralytics import YOLO | |
import identification | |
import pose | |
import segmentation | |
from identification import IdentificationModel | |
from utils import ( | |
PictureLayout, | |
crop, | |
get_picture_layout, | |
get_segmentation_mask_crop_box, | |
) | |
def load_pose_and_segmentation_models(
    filepath_weights_segmentation_model: Path,
    filepath_weights_pose_model: Path,
) -> dict[str, YOLO]:
    """
    Load the segmentation and pose estimation models into memory.

    Args:
        filepath_weights_segmentation_model (Path): weights file of the
            segmentation model.
        filepath_weights_pose_model (Path): weights file of the pose
            estimation model.

    Returns:
        dict[str, YOLO] with the keys:
            segmentation (YOLO): segmentation model.
            pose (YOLO): pose estimation model.
    """
    models: dict[str, YOLO] = {}
    # Both loaders are handed the weights location as a plain string.
    models["segmentation"] = segmentation.load_pretrained_model(
        str(filepath_weights_segmentation_model)
    )
    models["pose"] = pose.load_pretrained_model(str(filepath_weights_pose_model))
    return models
def load_models(
    filepath_weights_segmentation_model: Path,
    filepath_weights_pose_model: Path,
    device: torch.device,
    filepath_identification_lightglue_features: Path,
    filepath_identification_db: Path,
    extractor_type: str,
    n_keypoints: int,
    threshold_wasserstein: float,
) -> dict[str, YOLO | IdentificationModel]:
    """
    Load every model used by the full pipeline into memory.

    Args:
        filepath_weights_segmentation_model (Path): weights file of the
            segmentation model.
        filepath_weights_pose_model (Path): weights file of the pose
            estimation model.
        device (torch.device): device forwarded to `identification.load`.
        filepath_identification_lightglue_features (Path): forwarded to
            `identification.load` as `filepath_features`.
        filepath_identification_db (Path): forwarded to
            `identification.load` as `filepath_db`.
        extractor_type (str): forwarded to `identification.load`.
        n_keypoints (int): forwarded to `identification.load`.
        threshold_wasserstein (float): forwarded to `identification.load`.

    Returns:
        dict with the keys:
            segmentation (YOLO): segmentation model.
            pose (YOLO): pose estimation model.
            identification (IdentificationModel): identification model.
    """
    # Pose and segmentation are loaded first, then the identification model.
    models: dict[str, YOLO | IdentificationModel] = dict(
        load_pose_and_segmentation_models(
            filepath_weights_segmentation_model=filepath_weights_segmentation_model,
            filepath_weights_pose_model=filepath_weights_pose_model,
        )
    )
    models["identification"] = identification.load(
        device=device,
        filepath_features=filepath_identification_lightglue_features,
        filepath_db=filepath_identification_db,
        n_keypoints=n_keypoints,
        extractor_type=extractor_type,
        threshold_wasserstein=threshold_wasserstein,
    )
    return models
def run_preprocess(pil_image: Image.Image) -> dict[str, Any]:
    """
    Run the preprocess stage of the pipeline.

    A picture detected as portrait is rotated by 90 degrees so that it
    becomes a landscape picture. Any other layout is passed through
    unchanged (same object).

    Args:
        pil_image (PIL): original image.

    Returns:
        pil_image (PIL): the image after the optional rotation.
        layout (PictureLayout): layout detected on the original image.
    """
    layout = get_picture_layout(pil_image=pil_image)
    if layout == PictureLayout.PORTRAIT:
        # expand=True enlarges the output canvas so no pixels are clipped.
        landscape_image = pil_image.rotate(angle=90, expand=True)
    else:
        landscape_image = pil_image
    return {"pil_image": landscape_image, "layout": layout}
def run_pose(model: YOLO, pil_image: Image.Image) -> dict[str, Any]:
    """
    Run the pose estimation stage of the pipeline.

    This is a thin wrapper that delegates to `pose.predict`.

    Args:
        model (YOLO): loaded pose estimation model.
        pil_image (PIL): image to run the model on.

    Returns:
        prediction: raw prediction from the model.
        orig_image: original image used for inference after the
            preprocessing stages applied by ultralytics.
        keypoints_xy (np.ndarray): keypoints in xy format.
        keypoints_xyn (np.ndarray): keypoints in xyn format.
        theta (float): angle in radians to rotate the image to re-align it
            horizontally.
        side (FishSide): predicted side of the fish.
    """
    results_pose = pose.predict(model=model, pil_image=pil_image)
    return results_pose
def run_crop(
    pil_image_mask: Image.Image,
    pil_image_masked: Image.Image,
    padding: int,
) -> dict[str, Any]:
    """
    Run the crop stage on the segmentation outputs.

    The crop box is computed from the mask image and then applied to the
    masked image.

    Args:
        pil_image_mask (PIL): image containing the segmentation mask.
        pil_image_masked (PIL): image containing `pil_image_mask` applied
            to the original image.
        padding (int): padding passed to the crop-box computation.

    Returns:
        box (Tuple[int, int, int, int]): rectangle (x1, y1, x2, y2) with the
            upper left corner given first.
        pil_image (PIL): cropped masked image.
    """
    box = get_segmentation_mask_crop_box(
        pil_image_mask=pil_image_mask,
        padding=padding,
    )
    return {
        "box": box,
        "pil_image": crop(pil_image=pil_image_masked, box=box),
    }
def run_rotation(
    pil_image: Image.Image,
    angle_rad: float,
    keypoints_xy: np.ndarray,
) -> dict[str, Any]:
    """
    Run the rotation stage of the pipeline.

    The image is converted to a numpy array, rotated together with its
    keypoints by `pose.rotate_image_and_keypoints_xy`, and converted back to
    a PIL image.

    Args:
        pil_image (PIL): image to rotate.
        angle_rad (float): rotation angle in radians.
        keypoints_xy (np.ndarray): keypoints from the pose estimation
            prediction in xy format.

    Returns:
        pil_image (PIL): rotated PIL image.
        array_image (np.ndarray): rotated image as a numpy array.
        keypoints_xy (np.ndarray): rotated keypoints in xy format.
    """
    rotated = pose.rotate_image_and_keypoints_xy(
        angle_rad=angle_rad,
        array_image=np.array(pil_image),
        keypoints_xy=keypoints_xy,
    )
    array_rotated = rotated["array_image"]
    return {
        "pil_image": Image.fromarray(array_rotated),
        "array_image": array_rotated,
        "keypoints_xy": rotated["keypoints_xy"],
    }
def run_segmentation(model: YOLO, pil_image: Image.Image) -> dict[str, Any]:
    """
    Run the segmentation stage of the pipeline.

    This is a thin wrapper that delegates to `segmentation.predict`.

    Args:
        model (YOLO): loaded segmentation model.
        pil_image (PIL): image to segment.

    Returns:
        prediction: raw prediction from the model.
        orig_image: original image used for inference after the
            preprocessing stages applied by ultralytics.
        mask (PIL): postprocessed mask in black and white format, used for
            visualization.
        mask_raw (np.ndarray): raw mask, not postprocessed.
        masked (PIL): mask applied to `pil_image`.
    """
    return segmentation.predict(model=model, pil_image=pil_image)
def run_pre_identification_stages(
    loaded_models: dict[str, YOLO | IdentificationModel],
    pil_image: Image.Image,
    param_crop_padding: int = 0,
) -> dict[str, Any]:
    """
    Run the partial ML pipeline on `pil_image`, up to identifying the fish.

    It prepares the input image `pil_image` so that it can be identified,
    running these stages in order: preprocess, pose, rotation, segmentation,
    crop.

    Args:
        loaded_models (dict): result of calling `load_models` (or
            `load_pose_and_segmentation_models`). Only the "pose" and
            "segmentation" entries are used here.
        pil_image (PIL): image to run the pipeline on.
        param_crop_padding (int): how much to pad the resulting segmented
            image when cropped.

    Returns:
        order (list[str]): the stages, in execution order.
        stages (dict[str, Any]): for each stage, its input and output.
    """
    # Unpacking the loaded models
    model_pose = loaded_models["pose"]
    model_segmentation = loaded_models["segmentation"]

    # Stage: Preprocess (portrait pictures are turned into landscape)
    results_preprocess = run_preprocess(pil_image=pil_image)
    pil_image_preprocessed = results_preprocess["pil_image"]

    # Stage: Pose estimation
    results_pose = run_pose(model=model_pose, pil_image=pil_image_preprocessed)

    # Stage: Rotation, driven by the angle and keypoints from the pose stage
    results_rotation = run_rotation(
        pil_image=pil_image_preprocessed,
        keypoints_xy=results_pose["keypoints_xy"],
        angle_rad=results_pose["theta"],
    )

    # Stage: Segmentation. run_rotation already builds a PIL image from its
    # rotated array, so reuse it instead of converting the array a second time.
    pil_image_rotated = results_rotation["pil_image"]
    results_segmentation = run_segmentation(
        model=model_segmentation, pil_image=pil_image_rotated
    )

    # Stage: Crop
    results_crop = run_crop(
        pil_image_mask=results_segmentation["mask"],
        pil_image_masked=results_segmentation["masked"],
        padding=param_crop_padding,
    )

    return {
        "order": [
            "preprocess",
            "pose",
            "rotation",
            "segmentation",
            "crop",
        ],
        "stages": {
            "preprocess": {
                "input": {"pil_image": pil_image},
                "output": results_preprocess,
            },
            "pose": {
                "input": {"pil_image": pil_image_preprocessed},
                "output": results_pose,
            },
            "rotation": {
                "input": {
                    "pil_image": pil_image_preprocessed,
                    "angle_rad": results_pose["theta"],
                    "keypoints_xy": results_pose["keypoints_xy"],
                },
                "output": results_rotation,
            },
            "segmentation": {
                "input": {"pil_image": pil_image_rotated},
                "output": results_segmentation,
            },
            "crop": {
                "input": {
                    "pil_image_mask": results_segmentation["mask"],
                    "pil_image_masked": results_segmentation["masked"],
                    "padding": param_crop_padding,
                },
                "output": results_crop,
            },
        },
    }
def run(
    loaded_models: dict[str, YOLO | IdentificationModel],
    pil_image: Image.Image,
    param_crop_padding: int = 0,
    param_k: int = 3,
) -> dict[str, Any]:
    """
    Run the full ML pipeline on `pil_image`.

    The pre-identification stages run first (see
    `run_pre_identification_stages`). The cropped image they produce is then
    passed to the identification model, and an "identification" stage is
    appended to the results.

    Args:
        loaded_models (dict[str, YOLO | IdentificationModel]): result of
            calling `load_models`.
        pil_image (PIL): image to run the pipeline on.
        param_crop_padding (int): how much to pad the resulting segmented
            image when cropped.
        param_k (int): top k matches to return.

    Returns:
        order (list[str]): the stages, in execution order.
        stages (dict[str, Any]): for each stage, its input and output.
    """
    # Looked up before any stage runs, so a missing model fails fast.
    identifier = loaded_models["identification"]
    results = run_pre_identification_stages(
        loaded_models=loaded_models,
        pil_image=pil_image,
        param_crop_padding=param_crop_padding,
    )
    pil_image_cropped = results["stages"]["crop"]["output"]["pil_image"]
    output_identification = identification.predict(
        model=identifier,
        pil_image=pil_image_cropped,
        k=param_k,
    )
    stage_identification = {
        "input": {"pil_image": pil_image_cropped},
        "output": output_identification,
    }
    results["order"].append("identification")
    results["stages"]["identification"] = stage_identification
    return results