from ultralytics import YOLO
from PIL import Image
import gradio as gr
from huggingface_hub import snapshot_download
import os
import cv2
import numpy as np
from tqdm import tqdm
import tempfile


# Function to load the model
def load_model(repo_id):
    """Download and load the YOLO model."""
    download_dir = snapshot_download(repo_id)
    path = os.path.join(download_dir, "best_int8_openvino_model")
    detection_model = YOLO(path, task="detect")
    return detection_model
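
# snapshot_download() caches the repo locally, so restarts reuse the files.
# The "best_int8_openvino_model" folder is assumed to be an Ultralytics
# OpenVINO export, e.g. (a sketch, assuming the original PyTorch weights):
#   YOLO("best.pt").export(format="openvino", int8=True)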


# Function to process an image
def predict_image(pilimg, conf_threshold, iou_threshold):
    """Run detection on an image with user-defined thresholds."""
    try:
        result = detection_model.predict(pilimg, conf=conf_threshold, iou=iou_threshold)
        img_bgr = result[0].plot()  # Annotated image as a BGR ndarray
        out_pilimg = Image.fromarray(img_bgr[..., ::-1])  # Convert BGR to RGB PIL image
        return out_pilimg
    except Exception as e:
        # Returning a string to a gr.Image output fails; surface the error
        # in the UI instead.
        raise gr.Error(f"Error processing image: {e}")


# Function to process a video
def predict_video(video_file, conf_threshold, iou_threshold, start_time, end_time):
    """Run detection on a video clip and return the path of the annotated copy."""
    cap = cv2.VideoCapture(video_file)
    if not cap.isOpened():
        raise gr.Error("Unable to open the video file.")
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = int(cap.get(cv2.CAP_PROP_FPS)) or 30  # Fall back if FPS metadata is missing
    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    # Use a temporary file to store the processed video
    temp_video_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
    output_path = temp_video_file.name
    temp_video_file.close()  # VideoWriter reopens the path itself
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    start_frame = int(start_time * fps) if start_time else 0
    end_frame = min(int(end_time * fps), total_frames) if end_time else total_frames
    cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
    out = cv2.VideoWriter(output_path, fourcc, fps, (frame_width, frame_height))
    with tqdm(total=end_frame - start_frame, desc="Processing Video") as pbar:
        while cap.isOpened():
            current_frame = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
            if current_frame >= end_frame:
                break
            ret, frame = cap.read()
            if not ret:
                break
            resized_frame = cv2.resize(frame, (640, 640))  # Resize for inference
            result = detection_model.predict(resized_frame, conf=conf_threshold, iou=iou_threshold)
            output_frame = result[0].plot()
            output_frame = cv2.resize(output_frame, (frame_width, frame_height))  # Restore original size
            out.write(output_frame)
            pbar.update(1)
    cap.release()
    out.release()
    return output_path
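
# Playback note: 'mp4v' (MPEG-4 Part 2) encodes widely but some browsers will
# not play it inline in the gr.Video component. If your OpenCV build includes
# an H.264 encoder (an assumption about the local build), this FourCC is one
# alternative:
#   fourcc = cv2.VideoWriter_fourcc(*'avc1')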


# Load the YOLO model once at startup so both tabs share it
REPO_ID = "Ganrong/107project"
detection_model = load_model(REPO_ID)


# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("## Pangolin and Axolotl Detection")

    # Image Processing Tab
    with gr.Tab("Image Input"):
        img_input = gr.Image(type="pil", label="Upload an Image")
        conf_slider_img = gr.Slider(0.1, 1.0, value=0.5, step=0.05, label="Confidence Threshold")
        iou_slider_img = gr.Slider(0.1, 1.0, value=0.6, step=0.05, label="IoU Threshold")
        img_output = gr.Image(type="pil", label="Processed Image")
        img_submit = gr.Button("Process Image")
        img_submit.click(
            predict_image,
            inputs=[img_input, conf_slider_img, iou_slider_img],
            outputs=img_output,
        )

    # Video Processing Tab
    with gr.Tab("Video Input"):
        video_input = gr.Video(label="Upload a Video")
        conf_slider_video = gr.Slider(0.1, 1.0, value=0.5, step=0.05, label="Confidence Threshold")
        iou_slider_video = gr.Slider(0.1, 1.0, value=0.6, step=0.05, label="IoU Threshold")
        start_time = gr.Number(value=0, label="Start Time (seconds)")
        end_time = gr.Number(value=0, label="End Time (seconds, 0 for full video)")
        video_output = gr.Video(label="Processed Video")
        video_submit = gr.Button("Process Video")
        video_submit.click(
            predict_video,
            inputs=[video_input, conf_slider_video, iou_slider_video, start_time, end_time],
            outputs=video_output,
        )
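
# Note: share=True creates a temporary public gradio.live link when run
# locally; on Hugging Face Spaces the flag is ignored since the Space URL is
# already public.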
demo.launch(share=True)