# navsim/agents/transfuser/transfuser_callback.py
from typing import Any, Dict, Optional, Union
from PIL import ImageColor
import cv2
import numpy as np
import numpy.typing as npt
import pytorch_lightning as pl
import torch
import torchvision.utils as vutils
from nuplan.common.maps.abstract_map import SemanticMapLayer
from nuplan.common.actor_state.oriented_box import OrientedBox
from nuplan.common.actor_state.state_representation import StateSE2
from navsim.visualization.config import MAP_LAYER_CONFIG, AGENT_CONFIG
from navsim.agents.transfuser.transfuser_features import BoundingBox2DIndex
from navsim.agents.transfuser.transfuser_config import TransfuserConfig
class TransfuserCallback(pl.Callback):
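    """PyTorch Lightning callback that logs TransFuser input / prediction visualizations during validation."""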
def __init__(
self, config: TransfuserConfig, num_plots: int = 10, num_rows: int = 2, num_columns: int = 2
) -> None:
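        """
        Initializes the visualization callback.
        :param config: global TransFuser config
        :param num_plots: number of image grids logged per validation epoch
        :param num_rows: number of rows of samples per image grid
        :param num_columns: number of columns of samples per image grid
        """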
self._config = config
self._num_plots = num_plots
self._num_rows = num_rows
self._num_columns = num_columns
def on_validation_epoch_start(
self, trainer: pl.Trainer, lightning_module: pl.LightningModule
) -> None:
pass
def on_validation_epoch_end(
self, trainer: pl.Trainer, lightning_module: pl.LightningModule
) -> None:
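        """Samples validation batches, runs the agent forward pass, and logs visualization grids."""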
device = lightning_module.device
val_data_iter = iter(trainer.val_dataloaders)
for idx_plot in range(self._num_plots):
features, targets, tokens = next(val_data_iter)
features, targets = dict_to_device(features, device), dict_to_device(targets, device)
with torch.no_grad():
predictions = lightning_module.agent.forward(features)
features, targets, predictions = (
dict_to_device(features, "cpu"),
dict_to_device(targets, "cpu"),
dict_to_device(predictions, "cpu"),
)
grid = self._visualize_model(features, targets, predictions)
trainer.logger.experiment.add_image(
f"val_plot_{idx_plot}", grid, global_step=trainer.current_epoch
)
def on_test_epoch_start(
self, trainer: pl.Trainer, lightning_module: pl.LightningModule
) -> None:
pass
def on_test_epoch_end(self, trainer: pl.Trainer, lightning_module: pl.LightningModule) -> None:
pass
def on_train_epoch_start(
self, trainer: pl.Trainer, lightning_module: pl.LightningModule
) -> None:
pass
def on_train_epoch_end(
self,
trainer: pl.Trainer,
lightning_module: pl.LightningModule,
unused: Optional[Any] = None,
) -> None:
pass
# device = lightning_module.device
# train_data_iter = iter(trainer.train_dataloader)
# for idx_plot in range(self._num_plots):
# features, targets, _ = next(train_data_iter)
# features, targets = dict_to_device(features, device), dict_to_device(targets, device)
# with torch.no_grad():
# predictions = lightning_module.agent.forward(features)
#
# features, targets, predictions = (
# dict_to_device(features, "cpu"),
# dict_to_device(targets, "cpu"),
# dict_to_device(predictions, "cpu"),
# )
# grid = self._visualize_model(features, targets, predictions)
# trainer.logger.experiment.add_image(
# f"train_plot_{idx_plot}", grid, global_step=trainer.current_epoch
# )
def _visualize_model(
self,
features: Dict[str, torch.Tensor],
targets: Dict[str, torch.Tensor],
predictions: Dict[str, torch.Tensor],
) -> torch.Tensor:
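        """
        Creates a grid of visualization plots from features, targets, and predictions.
        :param features: dictionary of model input features
        :param targets: dictionary of ground-truth targets
        :param predictions: dictionary of model predictions
        :return: image grid tensor for logging
        """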
camera = features["camera_feature"].permute(0, 2, 3, 1).numpy()
bev = targets["bev_semantic_map"].numpy()
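        # if the LiDAR feature has multiple channels, visualize only the last one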
        if features["lidar_feature"].shape[1] > 1:
            lidar_map = features["lidar_feature"][:, -1].numpy()
else:
lidar_map = features["lidar_feature"].squeeze(1).numpy()
agent_labels = targets["agent_labels"].numpy()
agent_states = targets["agent_states"].numpy()
trajectory = targets["trajectory"].numpy()
pred_bev = predictions["bev_semantic_map"].argmax(1).numpy()
pred_agent_labels = predictions["agent_labels"].sigmoid().numpy()
pred_agent_states = predictions["agent_states"].numpy()
pred_trajectory = predictions["trajectory"].numpy()
plots = []
for sample_idx in range(self._num_rows * self._num_columns):
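            # 256x768 canvas: camera crop top-left, GT / predicted BEV maps bottom-left,
            # LiDAR BEV with agent boxes and trajectories on the right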
plot = np.zeros((256, 768, 3), dtype=np.uint8)
cam_stride = camera[sample_idx].shape[0] // 128
tmp = semantic_map_to_rgb(bev[sample_idx], self._config)
lidar_stride = tmp.shape[0] // 128
plot[:128, :512] = (camera[sample_idx] * 255).astype(np.uint8)[::cam_stride, ::cam_stride]
plot[128:, :256] = tmp[::lidar_stride, ::lidar_stride]
plot[128:, 256:512] = semantic_map_to_rgb(pred_bev[sample_idx], self._config)[::lidar_stride, ::lidar_stride]
agent_states_ = agent_states[sample_idx][agent_labels[sample_idx]]
pred_agent_states_ = pred_agent_states[sample_idx][pred_agent_labels[sample_idx] > 0.5]
plot[:, 512:] = lidar_map_to_rgb(
lidar_map[sample_idx],
agent_states_,
pred_agent_states_,
trajectory[sample_idx],
pred_trajectory[sample_idx],
self._config,
)[::lidar_stride, ::lidar_stride]
plots.append(torch.tensor(plot).permute(2, 0, 1))
        return vutils.make_grid(plots, normalize=False, nrow=self._num_columns)  # nrow = images per grid row
def dict_to_device(
dict: Dict[str, torch.Tensor], device: Union[torch.device, str]
) -> Dict[str, torch.Tensor]:
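    """Moves all tensors in the dictionary to the given device (in place) and returns it."""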
for key in dict.keys():
dict[key] = dict[key].to(device)
return dict
def semantic_map_to_rgb(
semantic_map: npt.NDArray[np.int64], config: TransfuserConfig
) -> npt.NDArray[np.uint8]:
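    """
    Converts a BEV semantic segmentation map (class indices) into an RGB image,
    using the map-layer / agent colors of the navsim visualization config.
    :param semantic_map: ground-truth or predicted BEV semantic map
    :param config: global TransFuser config
    :return: RGB image as uint8 array
    """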
height, width = semantic_map.shape[:2]
rgb_map = np.ones((height, width, 3), dtype=np.uint8) * 255
for label in range(1, config.num_bev_classes):
if config.bev_semantic_classes[label][0] == "linestring":
hex_color = MAP_LAYER_CONFIG[SemanticMapLayer.BASELINE_PATHS]["line_color"]
else:
layer = config.bev_semantic_classes[label][-1][0] # take color of first element
hex_color = (
AGENT_CONFIG[layer]["fill_color"]
if layer in AGENT_CONFIG.keys()
else MAP_LAYER_CONFIG[layer]["fill_color"]
)
rgb_map[semantic_map == label] = ImageColor.getcolor(hex_color, "RGB")
return rgb_map[::-1, ::-1]
def lidar_map_to_rgb(
lidar_map: npt.NDArray[np.int64],
agent_states: npt.NDArray[np.float32],
pred_agent_states: npt.NDArray[np.float32],
trajectory: npt.NDArray[np.float32],
pred_trajectory: npt.NDArray[np.float32],
config: TransfuserConfig,
) -> npt.NDArray[np.uint8]:
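    """
    Renders the LiDAR BEV map with ground-truth (green) and predicted (red)
    agent bounding boxes and ego trajectories.
    :param lidar_map: LiDAR BEV feature map
    :param agent_states: ground-truth agent bounding-box states
    :param pred_agent_states: predicted agent bounding-box states
    :param trajectory: ground-truth ego trajectory
    :param pred_trajectory: predicted ego trajectory
    :param config: global TransFuser config
    :return: RGB image as uint8 array
    """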
gt_color, pred_color = (0, 255, 0), (255, 0, 0)
point_size = 4
height, width = lidar_map.shape[:2]
def coords_to_pixel(coords):
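        """Converts local (x, y) coordinates in meters to integer BEV pixel indices."""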
pixel_center = np.array([[height / 2.0, width / 2.0]])
coords_idcs = (coords / config.bev_pixel_size) + pixel_center
return coords_idcs.astype(np.int32)
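    # invert the LiDAR intensity map (white background, dark points) and broadcast to 3 RGB channels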
rgb_map = (lidar_map * 255).astype(np.uint8)
rgb_map = 255 - rgb_map[..., None].repeat(3, axis=-1)
for color, agent_state_array in zip(
[gt_color, pred_color], [agent_states, pred_agent_states]
):
for agent_state in agent_state_array:
agent_box = OrientedBox(
StateSE2(*agent_state[BoundingBox2DIndex.STATE_SE2]),
agent_state[BoundingBox2DIndex.LENGTH],
agent_state[BoundingBox2DIndex.WIDTH],
                1.0,  # dummy box height; only the 2D footprint is drawn
)
exterior = np.array(agent_box.geometry.exterior.coords).reshape((-1, 1, 2))
exterior = coords_to_pixel(exterior)
exterior = np.flip(exterior, axis=-1)
cv2.polylines(rgb_map, [exterior], isClosed=True, color=color, thickness=2)
for color, traj in zip(
[gt_color, pred_color], [trajectory, pred_trajectory]
):
        trajectory_indices = coords_to_pixel(traj[:, :2])
for x, y in trajectory_indices:
cv2.circle(rgb_map, (y, x), point_size, color, -1) # -1 fills the circle
return rgb_map[::-1, ::-1]