# Yolo_V3 / app.py
# Gosula's picture
# Update app.py
# a5806ac verified
import gradio as gr
import numpy as np
import cv2
import torch
from torchvision import datasets, transforms
from PIL import Image
#from train import YOLOv3Lightning
from utils import non_max_suppression, plot_image, cells_to_bboxes
from dataset import YOLODataset
import config
import albumentations as A
from albumentations.pytorch import ToTensorV2
from model import YoloVersion3
import matplotlib.pyplot as plt
import matplotlib.patches as patches
# Build the network and restore the trained weights onto the CPU.
# strict=False lets load_state_dict tolerate missing/unexpected keys.
model = YoloVersion3()
model.load_state_dict(
    torch.load('Yolov3.pth', map_location=torch.device('cpu')),
    strict=False,
)
model.eval()

# Anchors multiplied by the matching entry of config.S (one entry per
# prediction scale; presumably the grid size of that scale -- confirm in
# config). The multiplier is expanded to a (len(S), 3, 2) tensor first.
_anchor_multiplier = (
    torch.tensor(config.S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
)
scaled_anchors = (torch.tensor(config.ANCHORS) * _anchor_multiplier).to("cpu")

# Pre-processing applied to every input image: shrink/grow so the longest
# side is 416, pad to at least 416x416 with a constant border, divide the
# pixel values by 255 (mean 0, std 1), and convert to a torch tensor.
test_transforms = A.Compose(
    [
        A.LongestMaxSize(max_size=416),
        A.PadIfNeeded(
            min_height=416,
            min_width=416,
            border_mode=cv2.BORDER_CONSTANT,
        ),
        A.Normalize(mean=[0, 0, 0], std=[1, 1, 1], max_pixel_value=255),
        ToTensorV2(),
    ]
)
def plot_image(image, boxes):
    """Draw predicted bounding boxes on an image and save it as ``inference.png``.

    Args:
        image: Image to draw on; anything ``np.array`` converts into a
            ``(height, width, channels)`` array.
        boxes: Iterable of 6-element boxes
            ``[class_pred, confidence, x, y, width, height]``. ``x``/``y`` are
            the box centre and ``width``/``height`` its size, all expressed as
            fractions of the image dimensions.

    Raises:
        ValueError: If a box does not contain exactly six entries.

    Side effects:
        Writes ``inference.png`` to the current working directory.
    """
    cmap = plt.get_cmap("tab20b")
    class_labels = config.PASCAL_CLASSES
    colors = [cmap(i) for i in np.linspace(0, 1, len(class_labels))]
    im = np.array(image)
    height, width, _ = im.shape

    fig, ax = plt.subplots(1)
    try:
        ax.imshow(im)

        for box in boxes:
            # Explicit check instead of ``assert`` so it survives ``python -O``.
            if len(box) != 6:
                raise ValueError(
                    "box should contain class pred, confidence, x, y, width, height"
                )
            class_index = int(box[0])
            center_x, center_y, box_w, box_h = box[2:]
            upper_left_x = center_x - box_w / 2
            upper_left_y = center_y - box_h / 2
            color = colors[class_index]

            rect = patches.Rectangle(
                (upper_left_x * width, upper_left_y * height),
                box_w * width,
                box_h * height,
                linewidth=2,
                edgecolor=color,
                facecolor="none",
            )
            ax.add_patch(rect)
            # Use the Axes/Figure objects directly rather than pyplot's global
            # "current figure", so concurrent calls cannot draw on or save
            # each other's figures.
            ax.text(
                upper_left_x * width,
                upper_left_y * height,
                s=class_labels[class_index],
                color="white",
                verticalalignment="top",
                bbox={"color": color, "pad": 0},
            )

        fig.subplots_adjust(left=0, right=1, top=1, bottom=0)
        ax.axis('off')
        fig.savefig('inference.png')
    finally:
        # pyplot keeps every figure alive until it is closed; without this
        # each call would leak one figure for the life of the process.
        plt.close(fig)
# Inference function
def _unletterbox_boxes(boxes, orig_height, orig_width, net_height, net_width):
    """Re-express midpoint boxes relative to the original image instead of the padded network input."""
    # Mirror the pre-processing: the image is scaled so it fits inside the
    # network input, then padded (centred) up to the network input size.
    scale = min(net_height / orig_height, net_width / orig_width)
    resized_height = orig_height * scale
    resized_width = orig_width * scale
    pad_top = (net_height - resized_height) / 2
    pad_left = (net_width - resized_width) / 2

    remapped = []
    for class_pred, confidence, x, y, w, h in boxes:
        remapped.append(
            [
                class_pred,
                confidence,
                (x * net_width - pad_left) / resized_width,
                (y * net_height - pad_top) / resized_height,
                w * net_width / resized_width,
                h * net_height / resized_height,
            ]
        )
    return remapped


def inference(inp_image):
    """Run the detector on one image and return the path of the annotated plot.

    Args:
        inp_image: Image as an array whose first two axes are height and
            width (what the Gradio image input delivers), or ``None`` when
            no image was supplied.

    Returns:
        ``'inference.png'`` (written by ``plot_image``), or ``None`` if
        ``inp_image`` is ``None``.
    """
    if inp_image is None:
        # Nothing to run on; let the UI show an empty output instead of
        # failing inside the transform pipeline.
        return None

    orig_height, orig_width = inp_image.shape[:2]

    # Pre-process and add the batch dimension expected by the model.
    x = test_transforms(image=inp_image)["image"].unsqueeze(0)
    net_height, net_width = x.shape[-2:]

    # The model is already on the CPU and in eval mode from module load.
    with torch.no_grad():
        out = model(x)

    # Collect the candidate boxes from every prediction scale, per image.
    bboxes = [[] for _ in range(x.shape[0])]
    for scale_out, anchor in zip(out, scaled_anchors):
        grid_size = scale_out.shape[2]
        boxes_scale_i = cells_to_bboxes(
            scale_out, anchor, S=grid_size, is_preds=True
        )
        for idx, box in enumerate(boxes_scale_i):
            bboxes[idx] += box

    nms_boxes = non_max_suppression(
        bboxes[0], iou_threshold=0.5, threshold=0.6, box_format="midpoint",
    )

    # The predictions are relative to the resized-and-padded network input.
    # plot_image scales boxes by the size of the image it draws on, so they
    # must be mapped back to the original frame first; otherwise boxes are
    # misplaced on any non-square image.
    nms_boxes = _unletterbox_boxes(
        nms_boxes, orig_height, orig_width, net_height, net_width
    )

    plot_image(inp_image, nms_boxes)
    return 'inference.png'
# Gradio UI wiring. The component classes are used directly because the
# gr.inputs / gr.outputs namespaces are deprecated in Gradio 3 and removed
# in Gradio 4.
inputs = gr.Image(label="Original Image")
outputs = gr.Image(type="pil", label="Output Image")
title = "YOLOv3 model trained on PASCAL VOC Dataset"
description = "YOLOv3 object detection using Gradio demo"
examples = [
    ['examples/car.jpg'],
    ['examples/home.jpg'],
    ['examples/train.jpg'],
    ['examples/train_persons.jpg'],
]
gr.Interface(
    inference,
    inputs,
    outputs,
    title=title,
    examples=examples,
    description=description,
    theme='xiaobaiyuan/theme_brief',
).launch(debug=False)