Spaces:
Runtime error
Runtime error
import asyncio | |
from datetime import datetime | |
import logging | |
import cv2 | |
import numpy as np | |
from pathlib import Path | |
import torch | |
from zoneinfo import ZoneInfo | |
from starlette.middleware import Middleware | |
from starlette.responses import StreamingResponse, Response | |
from starlette.requests import Request | |
from starlette.routing import Mount, Route | |
from starlette.staticfiles import StaticFiles | |
from starlette.templating import Jinja2Templates | |
from sse_starlette import EventSourceResponse | |
from asgi_htmx import HtmxMiddleware | |
from asgi_htmx import HtmxRequest | |
from ultralytics import YOLO | |
from ultralytics_solutions_modified import object_counter, speed_estimation | |
from vidgear.gears import CamGear | |
from vidgear.gears.asyncio import WebGear | |
from vidgear.gears.asyncio.helper import reducer | |
from helper import ( | |
draw_text, make_table_from_dict_multiselect, make_table_from_dict, try_site | |
) | |
# Directory containing this file; WebGear's static assets and Jinja2
# templates are served from the ".vidgear/webgear" tree beneath it.
HERE = Path(__file__).parent
static = StaticFiles(directory=HERE.joinpath(".vidgear", "webgear", "static"))
templates = Jinja2Templates(
    directory=HERE.joinpath(".vidgear", "webgear", "templates")
)
# Event-stream timing constants (presumably used by SSE endpoints further
# down the file -- confirm).
EVT_STREAM_DELAY_SEC = 0.05  # seconds
RETRY_TIMEOUT_MILSEC = 15000  # milliseconds
# Root logger: INFO and above, each record tagged with time, logger name,
# module, function and line number of its origin.
logging.basicConfig(
    level=logging.INFO,
    format=(
        '%(asctime)s %(name)-8s->%(module)-20s->'
        '%(funcName)-20s:%(lineno)-4s::%(levelname)-8s %(message)s'
    ),
)
class DemoCase: | |
def __init__( | |
self, | |
FRAME_WIDTH: int = 1280, | |
FRAME_HEIGHT: int = 720, | |
YOLO_VERBOSE: bool = True | |
): | |
self.FRAME_WIDTH: int = FRAME_WIDTH | |
self.FRAME_HEIGHT: int = FRAME_HEIGHT | |
self.YOLO_VERBOSE: bool = YOLO_VERBOSE | |
self.STREAM_RESOLUTION: str = "720p" | |
# predefined yolov8 model references | |
self.model_dict: dict = { | |
"y8nano": "./data/models/yolov8n.pt", | |
"y8small": "./data/models/yolov8s.pt", | |
"y8medium": "./data/models/yolov8m.pt", | |
"y8large": "./data/models/yolov8l.pt", | |
"y8huge": "./data/models/yolov8x.pt", | |
} | |
self.model_choice_default: str = "y8small" | |
self.model_choice: str = self.model_choice_default | |
# predefined youtube live stream urls | |
self.url_dict: dict = { | |
"Peace Bridge US": "https://youtu.be/9En2186vo5g", | |
"Peace Bridge CA": "https://youtu.be/WPMgP2C3_co", | |
"San Marcos TX": "https://youtu.be/E8LsKcVpL5A", | |
"4Corners Downtown": "https://youtu.be/ByED80IKdIU", | |
"Gangnam Seoul": "https://youtu.be/3ottn7kfRuc", | |
"Time Square NY": "https://youtu.be/QTTTY_ra2Tg", | |
"Port Everglades-1": "https://youtu.be/67-73mgWDf0", | |
"Port Everglades-2": "https://youtu.be/Nhuu1QsW5LI", | |
"Port Everglades-3": "https://youtu.be/Lpm-C_Gz6yM", | |
} | |
self.obj_dict: dict = { | |
"person": 0, | |
"bicycle": 1, | |
"car": 2, | |
"motorcycle": 3, | |
"airplane": 4, | |
"bus": 5, | |
"train": 6, | |
"truck": 7, | |
"boat": 8, | |
"traffic light": 9, | |
"fire hydrant": 10, | |
"stop sign": 11, | |
"parking meter": 12 | |
} | |
self.cam_loc_default: str = "Peace Bridge US" | |
self.cam_loc: str = self.cam_loc_default | |
self.frame_reduction: int = 35 | |
# run time parameters that are from user input | |
self.roi_height_default: int = int(FRAME_HEIGHT / 2) | |
self.roi_height: int = self.roi_height_default | |
self.roi_thickness_half_default: int = 30 | |
self.roi_thickness_half: int = self.roi_thickness_half_default | |
self.obj_class_id_default: list[int] = [2, 3, 5, 7] | |
self.obj_class_id: list[int] = self.obj_class_id_default | |
self.conf_threshold: float = 0.25 | |
self.iou_threshold: float = 0.7 | |
self.use_FP16: bool = False | |
self.use_stream_buffer: bool = True | |
self.stream0: CamGear = None | |
self.stream1: CamGear = None | |
self.counter = None | |
self.speed_obj = None | |
# define some logic flow control booleans | |
self._is_running: bool = False | |
self._is_tracking: bool = False | |
self._roi_changed: bool = False | |
def load_model( | |
self, | |
model_choice: str = "y8small", | |
conf_threshold: float = 0.25, | |
iou_threshold: float = 0.7, | |
use_FP16: bool = False, | |
use_stream_buffer: bool = False | |
) -> None: | |
""" | |
load the YOLOv8 model of choice | |
""" | |
if model_choice not in self.model_dict: | |
logging.warning( | |
f'\"{model_choice}\" not found in the model_dict, use ' | |
f'\"{self.model_dict[self.model_choice_default]}\" instead!' | |
) | |
self.model_choice = self.model_choice_default | |
else: | |
self.model_choice = model_choice | |
self.model = YOLO(f"{self.model_dict[self.model_choice]}") | |
# push the model to GPU if available | |
device = "cuda" if torch.cuda.is_available() else "cpu" | |
if device == "cuda": | |
torch.cuda.set_device(0) | |
self.model.to(device) | |
logging.info( | |
f"{self.model_dict[self.model_choice]} loaded using " | |
f"torch w GPU0" | |
) | |
else: | |
logging.info( | |
f"{self.model_dict[self.model_choice]} loaded using CPU" | |
) | |
# setup some configs | |
self.conf_threshold = conf_threshold if conf_threshold > 0.0 else 0.25 # noqa | |
self.iou_threshold = iou_threshold if iou_threshold > 0.0 else 0.7 | |
self.use_FP16 = use_FP16 | |
self.use_stream_buffer = use_stream_buffer | |
logging.info( | |
f"{self.model_choice}: conf={self.conf_threshold:.2f} | " | |
f"iou={self.iou_threshold:.2f} | FP16={self.use_FP16} | " | |
f"stream_buffer={self.use_stream_buffer}" | |
) | |
def select_cam_loc( | |
self, | |
cam_loc_key: str = "Peace Bridge US", | |
cam_loc_val: str = "https://www.youtube.com/watch?v=9En2186vo5g" | |
) -> None: | |
""" | |
select camera video feed from url_dict, or set as a new url | |
""" | |
if (bool(cam_loc_key) is False or bool(cam_loc_val) is False): | |
self.cam_loc = self.cam_loc_default | |
logging.warning( | |
f'input cam_loc_key, cam_loc_val pair invalid, use default ' | |
f'{{{self.cam_loc_default}: ' | |
f'{self.url_dict[self.cam_loc_default]}}}' | |
) | |
elif cam_loc_key not in self.url_dict: | |
if try_site(self.url_dict[self.cam_loc]): | |
self.url_dict.update({cam_loc_key: cam_loc_val}) | |
self.cam_loc = cam_loc_key | |
logging.info( | |
f'input cam_loc key:val pair is new and playable, add ' | |
f'{{{cam_loc_key}:{cam_loc_val}}} into url_dict' | |
) | |
else: | |
self.cam_loc = self.cam_loc_default | |
logging.warning( | |
f'input cam_loc key:val pair is new but not playable, ' | |
f'roll back to default {{{self.cam_loc_default}: ' | |
f'{self.url_dict[self.cam_loc_default]}}}' | |
) | |
self.cam_loc = self.cam_loc_default | |
else: | |
self.cam_loc = cam_loc_key | |
logging.info( | |
f'use {{{self.cam_loc}: {self.url_dict[self.cam_loc]}}} as source' | |
) | |
def select_obj_class_id( | |
self, | |
obj_names: list[str] = [ | |
"person", "bicycle", "car", "motorcycle", "airplane", "bus", | |
"train", "truck", "boat", "traffic light", "fire hydrant", | |
"stop sign", "parking meter" | |
] | |
) -> None: | |
""" | |
select object class id list based on the input obj_names str list | |
""" | |
if (bool(obj_names) is False): | |
self.obj_class_id = self.obj_class_id_default | |
logging.warning( | |
f'input obj_names invalid, use default id {self.obj_class_id_default}' | |
) | |
else: | |
obj_class_id = [] | |
for name in obj_names: | |
if name in list(self.obj_dict.keys()): | |
obj_class_id.append(self.obj_dict[name]) | |
if (len(obj_class_id) == 0): | |
self.obj_class_id = self.obj_class_id_default | |
logging.warning( | |
f'input obj_names invalid, use default id ' | |
f'{self.obj_class_id_default}' | |
) | |
else: | |
self.obj_class_id = obj_class_id | |
logging.info(f'object class id set as {self.obj_class_id}') | |
# def set_roi(self, roi_height: int = 360, roi_thickness_half: int = 30): | |
def set_roi(self, roi_height: int = 360): | |
if (roi_height < 120 or roi_height > 600): | |
self.roi_height = int(self.FRAME_HEIGHT / 2) | |
logging.warning( | |
f'roi_height invalid, use default {int(self.FRAME_HEIGHT / 2)}' | |
) | |
else: | |
self.roi_height = roi_height | |
logging.info(f'roi_height is set at {self.roi_height}') | |
self.roi_thickness_half = self.roi_thickness_half_default | |
''' | |
if ( | |
roi_thickness_half > 0 and | |
roi_thickness_half < int(self.FRAME_HEIGHT / 2) | |
): | |
if (self.roi_height + roi_thickness_half > self.FRAME_HEIGHT): | |
self.roi_thickness_half = self.FRAME_HEIGHT - self.roi_height | |
elif (self.roi_height - roi_thickness_half < 0): | |
self.roi_thickness_half = self.roi_height | |
else: | |
self.roi_thickness_half = roi_thickness_half | |
logging.info( | |
f'roi_thickness_half is set at {self.roi_thickness_half}' | |
) | |
else: | |
self.roi_thickness_half = self.roi_thickness_half_default | |
logging.warning('roi_half_thickness invalid, use default 30') | |
''' | |
def set_frame_reduction(self, frame_reduction: int = 35): | |
if (frame_reduction < 0 or frame_reduction > 100): | |
self.frame_reduction = 35 | |
logging.warning( | |
f'frame_reduction:{frame_reduction} invalid, ' | |
f'use default value 35' | |
) | |
else: | |
self.frame_reduction = frame_reduction | |
logging.info(f'frame_reduction is set at {self.frame_reduction}') | |
async def frame0_producer(self): | |
""" | |
!!! define your original video source here !!! | |
Yields: | |
_type_: an image frame as a bytestring output from the producer | |
""" | |
while True: | |
if self._is_running: | |
if self.stream0 is None: | |
try: | |
# Start the stream, set desired resolution to be 720p | |
options = {"STREAM_RESOLUTION": "720p"} | |
self.stream0 = CamGear( | |
source=self.url_dict[self.cam_loc], | |
colorspace=None, | |
stream_mode=True, | |
logging=True, | |
**options | |
).start() | |
except Exception: | |
# Start the stream, set best resolution | |
self.stream0 = CamGear( | |
source=self.url_dict[self.cam_loc], | |
colorspace=None, | |
stream_mode=True, | |
logging=True | |
).start() | |
logging.warning( | |
f"failed to connect {self.url_dict[self.cam_loc]} " | |
f"at 720p resolution, use best resolution" | |
) | |
try: | |
# loop over frames | |
while (self.stream0 is not None and self._is_running): | |
frame = self.stream0.read() | |
if frame is None: | |
frame = (np.random.standard_normal([ | |
self.FRAME_HEIGHT, self.FRAME_WIDTH, 3 | |
]) * 255).astype(np.uint8) | |
elif frame.shape != (self.FRAME_HEIGHT, self.FRAME_WIDTH, 3): | |
frame = cv2.resize(frame, (self.FRAME_HEIGHT, self.FRAME_WIDTH)) | |
# do something with your OpenCV frame here | |
draw_text( | |
img=frame, | |
text=datetime.now( | |
tz=ZoneInfo("America/Los_Angeles") | |
).strftime("%m/%d/%Y %H:%M:%S") + " PDT", | |
pos=(int(self.FRAME_WIDTH - 500), 50), | |
font=cv2.FONT_HERSHEY_SIMPLEX, | |
font_scale=1, | |
font_thickness=2, | |
line_type=cv2.LINE_AA, | |
text_color=(0, 255, 255), | |
text_color_bg=(0, 0, 0), | |
) | |
# reducer frame size for performance, percentage int | |
frame = await reducer( | |
frame, percentage=self.frame_reduction | |
) | |
# handle JPEG encoding & yield frame in byte format | |
img_encoded = cv2.imencode(".jpg", frame)[1].tobytes() | |
yield ( | |
b"--frame\r\nContent-Type:video/jpeg2000\r\n\r\n" + | |
img_encoded + b"\r\n" | |
) | |
await asyncio.sleep(0.00001) | |
if self.stream0 is not None: | |
self.stream0.stop() | |
while self.stream0.read() is not None: | |
continue | |
self.stream0 = None | |
self._is_running = False | |
except asyncio.CancelledError: | |
if self.stream0 is not None: | |
self.stream0.stop() | |
while self.stream0.read() is not None: | |
continue | |
self.stream0 = None | |
self._is_running = False | |
logging.warning( | |
"client disconneted in frame0_producer" | |
) | |
frame = (np.random.standard_normal([ | |
self.FRAME_HEIGHT, self.FRAME_WIDTH, 3 | |
]) * 255).astype(np.uint8) | |
frame = await reducer( | |
frame, percentage=self.frame_reduction | |
) | |
img_encoded = cv2.imencode(".jpg", frame)[1].tobytes() | |
logging.info( | |
f"_is_running is {self._is_running} in frame0_producer" | |
) | |
yield ( | |
b"--frame\r\nContent-Type:video/jpeg2000\r\n\r\n" + | |
img_encoded + b"\r\n" | |
) | |
await asyncio.sleep(0.00001) | |
else: | |
if self._is_running is True: | |
pass | |
frame = (np.random.standard_normal([ | |
self.FRAME_HEIGHT, self.FRAME_WIDTH, 3 | |
]) * 255).astype(np.uint8) | |
frame = await reducer( | |
frame, percentage=self.frame_reduction | |
) | |
img_encoded = cv2.imencode(".jpg", frame)[1].tobytes() | |
logging.info( | |
f"_is_running is {self._is_running} in frame0_producer" | |
) | |
yield ( | |
b"--frame\r\nContent-Type:video/jpeg2000\r\n\r\n" + | |
img_encoded + b"\r\n" | |
) | |
await asyncio.sleep(0.00001) | |
async def frame1_producer(self): | |
""" | |
!!! define your processed video producer here !!! | |
Yields: | |
_type_: an image frame as a bytestring output from the producer | |
""" | |
while True: | |
if self._is_running: | |
if self.stream1 is None: | |
try: | |
# Start the stream, set desired quality as 720p | |
options = {"STREAM_RESOLUTION": "720p"} | |
self.stream1 = CamGear( | |
source=self.url_dict[self.cam_loc], | |
colorspace=None, | |
stream_mode=True, | |
logging=True, | |
**options | |
).start() | |
except Exception: | |
# Start the stream, use the best resolution | |
self.stream1 = CamGear( | |
source=self.url_dict[self.cam_loc], | |
colorspace=None, | |
stream_mode=True, | |
logging=True | |
).start() | |
logging.warning( | |
f"failed to connect {self.url_dict[self.cam_loc]} " | |
f"at 720p resolution, use best resolution" | |
) | |
if (self._is_tracking and self.stream1 is not None): | |
if self.counter is None or self._roi_changed: | |
# setup object counter & speed estimator | |
region_points = [ | |
(5, -self.roi_thickness_half + self.roi_height), | |
(5, self.roi_thickness_half + self.roi_height), | |
( | |
self.FRAME_WIDTH - 5, | |
self.roi_thickness_half + self.roi_height | |
), | |
( | |
self.FRAME_WIDTH - 5, | |
-self.roi_thickness_half + self.roi_height | |
), | |
] | |
self.counter = object_counter.ObjectCounter() | |
self.counter.set_args( | |
view_img=False, | |
reg_pts=region_points, | |
classes_names=self.model.names, | |
draw_tracks=False, | |
draw_boxes=False, | |
draw_reg_pts=True, | |
) | |
self._roi_changed = False | |
if self.speed_obj is None or self._roi_changed: | |
# Init speed estimator | |
line_points = [ | |
(5, self.roi_height), | |
(self.FRAME_WIDTH - 5, self.roi_height) | |
] | |
self.speed_obj = speed_estimation.SpeedEstimator() | |
self.speed_obj.set_args( | |
reg_pts=line_points, | |
names=self.model.names, | |
view_img=False | |
) | |
self._roi_changed = False | |
try: | |
while (self.stream1 is not None and self._is_running): | |
if self._roi_changed: | |
# setup object counter & speed estimator | |
region_points = [ | |
(5, -self.roi_thickness_half + self.roi_height), | |
(5, self.roi_thickness_half + self.roi_height), | |
( | |
self.FRAME_WIDTH - 5, | |
self.roi_thickness_half + self.roi_height | |
), | |
( | |
self.FRAME_WIDTH - 5, | |
-self.roi_thickness_half + self.roi_height | |
), | |
] | |
self.counter = object_counter.ObjectCounter() | |
self.counter.set_args( | |
view_img=False, | |
reg_pts=region_points, | |
classes_names=self.model.names, | |
draw_tracks=False, | |
draw_boxes=False, | |
draw_reg_pts=True, | |
) | |
# Init speed estimator | |
line_points = [ | |
(5, self.roi_height), | |
(self.FRAME_WIDTH - 5, self.roi_height) | |
] | |
self.speed_obj = speed_estimation.SpeedEstimator() | |
self.speed_obj.set_args( | |
reg_pts=line_points, | |
names=self.model.names, | |
view_img=False | |
) | |
self._roi_changed = False | |
# read frame from provided source | |
frame = self.stream1.read() | |
if frame is None: | |
frame = (np.random.standard_normal([ | |
self.FRAME_HEIGHT, self.FRAME_WIDTH, 3 | |
]) * 255).astype(np.uint8) | |
elif frame.shape != (self.FRAME_HEIGHT, self.FRAME_WIDTH, 3): | |
frame = cv2.resize(frame, (self.FRAME_WIDTH, self.FRAME_HEIGHT)) | |
# do something with your OpenCV frame here | |
draw_text( | |
img=frame, | |
text=datetime.now( | |
tz=ZoneInfo("America/Los_Angeles") | |
).strftime("%m/%d/%Y %H:%M:%S") + " PDT", | |
pos=(self.FRAME_WIDTH - 500, 50), | |
font=cv2.FONT_HERSHEY_SIMPLEX, | |
font_scale=1, | |
font_thickness=2, | |
line_type=cv2.LINE_AA, | |
text_color=(0, 255, 255), | |
text_color_bg=(0, 0, 0), | |
) | |
frame_tagged = frame | |
if ( | |
self._is_tracking and self.model is not None | |
and self.speed_obj is not None | |
and self.counter is not None | |
and self._roi_changed is False | |
): | |
# YOLOv8 tracking, persisting tracks between frames | |
results = self.model.track( | |
source=frame, | |
classes=self.obj_class_id, | |
conf=self.conf_threshold, | |
iou=self.iou_threshold, | |
half=self.use_FP16, | |
stream_buffer=self.use_stream_buffer, | |
persist=True, | |
show=False, | |
verbose=self.YOLO_VERBOSE | |
) | |
if results[0].boxes.id is None: | |
pass | |
else: | |
self.speed_obj.estimate_speed( | |
frame_tagged, results | |
) | |
self.counter.start_counting( | |
frame_tagged, results | |
) | |
# reducer frames size for performance, int percentage | |
frame_tagged = await reducer( | |
frame=frame_tagged, | |
percentage=self.frame_reduction | |
) | |
# handle JPEG encoding & yield frame in byte format | |
img_encoded = \ | |
cv2.imencode(".jpg", frame_tagged)[1].tobytes() | |
yield ( | |
b"--frame\r\nContent-Type:video/jpeg2000\r\n\r\n" + | |
img_encoded + b"\r\n" | |
) | |
await asyncio.sleep(0.00001) | |
if self.stream1 is not None: | |
self.stream1.stop() | |
while self.stream1.read() is not None: | |
continue | |
self.stream1 = None | |
self._is_tracking = False | |
self._is_running = False | |
except asyncio.CancelledError: | |
if self.stream1 is not None: | |
self.stream1.stop() | |
while self.stream1.read() is not None: | |
continue | |
self.stream1 = None | |
self._is_tracking = False | |
self._is_running = False | |
logging.warning( | |
"client disconnected in frame1_producer" | |
) | |
frame = (np.random.standard_normal([ | |
self.FRAME_HEIGHT, self.FRAME_WIDTH, 3 | |
]) * 255).astype(np.uint8) | |
frame = await reducer( | |
frame, percentage=self.frame_reduction | |
) | |
img_encoded = cv2.imencode(".jpg", frame)[1].tobytes() | |
logging.info( | |
f"_is_running is {self._is_running} in frame0_producer" | |
) | |
yield ( | |
b"--frame\r\nContent-Type:video/jpeg2000\r\n\r\n" + | |
img_encoded + b"\r\n" | |
) | |
await asyncio.sleep(0.00001) | |
else: | |
if self._is_running is True: | |
pass | |
frame = (np.random.standard_normal([ | |
self.FRAME_HEIGHT, self.FRAME_WIDTH, 3 | |
]) * 255).astype(np.uint8) | |
# reducer frame size for more performance, percentage int | |
frame = await reducer(frame, percentage=self.frame_reduction) | |
# handle JPEG encoding & yield frame in byte format | |
img_encoded = cv2.imencode(".jpg", frame)[1].tobytes() | |
yield ( | |
b"--frame\r\nContent-Type:video/jpeg2000\r\n\r\n" + | |
img_encoded + b"\r\n" | |
) | |
await asyncio.sleep(0.00001) | |
async def custom_video_response(self, scope): | |
""" | |
Return a async video streaming response for `frame1_producer` generator | |
Tip1: use BackgroundTask to handle the async cleanup | |
https://github.com/tiangolo/fastapi/discussions/11022 | |
Tip2: use is_disconnected to check client disconnection | |
https://www.starlette.io/requests/#body | |
https://github.com/encode/starlette/pull/320/files/d56c917460a1e6488e1206c428445c39854859c1 | |
""" | |
assert scope["type"] in ["http", "https"] | |
await asyncio.sleep(0.00001) | |
return StreamingResponse( | |
content=self.frame1_producer(), | |
media_type="multipart/x-mixed-replace; boundary=frame" | |
) | |
async def models(self, request: HtmxRequest) -> Response: | |
# assert (htmx := request.scope["htmx"]) | |
if len(self.model_dict) == 0: | |
template = "partials/ack.html" | |
table_contents = ["model list unavailable!"] | |
context = {"request": request, "table": table_contents} | |
response = templates.TemplateResponse( | |
template, context, status_code=200 | |
) | |
# response.headers['Hx-Retarget'] = '#set-model-ack' | |
else: | |
template = "partials/yolo_models.html" | |
table_contents = make_table_from_dict( | |
self.model_dict, self.model_choice | |
) | |
context = {"request": request, "table": table_contents} | |
response = templates.TemplateResponse( | |
template, context, status_code=200 | |
) | |
await asyncio.sleep(0.001) | |
return response | |
async def urls(self, request: HtmxRequest) -> Response: | |
# assert (htmx := request.scope["htmx"]) | |
if len(self.url_dict) == 0: | |
template = "partials/ack.html" | |
table_contents = ["streaming url list unavailable!"] | |
context = {"request": request, "table": table_contents} | |
response = templates.TemplateResponse( | |
template, context, status_code=200 | |
) | |
# response.headers['Hx-Retarget'] = '#set-model-ack' | |
else: | |
template = "partials/camera_streams.html" | |
table_contents = make_table_from_dict(self.url_dict, self.cam_loc) | |
context = {"request": request, "table": table_contents} | |
response = templates.TemplateResponse( | |
template, context, status_code=200 | |
) | |
await asyncio.sleep(0.01) | |
return response | |
async def objects(self, request: HtmxRequest) -> Response: | |
# assert (htmx := request.scope["htmx"]) | |
if len(self.obj_dict) == 0: | |
template = "partials/ack.html" | |
table_contents = ["object list unavailable!"] | |
context = {"request": request, "table": table_contents} | |
response = templates.TemplateResponse( | |
template, context, status_code=200 | |
) | |
# response.headers['Hx-Retarget'] = '#set-model-ack' | |
else: | |
template = "partials/object_list.html" | |
table_contents = make_table_from_dict_multiselect( | |
self.obj_dict, self.obj_class_id | |
) | |
context = {"request": request, "table": table_contents} | |
response = templates.TemplateResponse( | |
template, context, status_code=200 | |
) | |
await asyncio.sleep(0.001) | |
return response | |
async def geturl(self, request: HtmxRequest) -> Response: | |
# assert (htmx := request.scope["htmx"]) | |
if len(self.url_dict) == 0: | |
template = "partials/ack.html" | |
table_contents = ["streaming url list unavailable!"] | |
context = {"request": request, "table": table_contents} | |
response = templates.TemplateResponse( | |
template, context, status_code=200 | |
) | |
# response.headers['Hx-Retarget'] = '#set-model-ack' | |
else: | |
template = "partials/ack.html" | |
if self.cam_loc in self.url_dict.keys(): | |
table_contents = [f"{self.cam_loc} selected"] | |
context = {"request": request, "table": table_contents} | |
response = templates.TemplateResponse( | |
template, context, status_code=201 | |
) | |
else: | |
table_contents = [ | |
f"{self.cam_loc} is not in the registered url_list" | |
] | |
context = {"request": request, "table": table_contents} | |
response = templates.TemplateResponse( | |
template, context, status_code=200 | |
) | |
# response.headers['Hx-Retarget'] = '#set-url-ack' | |
await asyncio.sleep(0.01) | |
return response | |
async def addurl(self, request: HtmxRequest) -> Response: | |
# assert (htmx := request.scope["htmx"]) | |
try: | |
req_json = await request.json() | |
except RuntimeError: | |
template = "partials/ack.html" | |
table_contents = ["receive channel unavailable!"] | |
context = {"request": request, "table": table_contents} | |
response = templates.TemplateResponse( | |
template, context, status_code=200 | |
) | |
response.headers['Hx-Retarget'] = '#add-url-ack' | |
await asyncio.sleep(0.01) | |
return response | |
if ( | |
"payload" in req_json | |
and "CamLoc" in req_json["payload"] and "URL" in req_json["payload"] | |
): | |
cam_loc = req_json["payload"]["CamLoc"] | |
cam_url = req_json["payload"]["URL"] | |
if cam_loc != "" and cam_url != "": | |
if try_site(cam_url) is False: | |
template = "partials/ack.html" | |
table_contents = ["invalid video URL!"] | |
context = {"request": request, "table": table_contents} | |
response = templates.TemplateResponse( | |
template, context, status_code=200 | |
) | |
response.headers['Hx-Retarget'] = '#add-url-ack' | |
else: | |
self.select_cam_loc( | |
cam_loc_key=cam_loc, cam_loc_val=cam_url | |
) | |
template = "partials/camera_streams.html" | |
table_contents = make_table_from_dict( | |
self.url_dict, self.cam_loc | |
) | |
context = {"request": request, "table": table_contents} | |
response = templates.TemplateResponse( | |
template, context, status_code=201 | |
) | |
else: | |
template = "partials/ack.html" | |
table_contents = ["empty or invalid inputs!"] | |
context = {"request": request, "table": table_contents} | |
response = templates.TemplateResponse( | |
template, context, status_code=200 | |
) | |
response.headers['Hx-Retarget'] = '#add-url-ack' | |
else: | |
template = "partials/ack.html" | |
table_contents = ["invalid POST request!"] | |
context = {"request": request, "table": table_contents} | |
response = templates.TemplateResponse( | |
template, context, status_code=200 | |
) | |
response.headers['Hx-Retarget'] = '#add-url-ack' | |
await asyncio.sleep(0.01) | |
return response | |
async def seturl(self, request: HtmxRequest) -> Response: | |
# assert (htmx := request.scope["htmx"]) | |
template = "partials/ack.html" | |
try: | |
req_json = await request.json() | |
except RuntimeError: | |
table_contents = ["receive channel unavailable!"] | |
context = {"request": request, "table": table_contents} | |
response = templates.TemplateResponse( | |
template, context, status_code=200 | |
) | |
# response.headers['Hx-Retarget'] = '#set-url-ack' | |
await asyncio.sleep(0.01) | |
return response | |
if ("payload" in req_json and "cam_url" in req_json["payload"]): | |
logging.info( | |
f"seturl: _is_running = {self._is_running}, " | |
f"_is_tracking = {self._is_tracking}" | |
) | |
if (self._is_running is True or self._is_tracking is True): | |
table_contents = ["turn off streaming and tracking before \ | |
setting a new camera stream!"] | |
context = {"request": request, "table": table_contents} | |
response = templates.TemplateResponse( | |
template, context, status_code=200 | |
) | |
# response.headers['Hx-Retarget'] = '#set-url-ack' | |
else: | |
cam_url = req_json["payload"]["cam_url"] | |
url_list = list(filter( | |
lambda x: self.url_dict[x] == cam_url, self.url_dict | |
)) | |
if len(url_list) > 0: | |
self.cam_loc = url_list[0] | |
table_contents = [f"{self.cam_loc} selected"] | |
context = {"request": request, "table": table_contents} | |
response = templates.TemplateResponse( | |
template, context, status_code=201 | |
) | |
else: | |
table_contents = [ | |
f"{cam_url} is not in the registered url_list" | |
] | |
context = {"request": request, "table": table_contents} | |
response = templates.TemplateResponse( | |
template, context, status_code=200 | |
) | |
# response.headers['Hx-Retarget'] = '#set-url-ack' | |
else: | |
table_contents = ["invalid POST request!"] | |
context = {"request": request, "table": table_contents} | |
response = templates.TemplateResponse( | |
template, context, status_code=200 | |
) | |
# response.headers['Hx-Retarget'] = '#set-url-ack' | |
await asyncio.sleep(0.01) | |
return response | |
async def getmodel(self, request: HtmxRequest) -> Response: | |
# assert (htmx := request.scope["htmx"]) | |
if len(self.model_dict) == 0: | |
template = "partials/ack.html" | |
table_contents = ["model list unavailable!"] | |
context = {"request": request, "table": table_contents} | |
response = templates.TemplateResponse( | |
template, context, status_code=200 | |
) | |
# response.headers['Hx-Retarget'] = '#set-model-ack' | |
else: | |
template = "partials/ack.html" | |
if self.model_choice in self.model_dict.keys(): | |
table_contents = [f"{self.model_choice} selected"] | |
context = {"request": request, "table": table_contents} | |
response = templates.TemplateResponse( | |
template, context, status_code=201 | |
) | |
else: | |
table_contents = [ | |
f"{self.model_choice} is not in the registered model_list" | |
] | |
context = {"request": request, "table": table_contents} | |
response = templates.TemplateResponse( | |
template, context, status_code=200 | |
) | |
# response.headers['Hx-Retarget'] = '#set-url-ack' | |
await asyncio.sleep(0.01) | |
return response | |
async def setmodel(self, request: HtmxRequest) -> Response: | |
# assert (htmx := request.scope["htmx"]) | |
template = "partials/ack.html" | |
try: | |
req_json = await request.json() | |
except RuntimeError: | |
table_contents = ["receive channel unavailable!"] | |
context = {"request": request, "table": table_contents} | |
response = templates.TemplateResponse( | |
template, context, status_code=200 | |
) | |
# response.headers['Hx-Retarget'] = '#set-model-ack' | |
await asyncio.sleep(0.01) | |
return response | |
if ("payload" in req_json and "model_path" in req_json["payload"]): | |
logging.info( | |
f"setmodel: _is_running = {self._is_running}, " | |
f"_is_tracking = {self._is_tracking}" | |
) | |
if (self._is_tracking is True): | |
table_contents = ["turn off tracking before setting a new \ | |
YOLO model!"] | |
context = {"request": request, "table": table_contents} | |
response = templates.TemplateResponse( | |
template, context, status_code=200 | |
) | |
# response.headers['Hx-Retarget'] = '#set-model-ack' | |
else: | |
model_path = req_json["payload"]["model_path"] | |
model_list = list(filter( | |
lambda x: self.model_dict[x] == model_path, self.model_dict | |
)) | |
if len(model_list) > 0: | |
self.model_choice = model_list[0] | |
self.load_model( | |
model_choice=self.model_choice, | |
conf_threshold=self.conf_threshold, | |
iou_threshold=self.iou_threshold, | |
use_FP16=self.use_FP16, | |
use_stream_buffer=self.use_stream_buffer | |
) | |
table_contents = [f"{self.model_choice} selected"] | |
context = {"request": request, "table": table_contents} | |
response = templates.TemplateResponse( | |
template, context, status_code=201 | |
) | |
else: | |
table_contents = [ | |
f"{model_path} is not in the registered model_list" | |
] | |
context = {"request": request, "table": table_contents} | |
response = templates.TemplateResponse( | |
template, context, status_code=200 | |
) | |
# response.headers['Hx-Retarget'] = '#set-model-ack' | |
else: | |
table_contents = ["invalid POST request!"] | |
context = {"request": request, "table": table_contents} | |
response = templates.TemplateResponse( | |
template, context, status_code=200 | |
) | |
# response.headers['Hx-Retarget'] = '#set-model-ack' | |
await asyncio.sleep(0.01) | |
return response | |
async def selectobjects(self, request: HtmxRequest) -> Response: | |
# assert (htmx := request.scope["htmx"]) | |
template = "partials/ack.html" | |
try: | |
req_json = await request.json() | |
except RuntimeError: | |
table_contents = ["receive channel unavailable!"] | |
context = {"request": request, "table": table_contents} | |
response = templates.TemplateResponse( | |
template, context, status_code=200 | |
) | |
await asyncio.sleep(0.01) | |
return response | |
if ("payload" in req_json and "object_id" in req_json["payload"]): | |
logging.info(f"requested_ids: {req_json['payload']}") | |
req_ids = req_json["payload"]["object_id"] | |
if len(req_ids) > 0: | |
self.obj_class_id = [ | |
int(id) for id in req_ids | |
if int(id) in self.obj_dict.values() | |
] | |
if len(self.obj_class_id) > 0: | |
table_contents = [ | |
f"{len(self.obj_class_id)} object types selected" | |
] | |
else: | |
self.obj_class_id = self.obj_class_id_default | |
table_contents = [ | |
"invalid objects selection, use default object types" | |
] | |
else: | |
table_contents = ["invalid POST request! need at least one object type"] | |
context = {"request": request, "table": table_contents} | |
response = templates.TemplateResponse( | |
template, context, status_code=200 | |
) | |
await asyncio.sleep(0.01) | |
return response | |
async def setroi(self, request: HtmxRequest) -> Response:
    """Move the horizontal counting line (region of interest).

    Expects ``{"payload": {"roi_height": <int>}}``. Accepted values lie in
    120..600 inclusive and strictly below ``self.FRAME_HEIGHT``; the stored
    ``self.roi_height`` is ``FRAME_HEIGHT - roi_height`` (so the requested
    value appears to be measured up from the bottom edge of the frame).
    An out-of-range value resets ``self.roi_height`` to
    ``self.roi_height_default``. In both of those cases
    ``self._roi_changed`` is set to True. Requests without a usable integer
    ``roi_height`` (or with a malformed body) leave the state untouched.

    Always renders ``partials/ack.html`` with HTTP 200.
    """
    # assert (htmx := request.scope["htmx"])
    template = "partials/ack.html"
    roi_min, roi_max = 120, 600  # accepted request range, in pixels
    try:
        req_json = await request.json()
    except RuntimeError:
        table_contents = ["receive channel unavailable!"]
    except ValueError:
        # body is not valid JSON (json.JSONDecodeError subclasses ValueError)
        table_contents = ["invalid POST request! need a valid roi_height"]
    else:
        payload = (
            req_json.get("payload") if isinstance(req_json, dict) else None
        )
        req_height = None
        if isinstance(payload, dict) and "roi_height" in payload:
            logging.info(f"{payload}")
            try:
                req_height = int(payload["roi_height"])
            except (TypeError, ValueError):
                req_height = None  # non-numeric value: reject the request
        if req_height is None:
            table_contents = ["invalid POST request! need a valid roi_height"]
        else:
            if roi_min <= req_height <= roi_max and (
                req_height < self.FRAME_HEIGHT
            ):
                self.roi_height = self.FRAME_HEIGHT - req_height
                table_contents = [
                    f"roi_height set at "
                    f"{self.FRAME_HEIGHT - self.roi_height}px"
                ]
            else:
                self.roi_height = self.roi_height_default
                table_contents = [
                    "invalid roi_height request, use default "
                    f"{self.FRAME_HEIGHT - self.roi_height_default}px"
                ]
            # the ROI row changed either to the request or to the default
            self._roi_changed = True
    context = {"request": request, "table": table_contents}
    response = templates.TemplateResponse(
        template, context, status_code=200
    )
    await asyncio.sleep(0.01)
    return response
async def streamswitch(self, request: HtmxRequest) -> Response:
    """Start or stop the video stream.

    Expects ``{"payload": {"stream_switch": "on" | ...}}``. ``"on"`` sets
    ``self._is_running`` to True; any other value (or a missing key) sets
    it to False. Switching in either direction also sets
    ``self._is_tracking`` to False.

    Renders ``partials/ack.html`` with ``["on"]`` / ``["off"]`` and HTTP 201
    when a payload is present. Otherwise it renders an error message with
    HTTP 200.
    """
    # assert (htmx := request.scope["htmx"])
    template = "partials/ack.html"
    try:
        req_json = await request.json()
    except RuntimeError:
        context = {
            "request": request, "table": ["receive channel unavailable!"]
        }
        await asyncio.sleep(0.01)
        return templates.TemplateResponse(
            template, context, status_code=200
        )
    except ValueError:
        # malformed JSON body: answer like a request without a payload
        req_json = None
    if isinstance(req_json, dict) and "payload" in req_json:
        payload = req_json["payload"]
        logging.info(f"payload = {payload}")
        turn_on = (
            isinstance(payload, dict)
            and payload.get("stream_switch") == "on"
        )
        self._is_running = turn_on
        # tracking always has to be re-enabled after a stream toggle
        self._is_tracking = False
        table_contents = ["on" if turn_on else "off"]
        status_code = 201
    else:
        table_contents = ["invalid POST request!"]
        status_code = 200
    context = {"request": request, "table": table_contents}
    await asyncio.sleep(0.1)
    return templates.TemplateResponse(
        template, context, status_code=status_code
    )
async def trackingswitch(self, request: HtmxRequest) -> Response:
    """Turn object counting / speed estimation on or off.

    Expects ``{"payload": {"tracking_switch": "on" | ...}}``. Tracking is
    enabled only when the switch is ``"on"`` *and* ``self._is_running`` is
    truthy; any other value turns it off. Each time tracking is enabled, a
    fresh ``ObjectCounter`` and ``SpeedEstimator`` are created, so previous
    counts are discarded.

    Renders ``partials/ack.html`` with ``["on"]`` / ``["off"]`` and HTTP 201
    when a payload is present. Otherwise it renders an error message with
    HTTP 200.
    """
    # assert (htmx := request.scope["htmx"])
    template = "partials/ack.html"
    try:
        req_json = await request.json()
    except RuntimeError:
        context = {
            "request": request, "table": ["receive channel unavailable!"]
        }
        await asyncio.sleep(0.01)
        return templates.TemplateResponse(
            template, context, status_code=200
        )
    except ValueError:
        # malformed JSON body: answer like a request without a payload
        req_json = None
    if isinstance(req_json, dict) and "payload" in req_json:
        payload = req_json["payload"]
        logging.info(f"payload = {payload}")
        wants_tracking = (
            isinstance(payload, dict)
            and payload.get("tracking_switch") == "on"
        )
        # tracking is only allowed while the stream is running
        self._is_tracking = wants_tracking and bool(self._is_running)
        if self._is_tracking:
            # Counting band: 40 px tall, centred on the ROI row, spanning
            # the frame width minus a 5 px margin on each side.
            region_points = [
                (5, self.roi_height - 20),
                (5, self.roi_height + 20),
                (self.FRAME_WIDTH - 5, self.roi_height + 20),
                (self.FRAME_WIDTH - 5, self.roi_height - 20),
            ]
            self.counter = object_counter.ObjectCounter()
            self.counter.set_args(
                view_img=False,
                reg_pts=region_points,
                classes_names=self.model.names,
                draw_tracks=False,
                draw_boxes=False,
                draw_reg_pts=True,
            )
            # The speed estimator gets the ROI row itself as its line.
            line_points = [
                (5, self.roi_height),
                (self.FRAME_WIDTH - 5, self.roi_height),
            ]
            self.speed_obj = speed_estimation.SpeedEstimator()
            self.speed_obj.set_args(
                reg_pts=line_points,
                names=self.model.names,
                view_img=False,
            )
            table_contents = ["on"]
        else:
            table_contents = ["off"]
        status_code = 201
    else:
        table_contents = ["invalid POST request!"]
        status_code = 200
    context = {"request": request, "table": table_contents}
    await asyncio.sleep(0.1)
    return templates.TemplateResponse(
        template, context, status_code=status_code
    )
async def sse_incounts(self, request: Request):
    """Server-sent-events endpoint publishing ``self.counter.in_counts``.

    Every ``EVT_STREAM_DELAY_SEC`` seconds the generator checks the state:
      * client disconnected -> send a final ``"..."`` event and stop;
      * running and tracking -> send the current in-count whenever
        ``self.counter.incounts_updated()`` is truthy;
      * running but not tracking -> send a single ``"---"`` placeholder
        (re-armed once tracking resumes).
    Event ids are wall-clock timestamps in America/Los_Angeles time.
    """
    tz = ZoneInfo("America/Los_Angeles")

    def make_event(data: str) -> dict:
        # one SSE message in the dict form EventSourceResponse accepts
        return {
            "event": "evt_in_counts",
            "id": datetime.now(tz=tz).strftime("%m/%d/%Y %H:%M:%S"),
            "retry": RETRY_TIMEOUT_MILSEC,
            "data": data,
        }

    async def event_generator():
        placeholder_sent = False
        while True:
            # If client closes connection, stop sending events
            if await request.is_disconnected():
                yield make_event("...")
                break
            # NOTE(review): the scraped source lost its indentation; this
            # assumes the "---" branch pairs with ``if self._is_tracking``.
            if self._is_running:
                if self._is_tracking:
                    placeholder_sent = False
                    counter = self.counter
                    # Check for None *before* calling into the counter;
                    # otherwise a missing counter raises AttributeError.
                    if counter is not None and counter.incounts_updated():
                        yield make_event(f"{counter.in_counts}")
                elif not placeholder_sent:
                    yield make_event("---")
                    placeholder_sent = True
            await asyncio.sleep(EVT_STREAM_DELAY_SEC)

    return EventSourceResponse(event_generator())
async def sse_outcounts(self, request: Request):
    """Server-sent-events endpoint publishing ``self.counter.out_counts``.

    Every ``EVT_STREAM_DELAY_SEC`` seconds the generator checks the state:
      * client disconnected -> send a final ``"..."`` event and stop;
      * running and tracking -> send the current out-count whenever
        ``self.counter.outcounts_updated()`` is truthy;
      * running but not tracking -> send a single ``"---"`` placeholder
        (re-armed once tracking resumes).
    Event ids are wall-clock timestamps in America/Los_Angeles time.
    """
    tz = ZoneInfo("America/Los_Angeles")

    def make_event(data: str) -> dict:
        # one SSE message in the dict form EventSourceResponse accepts
        return {
            "event": "evt_out_counts",
            "id": datetime.now(tz=tz).strftime("%m/%d/%Y %H:%M:%S"),
            "retry": RETRY_TIMEOUT_MILSEC,
            "data": data,
        }

    async def event_generator():
        placeholder_sent = False
        while True:
            # If client closes connection, stop sending events
            if await request.is_disconnected():
                yield make_event("...")
                break
            # NOTE(review): the scraped source lost its indentation; this
            # assumes the "---" branch pairs with ``if self._is_tracking``.
            if self._is_running:
                if self._is_tracking:
                    placeholder_sent = False
                    counter = self.counter
                    # Check for None *before* calling into the counter;
                    # otherwise a missing counter raises AttributeError.
                    if counter is not None and counter.outcounts_updated():
                        yield make_event(f"{counter.out_counts}")
                elif not placeholder_sent:
                    yield make_event("---")
                    placeholder_sent = True
            await asyncio.sleep(EVT_STREAM_DELAY_SEC)

    return EventSourceResponse(event_generator())
# ---------------------------------------------------------------------------
# Application wiring (runs at import time).
# Builds the demo case, loads the detector, then configures the WebGear app
# and registers its HTTP and server-sent-event routes. The app object is
# ``web``; serve it with an ASGI server, for example:
#   uvicorn webapp:web --host "localhost" --port 8080 --reload
# (a local run could also call uvicorn.run(web(), host=..., port=...) and
# web.shutdown() afterwards).
# ---------------------------------------------------------------------------
demo_case = DemoCase(YOLO_VERBOSE=False)
demo_case.set_frame_reduction(frame_reduction=35)
demo_case.load_model(
    model_choice="y8small",
    conf_threshold=0.1,
    iou_threshold=0.6,
    use_FP16=False,
    use_stream_buffer=True,
)
for attr_name in ("url_dict", "model_dict", "obj_dict", "obj_class_id"):
    logging.info(f"{attr_name}: {getattr(demo_case, attr_name)}")

# WebGear server setup; custom_data_location points at ./.vidgear
options = {"custom_data_location": "./"}
web = WebGear(logging=True, **options)
web.config.update(
    generator=demo_case.frame1_producer,
    middleware=[Middleware(HtmxMiddleware)],
)
web.routes.append(Mount("/static", static, name="static"))

# name -> (endpoint, allowed HTTP methods); each is served at "/<name>"
routes_dict = {
    "models": (demo_case.models, ["GET"]),
    "getmodel": (demo_case.getmodel, ["GET"]),
    "setmodel": (demo_case.setmodel, ["POST"]),
    "urls": (demo_case.urls, ["GET"]),
    "addurl": (demo_case.addurl, ["POST"]),
    "geturl": (demo_case.geturl, ["GET"]),
    "seturl": (demo_case.seturl, ["POST"]),
    "objects": (demo_case.objects, ["GET"]),
    "selectobjects": (demo_case.selectobjects, ["POST"]),
    "setroi": (demo_case.setroi, ["POST"]),
    "streamswitch": (demo_case.streamswitch, ["POST"]),
    "trackingswitch": (demo_case.trackingswitch, ["POST"]),
}
for route_name, (route_endpoint, route_methods) in routes_dict.items():
    web.routes.append(
        Route(
            path=f"/{route_name}",
            endpoint=route_endpoint,
            name=route_name,
            methods=route_methods,
        )
    )

# Server-sent event streams; no explicit methods, so Starlette's default
# for function endpoints applies.
for sse_name, sse_endpoint in (
    ("sseincounts", demo_case.sse_incounts),
    ("sseoutcounts", demo_case.sse_outcounts),
):
    web.routes.append(
        Route(path=f"/{sse_name}", endpoint=sse_endpoint, name=sse_name)
    )