|
import os
from argparse import ArgumentParser
from typing import Optional, Sequence

import pytorch_lightning as pl
from rich.console import Console
from torch_geometric.loader import DataLoader

from dev.datasets.scalable_dataset import MultiDataset, WaymoTargetBuilder
from dev.utils.func import load_config_act, Logging
from dev.model.smart import SMART
|
|
|
# Shared rich console used for the script's status messages; output width is
# fixed at 120 columns.
CONSOLE = Console(width=120)
|
|
|
|
|
def parse_args(argv: Optional[Sequence[str]] = None):
    """Parse the command-line options of the validation script.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``.

    Returns:
        The ``argparse.Namespace`` with ``seed``, ``config``, ``ckpt_path``,
        ``insert_agent``, ``t`` and ``save_path``.
    """
    parser = ArgumentParser()
    parser.add_argument('--seed', type=int, default=2024)
    parser.add_argument('--config', type=str, default='configs/train/train_scalable_with_state.yaml')
    parser.add_argument('--ckpt_path', type=str, default="")
    parser.add_argument('--insert_agent', action='store_true')
    # The default is the integer 2, so the option is parsed as an int as well;
    # with ``type=str`` a value given on the command line arrived as a string
    # while the default stayed an int.
    # NOTE(review): assumes SMART expects an integer ``t`` -- confirm.
    parser.add_argument('--t', type=int, default=2)
    parser.add_argument('--save_path', type=str, default=None)
    return parser.parse_args(argv)


def resolve_save_path(save_path: Optional[str], ckpt_path: str) -> str:
    """Return the directory in which validation results are stored.

    Args:
        save_path: Explicit output directory, or ``None`` to derive one from
            the checkpoint location.
        ckpt_path: Checkpoint path; only consulted when ``save_path`` is
            ``None``.

    Returns:
        ``save_path`` when given, otherwise a ``val`` directory next to the
        checkpoint file.

    Raises:
        FileNotFoundError: If ``save_path`` is ``None`` and ``ckpt_path`` is
            empty or does not exist.
    """
    if save_path is not None:
        return save_path
    # Raised explicitly rather than asserted: asserts vanish under ``python -O``.
    if not ckpt_path or not os.path.exists(ckpt_path):
        raise FileNotFoundError(f"Path {ckpt_path} not exist!")
    return os.path.join(os.path.dirname(ckpt_path), 'val')


def build_val_dataloader(config, logger):
    """Create the validation ``DataLoader`` from the ``Dataset``/``Model`` config sections.

    Args:
        config: Loaded configuration object exposing ``Dataset`` and ``Model``.
        logger: Logger handed through to ``MultiDataset``.

    Returns:
        A non-shuffling ``DataLoader`` over the ``val`` split.
    """
    data_config = config.Dataset
    model_config = config.Model

    target_builder = WaymoTargetBuilder(
        model_config.num_historical_steps,
        model_config.decoder.num_future_steps,
        max_num=data_config.max_num,
        training=False,
    )
    val_dataset = MultiDataset(
        split='val',
        raw_dir=data_config.val_raw_dir,
        token_size=data_config.token_size,
        transform=target_builder,
        tfrecord_dir=data_config.val_tfrecords_splitted,
        predict_motion=model_config.predict_motion,
        predict_state=model_config.predict_state,
        predict_map=model_config.predict_map,
        buffer_size=model_config.buffer_size,
        logger=logger,
    )
    return DataLoader(
        val_dataset,
        shuffle=False,
        num_workers=data_config.num_workers,
        pin_memory=data_config.pin_memory,
        # Persistent workers are only enabled when worker processes exist.
        persistent_workers=data_config.num_workers > 0,
    )


def main(argv: Optional[Sequence[str]] = None) -> None:
    """Run validation of a SMART model with a Lightning trainer.

    Args:
        argv: Command-line arguments; ``None`` means ``sys.argv[1:]``.
    """
    args = parse_args(argv)
    pl.seed_everything(args.seed, workers=True)
    config = load_config_act(args.config)
    logger = Logging().log(level='DEBUG')

    dataloader = build_val_dataloader(config, logger)

    save_path = resolve_save_path(args.save_path, args.ckpt_path)
    CONSOLE.log(f"Results will be saved to [yellow]{save_path}[/]")
    os.makedirs(save_path, exist_ok=True)

    model = SMART(config.Model, save_path=save_path, logger=logger, insert_agent=args.insert_agent, t=args.t)
    # Only announce a checkpoint when one was actually supplied; the weights
    # themselves are restored by ``trainer.validate`` below.
    if args.ckpt_path:
        CONSOLE.log(f"Loaded model from [yellow]{args.ckpt_path}[/]")

    trainer_config = config.Trainer
    trainer = pl.Trainer(
        accelerator=trainer_config.accelerator,
        devices=trainer_config.devices,
        strategy='ddp',
        num_sanity_val_steps=0,
    )
    # An empty ``--ckpt_path`` is mapped to ``None`` so that Lightning validates
    # the model's current weights instead of rejecting the empty path.
    trainer.validate(model, dataloader, ckpt_path=args.ckpt_path or None)

    CONSOLE.log("Validation done!")


if __name__ == '__main__':
    main()
|
|