# Page header captured with the source: "Spaces: Sleeping".
import gradio as gr | |
import torch | |
from torchvision import transforms | |
from PIL import Image | |
import numpy as np | |
import cv2 | |
from timm import create_model | |
from ultralytics import YOLO | |
import json | |
# Load the class-name mapping used by the classification model.
# The mapping is keyed by the class index as a string, and its length
# determines the number of classifier outputs.
# The encoding is given explicitly because JSON text is UTF-8, while the
# default for open() depends on the platform locale.
with open('class_names.json', 'r', encoding='utf-8') as json_file:
    class_mapping = json.load(json_file)
def load_classification_model(model_path):
    """Build the ResNet-18 classifier and load its weights from ``model_path``.

    The network is created without pretrained weights and with one output
    per entry in ``class_mapping``. The checkpoint is mapped onto the CPU,
    loaded as the model's state dict, and the model is put into evaluation
    mode before it is returned.
    """
    cpu = torch.device('cpu')
    net = create_model('resnet18', pretrained=False, num_classes=len(class_mapping))
    weights = torch.load(model_path, map_location=cpu)
    net.load_state_dict(weights)
    net.eval()
    return net
# Module-level classifier instance, loaded once when the file is imported.
classification_model = load_classification_model(
    model_path="res18_nabird555_acc596.pth",
)
def preprocess_image(image):
    """Convert a PIL image into a normalized ``(1, 3, 224, 224)`` tensor.

    Args:
        image: PIL image in any mode; it is converted to RGB when needed.

    Returns:
        A float tensor with a leading batch dimension of size 1, resized to
        224x224 and normalized per channel.
    """
    # Normalize below is configured for exactly three channels, so RGBA,
    # grayscale or palette images must be brought to RGB first.
    if image.mode != "RGB":
        image = image.convert("RGB")

    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        # Widely used ImageNet per-channel mean / std.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # unsqueeze(0) adds the batch dimension the model expects.
    return transform(image).unsqueeze(0)
def classify_image(image):
    """Return the class name the classifier predicts for a single image.

    The image is preprocessed into a one-element batch, run through
    ``classification_model`` without gradient tracking, and the index of the
    highest-scoring output is looked up in ``class_mapping``.
    """
    batch = preprocess_image(image)
    with torch.no_grad():
        logits = classification_model(batch)
    best_idx = torch.max(logits, 1).indices.item()
    # class_mapping is keyed by the class index rendered as a string.
    return class_mapping[str(best_idx)]
# Load the detection model.
_DETECTION_WEIGHTS = "nabird_det_ep3.pt"
detection_model = YOLO(_DETECTION_WEIGHTS)
def detect_and_classify(image: Image.Image):
    """Detect birds in ``image``, classify each detection and annotate it.

    Args:
        image: PIL image to analyse. ``None`` is accepted (for example when
            the form is submitted without an upload) and yields an empty
            result instead of an error.

    Returns:
        A ``(annotated_image, classifications)`` tuple. ``annotated_image``
        is a PIL image with a green rectangle and class label drawn for each
        detection, or ``None`` when no image was supplied.
        ``classifications`` is a list of
        ``{"bbox": [x1, y1, x2, y2], "class": class_name}`` dicts.
    """
    if image is None:
        return None, []

    # The colour conversion and the classifier both need three channels;
    # grayscale / RGBA / palette inputs would otherwise fail further down.
    if image.mode != "RGB":
        image = image.convert("RGB")

    # OpenCV drawing works in BGR, and Ultralytics documents NumPy inputs
    # as BGR as well.
    image_np = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)

    # Detect birds.
    results = detection_model.predict(image_np, save=False)

    classifications = []
    for result in results:
        for box in result.boxes:
            x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())  # [x1, y1, x2, y2]

            # Truncating to int can collapse a tiny box to zero area; an
            # empty crop cannot be resized or classified, so skip it.
            if x2 <= x1 or y2 <= y1:
                continue

            # Crop the bird region from the RGB image and classify it.
            class_name = classify_image(image.crop((x1, y1, x2, y2)))
            classifications.append({
                "bbox": [x1, y1, x2, y2],
                "class": class_name
            })

            # Draw the bounding box and label on the BGR working copy.
            cv2.rectangle(image_np, (x1, y1), (x2, y2), color=(0, 255, 0), thickness=2)
            # Put the label above the box when there is room; otherwise
            # place it just inside the top edge so it is not drawn off-canvas.
            label_y = y1 - 10 if y1 >= 20 else y1 + 15
            cv2.putText(image_np, class_name, (x1, label_y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)

    # Convert back to RGB for the returned PIL image.
    detected_image = Image.fromarray(cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB))
    return detected_image, classifications
# Gradio interface: one uploaded image in, annotated image plus JSON out.
interface = gr.Interface(
    title="Bird Detection and Recognition",
    description="Upload an image to detect birds and classify their species.",
    fn=detect_and_classify,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Image(type="pil", label="Detected Image"),
        gr.JSON(label="Classifications"),
    ],
)

# Launch the app only when this file is executed as a script.
if __name__ == "__main__":
    interface.launch()