navsim_ours / det_map /data /pipelines /point_shuffle.py
lkllkl's picture
Upload folder using huggingface_hub
da2e2ac verified
raw
history blame
452 Bytes
import numpy as np
from typing import Tuple
import torch
class PointShuffle(object):
    """Randomly reorder the LiDAR points in ``features`` during training.

    Only ``features['lidar']`` is modified; ``targets`` is passed through
    untouched. When constructed with a falsy ``is_train`` the transform is a
    pure pass-through.
    """

    def __init__(self, is_train):
        # Gate for the shuffle: only applied when this is truthy.
        self.is_train = is_train

    def __call__(self, features, targets):
        """Permute the rows of ``features['lidar']`` when training.

        Args:
            features: mapping holding the point cloud under ``'lidar'``.
                Assumed to be a ``torch.Tensor`` whose first dimension indexes
                points (it must expose ``.shape`` and ``.device``) — TODO
                confirm against the loader.
            targets: returned unchanged.

        Returns:
            The tuple ``(features, targets)``. ``features`` is the same object
            that was passed in; its ``'lidar'`` entry is replaced in place.
        """
        if not self.is_train:
            return features, targets

        cloud = features['lidar']
        # Draw the permutation on the tensor's own device so the indexing
        # below does not need a cross-device copy.
        order = torch.randperm(cloud.shape[0], device=cloud.device)
        features['lidar'] = cloud[order]
        return features, targets