gzzyyxy's picture
Upload folder using huggingface_hub
c1a7f73 verified
import torch
import os
import itertools
import multiprocessing as mp
import torch.nn.functional as F
from pathlib import Path
from torch.nn import CrossEntropyLoss
from torch_scatter import gather_csr
from torch_scatter import segment_csr
from torchmetrics import Metric
from typing import Optional, Tuple, Dict, List
__all__ = ['minADE', 'minFDE', 'TokenCls', 'StateAccuracy', 'GridOverlapRate']
class CustomCrossEntropyLoss(CrossEntropyLoss):
def __init__(self, label_smoothing=0.0, reduction='mean'):
super(CustomCrossEntropyLoss, self).__init__()
self.label_smoothing = label_smoothing
self.reduction = reduction
def forward(self, input, target):
num_classes = input.size(1)
log_probs = F.log_softmax(input, dim=1)
with torch.no_grad():
smooth_target = torch.zeros_like(log_probs).scatter_(1, target.unsqueeze(1), 1)
smooth_target = smooth_target * (1 - self.label_smoothing) + self.label_smoothing / num_classes
loss = -torch.sum(log_probs * smooth_target, dim=1)
if self.reduction == 'mean':
return loss.mean()
elif self.reduction == 'sum':
return loss.sum()
else:
return loss
def topk(
max_guesses: int,
pred: torch.Tensor,
prob: Optional[torch.Tensor] = None,
ptr: Optional[torch.Tensor] = None,
joint: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
max_guesses = min(max_guesses, pred.size(1))
if max_guesses == pred.size(1):
if prob is not None:
prob = prob / prob.sum(dim=-1, keepdim=True)
else:
prob = pred.new_ones((pred.size(0), max_guesses)) / max_guesses
return pred, prob
else:
if prob is not None:
if joint:
if ptr is None:
inds_topk = torch.topk((prob / prob.sum(dim=-1, keepdim=True)).mean(dim=0, keepdim=True),
k=max_guesses, dim=-1, largest=True, sorted=True)[1]
inds_topk = inds_topk.repeat(pred.size(0), 1)
else:
inds_topk = torch.topk(segment_csr(src=prob / prob.sum(dim=-1, keepdim=True), indptr=ptr,
reduce='mean'),
k=max_guesses, dim=-1, largest=True, sorted=True)[1]
inds_topk = gather_csr(src=inds_topk, indptr=ptr)
else:
inds_topk = torch.topk(prob, k=max_guesses, dim=-1, largest=True, sorted=True)[1]
pred_topk = pred[torch.arange(pred.size(0)).unsqueeze(-1).expand(-1, max_guesses), inds_topk]
prob_topk = prob[torch.arange(pred.size(0)).unsqueeze(-1).expand(-1, max_guesses), inds_topk]
prob_topk = prob_topk / prob_topk.sum(dim=-1, keepdim=True)
else:
pred_topk = pred[:, :max_guesses]
prob_topk = pred.new_ones((pred.size(0), max_guesses)) / max_guesses
return pred_topk, prob_topk
def topkind(
max_guesses: int,
pred: torch.Tensor,
prob: Optional[torch.Tensor] = None,
ptr: Optional[torch.Tensor] = None,
joint: bool = False) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
max_guesses = min(max_guesses, pred.size(1))
if max_guesses == pred.size(1):
if prob is not None:
prob = prob / prob.sum(dim=-1, keepdim=True)
else:
prob = pred.new_ones((pred.size(0), max_guesses)) / max_guesses
return pred, prob, None
else:
if prob is not None:
if joint:
if ptr is None:
inds_topk = torch.topk((prob / prob.sum(dim=-1, keepdim=True)).mean(dim=0, keepdim=True),
k=max_guesses, dim=-1, largest=True, sorted=True)[1]
inds_topk = inds_topk.repeat(pred.size(0), 1)
else:
inds_topk = torch.topk(segment_csr(src=prob / prob.sum(dim=-1, keepdim=True), indptr=ptr,
reduce='mean'),
k=max_guesses, dim=-1, largest=True, sorted=True)[1]
inds_topk = gather_csr(src=inds_topk, indptr=ptr)
else:
inds_topk = torch.topk(prob, k=max_guesses, dim=-1, largest=True, sorted=True)[1]
pred_topk = pred[torch.arange(pred.size(0)).unsqueeze(-1).expand(-1, max_guesses), inds_topk]
prob_topk = prob[torch.arange(pred.size(0)).unsqueeze(-1).expand(-1, max_guesses), inds_topk]
prob_topk = prob_topk / prob_topk.sum(dim=-1, keepdim=True)
else:
pred_topk = pred[:, :max_guesses]
prob_topk = pred.new_ones((pred.size(0), max_guesses)) / max_guesses
return pred_topk, prob_topk, inds_topk
def valid_filter(
pred: torch.Tensor,
target: torch.Tensor,
prob: Optional[torch.Tensor] = None,
valid_mask: Optional[torch.Tensor] = None,
ptr: Optional[torch.Tensor] = None,
keep_invalid_final_step: bool = True) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
torch.Tensor, torch.Tensor]:
if valid_mask is None:
valid_mask = target.new_ones(target.size()[:-1], dtype=torch.bool)
if keep_invalid_final_step:
filter_mask = valid_mask.any(dim=-1)
else:
filter_mask = valid_mask[:, -1]
pred = pred[filter_mask]
target = target[filter_mask]
if prob is not None:
prob = prob[filter_mask]
valid_mask = valid_mask[filter_mask]
if ptr is not None:
num_nodes_batch = segment_csr(src=filter_mask.long(), indptr=ptr, reduce='sum')
ptr = num_nodes_batch.new_zeros((num_nodes_batch.size(0) + 1,))
torch.cumsum(num_nodes_batch, dim=0, out=ptr[1:])
else:
ptr = target.new_tensor([0, target.size(0)])
return pred, target, prob, valid_mask, ptr
def new_batch_nms(pred_trajs, dist_thresh, num_ret_modes=6):
"""
Args:
pred_trajs (batch_size, num_modes, num_timestamps, 7)
pred_scores (batch_size, num_modes):
dist_thresh (float):
num_ret_modes (int, optional): Defaults to 6.
Returns:
ret_trajs (batch_size, num_ret_modes, num_timestamps, 5)
ret_scores (batch_size, num_ret_modes)
ret_idxs (batch_size, num_ret_modes)
"""
batch_size, num_modes, num_timestamps, num_feat_dim = pred_trajs.shape
pred_goals = pred_trajs[:, :, -1, :]
dist = (pred_goals[:, :, None, 0:2] - pred_goals[:, None, :, 0:2]).norm(dim=-1)
nearby_neighbor = dist < dist_thresh
pred_scores = nearby_neighbor.sum(dim=-1) / num_modes
sorted_idxs = pred_scores.argsort(dim=-1, descending=True)
bs_idxs_full = torch.arange(batch_size).type_as(sorted_idxs)[:, None].repeat(1, num_modes)
sorted_pred_scores = pred_scores[bs_idxs_full, sorted_idxs]
sorted_pred_trajs = pred_trajs[bs_idxs_full, sorted_idxs] # (batch_size, num_modes, num_timestamps, 7)
sorted_pred_goals = sorted_pred_trajs[:, :, -1, :] # (batch_size, num_modes, 7)
dist = (sorted_pred_goals[:, :, None, 0:2] - sorted_pred_goals[:, None, :, 0:2]).norm(dim=-1)
point_cover_mask = (dist < dist_thresh)
point_val = sorted_pred_scores.clone() # (batch_size, N)
point_val_selected = torch.zeros_like(point_val) # (batch_size, N)
ret_idxs = sorted_idxs.new_zeros(batch_size, num_ret_modes).long()
ret_trajs = sorted_pred_trajs.new_zeros(batch_size, num_ret_modes, num_timestamps, num_feat_dim)
ret_scores = sorted_pred_trajs.new_zeros(batch_size, num_ret_modes)
bs_idxs = torch.arange(batch_size).type_as(ret_idxs)
for k in range(num_ret_modes):
cur_idx = point_val.argmax(dim=-1) # (batch_size)
ret_idxs[:, k] = cur_idx
new_cover_mask = point_cover_mask[bs_idxs, cur_idx] # (batch_size, N)
point_val = point_val * (~new_cover_mask).float() # (batch_size, N)
point_val_selected[bs_idxs, cur_idx] = -1
point_val += point_val_selected
ret_trajs[:, k] = sorted_pred_trajs[bs_idxs, cur_idx]
ret_scores[:, k] = sorted_pred_scores[bs_idxs, cur_idx]
bs_idxs = torch.arange(batch_size).type_as(sorted_idxs)[:, None].repeat(1, num_ret_modes)
ret_idxs = sorted_idxs[bs_idxs, ret_idxs]
return ret_trajs, ret_scores, ret_idxs
def batch_nms(pred_trajs, pred_scores,
dist_thresh, num_ret_modes=6,
mode='static', speed=None):
"""
Args:
pred_trajs (batch_size, num_modes, num_timestamps, 7)
pred_scores (batch_size, num_modes):
dist_thresh (float):
num_ret_modes (int, optional): Defaults to 6.
Returns:
ret_trajs (batch_size, num_ret_modes, num_timestamps, 5)
ret_scores (batch_size, num_ret_modes)
ret_idxs (batch_size, num_ret_modes)
"""
batch_size, num_modes, num_timestamps, num_feat_dim = pred_trajs.shape
sorted_idxs = pred_scores.argsort(dim=-1, descending=True)
bs_idxs_full = torch.arange(batch_size).type_as(sorted_idxs)[:, None].repeat(1, num_modes)
sorted_pred_scores = pred_scores[bs_idxs_full, sorted_idxs]
sorted_pred_trajs = pred_trajs[bs_idxs_full, sorted_idxs] # (batch_size, num_modes, num_timestamps, 7)
sorted_pred_goals = sorted_pred_trajs[:, :, -1, :] # (batch_size, num_modes, 7)
if mode == "speed":
scale = torch.ones(batch_size).to(sorted_pred_goals.device)
lon_dist_thresh = 4 * scale
lat_dist_thresh = 0.5 * scale
lon_dist = (sorted_pred_goals[:, :, None, [0]] - sorted_pred_goals[:, None, :, [0]]).norm(dim=-1)
lat_dist = (sorted_pred_goals[:, :, None, [1]] - sorted_pred_goals[:, None, :, [1]]).norm(dim=-1)
point_cover_mask = (lon_dist < lon_dist_thresh[:, None, None]) & (lat_dist < lat_dist_thresh[:, None, None])
else:
dist = (sorted_pred_goals[:, :, None, 0:2] - sorted_pred_goals[:, None, :, 0:2]).norm(dim=-1)
point_cover_mask = (dist < dist_thresh)
point_val = sorted_pred_scores.clone() # (batch_size, N)
point_val_selected = torch.zeros_like(point_val) # (batch_size, N)
ret_idxs = sorted_idxs.new_zeros(batch_size, num_ret_modes).long()
ret_trajs = sorted_pred_trajs.new_zeros(batch_size, num_ret_modes, num_timestamps, num_feat_dim)
ret_scores = sorted_pred_trajs.new_zeros(batch_size, num_ret_modes)
bs_idxs = torch.arange(batch_size).type_as(ret_idxs)
for k in range(num_ret_modes):
cur_idx = point_val.argmax(dim=-1) # (batch_size)
ret_idxs[:, k] = cur_idx
new_cover_mask = point_cover_mask[bs_idxs, cur_idx] # (batch_size, N)
point_val = point_val * (~new_cover_mask).float() # (batch_size, N)
point_val_selected[bs_idxs, cur_idx] = -1
point_val += point_val_selected
ret_trajs[:, k] = sorted_pred_trajs[bs_idxs, cur_idx]
ret_scores[:, k] = sorted_pred_scores[bs_idxs, cur_idx]
bs_idxs = torch.arange(batch_size).type_as(sorted_idxs)[:, None].repeat(1, num_ret_modes)
ret_idxs = sorted_idxs[bs_idxs, ret_idxs]
return ret_trajs, ret_scores, ret_idxs
def batch_nms_token(pred_trajs, pred_scores,
dist_thresh, num_ret_modes=6,
mode='static', speed=None):
"""
Args:
pred_trajs (batch_size, num_modes, num_timestamps, 7)
pred_scores (batch_size, num_modes):
dist_thresh (float):
num_ret_modes (int, optional): Defaults to 6.
Returns:
ret_trajs (batch_size, num_ret_modes, num_timestamps, 5)
ret_scores (batch_size, num_ret_modes)
ret_idxs (batch_size, num_ret_modes)
"""
batch_size, num_modes, num_feat_dim = pred_trajs.shape
sorted_idxs = pred_scores.argsort(dim=-1, descending=True)
bs_idxs_full = torch.arange(batch_size).type_as(sorted_idxs)[:, None].repeat(1, num_modes)
sorted_pred_scores = pred_scores[bs_idxs_full, sorted_idxs]
sorted_pred_goals = pred_trajs[bs_idxs_full, sorted_idxs] # (batch_size, num_modes, num_timestamps, 7)
if mode == "nearby":
dist = (sorted_pred_goals[:, :, None, 0:2] - sorted_pred_goals[:, None, :, 0:2]).norm(dim=-1)
values, indices = torch.topk(dist, 5, dim=-1, largest=False)
thresh_hold = values[..., -1]
point_cover_mask = dist < thresh_hold[..., None]
else:
dist = (sorted_pred_goals[:, :, None, 0:2] - sorted_pred_goals[:, None, :, 0:2]).norm(dim=-1)
point_cover_mask = (dist < dist_thresh)
point_val = sorted_pred_scores.clone() # (batch_size, N)
point_val_selected = torch.zeros_like(point_val) # (batch_size, N)
ret_idxs = sorted_idxs.new_zeros(batch_size, num_ret_modes).long()
ret_goals = sorted_pred_goals.new_zeros(batch_size, num_ret_modes, num_feat_dim)
ret_scores = sorted_pred_goals.new_zeros(batch_size, num_ret_modes)
bs_idxs = torch.arange(batch_size).type_as(ret_idxs)
for k in range(num_ret_modes):
cur_idx = point_val.argmax(dim=-1) # (batch_size)
ret_idxs[:, k] = cur_idx
new_cover_mask = point_cover_mask[bs_idxs, cur_idx] # (batch_size, N)
point_val = point_val * (~new_cover_mask).float() # (batch_size, N)
point_val_selected[bs_idxs, cur_idx] = -1
point_val += point_val_selected
ret_goals[:, k] = sorted_pred_goals[bs_idxs, cur_idx]
ret_scores[:, k] = sorted_pred_scores[bs_idxs, cur_idx]
bs_idxs = torch.arange(batch_size).type_as(sorted_idxs)[:, None].repeat(1, num_ret_modes)
ret_idxs = sorted_idxs[bs_idxs, ret_idxs]
return ret_goals, ret_scores, ret_idxs
class TokenCls(Metric):
def __init__(self,
max_guesses: int = 6,
**kwargs) -> None:
super(TokenCls, self).__init__(**kwargs)
self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum')
self.max_guesses = max_guesses
def update(self,
pred: torch.Tensor,
target: torch.Tensor,
valid_mask: Optional[torch.Tensor] = None) -> None:
target = target[..., None]
acc = (pred[:, :self.max_guesses] == target).any(dim=1) * valid_mask
self.sum += acc.sum()
self.count += valid_mask.sum()
def compute(self) -> torch.Tensor:
return self.sum / self.count
class minMultiFDE(Metric):
def __init__(self,
max_guesses: int = 6,
**kwargs) -> None:
super(minMultiFDE, self).__init__(**kwargs)
self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum')
self.max_guesses = max_guesses
def update(self,
pred: torch.Tensor,
target: torch.Tensor,
prob: Optional[torch.Tensor] = None,
valid_mask: Optional[torch.Tensor] = None,
keep_invalid_final_step: bool = True) -> None:
pred, target, prob, valid_mask, _ = valid_filter(pred, target, prob, valid_mask, None, keep_invalid_final_step)
pred_topk, _ = topk(self.max_guesses, pred, prob)
inds_last = (valid_mask * torch.arange(1, valid_mask.size(-1) + 1, device=self.device)).argmax(dim=-1)
self.sum += torch.norm(pred_topk[torch.arange(pred.size(0)), :, inds_last] -
target[torch.arange(pred.size(0)), inds_last].unsqueeze(-2),
p=2, dim=-1).min(dim=-1)[0].sum()
self.count += pred.size(0)
def compute(self) -> torch.Tensor:
return self.sum / self.count
class minFDE(Metric):
def __init__(self,
max_guesses: int = 6,
**kwargs) -> None:
super(minFDE, self).__init__(**kwargs)
self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum')
self.max_guesses = max_guesses
self.eval_timestep = 70
def update(self,
pred: torch.Tensor,
target: torch.Tensor,
prob: Optional[torch.Tensor] = None,
valid_mask: Optional[torch.Tensor] = None,
keep_invalid_final_step: bool = True) -> None:
eval_timestep = min(self.eval_timestep, pred.shape[1]) - 1
self.sum += ((torch.norm(pred[:, eval_timestep-1:eval_timestep] - target[:, eval_timestep-1:eval_timestep], p=2, dim=-1) *
valid_mask[:, eval_timestep-1].unsqueeze(1)).sum(dim=-1)).sum()
self.count += valid_mask[:, eval_timestep-1].sum()
def compute(self) -> torch.Tensor:
return self.sum / self.count
class minMultiADE(Metric):
def __init__(self,
max_guesses: int = 6,
**kwargs) -> None:
super(minMultiADE, self).__init__(**kwargs)
self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum')
self.max_guesses = max_guesses
def update(self,
pred: torch.Tensor,
target: torch.Tensor,
prob: Optional[torch.Tensor] = None,
valid_mask: Optional[torch.Tensor] = None,
keep_invalid_final_step: bool = True,
min_criterion: str = 'FDE') -> None:
pred, target, prob, valid_mask, _ = valid_filter(pred, target, prob, valid_mask, None, keep_invalid_final_step)
pred_topk, _ = topk(self.max_guesses, pred, prob)
if min_criterion == 'FDE':
inds_last = (valid_mask * torch.arange(1, valid_mask.size(-1) + 1, device=self.device)).argmax(dim=-1)
inds_best = torch.norm(
pred_topk[torch.arange(pred.size(0)), :, inds_last] -
target[torch.arange(pred.size(0)), inds_last].unsqueeze(-2), p=2, dim=-1).argmin(dim=-1)
self.sum += ((torch.norm(pred_topk[torch.arange(pred.size(0)), inds_best] - target, p=2, dim=-1) *
valid_mask).sum(dim=-1) / valid_mask.sum(dim=-1)).sum()
elif min_criterion == 'ADE':
self.sum += ((torch.norm(pred_topk - target.unsqueeze(1), p=2, dim=-1) *
valid_mask.unsqueeze(1)).sum(dim=-1).min(dim=-1)[0] / valid_mask.sum(dim=-1)).sum()
else:
raise ValueError('{} is not a valid criterion'.format(min_criterion))
self.count += pred.size(0)
def compute(self) -> torch.Tensor:
return self.sum / self.count
class minADE(Metric):
def __init__(self,
max_guesses: int = 6,
**kwargs) -> None:
super(minADE, self).__init__(**kwargs)
self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum')
self.max_guesses = max_guesses
self.eval_timestep = 70
def update(self,
pred: torch.Tensor,
target: torch.Tensor,
prob: Optional[torch.Tensor] = None,
valid_mask: Optional[torch.Tensor] = None,
keep_invalid_final_step: bool = True,
min_criterion: str = 'ADE') -> None:
# pred, target, prob, valid_mask, _ = valid_filter(pred, target, prob, valid_mask, None, keep_invalid_final_step)
# pred_topk, _ = topk(self.max_guesses, pred, prob)
# if min_criterion == 'FDE':
# inds_last = (valid_mask * torch.arange(1, valid_mask.size(-1) + 1, device=self.device)).argmax(dim=-1)
# inds_best = torch.norm(
# pred[torch.arange(pred.size(0)), :, inds_last] -
# target[torch.arange(pred.size(0)), inds_last].unsqueeze(-2), p=2, dim=-1).argmin(dim=-1)
# self.sum += ((torch.norm(pred[torch.arange(pred.size(0)), inds_best] - target, p=2, dim=-1) *
# valid_mask).sum(dim=-1) / valid_mask.sum(dim=-1)).sum()
# elif min_criterion == 'ADE':
# self.sum += ((torch.norm(pred - target.unsqueeze(1), p=2, dim=-1) *
# valid_mask.unsqueeze(1)).sum(dim=-1).min(dim=-1)[0] / valid_mask.sum(dim=-1)).sum()
# else:
# raise ValueError('{} is not a valid criterion'.format(min_criterion))
eval_timestep = min(self.eval_timestep, pred.shape[1])
self.sum += ((torch.norm(pred[:, :eval_timestep] - target[:, :eval_timestep], p=2, dim=-1) * valid_mask[:, :eval_timestep]).sum(dim=-1) / pred.shape[1]).sum()
self.count += valid_mask[:, :eval_timestep].any(dim=-1).sum()
def compute(self) -> torch.Tensor:
return self.sum / self.count
class AverageMeter(Metric):
def __init__(self, **kwargs) -> None:
super(AverageMeter, self).__init__(**kwargs)
self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum')
def update(self, val: torch.Tensor) -> None:
self.sum += val.sum()
self.count += val.numel()
def compute(self) -> torch.Tensor:
return self.sum / self.count
class StateAccuracy(Metric):
def __init__(self, state_token: Dict[str, int], **kwargs) -> None:
super().__init__(**kwargs)
self.invalid_state = int(state_token['invalid'])
self.valid_state = int(state_token['valid'])
self.enter_state = int(state_token['enter'])
self.exit_state = int(state_token['exit'])
self.add_state('valid', default=torch.tensor(0), dist_reduce_fx='sum')
self.add_state('valid_count', default=torch.tensor(0), dist_reduce_fx='sum')
self.add_state('invalid', default=torch.tensor(0), dist_reduce_fx='sum')
self.add_state('invalid_count', default=torch.tensor(0), dist_reduce_fx='sum')
def update(self,
state_idx: torch.Tensor,
valid_mask: Optional[torch.Tensor] = None) -> None:
num_agent, num_step = state_idx.shape
# check the evaluation outputs
for a in range(num_agent):
bos_idx = torch.where(state_idx[a] == self.enter_state)[0]
eos_idx = torch.where(state_idx[a] == self.exit_state)[0]
bos = 0
eos = num_step - 1
if len(bos_idx) > 0:
bos = bos_idx[0]
self.invalid += (state_idx[a, :bos] == self.invalid_state).sum()
self.invalid_count += len(state_idx[a, :bos])
if len(eos_idx) > 0:
eos = eos_idx[0]
self.invalid += (state_idx[a, eos + 1:] == self.invalid_state).sum()
self.invalid_count += len(state_idx[a, eos + 1:])
self.valid += (state_idx[a, bos + 1 : eos] == self.valid_state).sum()
self.valid_count += len(state_idx[a, bos + 1 : eos])
# check the tokenization
if valid_mask is not None:
state_idx = state_idx.roll(shifts=1, dims=1)
for a in range(num_agent):
bos_idx = torch.where(state_idx[a] == self.enter_state)[0]
eos_idx = torch.where(state_idx[a] == self.exit_state)[0]
bos = 0
eos = num_step - 1
if len(bos_idx) > 0:
bos = bos_idx[0]
self.invalid += (valid_mask[a, :bos] == 0).sum()
self.invalid_count += len(valid_mask[a, :bos])
if len(eos_idx) > 0:
eos = eos_idx[-1]
self.invalid += (valid_mask[a, eos + 1:] != 0).sum()
self.invalid_count += len(valid_mask[a, eos + 1:])
self.invalid += (((state_idx[a, bos : eos + 1] > 0) != valid_mask[a, bos : eos + 1])[valid_mask[a, bos : eos + 1] == 0]).sum()
self.invalid_count += (valid_mask[a, bos : eos + 1] == 0).sum()
self.valid += (((state_idx[a, bos : eos + 1] > 0) != valid_mask[a, bos : eos + 1])[valid_mask[a, bos : eos + 1] == 1]).sum()
self.valid_count += (valid_mask[a, bos : eos + 1] == 1).sum()
def compute(self) -> Dict[str, torch.Tensor]:
return {'valid': self.valid / self.valid_count,
'invalid': self.invalid / self.invalid_count,
}
def __repr__(self):
head = "Results of " + self.__class__.__name__
results = self.compute()
body = [
"valid: {}".format(results['valid']),
"invalid: {}".format(results['invalid']),
]
_repr_indent = 4
lines = [head] + [" " * _repr_indent + line for line in body]
return "\n".join(lines)
class GridOverlapRate(Metric):
def __init__(self, num_step, state_token, seed_size, **kwargs) -> None:
super().__init__(**kwargs)
self.num_step = num_step
self.enter_state = int(state_token['enter'])
self.seed_size = seed_size
self.add_state('num_overlap_t', default=torch.zeros(num_step).long(), dist_reduce_fx='sum')
self.add_state('num_insert_agent_t', default=torch.zeros(num_step).long(), dist_reduce_fx='sum')
self.add_state('num_total_agent_t', default=torch.zeros(num_step).long(), dist_reduce_fx='sum')
self.add_state('num_exceed_seed_t', default=torch.zeros(num_step).long(), dist_reduce_fx='sum')
def update(self,
state_token: torch.Tensor,
grid_index: torch.Tensor) -> None:
for t in range(self.num_step):
inrange_mask_t = grid_index[:, t] != -1
insert_mask_t = (state_token[:, t] == self.enter_state) & inrange_mask_t
self.num_total_agent_t[t] += inrange_mask_t.sum()
self.num_insert_agent_t[t] += insert_mask_t.sum()
self.num_exceed_seed_t[t] += int(insert_mask_t.sum() >= self.seed_size)
occupied_grids = set(grid_index[:, t][(grid_index[:, t] != -1) & (state_token[:, t] != self.enter_state)].tolist())
to_inserted_grids = grid_index[:, t][(grid_index[:, t] != -1) & (state_token[:, t] == self.enter_state)].tolist()
while to_inserted_grids:
grid_index_t_i = to_inserted_grids.pop()
if grid_index_t_i in occupied_grids:
self.num_overlap_t[t] += 1
occupied_grids.add(grid_index_t_i)
def compute(self) -> Dict[str, torch.Tensor]:
overlap_rate_t = self.num_overlap_t / self.num_insert_agent_t
overlap_rate_t.nan_to_num_()
return {'num_overlap_t': self.num_overlap_t,
'num_insert_agent_t': self.num_insert_agent_t,
'num_total_agent_t': self.num_total_agent_t,
'overlap_rate_t': overlap_rate_t,
'num_exceed_seed_t': self.num_exceed_seed_t,
}
def __repr__(self):
head = "Results of " + self.__class__.__name__
results = self.compute()
body = [
"num_overlap_t: {}".format(results['num_overlap_t'].tolist()),
"num_insert_agent_t: {}".format(results['num_insert_agent_t'].tolist()),
"num_total_agent_t: {}".format(results['num_total_agent_t'].tolist()),
"overlap_rate_t: {}".format(results['overlap_rate_t'].tolist()),
"num_exceed_seed_t: {}".format(results['num_exceed_seed_t'].tolist()),
]
_repr_indent = 4
lines = [head] + [" " * _repr_indent + line for line in body]
return "\n".join(lines)
class NumInsertAccuracy(Metric):
def __init__(self, state_token: Dict[str, int], **kwargs) -> None:
super().__init__(**kwargs)
self.invalid_state = int(state_token['invalid'])
self.valid_state = int(state_token['valid'])
self.enter_state = int(state_token['enter'])
self.exit_state = int(state_token['exit'])
self.add_state('valid', default=torch.tensor(0), dist_reduce_fx='sum')
self.add_state('valid_count', default=torch.tensor(0), dist_reduce_fx='sum')
self.add_state('invalid', default=torch.tensor(0), dist_reduce_fx='sum')
self.add_state('invalid_count', default=torch.tensor(0), dist_reduce_fx='sum')
def update(self,
state_idx: torch.Tensor,
valid_mask: Optional[torch.Tensor] = None) -> None:
num_agent, num_step = state_idx.shape
# check the evaluation outputs
for a in range(num_agent):
bos_idx = torch.where(state_idx[a] == self.enter_state)[0]
eos_idx = torch.where(state_idx[a] == self.exit_state)[0]
bos = 0
eos = num_step - 1
if len(bos_idx) > 0:
bos = bos_idx[0]
self.invalid += (state_idx[a, :bos] == self.invalid_state).sum()
self.invalid_count += len(state_idx[a, :bos])
if len(eos_idx) > 0:
eos = eos_idx[0]
self.invalid += (state_idx[a, eos + 1:] == self.invalid_state).sum()
self.invalid_count += len(state_idx[a, eos + 1:])
self.valid += (state_idx[a, bos + 1 : eos] == self.valid_state).sum()
self.valid_count += len(state_idx[a, bos + 1 : eos])
# check the tokenization
if valid_mask is not None:
state_idx = state_idx.roll(shifts=1, dims=1)
for a in range(num_agent):
bos_idx = torch.where(state_idx[a] == self.enter_state)[0]
eos_idx = torch.where(state_idx[a] == self.exit_state)[0]
bos = 0
eos = num_step - 1
if len(bos_idx) > 0:
bos = bos_idx[0]
self.invalid += (valid_mask[a, :bos] == 0).sum()
self.invalid_count += len(valid_mask[a, :bos])
if len(eos_idx) > 0:
eos = eos_idx[-1]
self.invalid += (valid_mask[a, eos + 1:] != 0).sum()
self.invalid_count += len(valid_mask[a, eos + 1:])
self.invalid += (((state_idx[a, bos : eos + 1] > 0) != valid_mask[a, bos : eos + 1])[valid_mask[a, bos : eos + 1] == 0]).sum()
self.invalid_count += (valid_mask[a, bos : eos + 1] == 0).sum()
self.valid += (((state_idx[a, bos : eos + 1] > 0) != valid_mask[a, bos : eos + 1])[valid_mask[a, bos : eos + 1] == 1]).sum()
self.valid_count += (valid_mask[a, bos : eos + 1] == 1).sum()
def compute(self) -> Dict[str, torch.Tensor]:
return {'valid': self.valid / self.valid_count,
'invalid': self.invalid / self.invalid_count,
}
def __repr__(self):
head = "Results of " + self.__class__.__name__
results = self.compute()
body = [
"valid: {}".format(results['valid']),
"invalid: {}".format(results['invalid']),
]
_repr_indent = 4
lines = [head] + [" " * _repr_indent + line for line in body]
return "\n".join(lines)