# bird_det / app.py
# Last commit: "Update app.py" by heuue (55090a4, verified)
import gradio as gr
import torch
from torchvision import transforms
from PIL import Image
import torch.nn.functional as F
from ultralytics import YOLO
import cv2
import numpy as np
# Load the YOLO detector once at import time; detect_objects() reuses this
# single module-level instance for every request.
MODEL_WEIGHTS = "nabird_det_ep3.pt"  # weights file path (relative, so presumably resolved against the CWD)
model = YOLO(MODEL_WEIGHTS)
# Drawing style shared by the box outline and its label. The tuple is in
# OpenCV's BGR order; pure green is identical in RGB, so the colour is
# unaffected by the final BGR->RGB conversion.
_BOX_COLOR = (0, 255, 0)
_BOX_THICKNESS = 2
_LABEL_FONT = cv2.FONT_HERSHEY_SIMPLEX
_LABEL_SCALE = 0.5
_LABEL_OFFSET = 10  # pixel gap between the label baseline and the box's top edge


def _label_origin(x1, y1, label):
    """Return the bottom-left ``(x, y)`` at which to draw *label* for a box at ``(x1, y1)``.

    The label normally sits just above the box. When the box is too close to
    the top of the image for the text to fit above it, the label is moved just
    inside the box so it is not clipped off the canvas.
    """
    (_, text_height), _ = cv2.getTextSize(label, _LABEL_FONT, _LABEL_SCALE, _BOX_THICKNESS)
    above = y1 - _LABEL_OFFSET
    # putText's origin is the text's bottom-left corner, so the glyphs occupy
    # rows [y - text_height, y]; they are fully visible only if that top is >= 0.
    if above - text_height >= 0:
        return x1, above
    return x1, y1 + text_height + _LABEL_OFFSET


def detect_objects(image: Image.Image):
    """Run the YOLO model on *image* and return a copy with detections drawn.

    Args:
        image: Picture from the ``gr.Image(type="pil")`` input, in any PIL
            mode. The Gradio Image input supplies ``None`` when nothing was
            uploaded.

    Returns:
        A new RGB PIL image with a green rectangle and a
        ``"<class name> <confidence>"`` label for every detected box, or
        ``None`` when *image* is ``None``.
    """
    if image is None:
        return None

    # Normalise to 3-channel RGB first: grayscale ("L") or palette ("P")
    # uploads would otherwise give arrays that COLOR_RGB2BGR rejects. Then
    # switch to BGR, the channel order used by OpenCV drawing and expected by
    # Ultralytics for numpy inputs.
    canvas = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
    results = model.predict(canvas, save=False)

    for result in results:
        for box in result.boxes:
            # Corner coordinates [x1, y1, x2, y2], truncated to whole pixels.
            x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())
            conf = box.conf[0].item()  # detection confidence
            class_name = model.names[int(box.cls[0].item())]  # class index -> name
            label = f"{class_name} {conf:.2f}"

            cv2.rectangle(
                canvas, (x1, y1), (x2, y2), color=_BOX_COLOR, thickness=_BOX_THICKNESS
            )
            cv2.putText(
                canvas,
                label,
                _label_origin(x1, y1, label),
                _LABEL_FONT,
                _LABEL_SCALE,
                _BOX_COLOR,
                _BOX_THICKNESS,
            )

    # Back to RGB for PIL / Gradio display.
    return Image.fromarray(cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB))
# Sample images offered as one-click examples in the UI (replace with your own paths).
example_image_paths = ["PC060715.jpg", "PC061030.jpg", "PC060806.jpg"]

# Web UI: a single image in, the same image with detection boxes drawn out.
app = gr.Interface(
    fn=detect_objects,  # detection callback
    inputs=gr.Image(type="pil"),  # hand the uploaded image to fn as a PIL image
    outputs=gr.Image(type="pil"),  # show the annotated PIL image returned by fn
    examples=example_image_paths,
    title="YOLO Object Detection",
    description="Upload an image to detect objects using YOLO.",
)

# Start the server only when executed as a script, so importing this module
# (e.g. to reuse `app` or `detect_objects`) does not block on launch().
if __name__ == "__main__":
    app.launch()