# longsim-base / backups / dev / metrics / placement_features.py
# Uploaded by gzzyyxy using huggingface_hub (commit c1a7f73, verified).
from typing import List, Optional, Sequence, Tuple

import torch
from torch import Tensor
def compute_num_placement(
    valid: Tensor,  # [n_agent, n_step]
    state: Tensor,  # [n_agent, n_step]
    av_id: int,
    object_id: Tensor,
    agent_state: List[str],
) -> Tuple[Tensor, Tensor]:
    """Count agents in the 'enter' and 'exit' states at every time step.

    The self-driving car (the agent whose id is ``av_id``) is excluded from
    both counts.

    Args:
        valid: Validity mask, ``[n_agent, n_step]``. Currently unused; kept
            for interface compatibility.
        state: Per-agent, per-step state indices into ``agent_state``,
            ``[n_agent, n_step]``. It is not modified.
        av_id: Id of the self-driving car, looked up in ``object_id``.
        object_id: Agent ids, one per row of ``state``.
        agent_state: Ordered state names. It must contain ``'enter'`` and
            ``'exit'``.

    Returns:
        ``(num_bos, num_eos)``: per-step counts (summed over agents) of agents
        in the ``'enter'`` state and in the ``'exit'`` state.

    Raises:
        ValueError: If ``'enter'`` or ``'exit'`` is missing from
            ``agent_state``, or ``av_id`` is not in ``object_id``.
    """
    enter_state = agent_state.index('enter')
    exit_state = agent_state.index('exit')
    av_index = object_id.tolist().index(av_id)

    # Drop the sdc from freshly built masks rather than writing a sentinel
    # into ``state``. Writing the sentinel mutated the caller's tensor, and
    # -1 may not be representable for unsigned/bool state dtypes.
    is_bos = state == enter_state
    is_eos = state == exit_state
    is_bos[av_index] = False
    is_eos[av_index] = False

    num_bos = is_bos.sum(dim=0)
    num_eos = is_eos.sum(dim=0)
    return num_bos, num_eos
def compute_distance_placement(
    position: Tensor,
    state: Tensor,
    valid: Tensor,
    av_id: int,
    object_id: Tensor,
    agent_state: List[str],
) -> Tuple[Tensor, Tensor]:
    """Distance to the self-driving car for agents entering or exiting.

    For every agent and step, the Euclidean distance to the sdc at the same
    step is kept where the agent is in the 'enter' (bos) or 'exit' (eos)
    state. All other entries, including the sdc's own row, are 0.

    Args:
        position: Agent positions with agents on dim 0 and coordinates on the
            last dim. This presumably means ``[n_agent, n_step, n_dim]``;
            confirm against callers.
        state: Per-agent, per-step state indices into ``agent_state``. It must
            match ``position`` without its last dim. It is not modified.
        valid: Validity mask. Currently unused; kept for interface
            compatibility.
        av_id: Id of the self-driving car, looked up in ``object_id``.
        object_id: Agent ids, one per row of ``position``/``state``.
        agent_state: Ordered state names. It must contain ``'enter'`` and
            ``'exit'``.

    Returns:
        ``(bos_distance, eos_distance)``, each shaped like ``state``.

    Raises:
        ValueError: If ``'enter'`` or ``'exit'`` is missing from
            ``agent_state``, or ``av_id`` is not in ``object_id``.
    """
    enter_state = agent_state.index('enter')
    exit_state = agent_state.index('exit')
    av_index = object_id.tolist().index(av_id)

    # The sdc row is kept as a length-1 slice so it broadcasts over all agents.
    distance = torch.norm(position - position[av_index : av_index + 1], p=2, dim=-1)

    # Exclude the sdc via the masks instead of mutating the caller's ``state``.
    is_bos = state == enter_state
    is_eos = state == exit_state
    is_bos[av_index] = False
    is_eos[av_index] = False

    # Use masked_fill rather than multiplying by the mask. With multiplication,
    # 0 * nan/inf is nan, so non-finite positions at unrelated entries would
    # leak into the result.
    bos_distance = distance.masked_fill(~is_bos, 0.0)
    eos_distance = distance.masked_fill(~is_eos, 0.0)
    return bos_distance, eos_distance