# Uploaded by lkllkl via huggingface_hub ("Upload folder using huggingface_hub", commit da2e2ac, verified).
# traffic light metric
from typing import List
import matplotlib.lines as mlines
import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
from nuplan.common.maps.maps_datatypes import TrafficLightStatusData, TrafficLightStatusType
from nuplan.common.utils.interpolatable_state import InterpolatableState
from shapely.geometry import Point
from navsim.planning.metric_caching.metric_cache import MetricCache
from navsim.planning.simulation.planner.pdm_planner.observation.pdm_occupancy_map import (
PDMDrivableMap, PDMCrosswalkIntersectionMap,
)
from navsim.planning.simulation.planner.pdm_planner.utils.pdm_path import PDMPath
def calc_tl(trajectories,
            num_proposals,
            drivable_area_map: PDMDrivableMap,
            metric_cache: MetricCache,
            centerline: PDMPath,
            route_lane_ids,
            config) -> npt.NDArray:
    """
    Score every trajectory proposal for traffic-light compliance.

    A proposal is zeroed when it enters a crosswalk / intersection polygon belonging to a
    lane connector whose traffic light is RED, that is on the route and near the ego, and
    whose dangerous polygons the ground-truth (expert) trajectory does NOT enter.
    The expert check means a red light the expert drove through anyway is not enforced.

    vocab_size = 4096 or 8192
    trajectories: [
        PDM-Closed Trajectory + vocab_size trajs,
        1+40 (current pose + 4 secs * 10Hz poses),
        11: StateIndex, navsim/planning/simulation/planner/pdm_planner/utils/pdm_enums.py]
    num_proposals: PDM-Closed Trajectory + vocab_size trajs

    Args:
        trajectories: proposal states, laid out as described above; only the first two
            state entries (presumably x, y) are used.
        num_proposals: number of proposals; length of the returned score array.
        drivable_area_map: occupancy map whose ``tokens`` decide which lane connectors
            count as "near the ego" and which yields the lane geometry by token.
        metric_cache: reads ``others['gt_traj_global']``, ``others['traffic_lights']``
            and ``others['crosswalk_intersection']``.
        centerline: accepted but not read by this function.
        route_lane_ids: container of lane ids, tested with ``in`` against
            ``str(lane_connector_id)`` (presumably string ids — confirm with caller).
        config: accepted but not read by this function.

    Returns:
        Float array of shape [num_proposals]; each entry is 1.0 or 0.0.
    """
    # [num_proposals]
    result_scores = np.ones(num_proposals)

    # 1. keep only traffic lights that are RED, on the route and near the ego
    # 2. skip lights whose dangerous polygons the expert (GT) trajectory enters anyway
    # 3. zero out proposals entering the crosswalk / intersection polygons of the remaining red lanes
    # 9 GT poses: current pose + 4 s at 0.5 s spacing (assumes 'gt_traj_global' is sampled
    # every 0.5 s — NOTE(review): confirm against the metric-cache writer).
    timestamps = int(1 + 4 / 0.5)
    gt_traj_global = metric_cache.others['gt_traj_global'][:timestamps]
    traffic_lights: List[TrafficLightStatusData] = metric_cache.others['traffic_lights']
    crosswalk_intersection: PDMCrosswalkIntersectionMap = metric_cache.others['crosswalk_intersection']

    for tl_data in traffic_lights:
        is_red = tl_data.status == TrafficLightStatusType.RED
        lane_conn_id = str(tl_data.lane_connector_id)
        near_ego = lane_conn_id in drivable_area_map.tokens
        on_route = lane_conn_id in route_lane_ids
        # Only red, on-route, nearby lights can penalise a proposal. Filtering on is_red here
        # (rather than after the polygon queries) avoids geometry work for every other light.
        if not (is_red and on_route and near_ego):
            continue

        red_lane = drivable_area_map[lane_conn_id]

        # If the expert trajectory itself enters this lane's dangerous polygons, the red
        # light is treated as not binding and no proposal is penalised for it.
        gt_intersects = crosswalk_intersection.points_in_dangerous_polygons(
            gt_traj_global[:, :2], red_lane
        ).any()
        if gt_intersects:
            continue

        # Reduce over axes 0 and 2 to get one flag per proposal. Presumably the helper returns
        # (num_polygons, num_proposals, num_poses) — TODO confirm against PDMCrosswalkIntersectionMap.
        intersected_mask = crosswalk_intersection.points_in_dangerous_polygons(
            trajectories[:, :, :2], red_lane
        ).any((0, 2))
        result_scores *= (1 - intersected_mask)

    return result_scores