"""Gradio app for headphone-vs-microphone object detection on images and videos.

Loads an INT8 OpenVINO-exported YOLO model from the Hugging Face Hub and serves
a two-tab UI: batch image detection and (optionally frame-by-frame) video detection.
"""

from ultralytics import YOLO
from PIL import Image
import gradio as gr
from huggingface_hub import snapshot_download
import os
import cv2
import tempfile
from tqdm import tqdm
import numpy as np
from io import BytesIO
from typing import List, Dict, Tuple


def load_model(repo_id):
    """Download the model snapshot from the Hub and load the OpenVINO export."""
    download_dir = snapshot_download(repo_id)
    path = os.path.join(download_dir, "best_int8_openvino_model")
    detection_model = YOLO(path, task='detect')
    return detection_model


def process_images(files):
    """Run detection on each uploaded image and return annotated PIL images."""
    if not files:
        return None
    processed_images = []
    for file in files:
        try:
            # gr.File(type="binary") yields raw bytes; fall back to a file path otherwise.
            if isinstance(file, bytes):
                img = Image.open(BytesIO(file))
            else:
                img = Image.open(file.name)
            result = detection_model.predict(img, conf=0.5, iou=0.6)
            # plot() returns a BGR array; reverse the channel axis to get RGB for PIL.
            img_bgr = result[0].plot()
            processed = Image.fromarray(img_bgr[..., ::-1])
            processed_images.append(processed)
        except Exception as e:
            print(f"Error processing image: {e}")
            continue
    return processed_images if processed_images else None


def draw_boxes(frame: np.ndarray, last_boxes: List[Dict],
               color_map: Dict[str, Tuple[int, int, int]],
               thickness: int = 2) -> np.ndarray:
    """Draw labelled bounding boxes on a copy of the frame."""
    output_frame = frame.copy()
    for box in last_boxes:
        x1, y1, x2, y2 = map(int, box['bbox'])
        label = box['class']
        color = color_map.get(label, (255, 255, 255))
        cv2.rectangle(output_frame, (x1, y1), (x2, y2), color, thickness)
        conf = box.get('conf', 0.0)
        label_text = f"{label} {conf:.2f}"
        cv2.putText(output_frame, label_text, (x1, y1 - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
    return output_frame


def process_video(video_path, process_all_frames=False):
    """Annotate a video, detecting on every frame or roughly every 0.25 s."""
    if video_path is None:
        return None, "No video uploaded"
    output_path = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False).name
    cap = cv2.VideoCapture(video_path)
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = int(cap.get(cv2.CAP_PROP_FPS))
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    start_frame = 0
    end_frame = total_frames
    # Detect on every frame, or about four times per second; max(1, ...) guards
    # against a zero interval when the reported fps is very low.
    process_interval = 1 if process_all_frames else max(1, int(fps * 0.25))
    color_map = {
        'headphone': (139, 69, 19),   # dark blue (OpenCV colors are BGR)
        'microphone': (255, 191, 0),  # light blue (BGR)
    }
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
    try:
        frame_count = start_frame
        last_boxes = []
        last_detection_frame = 0
        cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
        with tqdm(total=end_frame - start_frame, desc="Processing video") as pbar:
            while cap.isOpened() and frame_count < end_frame:
                ret, frame = cap.read()
                if not ret:
                    break
                if frame_count % process_interval == 0:
                    results = detection_model.predict(frame, conf=0.5, iou=0.6)[0]
                    if len(results.boxes) > 0:
                        last_boxes = []
                        for box in results.boxes:
                            xyxy = box.xyxy[0].cpu().numpy()
                            conf = float(box.conf)
                            cls = results.names[int(box.cls)]
                            last_boxes.append({'bbox': xyxy, 'conf': conf, 'class': cls})
                        last_detection_frame = frame_count
                    processed_frame = draw_boxes(frame, last_boxes, color_map)
                else:
                    # Between detection frames, reuse the last boxes; drop them
                    # once they are older than 0.25 s.
                    if frame_count - last_detection_frame > fps * 0.25:
                        last_boxes = []
                    processed_frame = draw_boxes(frame, last_boxes, color_map)
                out.write(processed_frame)
                pbar.update(1)
                frame_count += 1
    finally:
        cap.release()
        out.release()
    return output_path, f"Processed frames {start_frame} to {frame_count}"


def get_video_duration(video_path):
    """Return the video duration in seconds (0 if unreadable)."""
    if not video_path:
        return 0
    cap = cv2.VideoCapture(video_path)
    fps = cap.get(cv2.CAP_PROP_FPS)
    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    cap.release()
    # Guard against a zero fps reading to avoid ZeroDivisionError.
    return frame_count / fps if fps > 0 else 0


# Load the model once at startup; the handlers above use it as a module-level global.
REPO_ID = "220502T/headphone-vs-microphone"
detection_model = load_model(REPO_ID)
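# Optional smoke test before launching the UI (a minimal sketch; "sample.jpg"
# is a hypothetical local file, not part of this repo). Uncomment to verify
# that the model loads and predicts:
#
#   boxes = detection_model.predict(Image.open("sample.jpg"), conf=0.5, iou=0.6)[0].boxes
#   print(boxes)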
with gr.Blocks() as demo:
    gr.Markdown("# Object Detection - Images and Videos")
    with gr.Tabs():
        with gr.Tab("Image Detection"):
            with gr.Row():
                with gr.Column():
                    with gr.Group():
                        upload_image = gr.File(
                            file_count="multiple",
                            label="Upload Image",
                            file_types=["image"],
                            type="binary",
                            height=200,
                        )
                        gr.Markdown("*You can upload a single image or multiple images at once*")
                        file_preview = gr.Gallery(
                            label="Preview",
                            show_label=True,
                            columns=2,
                            height="300px",             # px value enforces a fixed height
                            allow_preview=True,         # enable preview mode
                            show_share_button=False,    # hide the share button for a cleaner look
                            container=True,             # container mode enables scrolling
                            elem_id="preview-gallery",  # ID hook for custom styling
                        )

                    def update_preview(files):
                        if not files:
                            return None
                        previews = []
                        for file in files:
                            try:
                                if isinstance(file, bytes):
                                    previews.append(Image.open(BytesIO(file)))
                                elif hasattr(file, "name"):
                                    previews.append(Image.open(file.name))
                            except Exception as e:
                                print(f"Error in preview: {e}")
                                continue
                        return previews if previews else None

                    upload_image.change(
                        update_preview,
                        inputs=[upload_image],
                        outputs=[file_preview],
                    )

                with gr.Column():
                    processed_gallery = gr.Gallery(
                        label="Detected Objects",
                        show_label=True,
                        columns=2,
                    )
                    process_btn = gr.Button("Process Images")
                    process_btn.click(
                        process_images,
                        inputs=[upload_image],
                        outputs=[processed_gallery],
                    )

        with gr.Tab("Video Detection"):
            with gr.Row():
                video_input = gr.Video(label="Upload Video")
            with gr.Row():
                with gr.Column():
                    process_all = gr.Checkbox(label="Process all frames (slower but more accurate)")
                with gr.Column():
                    video_output = gr.Video(label="Processed Video")
                    status_text = gr.Markdown("Ready to process")
            process_btn = gr.Button("Process Video")

            def process_with_status(video_path, process_all_frames):
                if not video_path:
                    return None, "No video uploaded"
                return process_video(video_path, process_all_frames)

            process_btn.click(
                process_with_status,
                inputs=[video_input, process_all],
                outputs=[video_output, status_text],
            )

if __name__ == "__main__":
    demo.launch()