achouffe's picture
style(title): add css for title
197814f verified
"""
Gradio app to showcase the bear face recognition model.
"""
import hashlib
import uuid
from pathlib import Path
from typing import Any, Tuple

import gradio as gr
from PIL import Image

from ui import bearid_ui
from utils import load_models, run_pipeline, setup
def prediction_to_str(
    bear_ids: list[str],
    indexed_k_nearest_individuals: dict[str, Any],
) -> str:
    """
    Turn the prediction into a human friendly string.

    Args:
        bear_ids (list[str]): predicted individual ids; the first entry is
            reported as the closest individual.
        indexed_k_nearest_individuals (dict[str, Any]): maps each bear id to a
            sequence of neighbours; only the first neighbour's "distance"
            value is used.

    Returns:
        str: a summary line followed by one bullet per bear id giving its
            distance in the embedding space. When bear_ids is empty, a short
            "no individual" message is returned instead of raising.
    """
    if not bear_ids:
        # Avoid an IndexError on bear_ids[0] when nothing was identified.
        return "The model did not identify any individual."

    distance_lines = [
        f"- {bear_id} at distance "
        f"{indexed_k_nearest_individuals[bear_id][0]['distance']:.2f} "
        "in the embedding space"
        for bear_id in bear_ids
    ]
    distance_str = "\n".join(distance_lines)
    return f"The model found that the closest individual is {bear_ids[0]}:\n\n{distance_str}"
def examples(dir_examples: Path) -> list[Path]:
    """
    List the images from the dir_examples directory.

    Only ``*.jpg`` files directly inside dir_examples are returned. The list is
    sorted so that its order, and therefore any image picked from it by
    index, does not depend on filesystem iteration order.

    Args:
        dir_examples (Path): directory to scan (non-recursively).

    Returns:
        filepaths (list[Path]): sorted list of image filepaths.
    """
    return sorted(dir_examples.glob("*.jpg"))
# Model artifacts: the packaged pipeline archive and its install location.
INPUT_PACKAGED_PIPELINE = Path("./data/09_external/artifacts/packaged_pipeline.zip")
PIPELINE_INSTALL_PATH = Path("./data/06_models/pipeline/metriclearning/")

# Individual artifact files live underneath the install location.
_IDENTIFICATION_DIR = PIPELINE_INSTALL_PATH / "bearidentification"
METRIC_LEARNING_MODEL_FILEPATH = _IDENTIFICATION_DIR / "model.pt"
METRIC_LEARNING_KNN_INDEX_FILEPATH = _IDENTIFICATION_DIR / "knn.index"
INSTANCE_SEGMENTATION_WEIGHTS_FILEPATH = (
    PIPELINE_INSTALL_PATH / "bearfacesegmentation" / "model.pt"
)

# In-memory store for cached pipeline results; it is never evicted.
CACHE_PIPELINE_RUN: dict = {}
def pipeline_run_fn(
    loaded_models: dict[str, Any],
    pil_image: Image.Image,
    param_square_dim: int,
    param_k: int,
    param_n_samples_per_individual: int,
    knn_index_filepath: Path,
    cache: dict,
) -> dict:
    """
    A simple cached version of the pipeline.run function.

    The cache key combines a SHA-256 digest of the image pixels with the image
    mode and size and with every pipeline parameter, so that:
    - images whose raw pixel bytes are identical but whose shape or mode
      differ do not collide, and
    - a call with different parameters never returns a stale result.
    Storing a fixed-size digest instead of the raw pixel buffer also keeps
    each key small.

    loaded_models is not part of the key; it is assumed to be the same set of
    models for the lifetime of the cache.

    __Note__: It keeps the cache in memory and can grow unbounded.

    Args:
        loaded_models (dict[str, Any]): models forwarded to run_pipeline.
        pil_image (PIL.Image.Image): input image.
        param_square_dim (int): forwarded to run_pipeline.
        param_k (int): forwarded to run_pipeline.
        param_n_samples_per_individual (int): forwarded to run_pipeline.
        knn_index_filepath (Path): forwarded to run_pipeline.
        cache (dict): mutable mapping used as the memoization store.

    Returns:
        dict: the run_pipeline result, either freshly computed or cached.
    """
    cache_key = (
        hashlib.sha256(pil_image.tobytes()).hexdigest(),
        pil_image.mode,
        pil_image.size,
        param_square_dim,
        param_k,
        param_n_samples_per_individual,
        str(knn_index_filepath),
    )
    if cache_key in cache:
        print("CACHE HIT")
        return cache[cache_key]

    print("CACHE MISS")
    results = run_pipeline(
        loaded_models=loaded_models,
        pil_image=pil_image,
        param_square_dim=param_square_dim,
        param_k=param_k,
        param_n_samples_per_individual=param_n_samples_per_individual,
        knn_index_filepath=knn_index_filepath,
    )
    cache[cache_key] = results
    return results
# NOTE(review): setup presumably installs the packaged pipeline into
# PIPELINE_INSTALL_PATH, where the weight files loaded below live, so it is
# kept before load_models. Confirm against utils.setup.
setup(
    input_packaged_pipeline=INPUT_PACKAGED_PIPELINE,
    install_path=PIPELINE_INSTALL_PATH,
)

# Main Gradio interface
DIR_EXAMPLES = Path("data/images/")
DEFAULT_IMAGE_INDEX = 0
PARAM_K = 5
PARAM_N_SAMPLES_PER_INDIVIDUAL = 4
PARAM_SQUARE_DIM = 300
OUTPUT_DIR_PREDICTION = Path("./data/07_model_output/predict/")
# Display names for known bear ids; ids missing from this mapping get a
# fallback name when the UI looks them up.
BEAR_ID_TO_NAME = {
    "bf_480": "David Martinez",
    "bf_813": "Michael Thompson",
    "bf_503": "Jessica Brown",
}

loaded_models = load_models(
    filepath_metric_learning_weights=METRIC_LEARNING_MODEL_FILEPATH,
    filepath_segmentation_weights=INSTANCE_SEGMENTATION_WEIGHTS_FILEPATH,
)

image_filepaths = examples(dir_examples=DIR_EXAMPLES)
if not image_filepaths:
    # Fail fast with an actionable message rather than an opaque IndexError
    # on the indexing below.
    raise FileNotFoundError(f"No example images found in {DIR_EXAMPLES}")
default_value_input = Image.open(image_filepaths[DEFAULT_IMAGE_INDEX])
with gr.Blocks() as demo:
    gr.HTML(
        '<h1 style="text-align: center; margin-bottom: 1rem">ML model for bear face recognition 🐻</h1>'
    )
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(
                type="pil",
                value=default_value_input,
                label="input image",
                sources=["upload", "clipboard"],
            )
            gr.Examples(
                examples=image_filepaths,
                inputs=image_input,
            )
            submit_btn = gr.Button(value="Identify", variant="primary")
        with gr.Column():
            with gr.Tab("Prediction"):
                with gr.Row():
                    output_bear_id = gr.Text(label="Predicted Individual")
                    output_bear_name = gr.Text(label="Name", visible=False)
                output_bear_id_samples = gr.Gallery(
                    label="Similar faces of the same bear",
                    preview=True,
                    visible=False,
                )
            with gr.Tab("Details", visible=False) as tab_details:
                output_segmentation_image = gr.Image(
                    type="pil", label="model prediction"
                )
                output_cropped_image = gr.Image(type="pil", label="cropped bear face")
                output_identification_prediction_image = gr.Image(
                    type="pil",
                    label="prediction",
                )
                output_raw = gr.Text(label="raw prediction")

    def _load_image(filepath: Path) -> Image.Image:
        """
        Open an image file and read its pixel data immediately.

        Image.open is lazy and keeps the file handle open until the pixels are
        accessed. Calling load() reads them now and lets Pillow close the file
        it opened.
        """
        image = Image.open(filepath)
        image.load()
        return image

    def submit_fn(
        loaded_models: dict[str, Any],
        pil_image: Image.Image,
        param_square_dim: int,
        param_k: int,
        param_n_samples_per_individual: int,
    ) -> dict:
        """
        Main function used for the Gradio interface.

        Runs the cached pipeline on the image and renders the identification
        figure with bearid_ui. It then builds the new value/update for every
        output component.

        Args:
            loaded_models (dict[str, Any]): models returned by load_models.
            pil_image (PIL.Image.Image | None): image picked by the user; None
                when the input component is empty.
            param_square_dim (int): forwarded to the pipeline.
            param_k (int): forwarded to the pipeline and passed to bearid_ui
                as k_closest_neighbors.
            param_n_samples_per_individual (int): forwarded to the pipeline.

        Returns:
            dict: maps each output component to its new value or update:
                output_bear_id: predicted bear id (the first of bear_ids).
                output_bear_name: display name from BEAR_ID_TO_NAME, made visible.
                output_bear_id_samples: gallery of the sample images indexed
                    for the predicted id, made visible.
                output_segmentation_image: image from the segmentation stage.
                output_cropped_image: resized crop from the crop stage.
                output_identification_prediction_image: figure saved by bearid_ui.
                output_raw: human friendly summary from prediction_to_str.
                tab_details: update making the "Details" tab visible.

        Raises:
            gr.Error: when no image is provided or no individual is returned.
        """
        if pil_image is None:
            # The input can be cleared by the user; report it in the UI instead
            # of failing with an AttributeError inside the pipeline.
            raise gr.Error("Please provide an image of a bear.")

        result = pipeline_run_fn(
            loaded_models=loaded_models,
            pil_image=pil_image,
            param_square_dim=param_square_dim,
            param_k=param_k,
            param_n_samples_per_individual=param_n_samples_per_individual,
            knn_index_filepath=METRIC_LEARNING_KNN_INDEX_FILEPATH,
            cache=CACHE_PIPELINE_RUN,
        )
        stages = result["stages"]
        pil_image_segmented_head = stages["segmentation"]["output"]["pil_image"]
        pil_image_cropped_head = stages["crop"]["output"]["pil_images"]["resized"]
        identification_output = stages["identification"]["output"]
        bear_ids = identification_output["bear_ids"]
        indexed_samples = identification_output["indexed_samples"]
        indexed_k_nearest_individuals = identification_output[
            "indexed_k_nearest_individuals"
        ]
        if not bear_ids:
            # Defensive: bear_ids[0] below requires at least one prediction.
            raise gr.Error("The model did not identify any individual.")
        bear_id = bear_ids[0]

        # A uuid4-based name is unique per request, so no pre-existing file needs
        # to be removed and concurrent requests do not overwrite each other.
        # NOTE(review): one PNG is written per request and never deleted.
        output_filepath = OUTPUT_DIR_PREDICTION / f"prediction_{uuid.uuid4()}.png"
        output_filepath.parent.mkdir(parents=True, exist_ok=True)
        bearid_ui(
            pil_image_chip=pil_image_cropped_head,
            indexed_k_nearest_individuals=indexed_k_nearest_individuals,
            indexed_samples=indexed_samples,
            save_filepath=output_filepath,
            k_closest_neighbors=param_k,
        )
        pil_image_identification_prediction = _load_image(output_filepath)
        raw_prediction_str = prediction_to_str(
            bear_ids=bear_ids,
            indexed_k_nearest_individuals=indexed_k_nearest_individuals,
        )
        return {
            output_bear_id: bear_id,
            output_bear_name: gr.Text(
                BEAR_ID_TO_NAME.get(bear_id, "Laura Anderson"), visible=True
            ),
            output_bear_id_samples: gr.Gallery(
                [_load_image(fp) for fp in indexed_samples[bear_id]], visible=True
            ),
            output_segmentation_image: pil_image_segmented_head,
            output_cropped_image: pil_image_cropped_head,
            output_identification_prediction_image: pil_image_identification_prediction,
            output_raw: gr.Text(
                raw_prediction_str, lines=len(raw_prediction_str.split("\n"))
            ),
            # tab_details is a gr.Tab: use a type-agnostic visibility update
            # instead of constructing a different component type (gr.Column).
            tab_details: gr.update(visible=True),
        }

    submit_btn.click(
        fn=lambda pil_image: submit_fn(
            loaded_models=loaded_models,
            pil_image=pil_image,
            param_square_dim=PARAM_SQUARE_DIM,
            param_k=PARAM_K,
            param_n_samples_per_individual=PARAM_N_SAMPLES_PER_INDIVIDUAL,
        ),
        inputs=image_input,
        outputs=[
            output_bear_id,
            output_bear_name,
            output_bear_id_samples,
            output_segmentation_image,
            output_cropped_image,
            output_identification_prediction_image,
            output_raw,
            tab_details,
        ],
    )

demo.launch()