Spaces:
Runtime error
Runtime error
Upload 2 files
Browse files
helper.py
CHANGED
@@ -207,3 +207,17 @@ def make_table_from_dict(obj: dict, selected_key: str) -> list:
|
|
207 |
table.append({"name": k, "value": v, "selected": False})
|
208 |
|
209 |
return table
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
207 |
table.append({"name": k, "value": v, "selected": False})
|
208 |
|
209 |
return table
|
210 |
+
|
211 |
+
|
212 |
+
def make_table_from_dict_multiselect(
|
213 |
+
obj: dict, selected_vals: list[int]
|
214 |
+
) -> list:
|
215 |
+
table = []
|
216 |
+
for k, v in obj.items():
|
217 |
+
if v in selected_vals:
|
218 |
+
# print(k, v, selected_key)
|
219 |
+
table.append({"name": k, "value": v, "selected": True})
|
220 |
+
else:
|
221 |
+
table.append({"name": k, "value": v, "selected": False})
|
222 |
+
|
223 |
+
return table
|
webapp.py
CHANGED
@@ -5,6 +5,7 @@ import cv2
|
|
5 |
import numpy as np
|
6 |
from pathlib import Path
|
7 |
import torch
|
|
|
8 |
from starlette.middleware import Middleware
|
9 |
from starlette.responses import StreamingResponse, Response
|
10 |
from starlette.requests import Request
|
@@ -19,7 +20,9 @@ from ultralytics_solutions_modified import object_counter, speed_estimation
|
|
19 |
from vidgear.gears import CamGear
|
20 |
from vidgear.gears.asyncio import WebGear
|
21 |
from vidgear.gears.asyncio.helper import reducer
|
22 |
-
from helper import
|
|
|
|
|
23 |
|
24 |
|
25 |
HERE = Path(__file__).parent
|
@@ -31,7 +34,7 @@ RETRY_TIMEOUT_MILSEC = 15000 # milisecond
|
|
31 |
# Create and configure logger
|
32 |
# logger = logging.getLogger(__name__).addHandler(logging.NullHandler())
|
33 |
logging.basicConfig(
|
34 |
-
format='%(asctime)s %(name)-8s->%(module)-20s->%(funcName)-20s:%(lineno)-4s::%(levelname)-8s %(message)s',
|
35 |
level=logging.INFO
|
36 |
)
|
37 |
|
@@ -55,6 +58,7 @@ class DemoCase:
|
|
55 |
"y8huge": "./data/models/yolov8x.pt",
|
56 |
}
|
57 |
self.model_choice_default: str = "y8small"
|
|
|
58 |
# predefined youtube live stream urls
|
59 |
self.url_dict: dict = {
|
60 |
"Peace Bridge US": "https://youtu.be/9En2186vo5g",
|
@@ -67,26 +71,46 @@ class DemoCase:
|
|
67 |
"Port Everglades-2": "https://youtu.be/Nhuu1QsW5LI",
|
68 |
"Port Everglades-3": "https://youtu.be/Lpm-C_Gz6yM",
|
69 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
70 |
self.cam_loc_default: str = "Peace Bridge US"
|
71 |
-
# run time parameters that are from user input
|
72 |
-
self.model_choice: str = self.model_choice_default
|
73 |
self.cam_loc: str = self.cam_loc_default
|
74 |
-
self.
|
75 |
-
|
76 |
-
|
77 |
-
self.
|
78 |
-
|
79 |
-
self.
|
80 |
-
self.
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
|
|
|
|
85 |
self.stream0: CamGear = None
|
86 |
self.stream1: CamGear = None
|
87 |
self.counter = None
|
88 |
self.speed_obj = None
|
89 |
|
|
|
|
|
|
|
|
|
|
|
90 |
def load_model(
|
91 |
self,
|
92 |
model_choice: str = "y8small",
|
@@ -123,10 +147,10 @@ class DemoCase:
|
|
123 |
)
|
124 |
|
125 |
# setup some configs
|
126 |
-
self.conf_threshold
|
127 |
-
self.iou_threshold
|
128 |
-
self.use_FP16
|
129 |
-
self.use_stream_buffer
|
130 |
logging.info(
|
131 |
f"{self.model_choice}: conf={self.conf_threshold:.2f} | "
|
132 |
f"iou={self.iou_threshold:.2f} | FP16={self.use_FP16} | "
|
@@ -144,9 +168,9 @@ class DemoCase:
|
|
144 |
if (bool(cam_loc_key) is False or bool(cam_loc_val) is False):
|
145 |
self.cam_loc = self.cam_loc_default
|
146 |
logging.warning(
|
147 |
-
f
|
148 |
-
f
|
149 |
-
f
|
150 |
)
|
151 |
elif cam_loc_key not in self.url_dict:
|
152 |
if try_site(self.url_dict[self.cam_loc]):
|
@@ -171,8 +195,40 @@ class DemoCase:
|
|
171 |
f'use {{{self.cam_loc}: {self.url_dict[self.cam_loc]}}} as source'
|
172 |
)
|
173 |
|
174 |
-
def
|
175 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
176 |
self.roi_height = int(self.FRAME_HEIGHT / 2)
|
177 |
logging.warning(
|
178 |
f'roi_height invalid, use default {int(self.FRAME_HEIGHT / 2)}'
|
@@ -181,12 +237,33 @@ class DemoCase:
|
|
181 |
self.roi_height = roi_height
|
182 |
logging.info(f'roi_height is set at {self.roi_height}')
|
183 |
|
184 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
185 |
if (frame_reduction < 0 or frame_reduction > 100):
|
186 |
-
self.frame_reduction =
|
187 |
logging.warning(
|
188 |
f'frame_reduction:{frame_reduction} invalid, '
|
189 |
-
f'use default value
|
190 |
)
|
191 |
else:
|
192 |
self.frame_reduction = frame_reduction
|
@@ -222,8 +299,10 @@ class DemoCase:
|
|
222 |
# do something with your OpenCV frame here
|
223 |
draw_text(
|
224 |
img=frame,
|
225 |
-
text=datetime.now(
|
226 |
-
|
|
|
|
|
227 |
font=cv2.FONT_HERSHEY_SIMPLEX,
|
228 |
font_scale=1,
|
229 |
font_thickness=2,
|
@@ -231,7 +310,7 @@ class DemoCase:
|
|
231 |
text_color=(0, 255, 255),
|
232 |
text_color_bg=(0, 0, 0),
|
233 |
)
|
234 |
-
# reducer frame size for
|
235 |
frame = await reducer(
|
236 |
frame, percentage=self.frame_reduction
|
237 |
)
|
@@ -245,11 +324,15 @@ class DemoCase:
|
|
245 |
|
246 |
if self.stream0 is not None:
|
247 |
self.stream0.stop()
|
|
|
|
|
248 |
self.stream0 = None
|
249 |
self._is_running = False
|
250 |
except asyncio.CancelledError:
|
251 |
if self.stream0 is not None:
|
252 |
self.stream0.stop()
|
|
|
|
|
253 |
self.stream0 = None
|
254 |
self._is_running = False
|
255 |
logging.warning(
|
@@ -308,8 +391,8 @@ class DemoCase:
|
|
308 |
logging=True
|
309 |
).start()
|
310 |
|
311 |
-
if self._is_tracking:
|
312 |
-
if self.counter is None:
|
313 |
# setup object counter & speed estimator
|
314 |
region_points = [
|
315 |
(5, -self.roi_thickness_half + self.roi_height),
|
@@ -318,7 +401,7 @@ class DemoCase:
|
|
318 |
self.FRAME_WIDTH - 5,
|
319 |
self.roi_thickness_half + self.roi_height
|
320 |
),
|
321 |
-
(
|
322 |
self.FRAME_WIDTH - 5,
|
323 |
-self.roi_thickness_half + self.roi_height
|
324 |
),
|
@@ -332,8 +415,9 @@ class DemoCase:
|
|
332 |
draw_boxes=False,
|
333 |
draw_reg_pts=True,
|
334 |
)
|
|
|
335 |
|
336 |
-
if self.speed_obj is None:
|
337 |
# Init speed estimator
|
338 |
line_points = [
|
339 |
(5, self.roi_height),
|
@@ -345,9 +429,46 @@ class DemoCase:
|
|
345 |
names=self.model.names,
|
346 |
view_img=False
|
347 |
)
|
|
|
348 |
|
349 |
try:
|
350 |
while (self.stream1 is not None and self._is_running):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
351 |
# read frame from provided source
|
352 |
frame = self.stream1.read()
|
353 |
if frame is None:
|
@@ -359,8 +480,10 @@ class DemoCase:
|
|
359 |
# do something with your OpenCV frame here
|
360 |
draw_text(
|
361 |
img=frame,
|
362 |
-
text=datetime.now(
|
363 |
-
|
|
|
|
|
364 |
font=cv2.FONT_HERSHEY_SIMPLEX,
|
365 |
font_scale=1,
|
366 |
font_thickness=2,
|
@@ -370,7 +493,12 @@ class DemoCase:
|
|
370 |
)
|
371 |
|
372 |
frame_tagged = frame
|
373 |
-
if
|
|
|
|
|
|
|
|
|
|
|
374 |
# YOLOv8 tracking, persisting tracks between frames
|
375 |
results = self.model.track(
|
376 |
source=frame,
|
@@ -408,6 +536,8 @@ class DemoCase:
|
|
408 |
|
409 |
if self.stream1 is not None:
|
410 |
self.stream1.stop()
|
|
|
|
|
411 |
self.stream1 = None
|
412 |
self._is_tracking = False
|
413 |
self._is_running = False
|
@@ -415,12 +545,30 @@ class DemoCase:
|
|
415 |
except asyncio.CancelledError:
|
416 |
if self.stream1 is not None:
|
417 |
self.stream1.stop()
|
|
|
|
|
418 |
self.stream1 = None
|
419 |
self._is_tracking = False
|
420 |
self._is_running = False
|
421 |
logging.warning(
|
422 |
"client disconnected in frame1_producer"
|
423 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
424 |
else:
|
425 |
if self._is_running is True:
|
426 |
pass
|
@@ -496,6 +644,29 @@ class DemoCase:
|
|
496 |
await asyncio.sleep(0.01)
|
497 |
return response
|
498 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
499 |
async def geturl(self, request: HtmxRequest) -> Response:
|
500 |
# assert (htmx := request.scope["htmx"])
|
501 |
if len(self.url_dict) == 0:
|
@@ -745,6 +916,92 @@ class DemoCase:
|
|
745 |
await asyncio.sleep(0.01)
|
746 |
return response
|
747 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
748 |
async def streamswitch(self, request: HtmxRequest) -> Response:
|
749 |
# assert (htmx := request.scope["htmx"])
|
750 |
template = "partials/ack.html"
|
@@ -815,10 +1072,10 @@ class DemoCase:
|
|
815 |
status_code = 201
|
816 |
# setup object counter & speed estimator
|
817 |
region_points = [
|
818 |
-
(5, -
|
819 |
-
(5,
|
820 |
-
(self.FRAME_WIDTH - 5,
|
821 |
-
(self.FRAME_WIDTH - 5, -
|
822 |
]
|
823 |
self.counter = object_counter.ObjectCounter()
|
824 |
self.counter.set_args(
|
@@ -861,7 +1118,9 @@ class DemoCase:
|
|
861 |
if await request.is_disconnected():
|
862 |
yield {
|
863 |
"event": "evt_in_counts",
|
864 |
-
"id": datetime.now(
|
|
|
|
|
865 |
"retry": RETRY_TIMEOUT_MILSEC,
|
866 |
"data": "..."
|
867 |
}
|
@@ -874,7 +1133,9 @@ class DemoCase:
|
|
874 |
if (self.counter is not None and incounts_msg):
|
875 |
yield {
|
876 |
"event": "evt_in_counts",
|
877 |
-
"id": datetime.now(
|
|
|
|
|
878 |
"retry": RETRY_TIMEOUT_MILSEC,
|
879 |
"data": f"{self.counter.in_counts}"
|
880 |
}
|
@@ -882,7 +1143,9 @@ class DemoCase:
|
|
882 |
if _stop_sse is False:
|
883 |
yield {
|
884 |
"event": "evt_in_counts",
|
885 |
-
"id": datetime.now(
|
|
|
|
|
886 |
"retry": RETRY_TIMEOUT_MILSEC,
|
887 |
"data": "---"
|
888 |
}
|
@@ -898,7 +1161,9 @@ class DemoCase:
|
|
898 |
if await request.is_disconnected():
|
899 |
yield {
|
900 |
"event": "evt_out_counts",
|
901 |
-
"id": datetime.now(
|
|
|
|
|
902 |
"retry": RETRY_TIMEOUT_MILSEC,
|
903 |
"data": "..."
|
904 |
}
|
@@ -911,62 +1176,46 @@ class DemoCase:
|
|
911 |
if (self.counter is not None and outcounts_msg):
|
912 |
yield {
|
913 |
"event": "evt_out_counts",
|
914 |
-
"id": datetime.now(
|
|
|
|
|
915 |
"retry": RETRY_TIMEOUT_MILSEC,
|
916 |
"data": f"{self.counter.out_counts}"
|
917 |
}
|
918 |
else:
|
919 |
if _stop_sse is False:
|
920 |
yield {
|
921 |
-
|
922 |
-
|
923 |
-
"
|
924 |
-
|
|
|
|
|
925 |
}
|
926 |
_stop_sse = True
|
927 |
await asyncio.sleep(EVT_STREAM_DELAY_SEC)
|
928 |
return EventSourceResponse(event_generator())
|
929 |
|
930 |
|
|
|
|
|
|
|
|
|
931 |
# instantiate a demo case
|
932 |
demo_case = DemoCase(YOLO_VERBOSE=False)
|
933 |
-
demo_case.set_frame_reduction(frame_reduction=
|
934 |
-
demo_case.load_model(
|
935 |
-
|
936 |
-
|
937 |
-
|
938 |
-
|
939 |
-
|
940 |
-
demo_case.FRAME_WIDTH - 5,
|
941 |
-
demo_case.roi_thickness_half + demo_case.roi_height
|
942 |
-
),
|
943 |
-
(
|
944 |
-
demo_case.FRAME_WIDTH - 5,
|
945 |
-
-demo_case.roi_thickness_half + demo_case.roi_height
|
946 |
-
),
|
947 |
-
]
|
948 |
-
demo_case.counter = object_counter.ObjectCounter()
|
949 |
-
demo_case.counter.set_args(
|
950 |
-
view_img=False,
|
951 |
-
reg_pts=region_points,
|
952 |
-
classes_names=demo_case.model.names,
|
953 |
-
draw_tracks=False,
|
954 |
-
draw_boxes=False,
|
955 |
-
draw_reg_pts=True,
|
956 |
)
|
957 |
-
|
958 |
-
|
959 |
-
|
960 |
-
|
961 |
-
|
962 |
-
demo_case.speed_obj = speed_estimation.SpeedEstimator()
|
963 |
-
demo_case.speed_obj.set_args(
|
964 |
-
reg_pts=line_points,
|
965 |
-
names=demo_case.model.names,
|
966 |
-
view_img=False
|
967 |
-
)
|
968 |
-
logging.info([f"{x}" for x in list(demo_case.url_dict.keys())])
|
969 |
-
logging.info([f"{x}" for x in list(demo_case.model_dict.keys())])
|
970 |
|
971 |
# setup webgear server
|
972 |
options = {
|
@@ -976,16 +1225,17 @@ options = {
|
|
976 |
"jpeg_compression_fastdct": True,
|
977 |
"jpeg_compression_fastupsample": True,
|
978 |
}
|
|
|
979 |
web = WebGear(
|
980 |
logging=True, **options
|
981 |
)
|
982 |
# config webgear server
|
983 |
-
web.config["generator"] = demo_case.
|
984 |
web.config["middleware"] = [Middleware(HtmxMiddleware)]
|
985 |
web.routes.append(Mount("/static", static, name="static"))
|
986 |
-
web.routes.append(
|
987 |
-
|
988 |
-
)
|
989 |
routes_dict = {
|
990 |
"models": (demo_case.models, ["GET"]),
|
991 |
"getmodel": (demo_case.getmodel, ["GET"]),
|
@@ -994,6 +1244,9 @@ routes_dict = {
|
|
994 |
"addurl": (demo_case.addurl, ["POST"]),
|
995 |
"geturl": (demo_case.geturl, ["GET"]),
|
996 |
"seturl": (demo_case.seturl, ["POST"]),
|
|
|
|
|
|
|
997 |
"streamswitch": (demo_case.streamswitch, ["POST"]),
|
998 |
"trackingswitch": (demo_case.trackingswitch, ["POST"]),
|
999 |
}
|
@@ -1011,3 +1264,14 @@ web.routes.append(Route(
|
|
1011 |
endpoint=demo_case.sse_outcounts,
|
1012 |
name="sseoutcounts"
|
1013 |
))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
import numpy as np
|
6 |
from pathlib import Path
|
7 |
import torch
|
8 |
+
from zoneinfo import ZoneInfo
|
9 |
from starlette.middleware import Middleware
|
10 |
from starlette.responses import StreamingResponse, Response
|
11 |
from starlette.requests import Request
|
|
|
20 |
from vidgear.gears import CamGear
|
21 |
from vidgear.gears.asyncio import WebGear
|
22 |
from vidgear.gears.asyncio.helper import reducer
|
23 |
+
from helper import (
|
24 |
+
draw_text, make_table_from_dict_multiselect, make_table_from_dict, try_site
|
25 |
+
)
|
26 |
|
27 |
|
28 |
HERE = Path(__file__).parent
|
|
|
34 |
# Create and configure logger
|
35 |
# logger = logging.getLogger(__name__).addHandler(logging.NullHandler())
|
36 |
logging.basicConfig(
|
37 |
+
format='%(asctime)s %(name)-8s->%(module)-20s->%(funcName)-20s:%(lineno)-4s::%(levelname)-8s %(message)s', # noqa
|
38 |
level=logging.INFO
|
39 |
)
|
40 |
|
|
|
58 |
"y8huge": "./data/models/yolov8x.pt",
|
59 |
}
|
60 |
self.model_choice_default: str = "y8small"
|
61 |
+
self.model_choice: str = self.model_choice_default
|
62 |
# predefined youtube live stream urls
|
63 |
self.url_dict: dict = {
|
64 |
"Peace Bridge US": "https://youtu.be/9En2186vo5g",
|
|
|
71 |
"Port Everglades-2": "https://youtu.be/Nhuu1QsW5LI",
|
72 |
"Port Everglades-3": "https://youtu.be/Lpm-C_Gz6yM",
|
73 |
}
|
74 |
+
self.obj_dict: dict = {
|
75 |
+
"person": 0,
|
76 |
+
"bicycle": 1,
|
77 |
+
"car": 2,
|
78 |
+
"motorcycle": 3,
|
79 |
+
"airplane": 4,
|
80 |
+
"bus": 5,
|
81 |
+
"train": 6,
|
82 |
+
"truck": 7,
|
83 |
+
"boat": 8,
|
84 |
+
"traffic light": 9,
|
85 |
+
"fire hydrant": 10,
|
86 |
+
"stop sign": 11,
|
87 |
+
"parking meter": 12
|
88 |
+
}
|
89 |
self.cam_loc_default: str = "Peace Bridge US"
|
|
|
|
|
90 |
self.cam_loc: str = self.cam_loc_default
|
91 |
+
self.frame_reduction: int = 35
|
92 |
+
|
93 |
+
# run time parameters that are from user input
|
94 |
+
self.roi_height_default: int = int(FRAME_HEIGHT / 2)
|
95 |
+
self.roi_height: int = self.roi_height_default
|
96 |
+
self.roi_thickness_half_default: int = 30
|
97 |
+
self.roi_thickness_half: int = self.roi_thickness_half_default
|
98 |
+
self.obj_class_id_default: list[int] = [2, 3, 5, 7]
|
99 |
+
self.obj_class_id: list[int] = self.obj_class_id_default
|
100 |
+
self.conf_threshold: float = 0.25
|
101 |
+
self.iou_threshold: float = 0.7
|
102 |
+
self.use_FP16: bool = False
|
103 |
+
self.use_stream_buffer: bool = True
|
104 |
self.stream0: CamGear = None
|
105 |
self.stream1: CamGear = None
|
106 |
self.counter = None
|
107 |
self.speed_obj = None
|
108 |
|
109 |
+
# define some logic flow control booleans
|
110 |
+
self._is_running: bool = False
|
111 |
+
self._is_tracking: bool = False
|
112 |
+
self._roi_changed: bool = False
|
113 |
+
|
114 |
def load_model(
|
115 |
self,
|
116 |
model_choice: str = "y8small",
|
|
|
147 |
)
|
148 |
|
149 |
# setup some configs
|
150 |
+
self.conf_threshold = conf_threshold if conf_threshold > 0.0 else 0.25 # noqa
|
151 |
+
self.iou_threshold = iou_threshold if iou_threshold > 0.0 else 0.7
|
152 |
+
self.use_FP16 = use_FP16
|
153 |
+
self.use_stream_buffer = use_stream_buffer
|
154 |
logging.info(
|
155 |
f"{self.model_choice}: conf={self.conf_threshold:.2f} | "
|
156 |
f"iou={self.iou_threshold:.2f} | FP16={self.use_FP16} | "
|
|
|
168 |
if (bool(cam_loc_key) is False or bool(cam_loc_val) is False):
|
169 |
self.cam_loc = self.cam_loc_default
|
170 |
logging.warning(
|
171 |
+
f'input cam_loc_key, cam_loc_val pair invalid, use default '
|
172 |
+
f'{{{self.cam_loc_default}: '
|
173 |
+
f'{self.url_dict[self.cam_loc_default]}}}'
|
174 |
)
|
175 |
elif cam_loc_key not in self.url_dict:
|
176 |
if try_site(self.url_dict[self.cam_loc]):
|
|
|
195 |
f'use {{{self.cam_loc}: {self.url_dict[self.cam_loc]}}} as source'
|
196 |
)
|
197 |
|
198 |
+
def select_obj_class_id(
|
199 |
+
self,
|
200 |
+
obj_names: list[str] = [
|
201 |
+
"person", "bicycle", "car", "motorcycle", "airplane", "bus",
|
202 |
+
"train", "truck", "boat", "traffic light", "fire hydrant",
|
203 |
+
"stop sign", "parking meter"
|
204 |
+
]
|
205 |
+
) -> None:
|
206 |
+
"""
|
207 |
+
select object class id list based on the input obj_names str list
|
208 |
+
"""
|
209 |
+
if (bool(obj_names) is False):
|
210 |
+
self.obj_class_id = self.obj_class_id_default
|
211 |
+
logging.warning(
|
212 |
+
f'input obj_names invalid, use default id {self.obj_class_id_default}'
|
213 |
+
)
|
214 |
+
else:
|
215 |
+
obj_class_id = []
|
216 |
+
for name in obj_names:
|
217 |
+
if name in list(self.obj_dict.keys()):
|
218 |
+
obj_class_id.append(self.obj_dict[name])
|
219 |
+
if (len(obj_class_id) == 0):
|
220 |
+
self.obj_class_id = self.obj_class_id_default
|
221 |
+
logging.warning(
|
222 |
+
f'input obj_names invalid, use default id '
|
223 |
+
f'{self.obj_class_id_default}'
|
224 |
+
)
|
225 |
+
else:
|
226 |
+
self.obj_class_id = obj_class_id
|
227 |
+
logging.info(f'object class id set as {self.obj_class_id}')
|
228 |
+
|
229 |
+
# def set_roi(self, roi_height: int = 360, roi_thickness_half: int = 30):
|
230 |
+
def set_roi(self, roi_height: int = 360):
|
231 |
+
if (roi_height < 120 or roi_height > 600):
|
232 |
self.roi_height = int(self.FRAME_HEIGHT / 2)
|
233 |
logging.warning(
|
234 |
f'roi_height invalid, use default {int(self.FRAME_HEIGHT / 2)}'
|
|
|
237 |
self.roi_height = roi_height
|
238 |
logging.info(f'roi_height is set at {self.roi_height}')
|
239 |
|
240 |
+
self.roi_thickness_half = self.roi_thickness_half_default
|
241 |
+
|
242 |
+
'''
|
243 |
+
if (
|
244 |
+
roi_thickness_half > 0 and
|
245 |
+
roi_thickness_half < int(self.FRAME_HEIGHT / 2)
|
246 |
+
):
|
247 |
+
if (self.roi_height + roi_thickness_half > self.FRAME_HEIGHT):
|
248 |
+
self.roi_thickness_half = self.FRAME_HEIGHT - self.roi_height
|
249 |
+
elif (self.roi_height - roi_thickness_half < 0):
|
250 |
+
self.roi_thickness_half = self.roi_height
|
251 |
+
else:
|
252 |
+
self.roi_thickness_half = roi_thickness_half
|
253 |
+
logging.info(
|
254 |
+
f'roi_thickness_half is set at {self.roi_thickness_half}'
|
255 |
+
)
|
256 |
+
else:
|
257 |
+
self.roi_thickness_half = self.roi_thickness_half_default
|
258 |
+
logging.warning('roi_half_thickness invalid, use default 30')
|
259 |
+
'''
|
260 |
+
|
261 |
+
def set_frame_reduction(self, frame_reduction: int = 35):
|
262 |
if (frame_reduction < 0 or frame_reduction > 100):
|
263 |
+
self.frame_reduction = 35
|
264 |
logging.warning(
|
265 |
f'frame_reduction:{frame_reduction} invalid, '
|
266 |
+
f'use default value 35'
|
267 |
)
|
268 |
else:
|
269 |
self.frame_reduction = frame_reduction
|
|
|
299 |
# do something with your OpenCV frame here
|
300 |
draw_text(
|
301 |
img=frame,
|
302 |
+
text=datetime.now(
|
303 |
+
tz=ZoneInfo("America/Los_Angeles")
|
304 |
+
).strftime("%m/%d/%Y %H:%M:%S") + " PDT",
|
305 |
+
pos=(int(self.FRAME_WIDTH - 500), 50),
|
306 |
font=cv2.FONT_HERSHEY_SIMPLEX,
|
307 |
font_scale=1,
|
308 |
font_thickness=2,
|
|
|
310 |
text_color=(0, 255, 255),
|
311 |
text_color_bg=(0, 0, 0),
|
312 |
)
|
313 |
+
# reducer frame size for performance, percentage int
|
314 |
frame = await reducer(
|
315 |
frame, percentage=self.frame_reduction
|
316 |
)
|
|
|
324 |
|
325 |
if self.stream0 is not None:
|
326 |
self.stream0.stop()
|
327 |
+
while self.stream0.read() is not None:
|
328 |
+
continue
|
329 |
self.stream0 = None
|
330 |
self._is_running = False
|
331 |
except asyncio.CancelledError:
|
332 |
if self.stream0 is not None:
|
333 |
self.stream0.stop()
|
334 |
+
while self.stream0.read() is not None:
|
335 |
+
continue
|
336 |
self.stream0 = None
|
337 |
self._is_running = False
|
338 |
logging.warning(
|
|
|
391 |
logging=True
|
392 |
).start()
|
393 |
|
394 |
+
if (self._is_tracking and self.stream1 is not None):
|
395 |
+
if self.counter is None or self._roi_changed:
|
396 |
# setup object counter & speed estimator
|
397 |
region_points = [
|
398 |
(5, -self.roi_thickness_half + self.roi_height),
|
|
|
401 |
self.FRAME_WIDTH - 5,
|
402 |
self.roi_thickness_half + self.roi_height
|
403 |
),
|
404 |
+
(
|
405 |
self.FRAME_WIDTH - 5,
|
406 |
-self.roi_thickness_half + self.roi_height
|
407 |
),
|
|
|
415 |
draw_boxes=False,
|
416 |
draw_reg_pts=True,
|
417 |
)
|
418 |
+
self._roi_changed = False
|
419 |
|
420 |
+
if self.speed_obj is None or self._roi_changed:
|
421 |
# Init speed estimator
|
422 |
line_points = [
|
423 |
(5, self.roi_height),
|
|
|
429 |
names=self.model.names,
|
430 |
view_img=False
|
431 |
)
|
432 |
+
self._roi_changed = False
|
433 |
|
434 |
try:
|
435 |
while (self.stream1 is not None and self._is_running):
|
436 |
+
if self._roi_changed:
|
437 |
+
# setup object counter & speed estimator
|
438 |
+
region_points = [
|
439 |
+
(5, -self.roi_thickness_half + self.roi_height),
|
440 |
+
(5, self.roi_thickness_half + self.roi_height),
|
441 |
+
(
|
442 |
+
self.FRAME_WIDTH - 5,
|
443 |
+
self.roi_thickness_half + self.roi_height
|
444 |
+
),
|
445 |
+
(
|
446 |
+
self.FRAME_WIDTH - 5,
|
447 |
+
-self.roi_thickness_half + self.roi_height
|
448 |
+
),
|
449 |
+
]
|
450 |
+
self.counter = object_counter.ObjectCounter()
|
451 |
+
self.counter.set_args(
|
452 |
+
view_img=False,
|
453 |
+
reg_pts=region_points,
|
454 |
+
classes_names=self.model.names,
|
455 |
+
draw_tracks=False,
|
456 |
+
draw_boxes=False,
|
457 |
+
draw_reg_pts=True,
|
458 |
+
)
|
459 |
+
# Init speed estimator
|
460 |
+
line_points = [
|
461 |
+
(5, self.roi_height),
|
462 |
+
(self.FRAME_WIDTH - 5, self.roi_height)
|
463 |
+
]
|
464 |
+
self.speed_obj = speed_estimation.SpeedEstimator()
|
465 |
+
self.speed_obj.set_args(
|
466 |
+
reg_pts=line_points,
|
467 |
+
names=self.model.names,
|
468 |
+
view_img=False
|
469 |
+
)
|
470 |
+
self._roi_changed = False
|
471 |
+
|
472 |
# read frame from provided source
|
473 |
frame = self.stream1.read()
|
474 |
if frame is None:
|
|
|
480 |
# do something with your OpenCV frame here
|
481 |
draw_text(
|
482 |
img=frame,
|
483 |
+
text=datetime.now(
|
484 |
+
tz=ZoneInfo("America/Los_Angeles")
|
485 |
+
).strftime("%m/%d/%Y %H:%M:%S") + " PDT",
|
486 |
+
pos=(self.FRAME_WIDTH - 500, 50),
|
487 |
font=cv2.FONT_HERSHEY_SIMPLEX,
|
488 |
font_scale=1,
|
489 |
font_thickness=2,
|
|
|
493 |
)
|
494 |
|
495 |
frame_tagged = frame
|
496 |
+
if (
|
497 |
+
self._is_tracking and self.model is not None
|
498 |
+
and self.speed_obj is not None
|
499 |
+
and self.counter is not None
|
500 |
+
and self._roi_changed is False
|
501 |
+
):
|
502 |
# YOLOv8 tracking, persisting tracks between frames
|
503 |
results = self.model.track(
|
504 |
source=frame,
|
|
|
536 |
|
537 |
if self.stream1 is not None:
|
538 |
self.stream1.stop()
|
539 |
+
while self.stream1.read() is not None:
|
540 |
+
continue
|
541 |
self.stream1 = None
|
542 |
self._is_tracking = False
|
543 |
self._is_running = False
|
|
|
545 |
except asyncio.CancelledError:
|
546 |
if self.stream1 is not None:
|
547 |
self.stream1.stop()
|
548 |
+
while self.stream1.read() is not None:
|
549 |
+
continue
|
550 |
self.stream1 = None
|
551 |
self._is_tracking = False
|
552 |
self._is_running = False
|
553 |
logging.warning(
|
554 |
"client disconnected in frame1_producer"
|
555 |
)
|
556 |
+
frame = (np.random.standard_normal([
|
557 |
+
self.FRAME_HEIGHT, self.FRAME_WIDTH, 3
|
558 |
+
]) * 255).astype(np.uint8)
|
559 |
+
frame = await reducer(
|
560 |
+
frame, percentage=self.frame_reduction
|
561 |
+
)
|
562 |
+
img_encoded = cv2.imencode(".jpg", frame)[1].tobytes()
|
563 |
+
logging.info(
|
564 |
+
f"_is_running is {self._is_running} in frame0_producer"
|
565 |
+
)
|
566 |
+
yield (
|
567 |
+
b"--frame\r\nContent-Type:video/jpeg2000\r\n\r\n" +
|
568 |
+
img_encoded + b"\r\n"
|
569 |
+
)
|
570 |
+
await asyncio.sleep(0.00001)
|
571 |
+
|
572 |
else:
|
573 |
if self._is_running is True:
|
574 |
pass
|
|
|
644 |
await asyncio.sleep(0.01)
|
645 |
return response
|
646 |
|
647 |
+
async def objects(self, request: HtmxRequest) -> Response:
|
648 |
+
# assert (htmx := request.scope["htmx"])
|
649 |
+
if len(self.obj_dict) == 0:
|
650 |
+
template = "partials/ack.html"
|
651 |
+
table_contents = ["object list unavailable!"]
|
652 |
+
context = {"request": request, "table": table_contents}
|
653 |
+
response = templates.TemplateResponse(
|
654 |
+
template, context, status_code=200
|
655 |
+
)
|
656 |
+
# response.headers['Hx-Retarget'] = '#set-model-ack'
|
657 |
+
else:
|
658 |
+
template = "partials/object_list.html"
|
659 |
+
table_contents = make_table_from_dict_multiselect(
|
660 |
+
self.obj_dict, self.obj_class_id
|
661 |
+
)
|
662 |
+
context = {"request": request, "table": table_contents}
|
663 |
+
response = templates.TemplateResponse(
|
664 |
+
template, context, status_code=200
|
665 |
+
)
|
666 |
+
|
667 |
+
await asyncio.sleep(0.001)
|
668 |
+
return response
|
669 |
+
|
670 |
async def geturl(self, request: HtmxRequest) -> Response:
|
671 |
# assert (htmx := request.scope["htmx"])
|
672 |
if len(self.url_dict) == 0:
|
|
|
916 |
await asyncio.sleep(0.01)
|
917 |
return response
|
918 |
|
919 |
+
async def selectobjects(self, request: HtmxRequest) -> Response:
|
920 |
+
# assert (htmx := request.scope["htmx"])
|
921 |
+
template = "partials/ack.html"
|
922 |
+
try:
|
923 |
+
req_json = await request.json()
|
924 |
+
except RuntimeError:
|
925 |
+
table_contents = ["receive channel unavailable!"]
|
926 |
+
context = {"request": request, "table": table_contents}
|
927 |
+
response = templates.TemplateResponse(
|
928 |
+
template, context, status_code=200
|
929 |
+
)
|
930 |
+
await asyncio.sleep(0.01)
|
931 |
+
return response
|
932 |
+
|
933 |
+
if ("payload" in req_json and "object_id" in req_json["payload"]):
|
934 |
+
logging.info(f"requested_ids: {req_json['payload']}")
|
935 |
+
req_ids = req_json["payload"]["object_id"]
|
936 |
+
if len(req_ids) > 0:
|
937 |
+
self.obj_class_id = [
|
938 |
+
int(id) for id in req_ids
|
939 |
+
if int(id) in self.obj_dict.values()
|
940 |
+
]
|
941 |
+
if len(self.obj_class_id) > 0:
|
942 |
+
table_contents = [
|
943 |
+
f"{len(self.obj_class_id)} object types selected"
|
944 |
+
]
|
945 |
+
else:
|
946 |
+
self.obj_class_id = self.obj_class_id_default
|
947 |
+
table_contents = [
|
948 |
+
"invalid objects selection, use default object types"
|
949 |
+
]
|
950 |
+
else:
|
951 |
+
table_contents = ["invalid POST request! need at least one object type"]
|
952 |
+
|
953 |
+
context = {"request": request, "table": table_contents}
|
954 |
+
response = templates.TemplateResponse(
|
955 |
+
template, context, status_code=200
|
956 |
+
)
|
957 |
+
|
958 |
+
await asyncio.sleep(0.01)
|
959 |
+
return response
|
960 |
+
|
961 |
+
async def setroi(self, request: HtmxRequest) -> Response:
|
962 |
+
# assert (htmx := request.scope["htmx"])
|
963 |
+
template = "partials/ack.html"
|
964 |
+
try:
|
965 |
+
req_json = await request.json()
|
966 |
+
except RuntimeError:
|
967 |
+
table_contents = ["receive channel unavailable!"]
|
968 |
+
context = {"request": request, "table": table_contents}
|
969 |
+
response = templates.TemplateResponse(
|
970 |
+
template, context, status_code=200
|
971 |
+
)
|
972 |
+
await asyncio.sleep(0.01)
|
973 |
+
return response
|
974 |
+
|
975 |
+
if ("payload" in req_json and "roi_height" in req_json["payload"]):
|
976 |
+
logging.info(f"{req_json['payload']}")
|
977 |
+
req_height = (int)(req_json["payload"]["roi_height"])
|
978 |
+
if (
|
979 |
+
req_height >= 120 and req_height <= 600 and
|
980 |
+
req_height < self.FRAME_HEIGHT
|
981 |
+
):
|
982 |
+
self.roi_height = self.FRAME_HEIGHT - req_height
|
983 |
+
table_contents = [
|
984 |
+
f"roi_height set at "
|
985 |
+
f"{self.FRAME_HEIGHT - self.roi_height}px"
|
986 |
+
]
|
987 |
+
else:
|
988 |
+
self.roi_height = self.roi_height_default
|
989 |
+
table_contents = [
|
990 |
+
f"invalid roi_height request, use default"
|
991 |
+
f"{self.FRAME_HEIGHT - self.roi_height_default}px"
|
992 |
+
]
|
993 |
+
self._roi_changed = True
|
994 |
+
else:
|
995 |
+
table_contents = ["invalid POST request! need a valid roi_height"]
|
996 |
+
|
997 |
+
context = {"request": request, "table": table_contents}
|
998 |
+
response = templates.TemplateResponse(
|
999 |
+
template, context, status_code=200
|
1000 |
+
)
|
1001 |
+
|
1002 |
+
await asyncio.sleep(0.01)
|
1003 |
+
return response
|
1004 |
+
|
1005 |
async def streamswitch(self, request: HtmxRequest) -> Response:
|
1006 |
# assert (htmx := request.scope["htmx"])
|
1007 |
template = "partials/ack.html"
|
|
|
1072 |
status_code = 201
|
1073 |
# setup object counter & speed estimator
|
1074 |
region_points = [
|
1075 |
+
(5, -20 + self.roi_height),
|
1076 |
+
(5, 20 + self.roi_height),
|
1077 |
+
(self.FRAME_WIDTH - 5, 20 + self.roi_height),
|
1078 |
+
(self.FRAME_WIDTH - 5, -20 + self.roi_height),
|
1079 |
]
|
1080 |
self.counter = object_counter.ObjectCounter()
|
1081 |
self.counter.set_args(
|
|
|
1118 |
if await request.is_disconnected():
|
1119 |
yield {
|
1120 |
"event": "evt_in_counts",
|
1121 |
+
"id": datetime.now(
|
1122 |
+
tz=ZoneInfo("America/Los_Angeles")
|
1123 |
+
).strftime("%m/%d/%Y %H:%M:%S"),
|
1124 |
"retry": RETRY_TIMEOUT_MILSEC,
|
1125 |
"data": "..."
|
1126 |
}
|
|
|
1133 |
if (self.counter is not None and incounts_msg):
|
1134 |
yield {
|
1135 |
"event": "evt_in_counts",
|
1136 |
+
"id": datetime.now(
|
1137 |
+
tz=ZoneInfo("America/Los_Angeles")
|
1138 |
+
).strftime("%m/%d/%Y %H:%M:%S"),
|
1139 |
"retry": RETRY_TIMEOUT_MILSEC,
|
1140 |
"data": f"{self.counter.in_counts}"
|
1141 |
}
|
|
|
1143 |
if _stop_sse is False:
|
1144 |
yield {
|
1145 |
"event": "evt_in_counts",
|
1146 |
+
"id": datetime.now(
|
1147 |
+
tz=ZoneInfo("America/Los_Angeles")
|
1148 |
+
).strftime("%m/%d/%Y %H:%M:%S"),
|
1149 |
"retry": RETRY_TIMEOUT_MILSEC,
|
1150 |
"data": "---"
|
1151 |
}
|
|
|
1161 |
if await request.is_disconnected():
|
1162 |
yield {
|
1163 |
"event": "evt_out_counts",
|
1164 |
+
"id": datetime.now(
|
1165 |
+
tz=ZoneInfo("America/Los_Angeles")
|
1166 |
+
).strftime("%m/%d/%Y %H:%M:%S"),
|
1167 |
"retry": RETRY_TIMEOUT_MILSEC,
|
1168 |
"data": "..."
|
1169 |
}
|
|
|
1176 |
if (self.counter is not None and outcounts_msg):
|
1177 |
yield {
|
1178 |
"event": "evt_out_counts",
|
1179 |
+
"id": datetime.now(
|
1180 |
+
tz=ZoneInfo("America/Los_Angeles")
|
1181 |
+
).strftime("%m/%d/%Y %H:%M:%S"),
|
1182 |
"retry": RETRY_TIMEOUT_MILSEC,
|
1183 |
"data": f"{self.counter.out_counts}"
|
1184 |
}
|
1185 |
else:
|
1186 |
if _stop_sse is False:
|
1187 |
yield {
|
1188 |
+
"event": "evt_out_counts",
|
1189 |
+
"id": datetime.now(
|
1190 |
+
tz=ZoneInfo("America/Los_Angeles")
|
1191 |
+
).strftime("%m/%d/%Y %H:%M:%S"),
|
1192 |
+
"retry": RETRY_TIMEOUT_MILSEC,
|
1193 |
+
"data": "---"
|
1194 |
}
|
1195 |
_stop_sse = True
|
1196 |
await asyncio.sleep(EVT_STREAM_DELAY_SEC)
|
1197 |
return EventSourceResponse(event_generator())
|
1198 |
|
1199 |
|
1200 |
+
# is_huggingface = False
|
1201 |
+
# define the host url and port for webgear server
|
1202 |
+
# HOST_WEBGEAR, PORT_WEBGEAR = "localhost", 8080
|
1203 |
+
|
1204 |
# instantiate a demo case
|
1205 |
demo_case = DemoCase(YOLO_VERBOSE=False)
|
1206 |
+
demo_case.set_frame_reduction(frame_reduction=35)
|
1207 |
+
demo_case.load_model(
|
1208 |
+
model_choice="y8small",
|
1209 |
+
conf_threshold=0.1,
|
1210 |
+
iou_threshold=0.6,
|
1211 |
+
use_FP16=False,
|
1212 |
+
use_stream_buffer=True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1213 |
)
|
1214 |
+
logging.info(f"url_dict: {demo_case.url_dict}")
|
1215 |
+
logging.info(f"model_dict: {demo_case.model_dict}")
|
1216 |
+
logging.info(f"obj_dict: {demo_case.obj_dict}")
|
1217 |
+
logging.info(f"obj_class_id: {demo_case.obj_class_id}")
|
1218 |
+
# logging.info(f"model.names: {demo_case.model.names}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1219 |
|
1220 |
# setup webgear server
|
1221 |
options = {
|
|
|
1225 |
"jpeg_compression_fastdct": True,
|
1226 |
"jpeg_compression_fastupsample": True,
|
1227 |
}
|
1228 |
+
|
1229 |
web = WebGear(
|
1230 |
logging=True, **options
|
1231 |
)
|
1232 |
# config webgear server
|
1233 |
+
web.config["generator"] = demo_case.frame1_producer
|
1234 |
web.config["middleware"] = [Middleware(HtmxMiddleware)]
|
1235 |
web.routes.append(Mount("/static", static, name="static"))
|
1236 |
+
# web.routes.append(
|
1237 |
+
# Route("/video1", endpoint=demo_case.custom_video_response)
|
1238 |
+
# )
|
1239 |
routes_dict = {
|
1240 |
"models": (demo_case.models, ["GET"]),
|
1241 |
"getmodel": (demo_case.getmodel, ["GET"]),
|
|
|
1244 |
"addurl": (demo_case.addurl, ["POST"]),
|
1245 |
"geturl": (demo_case.geturl, ["GET"]),
|
1246 |
"seturl": (demo_case.seturl, ["POST"]),
|
1247 |
+
"objects": (demo_case.objects, ["GET"]),
|
1248 |
+
"selectobjects": (demo_case.selectobjects, ["POST"]),
|
1249 |
+
"setroi": (demo_case.setroi, ["POST"]),
|
1250 |
"streamswitch": (demo_case.streamswitch, ["POST"]),
|
1251 |
"trackingswitch": (demo_case.trackingswitch, ["POST"]),
|
1252 |
}
|
|
|
1264 |
endpoint=demo_case.sse_outcounts,
|
1265 |
name="sseoutcounts"
|
1266 |
))
|
1267 |
+
|
1268 |
+
# if is_huggingface is False:
|
1269 |
+
# # run this app on Uvicorn server at address http://localhost:8080/
|
1270 |
+
# uvicorn.run(
|
1271 |
+
# web(), host=HOST_WEBGEAR, port=PORT_WEBGEAR, log_level="info"
|
1272 |
+
# )
|
1273 |
+
# # close app safely
|
1274 |
+
# web.shutdown()
|
1275 |
+
#
|
1276 |
+
# or launch it using cli --
|
1277 |
+
# uvicorn webapp:web --host "localhost" --port 8080 --reload
|