""" | |
Module to manage the segmentation YOLO model. | |
""" | |
from pathlib import Path | |
from typing import Any | |
import cv2 | |
import numpy as np | |
from PIL import Image | |
from ultralytics import YOLO | |
import yolo | |


def load_pretrained_model(model_str: str) -> YOLO:
    """
    Load the pretrained model.
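
    Example (illustrative sketch; "yolov8n-seg.pt" is one of the standard
    Ultralytics segmentation checkpoints, substitute your own weights file):

        model = load_pretrained_model("yolov8n-seg.pt")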
""" | |
return yolo.load_pretrained_model(model_str) | |


def train(
    model: YOLO,
    data_yaml_path: Path,
    params: dict,
    project: Path = Path("data/04_models/yolo/"),
    experiment_name: str = "train",
):
    """Main function for running a training run. It saves the results
    under `project / experiment_name`.

    Args:
        model (YOLO): result of `load_pretrained_model`.
        data_yaml_path (Path): path to the data.yaml file that specifies
            the split and classes to train on.
        params (dict): parameters to override for the training run. See
            https://docs.ultralytics.com/modes/train/#train-settings for a
            complete list of parameters.
        project (Path): root path under which the run artifacts and results
            are stored.
        experiment_name (str): name of the experiment, appended to the
            project root path to store the run.
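
    Example (illustrative; the data.yaml path and parameter values below are
    placeholders, not files this module guarantees to exist):

        model = load_pretrained_model("yolov8n-seg.pt")
        train(
            model,
            data_yaml_path=Path("path/to/data.yaml"),
            params={"epochs": 10, "imgsz": 640},
        )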
""" | |
return yolo.train( | |
model=model, | |
data_yaml_path=data_yaml_path, | |
params=params, | |
project=project, | |
experiment_name=experiment_name, | |
) | |


def predict(model: YOLO, pil_image: Image.Image) -> dict[str, Any]:
    """
    Given a loaded model and a PIL image, return a dict
    containing the segmentation predictions.

    Args:
        model (YOLO): loaded YOLO segmentation model.
        pil_image (Image.Image): image to run the model on.

    Returns:
        prediction: raw prediction from the model; the original image
            (after the preprocessing applied by Ultralytics) is available
            as `prediction.orig_img`.
        mask (PIL.Image.Image): postprocessed black-and-white mask, used
            for visualization.
        mask_raw (np.ndarray): raw mask without postprocessing.
        masked (PIL.Image.Image): mask applied to `pil_image`.
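
    Example (illustrative sketch; "example.jpg" is a placeholder path):

        model = load_pretrained_model("yolov8n-seg.pt")
        result = predict(model, Image.open("example.jpg"))
        result["masked"].show()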
""" | |
predictions = model(pil_image) | |
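    # Only the first detected instance is used; `predictions[0].masks` is None
    # when the model finds nothing, in which case the line below raises.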
    mask_raw = predictions[0].masks[0].data.cpu().numpy().transpose(1, 2, 0).squeeze()
    # Convert the single-channel mask to a 3-channel image
    mask_3channel = cv2.merge((mask_raw, mask_raw, mask_raw))
    # Get the size of the original image (height, width, channels)
    h2, w2, c2 = predictions[0].orig_img.shape
    # Resize the mask to the same size as the image
    # (can probably be removed if the image is the same size as the model input)
    mask = cv2.resize(mask_3channel, (w2, h2))
    # Define the range of values treated as black
    lower_black = np.array([0, 0, 0])
    upper_black = np.array([0, 0, 1])
    # Threshold the resized mask to get everything black (the background)
    mask = cv2.inRange(mask, lower_black, upper_black)
    # Invert the mask to get everything but black (the segmented foreground)
    mask = cv2.bitwise_not(mask)
    # Apply the mask to the original image
    masked = cv2.bitwise_and(
        predictions[0].orig_img,
        predictions[0].orig_img,
        mask=mask,
    )
    # BGR to RGB and PIL conversion
    masked_image = Image.fromarray(masked[:, :, ::-1])
    return {
        "prediction": predictions[0],
        "mask": Image.fromarray(mask),
        "mask_raw": mask_raw,
        "masked": masked_image,
    }