Abs6187 committed on
Commit acdcf49 · verified · 1 Parent(s): 82ae22b

Update config.py

Files changed (1)
  1. config.py +205 -205
config.py CHANGED
@@ -1,205 +1,205 @@
- """
- Vehicle Detection Configuration Module
- =======================================
-
- Manages configuration settings for vehicle detection, tracking, and speed estimation.
-
- Authors:
- - Abhay Gupta (0205CC221005)
- - Aditi Lakhera (0205CC221011)
- - Balraj Patel (0205CC221049)
- - Bhumika Patel (0205CC221050)
- """
-
- import os
- from dataclasses import dataclass, field
- from typing import List, Tuple, Optional
- import logging
-
- logger = logging.getLogger(__name__)
-
-
- @dataclass
- class VehicleDetectionConfig:
-     """
-     Configuration class for vehicle detection and speed estimation system.
-
-     This class encapsulates all configuration parameters needed for the
-     vehicle detection pipeline, including video paths, model settings,
-     detection zones, and perspective transformation parameters.
-     """
-
-     # Video Configuration
-     input_video: str = "./data/vehicles.mp4"
-     output_video: str = "./data/vehicles_output.mp4"
-
-     # Model Configuration
-     model_name: str = "yolov8n"
-     model_path: Optional[str] = None
-     confidence_threshold: float = 0.3
-     iou_threshold: float = 0.7
-
-     # Detection Zone Configuration
-     line_y: int = 480
-     line_offset: int = 55
-     crossing_threshold: int = 1
-
-     # Perspective Transformation Configuration
-     # Source points define the region in the original video frame
-     source_points: List[List[int]] = field(default_factory=lambda: [
-         [450, 300],   # Top-left
-         [860, 300],   # Top-right
-         [1900, 720],  # Bottom-right
-         [-660, 720]   # Bottom-left
-     ])
-
-     # Target points define the transformed top-down view dimensions (in meters)
-     target_width_meters: float = 25.0
-     target_height_meters: float = 100.0
-
-     # Display Configuration
-     window_name: str = "Vehicle Speed Estimation - Traffic Analysis"
-     display_enabled: bool = True
-
-     # Annotation Configuration
-     enable_boxes: bool = True
-     enable_labels: bool = True
-     enable_traces: bool = True
-     enable_line_zones: bool = True
-     trace_length: int = 20
-
-     # Speed Estimation Configuration
-     speed_history_seconds: int = 1
-     speed_unit: str = "km/h"  # Options: "km/h", "mph", "m/s"
-
-     def __post_init__(self):
-         """Validate configuration after initialization."""
-         self._validate_config()
-         self._setup_model_path()
-
-     def _validate_config(self) -> None:
-         """
-         Validate configuration parameters.
-
-         Raises:
-             ValueError: If configuration parameters are invalid
-         """
-         # Validate video paths
-         if not self.input_video:
-             raise ValueError("Input video path cannot be empty")
-
-         # Validate model configuration
-         if not 0.0 <= self.confidence_threshold <= 1.0:
-             raise ValueError(f"Confidence threshold must be between 0 and 1, got {self.confidence_threshold}")
-
-         if not 0.0 <= self.iou_threshold <= 1.0:
-             raise ValueError(f"IOU threshold must be between 0 and 1, got {self.iou_threshold}")
-
-         # Validate detection zone
-         if self.line_y < 0:
-             raise ValueError(f"Line Y position must be positive, got {self.line_y}")
-
-         if self.line_offset < 0:
-             raise ValueError(f"Line offset must be positive, got {self.line_offset}")
-
-         # Validate perspective transformation
-         if len(self.source_points) != 4:
-             raise ValueError(f"Source points must contain exactly 4 points, got {len(self.source_points)}")
-
-         for i, point in enumerate(self.source_points):
-             if len(point) != 2:
-                 raise ValueError(f"Source point {i} must have 2 coordinates, got {len(point)}")
-
-         if self.target_width_meters <= 0 or self.target_height_meters <= 0:
-             raise ValueError("Target dimensions must be positive")
-
-         # Validate speed configuration
-         if self.speed_unit not in ["km/h", "mph", "m/s"]:
-             raise ValueError(f"Invalid speed unit: {self.speed_unit}. Must be 'km/h', 'mph', or 'm/s'")
-
-         logger.info("Configuration validation successful")
-
-     def _setup_model_path(self) -> None:
-         """Set up the model path based on model name."""
-         if self.model_path is None:
-             # Try to find model in models directory
-             model_dir = "./models"
-             potential_paths = [
-                 f"{model_dir}/{self.model_name}.pt",
-                 f"{model_dir}/VisDrone_YOLO_x2.pt",  # Custom trained model
-                 self.model_name  # Let ultralytics download from hub
-             ]
-
-             for path in potential_paths:
-                 if os.path.exists(path):
-                     self.model_path = path
-                     logger.info(f"Using model from: {path}")
-                     return
-
-             # Use model name directly (will be downloaded by ultralytics)
-             self.model_path = self.model_name
-             logger.info(f"Model will be downloaded: {self.model_name}")
-
-     @property
-     def target_points(self) -> List[List[float]]:
-         """
-         Generate target points for perspective transformation.
-
-         Returns:
-             List of 4 points defining the target perspective in meters
-         """
-         w, h = self.target_width_meters, self.target_height_meters
-         return [
-             [0, 0],  # Top-left
-             [w, 0],  # Top-right
-             [w, h],  # Bottom-right
-             [0, h]   # Bottom-left
-         ]
-
-     def get_speed_conversion_factor(self) -> float:
-         """
-         Get conversion factor for speed unit.
-
-         Returns:
-             Conversion factor from m/s to desired unit
-         """
-         conversions = {
-             "km/h": 3.6,
-             "mph": 2.23694,
-             "m/s": 1.0
-         }
-         return conversions[self.speed_unit]
-
-     def to_dict(self) -> dict:
-         """
-         Convert configuration to dictionary.
-
-         Returns:
-             Dictionary representation of configuration
-         """
-         return {
-             "input_video": self.input_video,
-             "output_video": self.output_video,
-             "model_name": self.model_name,
-             "model_path": self.model_path,
-             "confidence_threshold": self.confidence_threshold,
-             "line_y": self.line_y,
-             "speed_unit": self.speed_unit,
-         }
-
-     def __repr__(self) -> str:
-         """String representation of configuration."""
-         return f"VehicleDetectionConfig(model={self.model_name}, input={self.input_video})"
-
-
- # Default configuration instance for backward compatibility
- DEFAULT_CONFIG = VehicleDetectionConfig()
-
- # Export commonly used configuration values
- IN_VIDEO_PATH = DEFAULT_CONFIG.input_video
- OUT_VIDEO_PATH = DEFAULT_CONFIG.output_video
- YOLO_MODEL_PATH = DEFAULT_CONFIG.model_path
- LINE_Y = DEFAULT_CONFIG.line_y
- SOURCE_POINTS = DEFAULT_CONFIG.source_points
- TARGET_POINTS = DEFAULT_CONFIG.target_points
- WINDOW_NAME = DEFAULT_CONFIG.window_name
 
+ """
+ Vehicle Detection Configuration Module
+ =======================================
+
+ Manages configuration settings for vehicle detection, tracking, and speed estimation.
+
+ Authors:
+ - Abhay Gupta (0205CC221005)
+ - Aditi Lakhera (0205CC221011)
+ - Balraj Patel (0205CC221049)
+ - Bhumika Patel (0205CC221050)
+ """
+
+ import os
+ from dataclasses import dataclass, field
+ from typing import List, Tuple, Optional
+ import logging
+
+ logger = logging.getLogger(__name__)
+
+
+ @dataclass
+ class VehicleDetectionConfig:
+     """
+     Configuration class for vehicle detection and speed estimation system.
+
+     This class encapsulates all configuration parameters needed for the
+     vehicle detection pipeline, including video paths, model settings,
+     detection zones, and perspective transformation parameters.
+     """
+
+     # Video Configuration
+     input_video: str = "./data/vehicles.mp4"
+     output_video: str = "./data/vehicles_output.mp4"
+
+     # Model Configuration
+     model_name: str = "yolov8n"
+     model_path: Optional[str] = None
+     confidence_threshold: float = 0.3
+     iou_threshold: float = 0.7
+
+     # Detection Zone Configuration
+     line_y: int = 480
+     line_offset: int = 55
+     crossing_threshold: int = 1
+
+     # Perspective Transformation Configuration
+     # Source points define the region in the original video frame
+     source_points: List[List[int]] = field(default_factory=lambda: [
+         [450, 300],   # Top-left
+         [860, 300],   # Top-right
+         [1900, 720],  # Bottom-right
+         [-660, 720]   # Bottom-left
+     ])
+
+     # Target points define the transformed top-down view dimensions (in meters)
+     target_width_meters: float = 25.0
+     target_height_meters: float = 100.0
+
+     # Display Configuration (disabled by default for headless environments like HF Spaces)
+     window_name: str = "Vehicle Speed Estimation - Traffic Analysis"
+     display_enabled: bool = False
+
+     # Annotation Configuration
+     enable_boxes: bool = True
+     enable_labels: bool = True
+     enable_traces: bool = True
+     enable_line_zones: bool = True
+     trace_length: int = 20
+
+     # Speed Estimation Configuration
+     speed_history_seconds: int = 1
+     speed_unit: str = "km/h"  # Options: "km/h", "mph", "m/s"
+
+     def __post_init__(self):
+         """Validate configuration after initialization."""
+         self._validate_config()
+         self._setup_model_path()
+
+     def _validate_config(self) -> None:
+         """
+         Validate configuration parameters.
+
+         Raises:
+             ValueError: If configuration parameters are invalid
+         """
+         # Validate video paths
+         if not self.input_video:
+             raise ValueError("Input video path cannot be empty")
+
+         # Validate model configuration
+         if not 0.0 <= self.confidence_threshold <= 1.0:
+             raise ValueError(f"Confidence threshold must be between 0 and 1, got {self.confidence_threshold}")
+
+         if not 0.0 <= self.iou_threshold <= 1.0:
+             raise ValueError(f"IOU threshold must be between 0 and 1, got {self.iou_threshold}")
+
+         # Validate detection zone
+         if self.line_y < 0:
+             raise ValueError(f"Line Y position must be positive, got {self.line_y}")
+
+         if self.line_offset < 0:
+             raise ValueError(f"Line offset must be positive, got {self.line_offset}")
+
+         # Validate perspective transformation
+         if len(self.source_points) != 4:
+             raise ValueError(f"Source points must contain exactly 4 points, got {len(self.source_points)}")
+
+         for i, point in enumerate(self.source_points):
+             if len(point) != 2:
+                 raise ValueError(f"Source point {i} must have 2 coordinates, got {len(point)}")
+
+         if self.target_width_meters <= 0 or self.target_height_meters <= 0:
+             raise ValueError("Target dimensions must be positive")
+
+         # Validate speed configuration
+         if self.speed_unit not in ["km/h", "mph", "m/s"]:
+             raise ValueError(f"Invalid speed unit: {self.speed_unit}. Must be 'km/h', 'mph', or 'm/s'")
+
+         logger.info("Configuration validation successful")
+
+     def _setup_model_path(self) -> None:
+         """Set up the model path based on model name."""
+         if self.model_path is None:
+             # Try to find model in models directory
+             model_dir = "./models"
+             potential_paths = [
+                 f"{model_dir}/{self.model_name}.pt",
+                 f"{model_dir}/VisDrone_YOLO_x2.pt",  # Custom trained model
+                 self.model_name  # Let ultralytics download from hub
+             ]
+
+             for path in potential_paths:
+                 if os.path.exists(path):
+                     self.model_path = path
+                     logger.info(f"Using model from: {path}")
+                     return
+
+             # Use model name directly (will be downloaded by ultralytics)
+             self.model_path = self.model_name
+             logger.info(f"Model will be downloaded: {self.model_name}")
+
+     @property
+     def target_points(self) -> List[List[float]]:
+         """
+         Generate target points for perspective transformation.
+
+         Returns:
+             List of 4 points defining the target perspective in meters
+         """
+         w, h = self.target_width_meters, self.target_height_meters
+         return [
+             [0, 0],  # Top-left
+             [w, 0],  # Top-right
+             [w, h],  # Bottom-right
+             [0, h]   # Bottom-left
+         ]
+
+     def get_speed_conversion_factor(self) -> float:
+         """
+         Get conversion factor for speed unit.
+
+         Returns:
+             Conversion factor from m/s to desired unit
+         """
+         conversions = {
+             "km/h": 3.6,
+             "mph": 2.23694,
+             "m/s": 1.0
+         }
+         return conversions[self.speed_unit]
+
+     def to_dict(self) -> dict:
+         """
+         Convert configuration to dictionary.
+
+         Returns:
+             Dictionary representation of configuration
+         """
+         return {
+             "input_video": self.input_video,
+             "output_video": self.output_video,
+             "model_name": self.model_name,
+             "model_path": self.model_path,
+             "confidence_threshold": self.confidence_threshold,
+             "line_y": self.line_y,
+             "speed_unit": self.speed_unit,
+         }
+
+     def __repr__(self) -> str:
+         """String representation of configuration."""
+         return f"VehicleDetectionConfig(model={self.model_name}, input={self.input_video})"
+
+
+ # Default configuration instance for backward compatibility
+ DEFAULT_CONFIG = VehicleDetectionConfig()
+
+ # Export commonly used configuration values
+ IN_VIDEO_PATH = DEFAULT_CONFIG.input_video
+ OUT_VIDEO_PATH = DEFAULT_CONFIG.output_video
+ YOLO_MODEL_PATH = DEFAULT_CONFIG.model_path
+ LINE_Y = DEFAULT_CONFIG.line_y
+ SOURCE_POINTS = DEFAULT_CONFIG.source_points
+ TARGET_POINTS = DEFAULT_CONFIG.target_points
+ WINDOW_NAME = DEFAULT_CONFIG.window_name
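
Downstream of this config, the perspective points and the speed-unit factor are what turn pixel tracks into speeds. The following is a minimal sketch of that use, not code from this repository: it assumes numpy and opencv-python are installed and that config.py is importable, and pixel_to_meters / estimate_speed are hypothetical helper names.

import cv2
import numpy as np

from config import DEFAULT_CONFIG, SOURCE_POINTS, TARGET_POINTS

# Homography from the camera view (pixels) to the top-down plane (meters),
# built from the four source points and the 25 m x 100 m target rectangle.
src = np.array(SOURCE_POINTS, dtype=np.float32)
dst = np.array(TARGET_POINTS, dtype=np.float32)
matrix = cv2.getPerspectiveTransform(src, dst)

def pixel_to_meters(x: float, y: float) -> tuple:
    """Project a single image point into the metric top-down plane."""
    point = np.array([[[x, y]]], dtype=np.float32)
    wx, wy = cv2.perspectiveTransform(point, matrix)[0, 0]
    return float(wx), float(wy)

def estimate_speed(p0, p1, dt_seconds: float) -> float:
    """Speed over the interval, converted from m/s to the configured unit (3.6 for km/h)."""
    (x0, y0), (x1, y1) = pixel_to_meters(*p0), pixel_to_meters(*p1)
    meters = ((x1 - x0) ** 2 + (y1 - y0) ** 2) ** 0.5
    return (meters / dt_seconds) * DEFAULT_CONFIG.get_speed_conversion_factor()

# Two observed positions of one vehicle, one second apart.
print(estimate_speed((700, 400), (705, 520), 1.0))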
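The substantive change in this commit is display_enabled flipping from True to False so the pipeline can run where no GUI is available, per the new comment about headless environments like HF Spaces. A sketch of how a render loop might honor that flag; show_frame is a hypothetical helper, not repository code.

import cv2

from config import DEFAULT_CONFIG, WINDOW_NAME

def show_frame(frame) -> None:
    """Display the annotated frame only when a GUI is available; otherwise a no-op."""
    if not DEFAULT_CONFIG.display_enabled:
        return  # headless host: skip all GUI calls
    cv2.imshow(WINDOW_NAME, frame)
    cv2.waitKey(1)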
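Because __post_init__ runs the validators and the model-path resolution at construction time, overriding defaults and handling invalid values looks roughly like this (again a sketch, assuming config.py is importable):

from config import VehicleDetectionConfig

cfg = VehicleDetectionConfig(confidence_threshold=0.5, speed_unit="mph")
print(cfg.to_dict())      # includes the model_path resolved by _setup_model_path()
print(cfg.target_points)  # [[0, 0], [25.0, 0], [25.0, 100.0], [0, 100.0]]

try:
    VehicleDetectionConfig(speed_unit="knots")  # not in the allowed list
except ValueError as err:
    print(f"rejected: {err}")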