""" Module to manage the segmentation YOLO model. """

from pathlib import Path
from typing import Any

import cv2
import numpy as np
from PIL import Image
from ultralytics import YOLO

import yolo


def load_pretrained_model(model_str: str) -> YOLO:
    """
    Load the pretrained model.

    Thin wrapper that delegates to `yolo.load_pretrained_model`.
    """
    return yolo.load_pretrained_model(model_str)


def train(
    model: YOLO,
    data_yaml_path: Path,
    params: dict,
    project: Path = Path("data/04_models/yolo/"),
    experiment_name: str = "train",
):
    """Main function for running a train run.

    It saves the results under `project / experiment_name`.

    Args:
        model (YOLO): result of `load_pretrained_model`.
        data_yaml_path (Path): filepath to the data.yaml file that specifies
            the split and classes to train on.
        params (dict): parameters to override when running the training.
            See https://docs.ultralytics.com/modes/train/#train-settings for
            a complete list of parameters.
        project (Path): root path to store the run artifacts and results.
        experiment_name (str): name of the experiment, that is added to the
            project root path to store the run.

    Returns:
        Whatever `yolo.train` returns.
    """
    return yolo.train(
        model=model,
        data_yaml_path=data_yaml_path,
        params=params,
        project=project,
        experiment_name=experiment_name,
    )


def predict(model: YOLO, pil_image: Image.Image) -> dict[str, Any]:
    """
    Given a loaded model and a PIL image, return a map containing the
    segmentation predictions.

    Only the first instance mask (`masks[0]`) of the first result is used.

    Args:
        model (YOLO): loaded YOLO model for segmentation.
        pil_image (PIL): image to run the model on.

    Returns:
        dict with keys:
            prediction: Raw prediction (first result) from the model.
            mask (PIL): postprocessed binary mask (255 = foreground,
                0 = background) resized to the original image size - used
                for visualization.
            mask_raw (np.ndarray): Raw mask as produced by the model, not
                postprocessed.
            masked (PIL): RGB image obtained by applying `mask` to the
                original image (`prediction.orig_img`); background pixels
                are black.
    """
    predictions = model(pil_image)
    result = predictions[0]

    # NOTE(review): `result.masks` is None when the model detects nothing,
    # in which case the indexing below raises. Decide on a no-detection
    # contract (e.g. an all-black mask) if callers need one.
    mask_raw = result.masks[0].data.cpu().numpy().transpose(1, 2, 0).squeeze()

    orig_img = result.orig_img
    height, width = orig_img.shape[:2]
    # The raw mask may not match the original image size; cv2.resize takes
    # (width, height).
    mask_resized = cv2.resize(mask_raw, (width, height))

    # Any non-zero value counts as foreground, including edge pixels that the
    # interpolation made fractional.
    mask = np.where(mask_resized != 0, 255, 0).astype(np.uint8)

    # Keep original pixels where the mask is set, black elsewhere.
    masked = cv2.bitwise_and(orig_img, orig_img, mask=mask)

    return {
        "prediction": result,
        "mask": Image.fromarray(mask),
        "mask_raw": mask_raw,
        # orig_img is BGR (OpenCV convention); flip channels to RGB for PIL.
        "masked": Image.fromarray(masked[:, :, ::-1]),
    }