from typing import Tuple

import numpy as np
import torch
class PointShuffle(object):
    """Randomly permute the order of LiDAR points when training.

    The permutation is applied along the first axis of ``features['lidar']``,
    so every row is kept intact and only the row order changes. When
    ``is_train`` is falsy the transform is a no-op.

    Both ``torch.Tensor`` and ``numpy.ndarray`` point arrays are supported.
    """

    def __init__(self, is_train):
        """
        Args:
            is_train: Shuffle only when this is truthy.
        """
        self.is_train = is_train

    def __call__(self, features, targets):
        """Shuffle the rows of ``features['lidar']``.

        Args:
            features: Dict holding the point array under the ``'lidar'`` key.
                The entry is replaced with the permuted array, so the dict
                passed in is mutated.
            targets: Returned unchanged.

        Returns:
            The ``(features, targets)`` pair; ``features`` is the same dict
            object that was passed in.
        """
        if not self.is_train:
            return features, targets

        points = features['lidar']
        count = points.shape[0]
        if isinstance(points, np.ndarray):
            # NumPy arrays have no ``.device``; use NumPy's RNG for them.
            idx = np.random.permutation(count)
        else:
            # Build the indices on the tensor's own device so indexing does
            # not require a host/device transfer.
            idx = torch.randperm(count, device=points.device)
        features['lidar'] = points[idx]
        return features, targets