"""
Gradio app to showcase the pyronear model for salmon vision.
"""

from collections import Counter
from pathlib import Path
from typing import Any, Tuple

import gradio as gr
import numpy as np
import torch
from ultralytics import YOLO


def bgr_to_rgb(a: np.ndarray) -> np.ndarray:
    """
    Turn a BGR numpy array into an RGB numpy array when the array `a`
    represents an image.
    """
    return a[:, :, ::-1]
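
# Quick sanity sketch (hedged; a minimal dummy pixel, not part of the app):
#
#   bgr = np.zeros((1, 1, 3), dtype=np.uint8)
#   bgr[0, 0] = (255, 0, 0)                  # pure blue in BGR channel order
#   rgb = bgr_to_rgb(bgr)
#   assert tuple(rgb[0, 0]) == (0, 0, 255)   # blue now sits in the last slot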


def has_values(maybe_tensor: torch.Tensor | None) -> bool:
    """
    Check whether `maybe_tensor` is a non-empty tensor.
    """
    return isinstance(maybe_tensor, torch.Tensor) and len(maybe_tensor) > 0
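
# Usage sketch (hedged; torch is already imported above):
#
#   assert has_values(torch.tensor([1, 2, 3])) is True
#   assert has_values(torch.empty(0)) is False
#   assert has_values(None) is False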


def analyze_predictions(yolo_predictions) -> dict[str, Any]:
    """
    Analyze the raw `yolo_predictions` and return a dict containing summary
    information.

    Args:
        yolo_predictions: result of calling model.track() on a video.

    Returns:
        dict with the following keys:
            counts (int): number of distinct identifiers.
            ids (set[int]): all the assigned identifiers.
            detected_species (dict[int, int]): mapping from identifier to the
                most frequently predicted class for that identifier.
            names (dict[int, str] | None): the class names used by the model.
    """
    if len(yolo_predictions) == 0:
        return {
            "counts": 0,
            "ids": set(),
            "detected_species": {},
            "names": None,
        }
    names = yolo_predictions[0].names
    # Accumulate, for each tracked identifier, how often each class was
    # predicted across the frames of the video.
    class_counts: dict[int, Counter] = {}
    for prediction in yolo_predictions:
        if has_values(prediction.boxes.id):
            for track_id, klass in zip(
                prediction.boxes.id.numpy().astype("int"),
                prediction.boxes.cls.numpy().astype("int"),
            ):
                counter = class_counts.setdefault(track_id.item(), Counter())
                counter[klass.item()] += 1
    # Resolve each identifier to its most frequently predicted class.
    detected_species = {
        track_id: counter.most_common(1)[0][0]
        for track_id, counter in class_counts.items()
    }
    ids = set(detected_species.keys())
    return {
        "counts": len(ids),
        "ids": ids,
        "detected_species": detected_species,
        "names": names,
    }
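
# Shape of the returned summary (illustrative values, not from a real run;
# the class names depend entirely on the trained model):
#
#   {
#       "counts": 2,
#       "ids": {1, 2},
#       "detected_species": {1: 0, 2: 0},
#       "names": {0: "salmon"},
#   }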


def prediction_to_str(yolo_predictions) -> str:
    """
    Turn the `yolo_predictions` into a human-friendly string.
    """
    if len(yolo_predictions) == 0:
        return "No prediction"
    result = analyze_predictions(yolo_predictions=yolo_predictions)
    names = result["names"]
    detected_species = result["detected_species"]
    ids = result["ids"]
    summary_str = "\n".join(
        f"- The fish with id {track_id} is a {names.get(klass, 'Unknown')}"
        for track_id, klass in detected_species.items()
    )
    return f"Detected {len(ids)} salmon in the video clip with ids {ids}:\n{summary_str}"


def interface_fn(model: YOLO, video_filepath: Path) -> Tuple[Path, str]:
    """
    Main interface function that runs the model on the provided video and
    returns the expected tuple to populate the gradio interface.

    Args:
        model (YOLO): loaded ultralytics YOLO model.
        video_filepath (Path): video to run inference on.

    Returns:
        filepath_video_prediction (Path): video with predictions from the model.
        raw_prediction_str (str): string representing the raw prediction from
            the model.
    """
    project = "runs/track/"
    name = video_filepath.stem
    predictions = model.track(
        source=video_filepath,
        save=True,
        tracker="bytetrack.yaml",
        exist_ok=True,
        project=project,
        name=name,
    )
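    # With save=True, ultralytics writes the annotated video under
    # {project}/{name}/; the .avi container below matches the default
    # observed with the ultralytics versions this app targets (an
    # assumption worth rechecking after upgrades).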
    filepath_video_prediction = Path(project) / name / f"{name}.avi"
    raw_prediction_str = prediction_to_str(yolo_predictions=predictions)
    return (filepath_video_prediction, raw_prediction_str)


def examples(dir_examples: Path) -> list[Path]:
    """
    List the videos from the `dir_examples` directory.

    Returns:
        filepaths (list[Path]): list of video filepaths.
    """
    return list(dir_examples.glob("*.mp4"))


def load_model(filepath_weights: Path) -> YOLO:
    """
    Load the YOLO model given the filepath_weights.
    """
    return YOLO(filepath_weights)


# Main Gradio interface

MODEL_FILEPATH_WEIGHTS = Path("data/model/weights.pt")
DIR_EXAMPLES = Path("data/videos/")
DEFAULT_VIDEO_INDEX = 0

with gr.Blocks() as demo:
    model = load_model(MODEL_FILEPATH_WEIGHTS)
    videos_filepaths = examples(dir_examples=DIR_EXAMPLES)
    print(f"videos_filepaths: {videos_filepaths}")
    default_value_input = videos_filepaths[DEFAULT_VIDEO_INDEX]
    input_video = gr.Video(
        value=default_value_input,
        format="mp4",
        autoplay=True,
        loop=True,
        label="input video",
        sources=["upload"],
    )
    output_video = gr.Video(
        format="mp4",
        label="model prediction",
        autoplay=True,
        loop=True,
    )
    output_raw = gr.Text(label="raw prediction")

    def fn(video_filepath: str) -> Tuple[Path, str]:
        return interface_fn(model=model, video_filepath=Path(video_filepath))

    gr.Interface(
        title="ML model for wild salmon migration monitoring  🐟",
        fn=fn,
        inputs=input_video,
        outputs=[output_video, output_raw],
        examples=videos_filepaths,
        flagging_mode="never",
    )

demo.launch()
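
# To launch the demo locally (assuming this file is named app.py — adjust to
# the actual filename):
#
#   python app.py
#
# gradio serves the interface on http://127.0.0.1:7860 by default.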