Spaces:
Running
Running
""" | |
Gradio app to showcase the pyronear model for salmon vision. | |
""" | |
from collections import Counter | |
from pathlib import Path | |
from typing import Any, Tuple | |
import torch | |
import gradio as gr | |
import numpy as np | |
from ultralytics import YOLO | |
def bgr_to_rgb(a: np.ndarray) -> np.ndarray:
    """
    Convert an image array from BGR to RGB channel order.

    The third axis (the channel axis of an H x W x C image) is reversed; the
    result is a view on `a`, no data is copied.
    """
    return np.flip(a, axis=2)
def has_values(maybe_tensor: torch.Tensor | None) -> bool: | |
""" | |
Check whether the `maybe_tensor` contains items. | |
""" | |
if maybe_tensor is None: | |
return False | |
elif isinstance(maybe_tensor, torch.Tensor): | |
if len(maybe_tensor) == 0: | |
return False | |
else: | |
return True | |
def analyze_predictions(yolo_predictions) -> dict[str, Any]:
    """
    Analyze the raw `yolo_predictions` and output a dict summarizing the tracks.

    Each track identifier is assigned the class it was predicted as most often
    across all frames (majority vote; ties go to the class seen first).

    Args:
        yolo_predictions: result of calling model.track() on a video, i.e. a
            sequence of per-frame results exposing `.names`, `.boxes.id` and
            `.boxes.cls`.

    Returns:
        A dict with the keys:
            counts (int): number of distinct identifiers.
            ids (set[int]): all the assigned identifiers.
            detected_species (dict[int, int]): mapping from identifier to
                instance class.
            names: the model's mapping from class index to class name, or
                None when there are no predictions.
    """
    if len(yolo_predictions) == 0:
        return {
            "counts": 0,
            "ids": set(),
            "detected_species": {},
            "names": None,
        }

    names = yolo_predictions[0].names

    # One Counter of class votes per track id, filled in a single pass over
    # the frames instead of rescanning every frame once per id.
    class_votes: dict[int, Counter] = {}
    for prediction in yolo_predictions:
        boxes = prediction.boxes
        if not has_values(boxes.id):
            # Frame without any tracked box.
            continue
        # `.long().tolist()` works for tensors on any device, whereas
        # `.numpy()` raises for tensors that live on a GPU.
        track_ids = boxes.id.long().tolist()
        classes = boxes.cls.long().tolist()
        for track_id, klass in zip(track_ids, classes):
            class_votes.setdefault(track_id, Counter())[klass] += 1

    detected_species = {
        track_id: votes.most_common(1)[0][0]
        for track_id, votes in class_votes.items()
    }
    ids = set(class_votes)
    return {
        "counts": len(ids),
        "ids": ids,
        "detected_species": detected_species,
        "names": names,
    }
def prediction_to_str(yolo_predictions) -> str:
    """
    Turn the yolo_predictions into a human friendly string.

    Args:
        yolo_predictions: result of calling model.track() on a video.

    Returns:
        A summary listing every tracked fish id with its majority class name,
        or a short message when there is nothing to report.
    """
    if len(yolo_predictions) == 0:
        return "No prediction"

    result = analyze_predictions(yolo_predictions=yolo_predictions)
    names = result["names"]
    detected_species = result["detected_species"]
    ids = result["ids"]

    # Frames were processed but no box received a track id: avoid showing a
    # raw "set()" and an empty list to the user.
    if not ids:
        return "No salmon detected in the video clip"

    # Sort by track id so the listing order is deterministic.
    summary_str = "\n".join(
        f"- The fish with id {track_id} is a {names.get(klass, 'Unknown')}"
        for track_id, klass in sorted(detected_species.items())
    )
    print(summary_str)
    return f"Detected {len(ids)} salmons in the video clip with ids {ids}:\n{summary_str}"
def interface_fn(model: YOLO, video_filepath: Path) -> Tuple[Path, str]:
    """
    Main interface function: track objects in `video_filepath` with `model`
    and return the tuple expected to populate the gradio interface.

    Args:
        model (YOLO): Loaded ultralytics YOLO model.
        video_filepath (Path): video to run tracking on.

    Returns:
        filepath_video_prediction (Path): annotated video saved by ultralytics
            under `runs/track/<video stem>/`.
        raw_prediction_str (str): human friendly summary of the tracked fish.
    """
    project = "runs/track/"
    name = video_filepath.stem
    predictions = model.track(
        source=video_filepath,
        save=True,
        tracker="bytetrack.yaml",
        exist_ok=True,
        project=project,
        name=name,
    )

    save_dir = Path(project) / name
    filepath_video_prediction = save_dir / f"{name}.avi"
    if not filepath_video_prediction.exists() and save_dir.is_dir():
        # NOTE(review): the video container written by ultralytics appears to
        # vary across platforms/versions; fall back to any file saved with the
        # same stem rather than handing gradio a path that does not exist.
        candidates = sorted(
            p for p in save_dir.iterdir() if p.is_file() and p.stem == name
        )
        if candidates:
            filepath_video_prediction = candidates[0]

    raw_prediction_str = prediction_to_str(yolo_predictions=predictions)
    return (filepath_video_prediction, raw_prediction_str)
def examples(dir_examples: Path) -> list[Path]:
    """
    List the example videos (`*.mp4` files) from the dir_examples directory.

    Args:
        dir_examples (Path): directory containing the example videos.

    Returns:
        filepaths (list[Path]): video filepaths sorted by path, so that the
            default example picked by index is stable across runs and
            filesystems (glob order is not guaranteed).
    """
    return sorted(dir_examples.glob("*.mp4"))
def load_model(filepath_weights: Path) -> YOLO:
    """
    Build an ultralytics YOLO model from the weights file `filepath_weights`.
    """
    model = YOLO(filepath_weights)
    return model
# Main Gradio interface
MODEL_FILEPATH_WEIGHTS = Path("data/model/weights.pt")
DIR_EXAMPLES = Path("data/videos/")
DEFAULT_IMAGE_INDEX = 0

with gr.Blocks() as demo:
    model = load_model(MODEL_FILEPATH_WEIGHTS)
    videos_filepaths = examples(dir_examples=DIR_EXAMPLES)
    print(f"videos_filepaths: {videos_filepaths}")
    # Start with an empty input instead of crashing with an IndexError at
    # startup when no example video is bundled.
    default_value_input = (
        videos_filepaths[DEFAULT_IMAGE_INDEX] if videos_filepaths else None
    )
    input_video = gr.Video(
        value=default_value_input,
        format="mp4",
        autoplay=True,
        loop=True,
        label="input video",
        sources=["upload"],
    )
    output_video = gr.Video(
        format="mp4",
        label="model prediction",
        autoplay=True,
        loop=True,
    )
    output_raw = gr.Text(label="raw prediction")

    def fn(video_filepath: str | None) -> Tuple[Path, str]:
        """Run the tracking model on the video filepath provided by gradio."""
        if video_filepath is None:
            # No video set in the input component: report it in the UI
            # instead of failing on Path(None).
            raise gr.Error("Please upload a video first.")
        return interface_fn(model=model, video_filepath=Path(video_filepath))

    gr.Interface(
        # NOTE(review): the trailing character of the title looks like a
        # mis-encoded emoji — confirm the intended glyph.
        title="ML model for wild salmon migration monitoring π",
        fn=fn,
        inputs=input_video,
        outputs=[output_video, output_raw],
        examples=videos_filepaths or None,
        flagging_mode="never",
    )

demo.launch()