trout-reID / segmentation.py
achouffe's picture
feat: initial commit
641857b verified
"""
Module to manage the segmentation YOLO model.
"""
from pathlib import Path
from typing import Any
import cv2
import numpy as np
from PIL import Image
from ultralytics import YOLO
import yolo
def load_pretrained_model(model_str: str) -> YOLO:
    """
    Load the pretrained model identified by `model_str`.

    Thin wrapper: the argument is forwarded unchanged to
    `yolo.load_pretrained_model` and its result is returned as is.

    Args:
        model_str (str): identifier of the pretrained model, passed through
            untouched (presumably a weights name or path - verify against
            `yolo.load_pretrained_model`).

    Returns:
        YOLO: whatever `yolo.load_pretrained_model` returns for `model_str`.
    """
    pretrained_model = yolo.load_pretrained_model(model_str)
    return pretrained_model
def train(
    model: YOLO,
    data_yaml_path: Path,
    params: dict,
    project: Path = Path("data/04_models/yolo/"),
    experiment_name: str = "train",
):
    """Entry point for a training run; delegates to `yolo.train`.

    The run artifacts and results are stored under
    `project / experiment_name`.

    Args:
        model (YOLO): result of `load_pretrained_model`.
        data_yaml_path (Path): filepath to the data.yaml file that specifies
            the split and classes to train on.
        params (dict): parameters to override when running the training. See
            https://docs.ultralytics.com/modes/train/#train-settings for a
            complete list of parameters.
        project (Path): root path to store the run artifacts and results.
        experiment_name (str): name of the experiment, appended to the
            project root path to store the run.

    Returns:
        Whatever `yolo.train` returns, forwarded unchanged.
    """
    # Collect the arguments once so the delegation below stays a one-liner.
    train_kwargs = {
        "model": model,
        "data_yaml_path": data_yaml_path,
        "params": params,
        "project": project,
        "experiment_name": experiment_name,
    }
    return yolo.train(**train_kwargs)
def predict(model: YOLO, pil_image: Image.Image) -> dict[str, Any]:
    """
    Given a loaded model and a PIL image, return a map containing the
    segmentation predictions.

    Only the first mask of the first result is used; any further
    detections in the same image are ignored.

    NOTE(review): if the model detects nothing, `masks` is presumably
    `None` (see the ultralytics `Results` documentation) and the indexing
    below raises - confirm and decide how callers should handle that case.

    Args:
        model (YOLO): loaded YOLO model for segmentation.
        pil_image (PIL.Image.Image): image to run the model on.

    Returns:
        dict[str, Any]: map with the following keys:
            prediction: raw prediction from the model for the image.
            mask (PIL.Image.Image): postprocessed black and white mask,
                resized to the size of the image used for inference. It is
                255 wherever the resized raw mask is non-zero and 0
                elsewhere - used for visualization.
            mask_raw (np.ndarray): raw mask, neither resized nor
                thresholded.
            masked (PIL.Image.Image): the image used for inference
                restricted to the mask, converted from BGR to RGB.
    """
    prediction = model(pil_image)[0]
    # Move the channel axis last, then drop it to get a 2D mask.
    mask_raw = prediction.masks[0].data.cpu().numpy().transpose(1, 2, 0).squeeze()

    # Image actually used for inference, after the preprocessing stages
    # applied by ultralytics.
    orig_img = prediction.orig_img
    height, width = orig_img.shape[:2]

    # The model mask can have a different resolution than the image:
    # bring it to the image size (cv2 expects the size as (width, height)).
    resized = cv2.resize(mask_raw, (width, height))

    # Binarize: every pixel that is not exactly zero after resizing belongs
    # to the mask. cv2 expects an 8-bit single channel mask.
    mask = np.where(resized != 0, 255, 0).astype(np.uint8)

    # Apply the mask to the original image.
    masked = cv2.bitwise_and(orig_img, orig_img, mask=mask)

    return {
        "prediction": prediction,
        "mask": Image.fromarray(mask),
        "mask_raw": mask_raw,
        # Reverse the channel axis: BGR to RGB before the PIL conversion.
        "masked": Image.fromarray(masked[:, :, ::-1]),
    }