import torch
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from torch_scatter import gather_csr, segment_csr
from torchmetrics import Metric
from typing import Optional, Tuple, Dict


__all__ = ['minADE', 'minFDE', 'TokenCls', 'StateAccuracy', 'GridOverlapRate']


class CustomCrossEntropyLoss(CrossEntropyLoss):
    """Cross-entropy with label smoothing, computed against a smoothed one-hot target."""

    def __init__(self, label_smoothing=0.0, reduction='mean'):
        super(CustomCrossEntropyLoss, self).__init__()
        self.label_smoothing = label_smoothing
        self.reduction = reduction

    def forward(self, input, target):
        num_classes = input.size(1)

        log_probs = F.log_softmax(input, dim=1)

        # Build the smoothed target distribution without tracking gradients.
        with torch.no_grad():
            smooth_target = torch.zeros_like(log_probs).scatter_(1, target.unsqueeze(1), 1)
            smooth_target = smooth_target * (1 - self.label_smoothing) + self.label_smoothing / num_classes

        loss = -torch.sum(log_probs * smooth_target, dim=1)

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:
            return loss
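

# A minimal sanity-check sketch (assumption: PyTorch >= 1.10, where F.cross_entropy exposes
# `label_smoothing`): with reduction='mean', CustomCrossEntropyLoss should match the built-in
# smoothed cross-entropy. Illustrative only; this helper is not called anywhere in the module.
def _check_custom_ce_matches_builtin(num_classes: int = 5, batch_size: int = 4) -> bool:
    logits = torch.randn(batch_size, num_classes)
    target = torch.randint(0, num_classes, (batch_size,))
    custom = CustomCrossEntropyLoss(label_smoothing=0.1)(logits, target)
    builtin = F.cross_entropy(logits, target, label_smoothing=0.1)
    return bool(torch.allclose(custom, builtin, atol=1e-6))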


def topk(
        max_guesses: int,
        pred: torch.Tensor,
        prob: Optional[torch.Tensor] = None,
        ptr: Optional[torch.Tensor] = None,
        joint: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
    max_guesses = min(max_guesses, pred.size(1))
    if max_guesses == pred.size(1):
        if prob is not None:
            prob = prob / prob.sum(dim=-1, keepdim=True)
        else:
            prob = pred.new_ones((pred.size(0), max_guesses)) / max_guesses
        return pred, prob
    else:
        if prob is not None:
            if joint:
                if ptr is None:
                    inds_topk = torch.topk((prob / prob.sum(dim=-1, keepdim=True)).mean(dim=0, keepdim=True),
                                           k=max_guesses, dim=-1, largest=True, sorted=True)[1]
                    inds_topk = inds_topk.repeat(pred.size(0), 1)
                else:
                    inds_topk = torch.topk(segment_csr(src=prob / prob.sum(dim=-1, keepdim=True), indptr=ptr,
                                                       reduce='mean'),
                                           k=max_guesses, dim=-1, largest=True, sorted=True)[1]
                    inds_topk = gather_csr(src=inds_topk, indptr=ptr)
            else:
                inds_topk = torch.topk(prob, k=max_guesses, dim=-1, largest=True, sorted=True)[1]
            pred_topk = pred[torch.arange(pred.size(0)).unsqueeze(-1).expand(-1, max_guesses), inds_topk]
            prob_topk = prob[torch.arange(pred.size(0)).unsqueeze(-1).expand(-1, max_guesses), inds_topk]
            prob_topk = prob_topk / prob_topk.sum(dim=-1, keepdim=True)
        else:
            pred_topk = pred[:, :max_guesses]
            prob_topk = pred.new_ones((pred.size(0), max_guesses)) / max_guesses
        return pred_topk, prob_topk
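

# Shape sketch for `topk` (the tensors below are illustrative assumptions, not the project's
# real data layout): keep the 6 highest-probability modes per agent and renormalize their scores.
def _demo_topk() -> Tuple[torch.Tensor, torch.Tensor]:
    pred = torch.randn(8, 16, 60, 2)              # (num_agents, num_modes, num_steps, xy)
    prob = torch.rand(8, 16)                      # unnormalized per-mode scores
    pred_top, prob_top = topk(6, pred, prob)      # pred_top: (8, 6, 60, 2); prob_top rows sum to 1
    return pred_top, prob_top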


def topkind(
        max_guesses: int,
        pred: torch.Tensor,
        prob: Optional[torch.Tensor] = None,
        ptr: Optional[torch.Tensor] = None,
        joint: bool = False) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    max_guesses = min(max_guesses, pred.size(1))
    if max_guesses == pred.size(1):
        if prob is not None:
            prob = prob / prob.sum(dim=-1, keepdim=True)
        else:
            prob = pred.new_ones((pred.size(0), max_guesses)) / max_guesses
        return pred, prob, None
    else:
        if prob is not None:
            if joint:
                if ptr is None:
                    inds_topk = torch.topk((prob / prob.sum(dim=-1, keepdim=True)).mean(dim=0, keepdim=True),
                                           k=max_guesses, dim=-1, largest=True, sorted=True)[1]
                    inds_topk = inds_topk.repeat(pred.size(0), 1)
                else:
                    inds_topk = torch.topk(segment_csr(src=prob / prob.sum(dim=-1, keepdim=True), indptr=ptr,
                                                       reduce='mean'),
                                           k=max_guesses, dim=-1, largest=True, sorted=True)[1]
                    inds_topk = gather_csr(src=inds_topk, indptr=ptr)
            else:
                inds_topk = torch.topk(prob, k=max_guesses, dim=-1, largest=True, sorted=True)[1]
            pred_topk = pred[torch.arange(pred.size(0)).unsqueeze(-1).expand(-1, max_guesses), inds_topk]
            prob_topk = prob[torch.arange(pred.size(0)).unsqueeze(-1).expand(-1, max_guesses), inds_topk]
            prob_topk = prob_topk / prob_topk.sum(dim=-1, keepdim=True)
        else:
            pred_topk = pred[:, :max_guesses]
            prob_topk = pred.new_ones((pred.size(0), max_guesses)) / max_guesses
            # Bug fix: `inds_topk` was previously undefined on this branch; fall back to the
            # first `max_guesses` mode indices so the returned indices match `pred_topk`.
            inds_topk = torch.arange(max_guesses, device=pred.device).unsqueeze(0).expand(pred.size(0), -1)
        return pred_topk, prob_topk, inds_topk


def valid_filter(
        pred: torch.Tensor,
        target: torch.Tensor,
        prob: Optional[torch.Tensor] = None,
        valid_mask: Optional[torch.Tensor] = None,
        ptr: Optional[torch.Tensor] = None,
        keep_invalid_final_step: bool = True) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
                                                       torch.Tensor, torch.Tensor]:
    if valid_mask is None:
        valid_mask = target.new_ones(target.size()[:-1], dtype=torch.bool)
    if keep_invalid_final_step:
        filter_mask = valid_mask.any(dim=-1)
    else:
        filter_mask = valid_mask[:, -1]
    pred = pred[filter_mask]
    target = target[filter_mask]
    if prob is not None:
        prob = prob[filter_mask]
    valid_mask = valid_mask[filter_mask]
    if ptr is not None:
        num_nodes_batch = segment_csr(src=filter_mask.long(), indptr=ptr, reduce='sum')
        ptr = num_nodes_batch.new_zeros((num_nodes_batch.size(0) + 1,))
        torch.cumsum(num_nodes_batch, dim=0, out=ptr[1:])
    else:
        ptr = target.new_tensor([0, target.size(0)])
    return pred, target, prob, valid_mask, ptr
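

# Sketch of `valid_filter` (shapes are assumptions): agents whose mask has no valid step
# (or whose final step is invalid, when keep_invalid_final_step=False) are dropped from
# every returned tensor. Illustrative only; this helper is not called anywhere.
def _demo_valid_filter():
    pred = torch.randn(5, 60, 2)
    target = torch.randn(5, 60, 2)
    valid_mask = torch.ones(5, 60, dtype=torch.bool)
    valid_mask[0] = False                         # agent 0 has no valid step and is filtered out
    pred, target, prob, valid_mask, ptr = valid_filter(pred, target, valid_mask=valid_mask)
    return pred.shape, valid_mask.shape           # (4, 60, 2) and (4, 60)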


def new_batch_nms(pred_trajs, dist_thresh, num_ret_modes=6):
    """
    Greedy NMS over trajectory end points. Each mode's score is the fraction of modes whose
    end points fall within `dist_thresh` of it, so denser clusters are selected first.

    Args:
        pred_trajs (batch_size, num_modes, num_timestamps, 7)
        dist_thresh (float):
        num_ret_modes (int, optional): Defaults to 6.

    Returns:
        ret_trajs (batch_size, num_ret_modes, num_timestamps, 7)
        ret_scores (batch_size, num_ret_modes)
        ret_idxs (batch_size, num_ret_modes)
    """
    batch_size, num_modes, num_timestamps, num_feat_dim = pred_trajs.shape
    pred_goals = pred_trajs[:, :, -1, :]
    dist = (pred_goals[:, :, None, 0:2] - pred_goals[:, None, :, 0:2]).norm(dim=-1)
    nearby_neighbor = dist < dist_thresh
    pred_scores = nearby_neighbor.sum(dim=-1) / num_modes

    sorted_idxs = pred_scores.argsort(dim=-1, descending=True)
    bs_idxs_full = torch.arange(batch_size).type_as(sorted_idxs)[:, None].repeat(1, num_modes)
    sorted_pred_scores = pred_scores[bs_idxs_full, sorted_idxs]
    sorted_pred_trajs = pred_trajs[bs_idxs_full, sorted_idxs]
    sorted_pred_goals = sorted_pred_trajs[:, :, -1, :]

    dist = (sorted_pred_goals[:, :, None, 0:2] - sorted_pred_goals[:, None, :, 0:2]).norm(dim=-1)
    point_cover_mask = (dist < dist_thresh)

    point_val = sorted_pred_scores.clone()
    point_val_selected = torch.zeros_like(point_val)

    ret_idxs = sorted_idxs.new_zeros(batch_size, num_ret_modes).long()
    ret_trajs = sorted_pred_trajs.new_zeros(batch_size, num_ret_modes, num_timestamps, num_feat_dim)
    ret_scores = sorted_pred_trajs.new_zeros(batch_size, num_ret_modes)
    bs_idxs = torch.arange(batch_size).type_as(ret_idxs)

    for k in range(num_ret_modes):
        cur_idx = point_val.argmax(dim=-1)
        ret_idxs[:, k] = cur_idx

        # Suppress modes covered by the newly selected one and keep it from being picked again.
        new_cover_mask = point_cover_mask[bs_idxs, cur_idx]
        point_val = point_val * (~new_cover_mask).float()
        point_val_selected[bs_idxs, cur_idx] = -1
        point_val += point_val_selected

        ret_trajs[:, k] = sorted_pred_trajs[bs_idxs, cur_idx]
        ret_scores[:, k] = sorted_pred_scores[bs_idxs, cur_idx]

    bs_idxs = torch.arange(batch_size).type_as(sorted_idxs)[:, None].repeat(1, num_ret_modes)

    ret_idxs = sorted_idxs[bs_idxs, ret_idxs]
    return ret_trajs, ret_scores, ret_idxs


def batch_nms(pred_trajs, pred_scores,
              dist_thresh, num_ret_modes=6,
              mode='static', speed=None):
    """
    Greedy NMS over trajectory end points using the provided per-mode scores.

    Args:
        pred_trajs (batch_size, num_modes, num_timestamps, 7)
        pred_scores (batch_size, num_modes):
        dist_thresh (float):
        num_ret_modes (int, optional): Defaults to 6.
        mode (str, optional): 'speed' applies separate x/y thresholds (fixed at 4 and 0.5);
            any other value uses the Euclidean `dist_thresh`. Defaults to 'static'.
        speed: currently unused.

    Returns:
        ret_trajs (batch_size, num_ret_modes, num_timestamps, 7)
        ret_scores (batch_size, num_ret_modes)
        ret_idxs (batch_size, num_ret_modes)
    """
    batch_size, num_modes, num_timestamps, num_feat_dim = pred_trajs.shape

    sorted_idxs = pred_scores.argsort(dim=-1, descending=True)
    bs_idxs_full = torch.arange(batch_size).type_as(sorted_idxs)[:, None].repeat(1, num_modes)
    sorted_pred_scores = pred_scores[bs_idxs_full, sorted_idxs]
    sorted_pred_trajs = pred_trajs[bs_idxs_full, sorted_idxs]
    sorted_pred_goals = sorted_pred_trajs[:, :, -1, :]

    if mode == "speed":
        scale = torch.ones(batch_size).to(sorted_pred_goals.device)
        lon_dist_thresh = 4 * scale
        lat_dist_thresh = 0.5 * scale
        lon_dist = (sorted_pred_goals[:, :, None, [0]] - sorted_pred_goals[:, None, :, [0]]).norm(dim=-1)
        lat_dist = (sorted_pred_goals[:, :, None, [1]] - sorted_pred_goals[:, None, :, [1]]).norm(dim=-1)
        point_cover_mask = (lon_dist < lon_dist_thresh[:, None, None]) & (lat_dist < lat_dist_thresh[:, None, None])
    else:
        dist = (sorted_pred_goals[:, :, None, 0:2] - sorted_pred_goals[:, None, :, 0:2]).norm(dim=-1)
        point_cover_mask = (dist < dist_thresh)

    point_val = sorted_pred_scores.clone()
    point_val_selected = torch.zeros_like(point_val)

    ret_idxs = sorted_idxs.new_zeros(batch_size, num_ret_modes).long()
    ret_trajs = sorted_pred_trajs.new_zeros(batch_size, num_ret_modes, num_timestamps, num_feat_dim)
    ret_scores = sorted_pred_trajs.new_zeros(batch_size, num_ret_modes)
    bs_idxs = torch.arange(batch_size).type_as(ret_idxs)

    for k in range(num_ret_modes):
        cur_idx = point_val.argmax(dim=-1)
        ret_idxs[:, k] = cur_idx

        new_cover_mask = point_cover_mask[bs_idxs, cur_idx]
        point_val = point_val * (~new_cover_mask).float()
        point_val_selected[bs_idxs, cur_idx] = -1
        point_val += point_val_selected

        ret_trajs[:, k] = sorted_pred_trajs[bs_idxs, cur_idx]
        ret_scores[:, k] = sorted_pred_scores[bs_idxs, cur_idx]

    bs_idxs = torch.arange(batch_size).type_as(sorted_idxs)[:, None].repeat(1, num_ret_modes)

    ret_idxs = sorted_idxs[bs_idxs, ret_idxs]
    return ret_trajs, ret_scores, ret_idxs
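

# Shape sketch for `batch_nms` (hypothetical tensors; the trailing feature dim of 7 follows the
# docstring above): select 6 well-separated modes per sample by greedy end-point suppression.
def _demo_batch_nms() -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    pred_trajs = torch.randn(2, 64, 80, 7)        # (batch, modes, timestamps, feat)
    pred_scores = torch.rand(2, 64)               # one confidence per mode
    ret_trajs, ret_scores, ret_idxs = batch_nms(pred_trajs, pred_scores, dist_thresh=2.5)
    return ret_trajs, ret_scores, ret_idxs        # (2, 6, 80, 7), (2, 6), (2, 6)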


def batch_nms_token(pred_trajs, pred_scores,
                    dist_thresh, num_ret_modes=6,
                    mode='static', speed=None):
    """
    Greedy NMS over token end points (no time dimension).

    Args:
        pred_trajs (batch_size, num_modes, num_feat_dim)
        pred_scores (batch_size, num_modes):
        dist_thresh (float):
        num_ret_modes (int, optional): Defaults to 6.
        mode (str, optional): 'nearby' uses each mode's 5th-smallest end-point distance as its
            threshold; any other value uses `dist_thresh`. Defaults to 'static'.
        speed: currently unused.

    Returns:
        ret_goals (batch_size, num_ret_modes, num_feat_dim)
        ret_scores (batch_size, num_ret_modes)
        ret_idxs (batch_size, num_ret_modes)
    """
    batch_size, num_modes, num_feat_dim = pred_trajs.shape

    sorted_idxs = pred_scores.argsort(dim=-1, descending=True)
    bs_idxs_full = torch.arange(batch_size).type_as(sorted_idxs)[:, None].repeat(1, num_modes)
    sorted_pred_scores = pred_scores[bs_idxs_full, sorted_idxs]
    sorted_pred_goals = pred_trajs[bs_idxs_full, sorted_idxs]

    if mode == "nearby":
        dist = (sorted_pred_goals[:, :, None, 0:2] - sorted_pred_goals[:, None, :, 0:2]).norm(dim=-1)
        values, indices = torch.topk(dist, 5, dim=-1, largest=False)
        threshold = values[..., -1]
        point_cover_mask = dist < threshold[..., None]
    else:
        dist = (sorted_pred_goals[:, :, None, 0:2] - sorted_pred_goals[:, None, :, 0:2]).norm(dim=-1)
        point_cover_mask = (dist < dist_thresh)

    point_val = sorted_pred_scores.clone()
    point_val_selected = torch.zeros_like(point_val)

    ret_idxs = sorted_idxs.new_zeros(batch_size, num_ret_modes).long()
    ret_goals = sorted_pred_goals.new_zeros(batch_size, num_ret_modes, num_feat_dim)
    ret_scores = sorted_pred_goals.new_zeros(batch_size, num_ret_modes)
    bs_idxs = torch.arange(batch_size).type_as(ret_idxs)

    for k in range(num_ret_modes):
        cur_idx = point_val.argmax(dim=-1)
        ret_idxs[:, k] = cur_idx

        new_cover_mask = point_cover_mask[bs_idxs, cur_idx]
        point_val = point_val * (~new_cover_mask).float()
        point_val_selected[bs_idxs, cur_idx] = -1
        point_val += point_val_selected

        ret_goals[:, k] = sorted_pred_goals[bs_idxs, cur_idx]
        ret_scores[:, k] = sorted_pred_scores[bs_idxs, cur_idx]

    bs_idxs = torch.arange(batch_size).type_as(sorted_idxs)[:, None].repeat(1, num_ret_modes)

    ret_idxs = sorted_idxs[bs_idxs, ret_idxs]
    return ret_goals, ret_scores, ret_idxs


class TokenCls(Metric):

    def __init__(self,
                 max_guesses: int = 6,
                 **kwargs) -> None:
        super(TokenCls, self).__init__(**kwargs)
        self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
        self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum')
        self.max_guesses = max_guesses

    def update(self,
               pred: torch.Tensor,
               target: torch.Tensor,
               valid_mask: Optional[torch.Tensor] = None) -> None:
        target = target[..., None]
        acc = (pred[:, :self.max_guesses] == target).any(dim=1) * valid_mask
        self.sum += acc.sum()
        self.count += valid_mask.sum()

    def compute(self) -> torch.Tensor:
        return self.sum / self.count


class minMultiFDE(Metric):

    def __init__(self,
                 max_guesses: int = 6,
                 **kwargs) -> None:
        super(minMultiFDE, self).__init__(**kwargs)
        self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
        self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum')
        self.max_guesses = max_guesses

    def update(self,
               pred: torch.Tensor,
               target: torch.Tensor,
               prob: Optional[torch.Tensor] = None,
               valid_mask: Optional[torch.Tensor] = None,
               keep_invalid_final_step: bool = True) -> None:
        pred, target, prob, valid_mask, _ = valid_filter(pred, target, prob, valid_mask, None, keep_invalid_final_step)
        pred_topk, _ = topk(self.max_guesses, pred, prob)
        inds_last = (valid_mask * torch.arange(1, valid_mask.size(-1) + 1, device=self.device)).argmax(dim=-1)
        self.sum += torch.norm(pred_topk[torch.arange(pred.size(0)), :, inds_last] -
                               target[torch.arange(pred.size(0)), inds_last].unsqueeze(-2),
                               p=2, dim=-1).min(dim=-1)[0].sum()
        self.count += pred.size(0)

    def compute(self) -> torch.Tensor:
        return self.sum / self.count


class minFDE(Metric):

    def __init__(self,
                 max_guesses: int = 6,
                 **kwargs) -> None:
        super(minFDE, self).__init__(**kwargs)
        self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
        self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum')
        self.max_guesses = max_guesses
        self.eval_timestep = 70

    def update(self,
               pred: torch.Tensor,
               target: torch.Tensor,
               prob: Optional[torch.Tensor] = None,
               valid_mask: Optional[torch.Tensor] = None,
               keep_invalid_final_step: bool = True) -> None:
        eval_timestep = min(self.eval_timestep, pred.shape[1]) - 1
        self.sum += ((torch.norm(pred[:, eval_timestep - 1: eval_timestep] - target[:, eval_timestep - 1: eval_timestep], p=2, dim=-1) *
                      valid_mask[:, eval_timestep - 1].unsqueeze(1)).sum(dim=-1)).sum()
        self.count += valid_mask[:, eval_timestep - 1].sum()

    def compute(self) -> torch.Tensor:
        return self.sum / self.count


class minMultiADE(Metric):

    def __init__(self,
                 max_guesses: int = 6,
                 **kwargs) -> None:
        super(minMultiADE, self).__init__(**kwargs)
        self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
        self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum')
        self.max_guesses = max_guesses

    def update(self,
               pred: torch.Tensor,
               target: torch.Tensor,
               prob: Optional[torch.Tensor] = None,
               valid_mask: Optional[torch.Tensor] = None,
               keep_invalid_final_step: bool = True,
               min_criterion: str = 'FDE') -> None:
        pred, target, prob, valid_mask, _ = valid_filter(pred, target, prob, valid_mask, None, keep_invalid_final_step)
        pred_topk, _ = topk(self.max_guesses, pred, prob)
        if min_criterion == 'FDE':
            inds_last = (valid_mask * torch.arange(1, valid_mask.size(-1) + 1, device=self.device)).argmax(dim=-1)
            inds_best = torch.norm(
                pred_topk[torch.arange(pred.size(0)), :, inds_last] -
                target[torch.arange(pred.size(0)), inds_last].unsqueeze(-2), p=2, dim=-1).argmin(dim=-1)
            self.sum += ((torch.norm(pred_topk[torch.arange(pred.size(0)), inds_best] - target, p=2, dim=-1) *
                          valid_mask).sum(dim=-1) / valid_mask.sum(dim=-1)).sum()
        elif min_criterion == 'ADE':
            self.sum += ((torch.norm(pred_topk - target.unsqueeze(1), p=2, dim=-1) *
                          valid_mask.unsqueeze(1)).sum(dim=-1).min(dim=-1)[0] / valid_mask.sum(dim=-1)).sum()
        else:
            raise ValueError('{} is not a valid criterion'.format(min_criterion))
        self.count += pred.size(0)

    def compute(self) -> torch.Tensor:
        return self.sum / self.count


class minADE(Metric):

    def __init__(self,
                 max_guesses: int = 6,
                 **kwargs) -> None:
        super(minADE, self).__init__(**kwargs)
        self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
        self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum')
        self.max_guesses = max_guesses
        self.eval_timestep = 70

    def update(self,
               pred: torch.Tensor,
               target: torch.Tensor,
               prob: Optional[torch.Tensor] = None,
               valid_mask: Optional[torch.Tensor] = None,
               keep_invalid_final_step: bool = True,
               min_criterion: str = 'ADE') -> None:
        eval_timestep = min(self.eval_timestep, pred.shape[1])
        self.sum += ((torch.norm(pred[:, :eval_timestep] - target[:, :eval_timestep], p=2, dim=-1) *
                      valid_mask[:, :eval_timestep]).sum(dim=-1) / pred.shape[1]).sum()
        self.count += valid_mask[:, :eval_timestep].any(dim=-1).sum()

    def compute(self) -> torch.Tensor:
        return self.sum / self.count
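

# Usage sketch for the torchmetrics-style metrics above (shapes are assumptions inferred from
# how `update` indexes its arguments): accumulate over batches, then call `compute`.
def _demo_min_ade() -> torch.Tensor:
    metric = minADE()
    pred = torch.randn(3, 80, 2)                  # (num_agents, num_steps, xy)
    target = torch.randn(3, 80, 2)
    valid_mask = torch.ones(3, 80, dtype=torch.bool)
    metric.update(pred=pred, target=target, valid_mask=valid_mask)
    return metric.compute()                       # displacement error averaged over agents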


class AverageMeter(Metric):

    def __init__(self, **kwargs) -> None:
        super(AverageMeter, self).__init__(**kwargs)
        self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
        self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum')

    def update(self, val: torch.Tensor) -> None:
        self.sum += val.sum()
        self.count += val.numel()

    def compute(self) -> torch.Tensor:
        return self.sum / self.count


class StateAccuracy(Metric):

    def __init__(self, state_token: Dict[str, int], **kwargs) -> None:
        super().__init__(**kwargs)
        self.invalid_state = int(state_token['invalid'])
        self.valid_state = int(state_token['valid'])
        self.enter_state = int(state_token['enter'])
        self.exit_state = int(state_token['exit'])

        self.add_state('valid', default=torch.tensor(0), dist_reduce_fx='sum')
        self.add_state('valid_count', default=torch.tensor(0), dist_reduce_fx='sum')
        self.add_state('invalid', default=torch.tensor(0), dist_reduce_fx='sum')
        self.add_state('invalid_count', default=torch.tensor(0), dist_reduce_fx='sum')

    def update(self,
               state_idx: torch.Tensor,
               valid_mask: Optional[torch.Tensor] = None) -> None:

        num_agent, num_step = state_idx.shape

        # Consistency of the predicted state tokens with the enter/exit tokens: steps before
        # `enter` and after `exit` should be invalid, steps in between should be valid.
        for a in range(num_agent):
            bos_idx = torch.where(state_idx[a] == self.enter_state)[0]
            eos_idx = torch.where(state_idx[a] == self.exit_state)[0]
            bos = 0
            eos = num_step - 1
            if len(bos_idx) > 0:
                bos = bos_idx[0]
                self.invalid += (state_idx[a, :bos] == self.invalid_state).sum()
                self.invalid_count += len(state_idx[a, :bos])
            if len(eos_idx) > 0:
                eos = eos_idx[0]
                self.invalid += (state_idx[a, eos + 1:] == self.invalid_state).sum()
                self.invalid_count += len(state_idx[a, eos + 1:])
            self.valid += (state_idx[a, bos + 1: eos] == self.valid_state).sum()
            self.valid_count += len(state_idx[a, bos + 1: eos])

        if valid_mask is not None:

            # Shift the tokens by one step before comparing against the validity mask.
            state_idx = state_idx.roll(shifts=1, dims=1)

            for a in range(num_agent):
                bos_idx = torch.where(state_idx[a] == self.enter_state)[0]
                eos_idx = torch.where(state_idx[a] == self.exit_state)[0]
                bos = 0
                eos = num_step - 1
                if len(bos_idx) > 0:
                    bos = bos_idx[0]
                    self.invalid += (valid_mask[a, :bos] == 0).sum()
                    self.invalid_count += len(valid_mask[a, :bos])
                if len(eos_idx) > 0:
                    eos = eos_idx[-1]
                    self.invalid += (valid_mask[a, eos + 1:] != 0).sum()
                    self.invalid_count += len(valid_mask[a, eos + 1:])
                self.invalid += (((state_idx[a, bos: eos + 1] > 0) != valid_mask[a, bos: eos + 1])[valid_mask[a, bos: eos + 1] == 0]).sum()
                self.invalid_count += (valid_mask[a, bos: eos + 1] == 0).sum()
                self.valid += (((state_idx[a, bos: eos + 1] > 0) != valid_mask[a, bos: eos + 1])[valid_mask[a, bos: eos + 1] == 1]).sum()
                self.valid_count += (valid_mask[a, bos: eos + 1] == 1).sum()

    def compute(self) -> Dict[str, torch.Tensor]:
        return {'valid': self.valid / self.valid_count,
                'invalid': self.invalid / self.invalid_count,
                }

    def __repr__(self):
        head = "Results of " + self.__class__.__name__
        results = self.compute()
        body = [
            "valid: {}".format(results['valid']),
            "invalid: {}".format(results['invalid']),
        ]
        _repr_indent = 4
        lines = [head] + [" " * _repr_indent + line for line in body]
        return "\n".join(lines)
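

# Usage sketch for StateAccuracy (the token ids below are placeholders, not the project's real
# state vocabulary): one agent that enters at step 1 and exits at step 4.
def _demo_state_accuracy() -> Dict[str, torch.Tensor]:
    state_token = {'invalid': 0, 'valid': 1, 'enter': 2, 'exit': 3}
    metric = StateAccuracy(state_token=state_token)
    state_idx = torch.tensor([[0, 2, 1, 1, 3, 0]])   # invalid, enter, valid, valid, exit, invalid
    metric.update(state_idx)
    return metric.compute()                          # {'valid': 1.0, 'invalid': 1.0} for this input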


class GridOverlapRate(Metric):

    def __init__(self, num_step, state_token, seed_size, **kwargs) -> None:
        super().__init__(**kwargs)
        self.num_step = num_step
        self.enter_state = int(state_token['enter'])
        self.seed_size = seed_size
        self.add_state('num_overlap_t', default=torch.zeros(num_step).long(), dist_reduce_fx='sum')
        self.add_state('num_insert_agent_t', default=torch.zeros(num_step).long(), dist_reduce_fx='sum')
        self.add_state('num_total_agent_t', default=torch.zeros(num_step).long(), dist_reduce_fx='sum')
        self.add_state('num_exceed_seed_t', default=torch.zeros(num_step).long(), dist_reduce_fx='sum')

    def update(self,
               state_token: torch.Tensor,
               grid_index: torch.Tensor) -> None:

        for t in range(self.num_step):
            inrange_mask_t = grid_index[:, t] != -1
            insert_mask_t = (state_token[:, t] == self.enter_state) & inrange_mask_t
            self.num_total_agent_t[t] += inrange_mask_t.sum()
            self.num_insert_agent_t[t] += insert_mask_t.sum()
            self.num_exceed_seed_t[t] += int(insert_mask_t.sum() >= self.seed_size)

            # An inserted agent overlaps if its grid cell is already occupied, either by an
            # existing agent or by another agent inserted earlier in the same step.
            occupied_grids = set(grid_index[:, t][(grid_index[:, t] != -1) & (state_token[:, t] != self.enter_state)].tolist())
            to_inserted_grids = grid_index[:, t][(grid_index[:, t] != -1) & (state_token[:, t] == self.enter_state)].tolist()
            while to_inserted_grids:
                grid_index_t_i = to_inserted_grids.pop()
                if grid_index_t_i in occupied_grids:
                    self.num_overlap_t[t] += 1
                occupied_grids.add(grid_index_t_i)

    def compute(self) -> Dict[str, torch.Tensor]:
        overlap_rate_t = self.num_overlap_t / self.num_insert_agent_t
        overlap_rate_t.nan_to_num_()
        return {'num_overlap_t': self.num_overlap_t,
                'num_insert_agent_t': self.num_insert_agent_t,
                'num_total_agent_t': self.num_total_agent_t,
                'overlap_rate_t': overlap_rate_t,
                'num_exceed_seed_t': self.num_exceed_seed_t,
                }

    def __repr__(self):
        head = "Results of " + self.__class__.__name__
        results = self.compute()
        body = [
            "num_overlap_t: {}".format(results['num_overlap_t'].tolist()),
            "num_insert_agent_t: {}".format(results['num_insert_agent_t'].tolist()),
            "num_total_agent_t: {}".format(results['num_total_agent_t'].tolist()),
            "overlap_rate_t: {}".format(results['overlap_rate_t'].tolist()),
            "num_exceed_seed_t: {}".format(results['num_exceed_seed_t'].tolist()),
        ]
        _repr_indent = 4
        lines = [head] + [" " * _repr_indent + line for line in body]
        return "\n".join(lines)


class NumInsertAccuracy(StateAccuracy):
    """Identical to StateAccuracy; kept as a separate class so the metric is reported under its own name."""