|
|
|
from typing import List |
|
|
|
import matplotlib.lines as mlines |
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
import numpy.typing as npt |
|
from nuplan.common.maps.maps_datatypes import TrafficLightStatusData, TrafficLightStatusType |
|
from nuplan.common.utils.interpolatable_state import InterpolatableState |
|
from shapely.geometry import Point |
|
|
|
from navsim.planning.metric_caching.metric_cache import MetricCache |
|
from navsim.planning.simulation.planner.pdm_planner.observation.pdm_occupancy_map import ( |
|
PDMDrivableMap, PDMCrosswalkIntersectionMap, |
|
) |
|
from navsim.planning.simulation.planner.pdm_planner.utils.pdm_path import PDMPath |
|
|
|
|
|
def calc_tl(trajectories,
            num_proposals,
            drivable_area_map: PDMDrivableMap,
            metric_cache: MetricCache,
            centerline: PDMPath,
            route_lane_ids,
            config) -> npt.NDArray:
    """
    Score each proposal for red-light compliance.

    Every proposal starts with a score of 1. For each traffic-light entry in
    the metric cache whose lane connector is RED, lies on the route and is
    present in the drivable-area map, the score of every proposal that has a
    pose inside that lane connector is multiplied by 0 -- unless the
    ground-truth trajectory itself enters that lane connector, in which case
    the entry is ignored.

    vocab_size = 4096 or 8192

    :param trajectories: array of shape [
        PDM-Closed Trajectory + vocab_size trajs,
        1+40 (current pose + 4 secs * 10Hz poses),
        11: StateIndex, navsim/planning/simulation/planner/pdm_planner/utils/pdm_enums.py];
        only the first two state columns are read here.
    :param num_proposals: PDM-Closed Trajectory + vocab_size trajs; length of
        the returned score vector.
    :param drivable_area_map: map queried by lane-connector id; must expose
        ``tokens`` and item access by id.
    :param metric_cache: cache whose ``others`` dict provides
        ``'gt_traj_global'``, ``'traffic_lights'`` and
        ``'crosswalk_intersection'``.
    :param centerline: unused; kept for interface compatibility.
    :param route_lane_ids: container of lane ids (as strings) on the route.
    :param config: unused; kept for interface compatibility.
    :return: array of length ``num_proposals``; 1 for proposals without a
        detected violation, 0 otherwise (assuming the polygon test returns a
        boolean / 0-1 mask -- TODO confirm).
    """
    result_scores = np.ones(num_proposals)

    # Number of ground-truth poses to consider: the current pose plus
    # (presumably) one pose every 0.5 s over a 4 s horizon -> 9 poses.
    horizon_s = 4
    gt_interval_s = 0.5
    num_gt_poses = int(1 + horizon_s / gt_interval_s)

    gt_traj_global = metric_cache.others['gt_traj_global'][:num_gt_poses]
    traffic_lights: List[TrafficLightStatusData] = metric_cache.others['traffic_lights']
    crosswalk_intersection: PDMCrosswalkIntersectionMap = metric_cache.others['crosswalk_intersection']

    for tl_data in traffic_lights:
        # Only RED signals can lower a score, so filter on the status first
        # and avoid any geometry work for the remaining entries.
        is_red = tl_data.status == TrafficLightStatusType.RED
        if not is_red:
            continue

        lane_conn_id = str(tl_data.lane_connector_id)
        on_route = lane_conn_id in route_lane_ids
        near_ego = lane_conn_id in drivable_area_map.tokens
        if not (on_route and near_ego):
            continue

        red_lane = drivable_area_map[lane_conn_id]

        # If the ground-truth trajectory enters this lane connector, the RED
        # status is not trusted and the entry is skipped.
        gt_enters_lane = crosswalk_intersection.points_in_dangerous_polygons(
            gt_traj_global[:, :2], red_lane
        ).any()
        if gt_enters_lane:
            continue

        # Reduce over axes 0 and 2 so that one flag per proposal remains
        # (axis 1 is presumably the proposal axis -- TODO confirm against
        # points_in_dangerous_polygons).
        intersected_mask = crosswalk_intersection.points_in_dangerous_polygons(
            trajectories[:, :, :2], red_lane
        ).any((0, 2))

        # A single violating pose zeroes the proposal's score.
        result_scores *= (1 - intersected_mask)

    return result_scores
|
|