|
""" |
|
Gradio app to showcase the bear face recognition model. |
|
""" |
|
|
|
import uuid |
|
from pathlib import Path |
|
from typing import Any, Tuple |
|
|
|
import gradio as gr |
|
from PIL import Image |
|
|
|
from ui import bearid_ui |
|
from utils import load_models, run_pipeline, setup |
|
|
|
|
|
def prediction_to_str(bear_ids: list[str], indexed_k_nearest_individuals) -> str:
    """
    Turn the prediction into a human friendly string.

    Args:
        bear_ids (list[str]): predicted bear identifiers. The first entry is
            the one reported in the headline as the closest individual.
        indexed_k_nearest_individuals: mapping from a bear id to a sequence of
            neighbor records. Only the first record of each id is read and it
            must expose a numeric ``"distance"`` entry.

    Returns:
        str: a headline naming the closest individual followed by one bullet
        per bear id with its distance (two decimals), or a short notice when
        ``bear_ids`` is empty.

    Raises:
        KeyError: when a bear id is missing from
            ``indexed_k_nearest_individuals``.
    """
    # Without any id the headline lookup below would raise an IndexError;
    # report the absence of a prediction instead.
    if not bear_ids:
        return "The model did not find any individual."

    distance_str = "\n".join(
        f"- {bear_id} at distance "
        f"{indexed_k_nearest_individuals[bear_id][0]['distance']:.2f}"
        " in the embedding space"
        for bear_id in bear_ids
    )
    return f"The model found that the closest individual is {bear_ids[0]}:\n\n{distance_str}"
|
|
|
|
|
def examples(dir_examples: Path) -> list[Path]:
    """
    List the images from the dir_examples directory.

    Only entries matching ``*.jpg`` directly inside ``dir_examples`` are
    returned; sub-directories are not searched.

    Args:
        dir_examples (Path): directory holding the example images.

    Returns:
        filepaths (list[Path]): image filepaths, sorted so that the order is
        the same on every run and every filesystem.
    """
    # Path.glob yields entries in an arbitrary, filesystem-dependent order.
    # Sorting keeps positional lookups into the result stable.
    return sorted(dir_examples.glob("*.jpg"))
|
|
|
|
|
|
|
# Archive and target directory handed to setup() further down in this module.
INPUT_PACKAGED_PIPELINE = Path("./data/09_external/artifacts/packaged_pipeline.zip")
PIPELINE_INSTALL_PATH = Path("./data/06_models/pipeline/metriclearning/")

# Model artifacts located under the install directory: the two weight files
# are given to load_models(), the KNN index to the pipeline run.
METRIC_LEARNING_MODEL_FILEPATH = Path(
    "./data/06_models/pipeline/metriclearning/bearidentification/model.pt"
)
METRIC_LEARNING_KNN_INDEX_FILEPATH = Path(
    "./data/06_models/pipeline/metriclearning/bearidentification/knn.index"
)
INSTANCE_SEGMENTATION_WEIGHTS_FILEPATH = Path(
    "./data/06_models/pipeline/metriclearning/bearfacesegmentation/model.pt"
)

# Process-wide, in-memory store passed to pipeline_run_fn as its cache.
CACHE_PIPELINE_RUN: dict = {}
|
|
|
|
|
def pipeline_run_fn(
    loaded_models: dict[str, Any],
    pil_image: Image.Image,
    param_square_dim: int,
    param_k: int,
    param_n_samples_per_individual: int,
    knn_index_filepath: Path,
    cache: dict,
) -> dict:
    """
    A simple cached version of the pipeline.run function.

    __Note__: It keeps the cache in memory and can grow unbounded.

    The cache key combines the image (mode, size and raw bytes) with every
    hashable parameter forwarded to run_pipeline, so the same pixels requested
    with different parameters trigger a fresh run instead of returning a stale
    result. ``loaded_models`` is a dict and therefore cannot be part of the
    key; callers are expected to keep using the same models with a given cache.

    Args:
        loaded_models (dict[str, Any]): models forwarded to run_pipeline.
        pil_image (PIL.Image.Image): image to run the pipeline on.
        param_square_dim (int): forwarded to run_pipeline.
        param_k (int): forwarded to run_pipeline.
        param_n_samples_per_individual (int): forwarded to run_pipeline.
        knn_index_filepath (Path): forwarded to run_pipeline.
        cache (dict): mutable mapping used as the cache; updated in place on
            a cache miss.

    Returns:
        dict: the value returned by run_pipeline, served from ``cache`` when an
        identical request was already processed.
    """
    # tobytes() only carries raw pixel data, so mode and size are added to
    # tell apart images that share the same byte sequence.
    cache_key = (
        pil_image.mode,
        pil_image.size,
        pil_image.tobytes(),
        param_square_dim,
        param_k,
        param_n_samples_per_individual,
        str(knn_index_filepath),
    )

    if cache_key in cache:
        print("CACHE HIT")
        return cache[cache_key]

    print("CACHE MISS")
    results = run_pipeline(
        loaded_models=loaded_models,
        pil_image=pil_image,
        param_square_dim=param_square_dim,
        param_k=param_k,
        param_n_samples_per_individual=param_n_samples_per_individual,
        knn_index_filepath=knn_index_filepath,
    )
    cache[cache_key] = results
    return results
|
|
|
|
|
# Prepare the packaged pipeline before any model is loaded. The exact work is
# done by utils.setup; it receives the archive and the install directory.
setup(
    input_packaged_pipeline=INPUT_PACKAGED_PIPELINE,
    install_path=PIPELINE_INSTALL_PATH,
)

# Directory scanned for the example images shown in the UI, and the position
# of the example preloaded in the input widget.
DIR_EXAMPLES = Path("data/images/")
DEFAULT_IMAGE_INDEX = 0

# Values forwarded to the pipeline on every submission (param_k,
# param_n_samples_per_individual and param_square_dim respectively).
PARAM_K = 5
PARAM_N_SAMPLES_PER_INDIVIDUAL = 4
PARAM_SQUARE_DIM = 300

# Directory receiving the prediction image rendered for each submission.
OUTPUT_DIR_PREDICTION = Path("./data/07_model_output/predict/")

# Display names looked up by bear id when a prediction is shown.
BEAR_ID_TO_NAME = {
    "bf_480": "David Martinez",
    "bf_813": "Michael Thompson",
    "bf_503": "Jessica Brown",
}

# Load the models once at import time so every request reuses them.
loaded_models = load_models(
    filepath_metric_learning_weights=METRIC_LEARNING_MODEL_FILEPATH,
    filepath_segmentation_weights=INSTANCE_SEGMENTATION_WEIGHTS_FILEPATH,
)

# Example images offered in the UI; one of them is opened as the initial
# content of the image input.
image_filepaths = examples(dir_examples=DIR_EXAMPLES)
default_value_input = Image.open(image_filepaths[DEFAULT_IMAGE_INDEX])
|
|
|
with gr.Blocks() as demo:
    gr.HTML(
        '<h1 style="text-align: center; margin-bottom: 1rem">ML model for bear face recognition 🐻</h1>'
    )
    with gr.Row():
        # Left column: image selection and the button that triggers a run.
        with gr.Column():
            image_input = gr.Image(
                type="pil",
                value=default_value_input,
                label="input image",
                sources=["upload", "clipboard"],
            )
            gr.Examples(
                examples=image_filepaths,
                inputs=image_input,
            )
            submit_btn = gr.Button(value="Identify", variant="primary")

        # Right column: the prediction summary and a tab with intermediate
        # outputs that stays hidden until a first prediction is available.
        with gr.Column():
            with gr.Tab("Prediction"):
                with gr.Row():
                    output_bear_id = gr.Text(label="Predicted Individual")
                    output_bear_name = gr.Text(label="Name", visible=False)
                output_bear_id_samples = gr.Gallery(
                    label="Similar faces of the same bear",
                    preview=True,
                    visible=False,
                )

            with gr.Tab("Details", visible=False) as tab_details:
                output_segmentation_image = gr.Image(
                    type="pil", label="model prediction"
                )
                output_cropped_image = gr.Image(type="pil", label="cropped bear face")
                output_identification_prediction_image = gr.Image(
                    type="pil",
                    label="prediction",
                )
                output_raw = gr.Text(label="raw prediction")

    def submit_fn(
        loaded_models: dict[str, Any],
        pil_image: Image.Image,
        param_square_dim: int,
        param_k: int,
        param_n_samples_per_individual: int,
    ) -> dict:
        """
        Main function used for the Gradio interface.

        Runs the (cached) pipeline on the image, renders the identification
        figure to a uniquely named PNG under OUTPUT_DIR_PREDICTION and builds
        the values shown in the output components.

        Args:
            loaded_models (dict[str, Any]): loaded models.
            pil_image (PIL): original image picked by the user; None when the
                input widget is empty.
            param_square_dim (int): forwarded to the pipeline.
            param_k (int): forwarded to the pipeline and used as the number of
                closest neighbors drawn by bearid_ui.
            param_n_samples_per_individual (int): forwarded to the pipeline.

        Returns:
            dict: mapping from output component to its new value:
                output_bear_id (str): identification string for the bear.
                output_bear_name (gr.Text): display name, made visible.
                output_bear_id_samples (gr.Gallery): sample images of the
                    identified bear, made visible.
                output_segmentation_image (PIL): image of the segmented bear head.
                output_cropped_image (PIL): image of the cropped bear head.
                output_identification_prediction_image (PIL): image with
                    prediction from the model.
                output_raw (gr.Text): string representing the raw prediction
                    from the model.
                tab_details: update making the Details tab visible.

        Raises:
            gr.Error: when no image is provided or no individual is returned
                by the identification stage.
        """
        # The image widget yields None once the user clears it; surface a
        # readable message instead of failing inside the pipeline.
        if pil_image is None:
            raise gr.Error("Please provide an image before running the identification.")

        result = pipeline_run_fn(
            loaded_models=loaded_models,
            pil_image=pil_image,
            param_square_dim=param_square_dim,
            param_k=param_k,
            param_n_samples_per_individual=param_n_samples_per_individual,
            knn_index_filepath=METRIC_LEARNING_KNN_INDEX_FILEPATH,
            cache=CACHE_PIPELINE_RUN,
        )
        stages = result["stages"]
        pil_image_segmented_head = stages["segmentation"]["output"]["pil_image"]
        pil_image_cropped_head = stages["crop"]["output"]["pil_images"]["resized"]

        identification_output = stages["identification"]["output"]
        bear_ids = identification_output["bear_ids"]
        indexed_samples = identification_output["indexed_samples"]
        indexed_k_nearest_individuals = identification_output[
            "indexed_k_nearest_individuals"
        ]

        # Defensive: everything below relies on a first predicted id.
        if not bear_ids:
            raise gr.Error("The model could not identify any individual in this image.")
        bear_id = bear_ids[0]

        output_filepath = OUTPUT_DIR_PREDICTION / f"prediction_{uuid.uuid4()}.png"
        output_filepath.parent.mkdir(parents=True, exist_ok=True)
        # Make sure bearid_ui writes to a fresh file.
        output_filepath.unlink(missing_ok=True)

        bearid_ui(
            pil_image_chip=pil_image_cropped_head,
            indexed_k_nearest_individuals=indexed_k_nearest_individuals,
            indexed_samples=indexed_samples,
            save_filepath=output_filepath,
            k_closest_neighbors=param_k,
        )

        pil_image_identification_prediction = Image.open(output_filepath)

        raw_prediction_str = prediction_to_str(
            bear_ids=bear_ids,
            indexed_k_nearest_individuals=indexed_k_nearest_individuals,
        )

        return {
            output_bear_id: bear_id,
            output_bear_name: gr.Text(
                BEAR_ID_TO_NAME.get(bear_id, "Laura Anderson"), visible=True
            ),
            output_bear_id_samples: gr.Gallery(
                [Image.open(fp) for fp in indexed_samples[bear_id]], visible=True
            ),
            output_segmentation_image: pil_image_segmented_head,
            output_cropped_image: pil_image_cropped_head,
            output_identification_prediction_image: pil_image_identification_prediction,
            # Size the text box to the number of lines of the message.
            output_raw: gr.Text(
                raw_prediction_str, lines=len(raw_prediction_str.split("\n"))
            ),
            # tab_details is a Tab, so use a component-agnostic update rather
            # than an instance of an unrelated layout class.
            tab_details: gr.update(visible=True),
        }

    # The click handler only receives the image; models and parameters are
    # bound here from the module-level values.
    submit_btn.click(
        fn=lambda pil_image: submit_fn(
            loaded_models=loaded_models,
            pil_image=pil_image,
            param_square_dim=PARAM_SQUARE_DIM,
            param_k=PARAM_K,
            param_n_samples_per_individual=PARAM_N_SAMPLES_PER_INDIVIDUAL,
        ),
        inputs=image_input,
        outputs=[
            output_bear_id,
            output_bear_name,
            output_bear_id_samples,
            output_segmentation_image,
            output_cropped_image,
            output_identification_prediction_image,
            output_raw,
            tab_details,
        ],
    )

demo.launch()
|
|