Spaces:
Running
Running
File size: 5,670 Bytes
b5f64dc 7b4f324 b5f64dc 7b4f324 b5f64dc 7b4f324 b5f64dc 7b4f324 b5f64dc 7b4f324 b5f64dc 7b4f324 b5f64dc f6ad7e6 b5f64dc f6ad7e6 b5f64dc 7b4f324 b5f64dc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 |
"""
Gradio app to showcase the pyronear model for salmon vision.
"""
from collections import Counter
from pathlib import Path
from typing import Any, Tuple
import torch
import gradio as gr
import numpy as np
from ultralytics import YOLO
def bgr_to_rgb(a: np.ndarray) -> np.ndarray:
    """
    Convert an image array from BGR channel order to RGB channel order.

    Args:
      a (np.ndarray): image array whose third axis holds the colour channels.

    Returns:
      np.ndarray: the same data with the order of the third axis reversed.
    """
    # Reversing the third axis swaps the first and last channel (B <-> R).
    return np.flip(a, axis=2)
def has_values(maybe_tensor: torch.Tensor | None) -> bool:
"""
Check whether the `maybe_tensor` contains items.
"""
if maybe_tensor is None:
return False
elif isinstance(maybe_tensor, torch.Tensor):
if len(maybe_tensor) == 0:
return False
else:
return True
def analyze_predictions(yolo_predictions) -> dict[str, Any]:
    """
    Analyze the raw `yolo_predictions` and output a dict containing a summary.

    Every tracked identifier is assigned the class that was predicted for it
    most often across all the predictions.

    Args:
      yolo_predictions: result of calling model.track() on a video

    Returns:
      A dict with the following keys:
        counts (int): number of distinct identifiers.
        ids (set[int]): all the assigned identifiers.
        detected_species (dict[int, int]): mapping from identifier to the most
          frequently predicted class index, ordered by identifier.
        names: the class names exposed by the first prediction (`.names`), or
          None when `yolo_predictions` is empty.
    """
    if len(yolo_predictions) == 0:
        return {
            "counts": 0,
            "ids": set(),
            "detected_species": {},
            "names": None,
        }

    names = yolo_predictions[0].names

    # Single pass over the predictions: one vote counter per identifier,
    # instead of re-scanning every prediction once per identifier.
    votes: dict[int, Counter] = {}
    for prediction in yolo_predictions:
        boxes = prediction.boxes
        if not has_values(boxes.id):
            continue
        # `.numpy()` only works on CPU tensors; `.cpu()` is a no-op for
        # tensors that already live on the CPU.
        track_ids = boxes.id.cpu().numpy().astype("int").tolist()
        class_ids = boxes.cls.cpu().numpy().astype("int").tolist()
        for track_id, class_id in zip(track_ids, class_ids):
            votes.setdefault(track_id, Counter())[class_id] += 1

    # Sorted by identifier so the result does not depend on set ordering.
    detected_species = {
        track_id: votes[track_id].most_common(1)[0][0]
        for track_id in sorted(votes)
    }
    ids = set(votes)
    return {
        "counts": len(ids),
        "ids": ids,
        "detected_species": detected_species,
        "names": names,
    }
def prediction_to_str(yolo_predictions) -> str:
    """
    Turn the yolo_predictions into a human friendly string.

    Args:
      yolo_predictions: result of calling model.track() on a video

    Returns:
      str: "No prediction" when `yolo_predictions` is empty, a short sentence
        when no identifier was tracked, otherwise a header followed by one
        line per tracked identifier (ordered by identifier).
    """
    if len(yolo_predictions) == 0:
        return "No prediction"

    result = analyze_predictions(yolo_predictions=yolo_predictions)
    names = result["names"]
    detected_species = result["detected_species"]
    ids = sorted(result["ids"])

    # Without this guard the message would read "... with ids set():" followed
    # by an empty summary.
    if not ids:
        return "Detected 0 salmons in the video clip"

    summary_str = "\n".join(
        f"- The fish with id {fish_id} is a {names.get(klass, 'Unknown')}"
        for fish_id, klass in sorted(detected_species.items())
    )
    print(summary_str)
    # Rendered like a set literal, but in a deterministic (sorted) order.
    ids_str = "{" + ", ".join(str(fish_id) for fish_id in ids) + "}"
    return f"Detected {len(ids)} salmons in the video clip with ids {ids_str}:\n{summary_str}"
def interface_fn(model: YOLO, video_filepath: Path) -> Tuple[Path, str]:
    """
    Run the tracker on the provided video and return the tuple expected by the
    gradio interface.

    Args:
      model (YOLO): Loaded ultralytics YOLO model.
      video_filepath (Path): video to run the tracking on.

    Returns:
      filepath_video_prediction (Path): runs/track/<stem>/<stem>.avi, where
        <stem> is the stem of `video_filepath`.
      raw_prediction_str (str): human friendly summary of the predictions.
    """
    project = "runs/track/"
    name = video_filepath.stem
    tracking_options = {
        "save": True,
        "tracker": "bytetrack.yaml",
        "exist_ok": True,
        "project": project,
        "name": name,
    }
    predictions = model.track(source=video_filepath, **tracking_options)
    # NOTE(review): presumably where model.track(save=True) writes its
    # annotated video - confirm the .avi extension on the target platform.
    filepath_video_prediction = Path(f"{project}/{name}/{name}.avi")
    raw_prediction_str = prediction_to_str(yolo_predictions=predictions)
    return filepath_video_prediction, raw_prediction_str
def examples(dir_examples: Path) -> list[Path]:
    """
    List the mp4 videos from the dir_examples directory.

    Args:
      dir_examples (Path): directory to look into (not searched recursively).

    Returns:
      filepaths (list[Path]): sorted list of the `*.mp4` filepaths, empty when
        there is none.
    """
    # glob() yields entries in an arbitrary, filesystem-dependent order;
    # sorting keeps index-based selection of an example stable across runs.
    return sorted(dir_examples.glob("*.mp4"))
def load_model(filepath_weights: Path) -> YOLO:
    """
    Build a YOLO model from the weights stored at filepath_weights.

    Args:
      filepath_weights (Path): location of the weights file.

    Returns:
      YOLO: the model constructed from `filepath_weights`.
    """
    model = YOLO(filepath_weights)
    return model
# Main Gradio interface
MODEL_FILEPATH_WEIGHTS = Path("data/model/weights.pt")
DIR_EXAMPLES = Path("data/videos/")
DEFAULT_IMAGE_INDEX = 0

with gr.Blocks() as demo:
    model = load_model(MODEL_FILEPATH_WEIGHTS)
    videos_filepaths = examples(dir_examples=DIR_EXAMPLES)
    print(f"videos_filepaths: {videos_filepaths}")
    # An empty examples directory must not crash the app at startup: fall
    # back to no preloaded video instead of raising IndexError.
    default_value_input = (
        videos_filepaths[DEFAULT_IMAGE_INDEX]
        if len(videos_filepaths) > DEFAULT_IMAGE_INDEX
        else None
    )
    # Named `input_video` rather than `input` to avoid shadowing the builtin.
    input_video = gr.Video(
        value=default_value_input,
        format="mp4",
        autoplay=True,
        loop=True,
        label="input video",
        sources=["upload"],
    )
    output_video = gr.Video(
        format="mp4",
        label="model prediction",
        autoplay=True,
        loop=True,
    )
    output_raw = gr.Text(label="raw prediction")

    def fn(video_filepath):
        """Bind the loaded model and hand the video over as a Path."""
        return interface_fn(model=model, video_filepath=Path(video_filepath))

    gr.Interface(
        title="ML model for wild salmon migration monitoring π",
        fn=fn,
        inputs=input_video,
        outputs=[output_video, output_raw],
        # No examples section at all when there is nothing to show.
        examples=videos_filepaths or None,
        flagging_mode="never",
    )

demo.launch()
|