Spaces:
Running
on
Zero
Running
on
Zero
| from ultralytics import YOLO | |
| from typing import Any, List, Dict, Optional | |
| import torch | |
| import numpy as np | |
| import os | |
| class DetectionModel: | |
| """Core detection model class for object detection using YOLOv8""" | |
| # Model information dictionary | |
| MODEL_INFO = { | |
| "yolov8n.pt": { | |
| "name": "YOLOv8n (Nano)", | |
| "description": "Fastest model with smallest size (3.2M parameters). Best for speed-critical applications.", | |
| "size_mb": 6, | |
| "inference_speed": "Very Fast" | |
| }, | |
| "yolov8m.pt": { | |
| "name": "YOLOv8m (Medium)", | |
| "description": "Balanced model with good accuracy-speed tradeoff (25.9M parameters). Recommended for general use.", | |
| "size_mb": 25, | |
| "inference_speed": "Medium" | |
| }, | |
| "yolov8x.pt": { | |
| "name": "YOLOv8x (XLarge)", | |
| "description": "Most accurate but slower model (68.2M parameters). Best for accuracy-critical applications.", | |
| "size_mb": 68, | |
| "inference_speed": "Slower" | |
| } | |
| } | |
| def __init__(self, model_name: str = 'yolov8m.pt', confidence: float = 0.25, iou: float = 0.25): | |
| """ | |
| Initialize the detection model | |
| Args: | |
| model_name: Model name or path, default is yolov8m.pt | |
| confidence: Confidence threshold, default is 0.25 | |
| iou: IoU threshold for non-maximum suppression, default is 0.45 | |
| """ | |
| self.model_name = model_name | |
| self.confidence = confidence | |
| self.iou = iou | |
| self.model = None | |
| self.class_names = {} | |
| self.is_model_loaded = False | |
| # Load model on initialization | |
| self._load_model() | |
| def _load_model(self): | |
| """Load the YOLO model""" | |
| try: | |
| print(f"Loading model: {self.model_name}") | |
| self.model = YOLO(self.model_name) | |
| self.class_names = self.model.names | |
| self.is_model_loaded = True | |
| print(f"Successfully loaded model: {self.model_name}") | |
| print(f"Number of classes the model can recognize: {len(self.class_names)}") | |
| except Exception as e: | |
| print(f"Error occurred when loading the model: {e}") | |
| self.is_model_loaded = False | |
| def change_model(self, new_model_name: str) -> bool: | |
| """ | |
| Change the currently loaded model | |
| Args: | |
| new_model_name: Name of the new model to load | |
| Returns: | |
| bool: True if model changed successfully, False otherwise | |
| """ | |
| if self.model_name == new_model_name and self.is_model_loaded: | |
| print(f"Model {new_model_name} is already loaded") | |
| return True | |
| print(f"Changing model from {self.model_name} to {new_model_name}") | |
| # Unload current model to free memory | |
| if self.model is not None: | |
| del self.model | |
| self.model = None | |
| # Clean GPU memory if available | |
| if torch.cuda.is_available(): | |
| torch.cuda.empty_cache() | |
| # Update model name and load new model | |
| self.model_name = new_model_name | |
| self._load_model() | |
| return self.is_model_loaded | |
| def reload_model(self): | |
| """Reload the model (useful for changing model or after error)""" | |
| if self.model is not None: | |
| del self.model | |
| self.model = None | |
| # Clean GPU memory if available | |
| if torch.cuda.is_available(): | |
| torch.cuda.empty_cache() | |
| self._load_model() | |
| def detect(self, image_input: Any) -> Optional[Any]: | |
| """ | |
| Perform object detection on a single image | |
| Args: | |
| image_input: Image path (str), PIL Image, or numpy array | |
| Returns: | |
| Detection result object or None if error occurred | |
| """ | |
| if self.model is None or not self.is_model_loaded: | |
| print("Model not found or not loaded. Attempting to reload...") | |
| self._load_model() | |
| if self.model is None or not self.is_model_loaded: | |
| print("Failed to load model. Cannot perform detection.") | |
| return None | |
| try: | |
| results = self.model(image_input, conf=self.confidence, iou=self.iou) | |
| return results[0] | |
| except Exception as e: | |
| print(f"Error occurred during detection: {e}") | |
| return None | |
| def get_class_names(self, class_id: int) -> str: | |
| """Get class name for a given class ID""" | |
| return self.class_names.get(class_id, "Unknown Class") | |
| def get_supported_classes(self) -> Dict[int, str]: | |
| """Get all supported classes as a dictionary of {id: class_name}""" | |
| return self.class_names | |
| def get_available_models(cls) -> List[Dict]: | |
| """ | |
| Get list of available models with their information | |
| Returns: | |
| List of dictionaries containing model information | |
| """ | |
| models = [] | |
| for model_file, info in cls.MODEL_INFO.items(): | |
| models.append({ | |
| "model_file": model_file, | |
| "name": info["name"], | |
| "description": info["description"], | |
| "size_mb": info["size_mb"], | |
| "inference_speed": info["inference_speed"] | |
| }) | |
| return models | |
| def get_model_description(cls, model_name: str) -> str: | |
| """Get description for a specific model""" | |
| if model_name in cls.MODEL_INFO: | |
| info = cls.MODEL_INFO[model_name] | |
| return f"{info['name']}: {info['description']} (Size: ~{info['size_mb']}MB, Speed: {info['inference_speed']})" | |
| return "Model information not available" | |