# navsim_ours/navsim/agents/transfuser/transfuser_callback.py
import time
from typing import Any, Dict, Optional, Union
from PIL import ImageColor
import cv2
import numpy as np
import numpy.typing as npt
import pytorch_lightning as pl
import torch
import torchvision.utils as vutils
from nuplan.common.maps.abstract_map import SemanticMapLayer
from nuplan.common.actor_state.oriented_box import OrientedBox
from nuplan.common.actor_state.state_representation import StateSE2
from navsim.visualization.config import TAB_10, MAP_LAYER_CONFIG, AGENT_CONFIG
from navsim.agents.transfuser.transfuser_features import BoundingBox2DIndex
from navsim.agents.transfuser.transfuser_config import TransfuserConfig
class TransfuserCallback(pl.Callback):
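    """Visualization callback for TransFuser that renders camera, BEV, and trajectory plots during validation."""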
def __init__(
self, config: TransfuserConfig, num_plots: int = 10, num_rows: int = 2, num_columns: int = 2
) -> None:
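        """
        Initializes the visualization callback.
        :param config: global config of the TransFuser agent
        :param num_plots: number of plot grids to log per validation epoch
        :param num_rows: number of rows of samples in each plot grid
        :param num_columns: number of columns of samples in each plot grid
        """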
self._config = config
self._num_plots = num_plots
self._num_rows = num_rows
self._num_columns = num_columns
def on_validation_epoch_start(
self, trainer: pl.Trainer, lightning_module: pl.LightningModule
) -> None:
pass
def on_validation_epoch_end(
self, trainer: pl.Trainer, lightning_module: pl.LightningModule
) -> None:
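        """Runs the model on a few validation batches and logs visualization grids as images."""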
device = lightning_module.device
val_data_iter = iter(trainer.val_dataloaders)
for idx_plot in range(self._num_plots):
features, targets, tokens = next(val_data_iter)
features, targets = dict_to_device(features, device), dict_to_device(targets, device)
with torch.no_grad():
predictions = lightning_module.agent.forward(features)
features, targets, predictions = (
dict_to_device(features, "cpu"),
dict_to_device(targets, "cpu"),
dict_to_device(predictions, "cpu"),
)
grid = self._visualize_model(features, targets, predictions)
trainer.logger.experiment.add_image(
f"val_plot_{idx_plot}", grid, global_step=trainer.current_epoch
)
def on_test_epoch_start(
self, trainer: pl.Trainer, lightning_module: pl.LightningModule
) -> None:
pass
def on_test_epoch_end(self, trainer: pl.Trainer, lightning_module: pl.LightningModule) -> None:
pass
def on_train_epoch_start(
self, trainer: pl.Trainer, lightning_module: pl.LightningModule
) -> None:
pass
def on_train_epoch_end(
self,
trainer: pl.Trainer,
lightning_module: pl.LightningModule,
unused: Optional[Any] = None,
) -> None:
pass
# device = lightning_module.device
# train_data_iter = iter(trainer.train_dataloader)
# for idx_plot in range(self._num_plots):
# features, targets, _ = next(train_data_iter)
# features, targets = dict_to_device(features, device), dict_to_device(targets, device)
# with torch.no_grad():
# predictions = lightning_module.agent.forward(features)
#
# features, targets, predictions = (
# dict_to_device(features, "cpu"),
# dict_to_device(targets, "cpu"),
# dict_to_device(predictions, "cpu"),
# )
# grid = self._visualize_model(features, targets, predictions)
# trainer.logger.experiment.add_image(
# f"train_plot_{idx_plot}", grid, global_step=trainer.current_epoch
# )
def _visualize_model(
self,
features: Dict[str, torch.Tensor],
targets: Dict[str, torch.Tensor],
predictions: Dict[str, torch.Tensor],
) -> torch.Tensor:
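        """
        Creates an image grid of camera view, ground-truth and predicted BEV semantic maps,
        and the LiDAR BEV with agent boxes and trajectories for several samples.
        :return: image grid tensor in (C, H, W) format for logging
        """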
camera = features["camera_feature"].permute(0, 2, 3, 1).numpy()
bev = targets["bev_semantic_map"].numpy()
        # if the LiDAR BEV feature has several channels, visualize only the last one
        if features["lidar_feature"].shape[1] > 1:
            lidar_map = features["lidar_feature"][:, -1].numpy()
        else:
            lidar_map = features["lidar_feature"].squeeze(1).numpy()
agent_labels = targets["agent_labels"].numpy()
agent_states = targets["agent_states"].numpy()
trajectory = targets["trajectory"].numpy()
pred_bev = predictions["bev_semantic_map"].argmax(1).numpy()
pred_agent_labels = predictions["agent_labels"].sigmoid().numpy()
pred_agent_states = predictions["agent_states"].numpy()
pred_trajectory = predictions["trajectory"].numpy()
plots = []
        for sample_idx in range(self._num_rows * self._num_columns):
            # plot layout (256 x 768): camera on the top-left, ground-truth and predicted BEV
            # semantic maps on the bottom-left, LiDAR BEV with boxes/trajectories on the right
            plot = np.zeros((256, 768, 3), dtype=np.uint8)
            cam_stride = camera[sample_idx].shape[0] // 128
            bev_rgb = semantic_map_to_rgb(bev[sample_idx], self._config)
            lidar_stride = bev_rgb.shape[0] // 128
            plot[:128, :512] = (camera[sample_idx] * 255).astype(np.uint8)[::cam_stride, ::cam_stride]
            plot[128:, :256] = bev_rgb[::lidar_stride, ::lidar_stride]
            plot[128:, 256:512] = semantic_map_to_rgb(pred_bev[sample_idx], self._config)[
                ::lidar_stride, ::lidar_stride
            ]
agent_states_ = agent_states[sample_idx][agent_labels[sample_idx]]
pred_agent_states_ = pred_agent_states[sample_idx][pred_agent_labels[sample_idx] > 0.5]
plot[:, 512:] = lidar_map_to_rgb(
lidar_map[sample_idx],
agent_states_,
pred_agent_states_,
trajectory[sample_idx],
pred_trajectory[sample_idx],
self._config,
)[::lidar_stride, ::lidar_stride]
plots.append(torch.tensor(plot).permute(2, 0, 1))
return vutils.make_grid(plots, normalize=False, nrow=self._num_rows)
def dict_to_device(
    tensor_dict: Dict[str, torch.Tensor], device: Union[torch.device, str]
) -> Dict[str, torch.Tensor]:
    """Moves all tensors of a dictionary to the given device (in place) and returns it."""
    for key in tensor_dict.keys():
        tensor_dict[key] = tensor_dict[key].to(device)
    return tensor_dict
def semantic_map_to_rgb(
semantic_map: npt.NDArray[np.int64], config: TransfuserConfig
) -> npt.NDArray[np.uint8]:
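    """
    Converts a BEV semantic class-index map to an RGB image using the map-layer and agent color configs.
    :param semantic_map: array of BEV semantic class indices
    :param config: global config of the TransFuser agent
    :return: RGB image as uint8 array
    """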
height, width = semantic_map.shape[:2]
rgb_map = np.ones((height, width, 3), dtype=np.uint8) * 255
for label in range(1, config.num_bev_classes):
if config.bev_semantic_classes[label][0] == "linestring":
hex_color = MAP_LAYER_CONFIG[SemanticMapLayer.BASELINE_PATHS]["line_color"]
else:
layer = config.bev_semantic_classes[label][-1][0] # take color of first element
hex_color = (
AGENT_CONFIG[layer]["fill_color"]
if layer in AGENT_CONFIG.keys()
else MAP_LAYER_CONFIG[layer]["fill_color"]
)
rgb_map[semantic_map == label] = ImageColor.getcolor(hex_color, "RGB")
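    # flip both axes (180 degree rotation) for plotting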
return rgb_map[::-1, ::-1]
def lidar_map_to_rgb(
lidar_map: npt.NDArray[np.int64],
agent_states: npt.NDArray[np.float32],
pred_agent_states: npt.NDArray[np.float32],
trajectory: npt.NDArray[np.float32],
pred_trajectory: npt.NDArray[np.float32],
config: TransfuserConfig,
) -> npt.NDArray[np.uint8]:
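    """
    Renders the LiDAR BEV map with ground-truth (green) and predicted (red) agent boxes and trajectories.
    :return: RGB image as uint8 array
    """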
gt_color, pred_color = (0, 255, 0), (255, 0, 0)
point_size = 4
height, width = lidar_map.shape[:2]
def coords_to_pixel(coords):
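        """Converts local (x, y) coordinates in meters to BEV pixel indices."""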
pixel_center = np.array([[height / 2.0, width / 2.0]])
coords_idcs = (coords / config.bev_pixel_size) + pixel_center
return coords_idcs.astype(np.int32)
rgb_map = (lidar_map * 255).astype(np.uint8)
rgb_map = 255 - rgb_map[..., None].repeat(3, axis=-1)
for color, agent_state_array in zip(
[gt_color, pred_color], [agent_states, pred_agent_states]
):
for agent_state in agent_state_array:
agent_box = OrientedBox(
StateSE2(*agent_state[BoundingBox2DIndex.STATE_SE2]),
agent_state[BoundingBox2DIndex.LENGTH],
agent_state[BoundingBox2DIndex.WIDTH],
1.0,
)
exterior = np.array(agent_box.geometry.exterior.coords).reshape((-1, 1, 2))
exterior = coords_to_pixel(exterior)
exterior = np.flip(exterior, axis=-1)
cv2.polylines(rgb_map, [exterior], isClosed=True, color=color, thickness=2)
    for color, traj in zip([gt_color, pred_color], [trajectory, pred_trajectory]):
        trajectory_indices = coords_to_pixel(traj[:, :2])
        for x, y in trajectory_indices:
            cv2.circle(rgb_map, (y, x), point_size, color, -1)  # -1 fills the circle
    # flip both axes (180 degree rotation) for plotting
    return rgb_map[::-1, ::-1]