File size: 3,441 Bytes
641857b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
"""
Module to manage the segmentation YOLO model.
"""

from pathlib import Path
from typing import Any

import cv2
import numpy as np
from PIL import Image
from ultralytics import YOLO

import yolo


def load_pretrained_model(model_str: str) -> YOLO:
    """
    Return a pretrained YOLO model for segmentation.

    Thin delegate to ``yolo.load_pretrained_model``; see that function for
    how ``model_str`` is interpreted.

    Args:
        model_str (str): model identifier, forwarded unchanged to
            ``yolo.load_pretrained_model``.

    Returns:
        YOLO: whatever ``yolo.load_pretrained_model`` returns.
    """
    loaded_model = yolo.load_pretrained_model(model_str)
    return loaded_model


def train(
    model: YOLO,
    data_yaml_path: Path,
    params: dict,
    project: Path = Path("data/04_models/yolo/"),
    experiment_name: str = "train",
):
    """Run a training session by delegating to ``yolo.train``.

    Results are stored under ``project / experiment_name``.

    Args:
        model (YOLO): model obtained from `load_pretrained_model`.
        data_yaml_path (Path): path to the data.yaml file describing the
            dataset splits and the classes to train on.
        params (dict): training parameters to override. See
            https://docs.ultralytics.com/modes/train/#train-settings for the
            complete list of parameters.
        project (Path): root directory where run artifacts and results are stored.
        experiment_name (str): experiment name, appended to ``project`` to
            form the run directory.

    Returns:
        Whatever ``yolo.train`` returns.
    """
    forwarded_kwargs = dict(
        model=model,
        data_yaml_path=data_yaml_path,
        params=params,
        project=project,
        experiment_name=experiment_name,
    )
    return yolo.train(**forwarded_kwargs)


def predict(model: YOLO, pil_image: Image.Image) -> dict[str, Any]:
    """
    Run the segmentation model on a PIL image and post-process its mask.

    Only the first predicted instance mask (``masks[0]``) is used; any
    further instances are ignored.

    Args:
        model (YOLO): loaded YOLO model for segmentation.
        pil_image (PIL.Image.Image): image to run the model on.

    Returns:
        dict[str, Any]: a map with the keys
            prediction: raw prediction (first results entry) from the model.
            mask (PIL.Image.Image): binary mask resized to the original
                image size; 255 wherever the resized raw mask is non-zero
                and 0 elsewhere. Used for visualization.
            mask_raw (np.ndarray): 2D raw mask of the first instance, not
                post-processed. If the model predicted no mask at all, this
                is an all-zero float32 array with the original image's
                height and width.
            masked (PIL.Image.Image): the original image, converted to RGB,
                with every pixel outside ``mask`` set to black.
    """
    prediction = model(pil_image)[0]
    orig_img = prediction.orig_img
    # Only the height and width of the original image are needed.
    height, width = orig_img.shape[:2]

    if prediction.masks is None or len(prediction.masks) == 0:
        # Nothing was segmented: return an empty mask instead of failing
        # with an opaque TypeError/IndexError on the missing first mask.
        mask_raw = np.zeros((height, width), dtype=np.float32)
    else:
        # masks.data holds one (H, W) plane per instance; keep the first.
        mask_raw = prediction.masks.data[0].cpu().numpy()

    # Resize the mask to the original image size (cv2.resize takes
    # (width, height)). The float32 cast is copy-free when the mask is
    # already float32, and makes resize work for other dtypes as well.
    resized = cv2.resize(mask_raw.astype(np.float32, copy=False), (width, height))
    # Any non-zero value counts as foreground. This includes partial values
    # produced by interpolation at the mask edges.
    mask = np.where(resized != 0, 255, 0).astype(np.uint8)

    # Black out every pixel of the original image outside the mask.
    masked = cv2.bitwise_and(orig_img, orig_img, mask=mask)

    # orig_img is in BGR channel order; reverse the channels for an RGB PIL image.
    return {
        "prediction": prediction,
        "mask": Image.fromarray(mask),
        "mask_raw": mask_raw,
        "masked": Image.fromarray(masked[:, :, ::-1]),
    }