Abs6187's picture
Update config.py
acdcf49 verified
"""
Vehicle Detection Configuration Module
=======================================
Manages configuration settings for vehicle detection, tracking, and speed estimation.
Authors:
- Abhay Gupta (0205CC221005)
- Aditi Lakhera (0205CC221011)
- Balraj Patel (0205CC221049)
- Bhumika Patel (0205CC221050)
"""
import os
from dataclasses import dataclass, field
from typing import List, Tuple, Optional
import logging
logger = logging.getLogger(__name__)
@dataclass
class VehicleDetectionConfig:
    """
    Configuration class for vehicle detection and speed estimation system.

    This class encapsulates all configuration parameters needed for the
    vehicle detection pipeline, including video paths, model settings,
    detection zones, and perspective transformation parameters.

    All parameters are validated once, in ``__post_init__``; an invalid value
    raises ``ValueError`` at construction time. When ``model_path`` is left as
    ``None`` it is resolved from ``model_name`` (see ``_setup_model_path``).
    Values assigned after construction are not re-validated.
    """

    # Conversion factors from metres per second to each supported speed unit.
    # Deliberately unannotated so ``@dataclass`` treats it as a class constant
    # rather than a field. It is the single source of truth for both
    # validation (``_validate_config``) and conversion
    # (``get_speed_conversion_factor``), so the two cannot drift apart.
    _SPEED_CONVERSIONS = {
        "km/h": 3.6,
        "mph": 2.23694,
        "m/s": 1.0,
    }

    # Video Configuration
    input_video: str = "./data/vehicles.mp4"  # must be non-empty
    output_video: str = "./data/vehicles_output.mp4"

    # Model Configuration
    model_name: str = "yolov8n"
    model_path: Optional[str] = None  # None -> resolved by _setup_model_path
    confidence_threshold: float = 0.3  # validated to lie in [0, 1]
    iou_threshold: float = 0.7  # validated to lie in [0, 1]

    # Detection Zone Configuration
    line_y: int = 480  # validated to be >= 0
    line_offset: int = 55  # validated to be >= 0
    crossing_threshold: int = 1

    # Perspective Transformation Configuration
    # Source points define the region in the original video frame, ordered
    # top-left, top-right, bottom-right, bottom-left. Note the bottom-left x
    # is negative, i.e. left of the frame edge -- presumably so the
    # quadrilateral follows the road's vanishing lines; confirm against the
    # source footage before changing.
    source_points: List[List[int]] = field(default_factory=lambda: [
        [450, 300],  # Top-left
        [860, 300],  # Top-right
        [1900, 720],  # Bottom-right
        [-660, 720],  # Bottom-left
    ])
    # Target points define the transformed top-down view dimensions (in meters)
    target_width_meters: float = 25.0  # must be > 0
    target_height_meters: float = 100.0  # must be > 0

    # Display Configuration (disabled by default for headless environments like HF Spaces)
    window_name: str = "Vehicle Speed Estimation - Traffic Analysis"
    display_enabled: bool = False

    # Annotation Configuration
    enable_boxes: bool = True
    enable_labels: bool = True
    enable_traces: bool = True
    enable_line_zones: bool = True
    trace_length: int = 20  # must be > 0

    # Speed Estimation Configuration
    speed_history_seconds: int = 1  # must be > 0
    speed_unit: str = "km/h"  # Options: "km/h", "mph", "m/s"

    def __post_init__(self):
        """Validate the configuration, then resolve the model path."""
        self._validate_config()
        self._setup_model_path()

    def _validate_config(self) -> None:
        """
        Validate configuration parameters.

        Raises:
            ValueError: If any configuration parameter is invalid.
        """
        # Validate video paths
        if not self.input_video:
            raise ValueError("Input video path cannot be empty")

        # Validate model configuration. The chained comparison is False for
        # NaN, so NaN thresholds are rejected as well.
        if not 0.0 <= self.confidence_threshold <= 1.0:
            raise ValueError(f"Confidence threshold must be between 0 and 1, got {self.confidence_threshold}")
        if not 0.0 <= self.iou_threshold <= 1.0:
            raise ValueError(f"IOU threshold must be between 0 and 1, got {self.iou_threshold}")

        # Validate detection zone. Zero is accepted, so the messages say
        # "non-negative" rather than "positive".
        if self.line_y < 0:
            raise ValueError(f"Line Y position must be non-negative, got {self.line_y}")
        if self.line_offset < 0:
            raise ValueError(f"Line offset must be non-negative, got {self.line_offset}")

        # Validate perspective transformation: a quadrilateral of (x, y) pairs.
        if len(self.source_points) != 4:
            raise ValueError(f"Source points must contain exactly 4 points, got {len(self.source_points)}")
        for i, point in enumerate(self.source_points):
            if len(point) != 2:
                raise ValueError(f"Source point {i} must have 2 coordinates, got {len(point)}")
        if self.target_width_meters <= 0 or self.target_height_meters <= 0:
            raise ValueError("Target dimensions must be positive")

        # Validate annotation and speed-history sizes: a non-positive length
        # or duration cannot describe any trace or history window.
        if self.trace_length <= 0:
            raise ValueError(f"Trace length must be positive, got {self.trace_length}")
        if self.speed_history_seconds <= 0:
            raise ValueError(f"Speed history seconds must be positive, got {self.speed_history_seconds}")

        # Validate speed configuration against the same table used for conversion.
        if self.speed_unit not in self._SPEED_CONVERSIONS:
            raise ValueError(f"Invalid speed unit: {self.speed_unit}. Must be 'km/h', 'mph', or 'm/s'")

        logger.info("Configuration validation successful")

    def _setup_model_path(self) -> None:
        """
        Resolve ``model_path`` from ``model_name`` when it was not provided.

        Candidates are tried in order and the first that exists on disk wins:

        1. ``./models/<model_name>.pt``
        2. ``./models/VisDrone_YOLO_x2.pt`` (custom trained model)
        3. ``model_name`` itself, treated as a local path

        If none exists, ``model_path`` is set to ``model_name`` so the
        detector can fetch it by name (presumably downloaded by ultralytics).
        An explicitly provided ``model_path`` is left untouched.
        """
        if self.model_path is not None:
            return

        model_dir = "./models"
        # NOTE(review): the custom VisDrone weights take precedence over
        # downloading ``model_name``, even if a different model was requested.
        candidates = (
            f"{model_dir}/{self.model_name}.pt",
            f"{model_dir}/VisDrone_YOLO_x2.pt",  # Custom trained model
            self.model_name,  # model_name given as a local file path
        )
        for path in candidates:
            if os.path.exists(path):
                self.model_path = path
                logger.info("Using model from: %s", path)
                return

        # Use model name directly (will be downloaded by ultralytics)
        self.model_path = self.model_name
        logger.info("Model will be downloaded: %s", self.model_name)

    @property
    def target_points(self) -> List[List[float]]:
        """
        Generate target points for perspective transformation.

        Returns:
            A new list of 4 ``[x, y]`` points (in meters) describing a
            ``target_width_meters`` x ``target_height_meters`` rectangle, in the
            same corner order as ``source_points``.
        """
        w, h = self.target_width_meters, self.target_height_meters
        return [
            [0, 0],  # Top-left
            [w, 0],  # Top-right
            [w, h],  # Bottom-right
            [0, h],  # Bottom-left
        ]

    def get_speed_conversion_factor(self) -> float:
        """
        Get conversion factor for speed unit.

        Returns:
            Multiplier converting a speed in m/s to ``speed_unit``.

        Raises:
            KeyError: If ``speed_unit`` was changed to an unsupported value
                after construction (it is only validated in ``__post_init__``).
        """
        return self._SPEED_CONVERSIONS[self.speed_unit]

    def to_dict(self) -> dict:
        """
        Convert configuration to dictionary.

        Returns:
            Dictionary with a summary subset of the configuration
            (paths, model settings, confidence, line position, speed unit).
        """
        return {
            "input_video": self.input_video,
            "output_video": self.output_video,
            "model_name": self.model_name,
            "model_path": self.model_path,
            "confidence_threshold": self.confidence_threshold,
            "line_y": self.line_y,
            "speed_unit": self.speed_unit,
        }

    def __repr__(self) -> str:
        """Return a short string showing the model name and input video."""
        return f"VehicleDetectionConfig(model={self.model_name}, input={self.input_video})"
# Module-wide default configuration. Building it runs validation and model
# path resolution once, at import time. It is kept for backward compatibility
# with code that imports the flat constants below instead of the dataclass.
DEFAULT_CONFIG = VehicleDetectionConfig()

# Flat aliases of frequently used settings, all read from DEFAULT_CONFIG.
# SOURCE_POINTS refers to the very same list object as
# DEFAULT_CONFIG.source_points. TARGET_POINTS is a fresh list produced by the
# target_points property.
IN_VIDEO_PATH, OUT_VIDEO_PATH = DEFAULT_CONFIG.input_video, DEFAULT_CONFIG.output_video
YOLO_MODEL_PATH = DEFAULT_CONFIG.model_path
LINE_Y = DEFAULT_CONFIG.line_y
SOURCE_POINTS, TARGET_POINTS = DEFAULT_CONFIG.source_points, DEFAULT_CONFIG.target_points
WINDOW_NAME = DEFAULT_CONFIG.window_name