import torch
from torch.utils.data.dataloader import default_collate

# Feature keys whose per-sample values are NOT stacked into one tensor but
# returned as plain per-sample lists (LiDAR point sets; their sizes presumably
# differ between samples, which would make stacking fail).
LIDAR_LIST_KEYS = ('lidar', 'lidars_warped')


def collate_tensor_fn(batch):
    """Stack a list of same-shaped tensors along a new leading batch dimension.

    Mirrors PyTorch's built-in tensor collation: when called inside a
    DataLoader worker process, the result is written directly into a
    shared-memory tensor to avoid an extra copy.

    Args:
        batch: Non-empty list of tensors that all have the same shape.

    Returns:
        A tensor of shape ``(len(batch), *batch[0].shape)``.
    """
    elem = batch[0]
    out = None
    if torch.utils.data.get_worker_info() is not None:
        # Background worker: allocate the output in shared memory so the
        # stacked batch can be handed to the main process without a copy.
        numel = sum(x.numel() for x in batch)
        storage = elem._typed_storage()._new_shared(numel, device=elem.device)
        out = elem.new(storage).resize_(len(batch), *list(elem.size()))
    return torch.stack(batch, 0, out=out)


def _collate_values(values):
    """Collate the per-sample values of one feature key.

    Tensors are stacked with ``collate_tensor_fn``; any other type (Python
    numbers, NumPy arrays, strings, nested containers, ...) is delegated to
    ``default_collate`` instead of failing inside ``torch.stack``.
    """
    if isinstance(values[0], torch.Tensor):
        return collate_tensor_fn(values)
    return default_collate(values)


def collate_fn_pad_lidar(batch, *, list_keys=LIDAR_LIST_KEYS):
    """Collate ``(features, targets)`` samples into a batch for a DataLoader.

    Note: despite the name, nothing is padded. The keys in ``list_keys`` are
    returned as Python lists, one entry per sample.

    Args:
        batch: Non-empty list of samples. Each sample is a sequence whose
            element 0 is a dict of features and whose optional element 1 is a
            dict of targets. The key sets are taken from the first sample.
        list_keys: Feature keys to keep as per-sample lists instead of
            collating. Defaults to ``LIDAR_LIST_KEYS``.

    Returns:
        ``(feats, targets)``:
        - ``feats`` maps each feature key either to a per-sample list (for
          ``list_keys``) or to the collated value (a stacked tensor for
          tensor features, ``default_collate`` output otherwise).
        - ``targets`` maps each target key to a per-sample list. It is empty
          when the first sample has no targets element.
    """
    feats = {}
    for key in batch[0][0]:
        values = [sample[0][key] for sample in batch]
        feats[key] = values if key in list_keys else _collate_values(values)

    targets = {}  # ground truth
    if len(batch[0]) > 1:
        for key in batch[0][1]:
            # Targets are deliberately not stacked: keep one entry per
            # sample, in batch order.
            targets[key] = [sample[1][key] for sample in batch]
    return feats, targets