import torch | |
from torch.utils.data.dataloader import default_collate | |
def collate_tensor_fn(batch):
    """Stack a sequence of same-shaped tensors along a new leading dimension.

    Outside a DataLoader worker this is a plain ``torch.stack(batch, 0)``.
    Inside a worker (``torch.utils.data.get_worker_info()`` is not None) the
    result is written straight into a freshly allocated shared-memory tensor,
    which avoids an extra copy.

    Args:
        batch: non-empty sequence of tensors with identical shapes.

    Returns:
        Tensor of shape ``(len(batch), *batch[0].size())``.
    """
    first = batch[0]

    # Main process: no shared-memory buffer needed.
    if torch.utils.data.get_worker_info() is None:
        return torch.stack(batch, 0)

    # Worker process: pre-allocate shared storage large enough for every
    # element, view it with the batched shape, and stack into it directly.
    total_elems = sum(t.numel() for t in batch)
    shared = first._typed_storage()._new_shared(total_elems, device=first.device)
    buffer = first.new(shared).resize_(len(batch), *list(first.size()))
    return torch.stack(batch, 0, out=buffer)
def collate_fn_pad_lidar(batch):
    """Collate ``(features, targets)`` samples into a pair of batched dicts.

    Each sample is indexed as ``sample[0]`` (a dict of features) and,
    optionally, ``sample[1]`` (a dict of targets). The keys of the FIRST
    sample decide which keys are collected from every sample.

    Features stored under ``'lidar'`` or ``'lidars_warped'`` are not stacked;
    they are returned as a plain list with one entry per sample. Every other
    feature is stacked with ``collate_tensor_fn``. Targets are never stacked:
    each target key maps to the list of per-sample values. If the first
    sample has no second element, the returned target dict is empty.

    Args:
        batch: sequence of samples shaped as described above.

    Returns:
        ``(feats, targets)`` tuple of dicts.
    """
    first = batch[0]
    # Feature keys whose values are gathered as lists rather than stacked.
    keep_as_list = ('lidar', 'lidars_warped')

    feats = {}
    for key in first[0]:
        per_sample = [sample[0][key] for sample in batch]
        feats[key] = per_sample if key in keep_as_list else collate_tensor_fn(per_sample)

    targets = {}
    # Ground-truth dict is present only when samples carry a second element.
    if len(first) > 1:
        for key in first[1]:
            targets[key] = [sample[1][key] for sample in batch]

    return feats, targets