# Hugging Face Space (status when captured: Sleeping)
from ultralytics import YOLO | |
from PIL import Image | |
import gradio as gr | |
from huggingface_hub import snapshot_download | |
import os | |
import cv2 | |
import numpy as np | |
from tqdm import tqdm | |
import tempfile | |
# Function to load the model | |
def load_model(repo_id):
    """Fetch a model repository snapshot and load it as a YOLO detector.

    Args:
        repo_id: Repository identifier passed to ``snapshot_download``.

    Returns:
        The ``YOLO`` object built with ``task="detect"`` from the
        ``best_int8_openvino_model`` entry inside the downloaded snapshot.
    """
    snapshot_dir = snapshot_download(repo_id)
    model_dir = os.path.join(snapshot_dir, "best_int8_openvino_model")
    return YOLO(model_dir, task="detect")
# Function to process an image | |
def predict_image(pilimg, conf_threshold, iou_threshold):
    """Run detection on one image and return the annotated picture.

    Args:
        pilimg: Image to analyse, or ``None`` when nothing was provided.
        conf_threshold: Value forwarded to the model as ``conf``.
        iou_threshold: Value forwarded to the model as ``iou``.

    Returns:
        A PIL image built from the first result's ``plot()`` output.

    Raises:
        gr.Error: If no image was supplied or inference fails. An exception
            is raised instead of returning the message as a string because
            the caller wires this function to an image output, where a
            string would be interpreted as an image path rather than shown
            as an error.
    """
    if pilimg is None:
        raise gr.Error("Error processing image: no image was provided.")

    try:
        results = detection_model.predict(
            pilimg, conf=conf_threshold, iou=iou_threshold
        )
        annotated_bgr = results[0].plot()
    except Exception as e:
        raise gr.Error(f"Error processing image: {e}") from e

    # The plotted array is treated as BGR; reverse the channel axis so PIL
    # receives RGB.
    return Image.fromarray(annotated_bgr[..., ::-1])
# Function to process a video | |
def predict_video(video_file, conf_threshold, iou_threshold, start_time, end_time):
    """Annotate a time window of a video and return the output file path.

    Each frame in the window is resized to 640x640 for inference, the
    annotated frame is resized back to the source resolution, and the result
    is written to a temporary ``.mp4`` file (``mp4v`` codec).

    Args:
        video_file: Path of the input video.
        conf_threshold: Value forwarded to the model as ``conf``.
        iou_threshold: Value forwarded to the model as ``iou``.
        start_time: Window start in seconds; a falsy value (0 or ``None``)
            means the beginning of the video.
        end_time: Window end in seconds; a falsy value (0 or ``None``) means
            the end of the video.

    Returns:
        Path of the temporary file holding the annotated video. The file is
        created with ``delete=False``, so it is not removed automatically.

    Raises:
        gr.Error: If the input cannot be opened, the time window is empty,
            the output writer cannot be created, no frame could be read, or
            processing fails.
    """
    if not video_file:
        raise gr.Error("Error: Unable to open the video file.")

    cap = cv2.VideoCapture(video_file)
    if not cap.isOpened():
        cap.release()
        raise gr.Error("Error: Unable to open the video file.")

    out = None
    output_path = None
    completed = False
    try:
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        # Keep the frame rate as a float: truncating e.g. 29.97 to 29 skews
        # both the seconds-to-frame conversion and the playback speed.
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:
            # Rate not reported; fall back to a usable value rather than
            # handing 0 to the writer.
            fps = 30.0

        start_frame = max(int(start_time * fps), 0) if start_time else 0
        if end_time:
            end_frame = int(end_time * fps)
            if total_frames > 0:
                end_frame = min(end_frame, total_frames)
        elif total_frames > 0:
            end_frame = total_frames
        else:
            # Frame count not reported: read until the stream ends.
            end_frame = None

        if end_frame is not None and end_frame <= start_frame:
            raise gr.Error(
                "Invalid time range: the end time must be after the start "
                "time and the start time must lie inside the video."
            )

        # Reserve a unique path, then close the handle so the video writer
        # is the only one holding the file open.
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_video_file:
            output_path = temp_video_file.name

        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        out = cv2.VideoWriter(output_path, fourcc, fps, (frame_width, frame_height))
        if not out.isOpened():
            raise gr.Error("Error: Unable to create the output video file.")

        cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
        frames_written = 0
        progress_total = None if end_frame is None else end_frame - start_frame
        with tqdm(total=progress_total, desc="Processing Video") as pbar:
            while True:
                current_frame = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
                if end_frame is not None and current_frame >= end_frame:
                    break
                ret, frame = cap.read()
                if not ret:
                    break
                resized_frame = cv2.resize(frame, (640, 640))  # Resize for inference
                result = detection_model.predict(
                    resized_frame, conf=conf_threshold, iou=iou_threshold
                )
                output_frame = result[0].plot()
                # Restore the source resolution expected by the writer.
                output_frame = cv2.resize(output_frame, (frame_width, frame_height))
                out.write(output_frame)
                frames_written += 1
                pbar.update(1)

        if frames_written == 0:
            raise gr.Error("Error: No frames could be read from the selected range.")

        completed = True
        return output_path
    except gr.Error:
        raise
    except Exception as e:
        raise gr.Error(f"Error processing video: {e}") from e
    finally:
        # Always free the capture/writer, even when inference raised.
        cap.release()
        if out is not None:
            out.release()
        if not completed and output_path is not None:
            # Best-effort cleanup of a partial output file.
            try:
                os.remove(output_path)
            except OSError:
                pass
# Model shared by both prediction callbacks; loaded once at import time.
REPO_ID = "Ganrong/107project"
detection_model = load_model(REPO_ID)

# Gradio interface: one tab for still images, one for videos.
with gr.Blocks() as demo:
    gr.Markdown(value="## Pangolin and Axolotl Detection")

    # --- Image tab -------------------------------------------------------
    with gr.Tab(label="Image Input"):
        img_input = gr.Image(label="Upload an Image", type="pil")
        conf_slider_img = gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.5,
            step=0.05,
            label="Confidence Threshold",
        )
        iou_slider_img = gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.6,
            step=0.05,
            label="IoU Threshold",
        )
        img_output = gr.Image(label="Processed Image", type="pil")
        img_submit = gr.Button(value="Process Image")
        img_submit.click(
            fn=predict_image,
            inputs=[img_input, conf_slider_img, iou_slider_img],
            outputs=img_output,
        )

    # --- Video tab -------------------------------------------------------
    with gr.Tab(label="Video Input"):
        video_input = gr.Video(label="Upload a Video")
        conf_slider_video = gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.5,
            step=0.05,
            label="Confidence Threshold",
        )
        iou_slider_video = gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.6,
            step=0.05,
            label="IoU Threshold",
        )
        start_time = gr.Number(label="Start Time (seconds)", value=0)
        end_time = gr.Number(label="End Time (seconds, 0 for full video)", value=0)
        video_output = gr.Video(label="Processed Video")
        video_submit = gr.Button(value="Process Video")
        video_submit.click(
            fn=predict_video,
            inputs=[
                video_input,
                conf_slider_video,
                iou_slider_video,
                start_time,
                end_time,
            ],
            outputs=video_output,
        )

# Start the app and request a public share link.
demo.launch(share=True)