Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
# Copyright (c) Microsoft Corporation. All rights reserved. | |
# Licensed under the MIT License. | |
""" Demo using Gradio interface""" | |
#%% | |
# Importing basic libraries | |
import os | |
import time | |
from PIL import Image | |
import supervision as sv | |
import gradio as gr | |
from zipfile import ZipFile | |
from torch.utils.data import DataLoader | |
#%% | |
# Importing the models, dataset, transformations, and utility functions from PytorchWildlife | |
from PytorchWildlife.models import detection as pw_detection | |
from PytorchWildlife.models import classification as pw_classification | |
from PytorchWildlife.data import transforms as pw_trans | |
from PytorchWildlife.data import datasets as pw_data | |
from PytorchWildlife import utils as pw_utils | |
#%% | |
# Setting the device to use for computations ('cuda' indicates GPU) | |
DEVICE = "cpu" | |
# Initializing a supervision box annotator for visualizing detections | |
box_annotator = sv.BoxAnnotator(thickness=4, text_thickness=4, text_scale=2) | |
# Initializing the detection and classification models | |
detection_model = None | |
classification_model = None | |
# Defining transformations for detection and classification | |
trans_det = None | |
trans_clf = None | |
#%% Defining functions for different detection scenarios | |
def load_models(det, clf): | |
global detection_model, classification_model, trans_det, trans_clf | |
detection_model = pw_detection.__dict__[det](device=DEVICE, pretrained=True) | |
if clf != "None": | |
classification_model = pw_classification.__dict__[clf](device=DEVICE, pretrained=True) | |
trans_det = pw_trans.MegaDetector_v5_Transform(target_size=detection_model.IMAGE_SIZE, | |
stride=detection_model.STRIDE) | |
trans_clf = pw_trans.Classification_Inference_Transform(target_size=224) | |
return "Loaded Detector: {}. Loaded Classifier: {}".format(det, clf) | |
def single_image_detection(input_img, det_conf_thres, clf_conf_thres, img_index=None): | |
"""Performs detection on a single image and returns an annotated image. | |
Args: | |
input_img (np.ndarray): Input image in numpy array format defaulted by Gradio. | |
det_conf_thre (float): Confidence threshold for detection. | |
clf_conf_thre (float): Confidence threshold for classification. | |
img_index: Image index identifier. | |
Returns: | |
annotated_img (PIL.Image.Image): Annotated image with bounding box instances. | |
""" | |
results_det = detection_model.single_image_detection(trans_det(input_img), | |
input_img.shape, | |
img_path=img_index, | |
conf_thres=det_conf_thres) | |
if classification_model is not None: | |
labels = [] | |
for xyxy, det_id in zip(results_det["detections"].xyxy, results_det["detections"].class_id): | |
# Only run classifier when detection class is animal | |
if det_id == 0: | |
cropped_image = sv.crop_image(image=input_img, xyxy=xyxy) | |
results_clf = classification_model.single_image_classification(trans_clf(Image.fromarray(cropped_image))) | |
labels.append("{} {:.2f}".format(results_clf["prediction"] if results_clf["confidence"] > clf_conf_thres else "Unknown", | |
results_clf["confidence"])) | |
else: | |
labels = results_det["labels"] | |
else: | |
labels = results_det["labels"] | |
annotated_img = box_annotator.annotate(scene=input_img, detections=results_det["detections"], labels=labels) | |
return annotated_img | |
def batch_detection(zip_file, det_conf_thres): | |
"""Perform detection on a batch of images from a zip file and return path to results JSON. | |
Args: | |
zip_file (File): Zip file containing images. | |
det_conf_thre (float): Confidence threshold for detection. | |
clf_conf_thre (float): Confidence threshold for classification. | |
Returns: | |
json_save_path (str): Path to the JSON file containing detection results. | |
""" | |
extract_path = os.path.join("..","temp","zip_upload") | |
json_save_path = os.path.join(extract_path, "results.json") | |
with ZipFile(zip_file.name) as zfile: | |
zfile.extractall(extract_path) | |
#tgt_folder_path = os.path.join(extract_path, zip_file.name.rsplit(os.sep, 1)[1].rstrip(".zip")) | |
tgt_folder_path = os.path.join(extract_path) | |
det_dataset = pw_data.DetectionImageFolder(tgt_folder_path, transform=trans_det) | |
det_loader = DataLoader(det_dataset, batch_size=32, shuffle=False, | |
pin_memory=True, num_workers=8, drop_last=False) | |
det_results = detection_model.batch_image_detection(det_loader, conf_thres=det_conf_thres, id_strip=tgt_folder_path) | |
if classification_model is not None: | |
clf_dataset = pw_data.DetectionCrops( | |
det_results, | |
transform=pw_trans.Classification_Inference_Transform(target_size=224), | |
path_head=tgt_folder_path | |
) | |
clf_loader = DataLoader(clf_dataset, batch_size=32, shuffle=False, | |
pin_memory=True, num_workers=8, drop_last=False) | |
clf_results = classification_model.batch_image_classification(clf_loader, id_strip=tgt_folder_path) | |
pw_utils.save_detection_classification_json(det_results=det_results, | |
clf_results=clf_results, | |
det_categories=detection_model.CLASS_NAMES, | |
clf_categories=classification_model.CLASS_NAMES, | |
output_path=json_save_path) | |
else: | |
pw_utils.save_detection_json(det_results, json_save_path, categories=detection_model.CLASS_NAMES) | |
return json_save_path | |
def video_detection(video, det_conf_thres, clf_conf_thres, target_fps): | |
"""Perform detection on a video and return path to processed video. | |
Args: | |
video (str): Video source path. | |
det_conf_thre (float): Confidence threshold for detection. | |
clf_conf_thre (float): Confidence threshold for classification. | |
""" | |
def callback(frame, index): | |
annotated_frame = single_image_detection(frame, | |
img_index=index, | |
det_conf_thres=det_conf_thres, | |
clf_conf_thres=clf_conf_thres) | |
return annotated_frame | |
target_path = "../temp/video_detection.mp4" | |
pw_utils.process_video(source_path=video, target_path=target_path, | |
callback=callback, target_fps=target_fps) | |
return target_path | |
#%% Building Gradio UI | |
with gr.Blocks() as demo: | |
gr.Markdown("# Pytorch-Wildlife Demo.") | |
with gr.Row(): | |
det_drop = gr.Dropdown( | |
["MegaDetectorV5"], | |
label="Detection model", | |
info="Will add more detection models!", | |
value="MegaDetectorV5" | |
) | |
clf_drop = gr.Dropdown( | |
["None", "AI4GOpossum", "AI4GAmazonRainforest"], | |
label="Classification model", | |
info="Will add more classification models!", | |
value="None" | |
) | |
with gr.Column(): | |
load_but = gr.Button("Load Models!") | |
load_out = gr.Text("NO MODEL LOADED!!", label="Loaded models:") | |
with gr.Tab("Single Image Process"): | |
with gr.Row(): | |
with gr.Column(): | |
sgl_in = gr.Image() | |
sgl_conf_sl_det = gr.Slider(0, 1, label="Detection Confidence Threshold", value=0.2) | |
sgl_conf_sl_clf = gr.Slider(0, 1, label="Classification Confidence Threshold", value=0.7) | |
sgl_out = gr.Image() | |
sgl_but = gr.Button("Detect Animals!") | |
with gr.Tab("Batch Image Process"): | |
with gr.Row(): | |
with gr.Column(): | |
bth_in = gr.File(label="Upload zip file.") | |
bth_conf_sl = gr.Slider(0, 1, label="Detection Confidence Threshold", value=0.2) | |
bth_out = gr.File(label="Detection Results JSON.", height=200) | |
bth_but = gr.Button("Detect Animals!") | |
with gr.Tab("Single Video Process"): | |
with gr.Row(): | |
with gr.Column(): | |
vid_in = gr.Video(label="Upload a video.") | |
vid_conf_sl_det = gr.Slider(0, 1, label="Detection Confidence Threshold", value=0.2) | |
vid_conf_sl_clf = gr.Slider(0, 1, label="Classification Confidence Threshold", value=0.7) | |
vid_fr = gr.Dropdown([5, 10, 30], label="Output video framerate", value=30) | |
vid_out = gr.Video() | |
vid_but = gr.Button("Detect Animals!") | |
load_but.click(load_models, inputs=[det_drop, clf_drop], outputs=load_out) | |
sgl_but.click(single_image_detection, inputs=[sgl_in, sgl_conf_sl_det, sgl_conf_sl_clf], outputs=sgl_out) | |
bth_but.click(batch_detection, inputs=[bth_in, bth_conf_sl], outputs=bth_out) | |
vid_but.click(video_detection, inputs=[vid_in, vid_conf_sl_det, vid_conf_sl_clf, vid_fr], outputs=vid_out) | |
if __name__ == "__main__": | |
demo.queue() | |
demo.launch() | |