# bird_det_cls / app.py
# Author: root -- commit "init app" (5760d9c)
import gradio as gr
import torch
from torchvision import transforms
from PIL import Image
import numpy as np
import cv2
from timm import create_model
from ultralytics import YOLO
import json
# Load the classification label map.
# The JSON object is keyed by the stringified class index (classify_image
# looks labels up with class_mapping[str(idx)]), and its size determines the
# classifier's number of output classes (see load_classification_model).
# Decode explicitly as UTF-8 so species names load identically regardless of
# the host platform's default locale encoding.
with open('class_names.json', 'r', encoding='utf-8') as json_file:
    class_mapping = json.load(json_file)
def load_classification_model(model_path):
    """Build a ResNet-18 classifier and load trained weights into it.

    The network is created without pretrained weights and with one output per
    entry in ``class_mapping``. The checkpoint is mapped onto the CPU, and the
    model is put into eval mode before being returned.

    Args:
        model_path: Path to a saved ``state_dict`` for the ResNet-18 model.

    Returns:
        The loaded model, in eval mode.
    """
    net = create_model('resnet18', pretrained=False, num_classes=len(class_mapping))
    cpu = torch.device('cpu')
    state = torch.load(model_path, map_location=cpu)
    net.load_state_dict(state)
    # nn.Module.eval() switches to inference mode and returns the module itself.
    return net.eval()
# Module-level classifier, loaded once at import time and reused by
# classify_image for every request.
# NOTE(review): the filename presumably encodes "NABirds, 555 classes,
# ~59.6% accuracy" -- confirm against the training run.
classification_model = load_classification_model("res18_nabird555_acc596.pth")
# Classifier preprocessing: fixed 224x224 resize, conversion to a float
# tensor, then per-channel mean/std normalisation (ImageNet statistics).
# Built once at import time instead of on every call.
_CLASSIFY_TRANSFORM = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def preprocess_image(image):
    """Turn a PIL image into a normalised single-image batch for the classifier.

    Args:
        image: A ``PIL.Image.Image`` in any mode (RGB, RGBA, L, P, ...).

    Returns:
        A float tensor of shape ``(1, 3, 224, 224)``.
    """
    # Normalize() carries exactly three channel statistics. RGBA / grayscale /
    # palette images would yield a different channel count and fail, so
    # coerce them to RGB first.
    if image.mode != 'RGB':
        image = image.convert('RGB')
    # unsqueeze(0) adds the batch dimension the model expects.
    return _CLASSIFY_TRANSFORM(image).unsqueeze(0)
def classify_image(image):
    """Predict the species label for a single (cropped) bird image.

    Args:
        image: A ``PIL.Image.Image`` to classify.

    Returns:
        The ``class_mapping`` entry for the highest-scoring class index.
    """
    batch = preprocess_image(image)
    with torch.no_grad():
        logits = classification_model(batch)
    # Top-scoring class along the class dimension for the one image in the batch.
    best_idx = logits.argmax(dim=1).item()
    # The label map's keys are stringified indices.
    return class_mapping[str(best_idx)]
# Load the detection model.
# Ultralytics YOLO detector, created once at import time; detect_and_classify
# uses it to locate birds before each crop is classified.
detection_model = YOLO("nabird_det_ep3.pt")
def detect_and_classify(image: Image.Image):
    """Detect birds in *image*, classify each one, and annotate the image.

    Each box returned by ``detection_model`` is cropped from the original,
    unannotated image and passed to ``classify_image``. The box and predicted
    label are then drawn onto a BGR copy, which is converted back to RGB.

    Args:
        image: The uploaded picture as a PIL image, or ``None`` when the
            user submits without an image.

    Returns:
        A tuple ``(annotated_image, classifications)``.
        ``annotated_image`` is a PIL RGB image with green boxes and labels,
        or ``None`` if no image was given. ``classifications`` is a list of
        ``{"bbox": [x1, y1, x2, y2], "class": label}`` dicts in pixel
        coordinates.
    """
    if image is None:
        # Gradio passes None when the input is empty or cleared.
        return None, []

    # The RGB->BGR conversion below requires exactly three channels.
    if image.mode != "RGB":
        image = image.convert("RGB")
    width, height = image.size

    # The detector and the OpenCV drawing calls operate on a BGR numpy array.
    image_np = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)

    # Detect birds.
    results = detection_model.predict(image_np, save=False)

    font, font_scale, font_thickness = cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2
    classifications = []
    for result in results:
        for box in result.boxes:
            x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())  # [x1, y1, x2, y2]
            # Keep the crop inside the image bounds.
            x1, y1 = max(0, x1), max(0, y1)
            x2, y2 = min(width, x2), min(height, y2)
            if x2 <= x1 or y2 <= y1:
                # Zero-area box after rounding/clamping: nothing to classify,
                # and an empty crop would break the classifier's resize.
                continue

            # Crop from the untouched PIL image so boxes and labels already
            # drawn on image_np never leak into later crops.
            bird_crop = image.crop((x1, y1, x2, y2))
            class_name = classify_image(bird_crop)
            classifications.append({
                "bbox": [x1, y1, x2, y2],
                "class": class_name,
            })

            # Draw the box and label on the annotated copy.
            cv2.rectangle(image_np, (x1, y1), (x2, y2), color=(0, 255, 0), thickness=2)
            # Put the label above the box when it fits; otherwise put it just
            # inside the top edge so it is not clipped off the image.
            (_, text_h), _ = cv2.getTextSize(class_name, font, font_scale, font_thickness)
            text_y = y1 - 10 if y1 - 10 - text_h >= 0 else y1 + text_h + 5
            # NOTE(review): cv2.putText renders only ASCII glyphs; non-ASCII
            # class names would be drawn as '?'.
            cv2.putText(image_np, class_name, (x1, text_y), font,
                        font_scale, (0, 255, 0), font_thickness)

    # Convert back to RGB for display in Gradio.
    detected_image = Image.fromarray(cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB))
    return detected_image, classifications
# Gradio interface: a single PIL image as input; outputs are the annotated
# image and the per-bird detection/classification list rendered as JSON.
interface = gr.Interface(
    title="Bird Detection and Recognition",
    description="Upload an image to detect birds and classify their species.",
    fn=detect_and_classify,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Image(type="pil", label="Detected Image"),
        gr.JSON(label="Classifications"),
    ],
)

if __name__ == "__main__":
    # Start the web server only when the file is run as a script, not on import.
    interface.launch()