import torch
from torch import Tensor
from typing import Optional, Sequence, List, Tuple


def _mask_out_av(state: Tensor, av_id: int, object_id: Tensor) -> Tuple[Tensor, int]:
    """Return a copy of ``state`` with the AV's row set to -1, plus the AV's row index.

    The caller's ``state`` is left untouched (the original code wrote into it
    in place, which leaked the -1 back to the caller).

    Raises:
        ValueError: if ``av_id`` does not occur in ``object_id``.
    """
    av_index = object_id.tolist().index(av_id)
    masked = state.clone()
    # -1 can never equal an ``agent_state.index(...)`` result (always >= 0),
    # so the sdc row is never counted as 'enter' or 'exit'.
    masked[av_index] = -1
    return masked, av_index


def compute_num_placement(
    valid: Tensor,  # [n_agent, n_step]
    state: Tensor,  # [n_agent, n_step]
    av_id: int,
    object_id: Tensor,
    agent_state: List[str],
) -> Tuple[Tensor, Tensor]:
    """Count, per step, how many non-AV agents are in the 'enter' / 'exit' state.

    Args:
        valid: accepted for interface compatibility; not used by this function.
        state: per-agent, per-step state codes, where a code is the index of the
            corresponding name in ``agent_state``. Not modified.
        av_id: id of the self-driving car (sdc), which is excluded from counts.
        object_id: agent ids, row-aligned with ``state``.
        agent_state: list of state names; must contain 'enter' and 'exit'.

    Returns:
        ``(num_bos, num_eos)``: the number of agents in the 'enter' state and in
        the 'exit' state, obtained by summing over dim 0 (the agent axis).

    Raises:
        ValueError: if 'enter'/'exit' is missing from ``agent_state`` or
            ``av_id`` is missing from ``object_id``.
    """
    enter_state = agent_state.index('enter')
    exit_state = agent_state.index('exit')
    state, _ = _mask_out_av(state, av_id, object_id)  # we do not incorporate the sdc

    num_bos = torch.sum(state == enter_state, dim=0)
    num_eos = torch.sum(state == exit_state, dim=0)
    return num_bos, num_eos


def compute_distance_placement(
    position: Tensor,
    state: Tensor,
    valid: Tensor,
    av_id: int,
    object_id: Tensor,
    agent_state: List[str],
) -> Tuple[Tensor, Tensor]:
    """Distance of each agent from the AV, kept only where the agent enters / exits.

    Args:
        position: per-agent positions, with the coordinate axis last. The AV's
            row is broadcast against every agent, and the L2 norm is taken over
            the last dimension.
        state: per-agent, per-step state codes, where a code is the index of the
            corresponding name in ``agent_state``. Not modified.
        valid: accepted for interface compatibility; not used by this function.
        av_id: id of the self-driving car (sdc), which is excluded.
        object_id: agent ids, row-aligned with ``state`` / ``position``.
        agent_state: list of state names; must contain 'enter' and 'exit'.

    Returns:
        ``(bos_distance, eos_distance)``: the distance to the AV, zeroed wherever
        the agent is not in the 'enter' (resp. 'exit') state. The AV's own
        row is always zero.

    Raises:
        ValueError: if 'enter'/'exit' is missing from ``agent_state`` or
            ``av_id`` is missing from ``object_id``.
    """
    enter_state = agent_state.index('enter')
    exit_state = agent_state.index('exit')
    state, av_index = _mask_out_av(state, av_id, object_id)  # we do not incorporate the sdc

    # Slicing keeps the agent dim so the AV row broadcasts against all agents.
    distance = torch.norm(position - position[av_index : av_index + 1], p=2, dim=-1)
    bos_distance = distance * (state == enter_state)
    eos_distance = distance * (state == exit_state)
    return bos_distance, eos_distance