# traffic light metric
from typing import List

import matplotlib.lines as mlines
import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
from nuplan.common.maps.maps_datatypes import TrafficLightStatusData, TrafficLightStatusType
from nuplan.common.utils.interpolatable_state import InterpolatableState
from shapely.geometry import Point

from navsim.planning.metric_caching.metric_cache import MetricCache
from navsim.planning.simulation.planner.pdm_planner.observation.pdm_occupancy_map import (
    PDMDrivableMap,
    PDMCrosswalkIntersectionMap,
)
from navsim.planning.simulation.planner.pdm_planner.utils.pdm_path import PDMPath


def calc_tl(trajectories, num_proposals, drivable_area_map: PDMDrivableMap, metric_cache: MetricCache,
            centerline: PDMPath, route_lane_ids, config) -> npt.NDArray:
    """Score each proposal on compliance with red traffic lights.

    Procedure:
      1. Keep only traffic lights whose lane connector is both near ego (present in
         ``drivable_area_map``) and on the route (present in ``route_lane_ids``).
      2. Of those, consider only lights whose status is RED.
      3. A RED light is enforced only if the ground-truth trajectory does NOT enter the
         crosswalk / intersection zone tied to that lane. If the expert drove into it,
         the light is not treated as binding.
      4. Every proposal that enters the zone of an enforced RED light is scored 0.

    Args:
        trajectories: proposal states. Per the original notes, the shape is
            [PDM-Closed trajectory + vocab_size (4096 or 8192) trajectories,
             1 + 40 (current pose + 4 s at 10 Hz),
             11 StateIndex fields (navsim/planning/simulation/planner/pdm_planner/utils/pdm_enums.py)].
            Only the first two state fields are read (presumably x, y; confirm against StateIndex).
        num_proposals: number of proposals (PDM-Closed trajectory + vocabulary trajectories).
        drivable_area_map: map whose ``tokens`` are the lane / lane-connector ids near ego.
            Indexing it by id yields the lane geometry.
        metric_cache: must provide ``others['gt_traj_global']``, ``others['traffic_lights']``
            and ``others['crosswalk_intersection']``.
        centerline: unused by this metric.
        route_lane_ids: container of on-route lane ids. It is membership-tested with ``str`` ids.
        config: unused by this metric.

    Returns:
        Float array built from ``np.ones(num_proposals)``. Each entry is 1.0 (compliant)
        or 0.0 (enters the zone of an enforced red light).
    """
    # [num_proposals]: every proposal starts compliant and can only be zeroed out.
    result_scores = np.ones(num_proposals)

    # Current pose + 4 s of future at 0.5 s spacing = 9 poses.
    # NOTE(review): assumes gt_traj_global is sampled every 0.5 s; confirm against the metric cache.
    num_gt_poses = int(1 + 4 / 0.5)
    gt_traj_global = metric_cache.others['gt_traj_global'][:num_gt_poses]
    traffic_lights: List[TrafficLightStatusData] = metric_cache.others['traffic_lights']
    crosswalk_intersection: PDMCrosswalkIntersectionMap = metric_cache.others['crosswalk_intersection']

    for tl_data in traffic_lights:
        lane_conn_id = str(tl_data.lane_connector_id)

        # Only lights on lane connectors that are both nearby and on the ego route can bind ego.
        near_ego = lane_conn_id in drivable_area_map.tokens
        on_route = lane_conn_id in route_lane_ids
        if not (on_route and near_ego):
            continue

        # A non-red light never penalizes anything, so skip the geometric zone queries for it.
        if tl_data.status != TrafficLightStatusType.RED:
            continue

        red_lane = drivable_area_map[lane_conn_id]

        # Enforce the red light only if the ground truth stays out of this lane's
        # filtered crosswalk / intersection zone.
        gt_enters_zone = crosswalk_intersection.points_in_dangerous_polygons(
            gt_traj_global[:, :2], red_lane
        ).any()
        if gt_enters_zone:
            continue

        # NOTE(review): `.any((0, 2))` collapses axes 0 and 2 and keeps axis 1, presumably
        # the proposal axis of the query result. Confirm against points_in_dangerous_polygons.
        intersected_mask = crosswalk_intersection.points_in_dangerous_polygons(
            trajectories[:, :, :2], red_lane
        ).any((0, 2))
        result_scores *= np.logical_not(intersected_mask)

    return result_scores