File size: 1,169 Bytes
da2e2ac |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 |
import torch
from torch.utils.data.dataloader import default_collate
def collate_tensor_fn(batch):
    """Stack a sequence of equally-shaped tensors along a new leading dim.

    Mirrors the tensor branch of PyTorch's default collate. Outside a
    DataLoader worker this is a plain ``torch.stack(batch, 0)``. Inside a
    worker, the result buffer is first allocated in shared memory so the
    stack writes directly into it, avoiding an extra copy.

    Args:
        batch: Sequence of tensors with identical shape (required by
            ``torch.stack``).

    Returns:
        A tensor of shape ``(len(batch), *batch[0].shape)``.

    Raises:
        IndexError: if ``batch`` is empty (``batch[0]`` is read first).
    """
    first = batch[0]

    if torch.utils.data.get_worker_info() is None:
        # Main process: let torch.stack allocate the output itself.
        return torch.stack(batch, 0)

    # Background worker: build the destination tensor on shared storage
    # sized for every element of the batch, then stack into it.
    element_count = sum(t.numel() for t in batch)
    shared_storage = first._typed_storage()._new_shared(
        element_count, device=first.device
    )
    stacked_shape = (len(batch), *first.size())
    buffer = first.new(shared_storage).resize_(*stacked_shape)
    return torch.stack(batch, 0, out=buffer)
def collate_fn_pad_lidar(batch, *, passthrough_keys=('lidar', 'lidars_warped')):
    """Collate a list of ``(features, [targets])`` samples into one batch.

    Each sample is indexed as ``sample[0]`` (a dict of features) and,
    optionally, ``sample[1]`` (a dict of ground-truth targets). The keys of
    the FIRST sample decide which keys are collated for the whole batch.

    Features:
        Keys in ``passthrough_keys`` (lidar point data by default) are not
        collated. They are returned as plain lists with one entry per sample,
        presumably because their size varies per sample. Every other feature
        key is stacked along a new leading dimension via ``collate_tensor_fn``,
        so those values must be equally-shaped tensors.

    Targets:
        Collected only when the first sample has more than one element. Each
        target key maps to a list of per-sample values and nothing is stacked
        (e.g. boxes, whose count presumably varies per sample).

    Args:
        batch: Non-empty sequence of samples as described above.
        passthrough_keys: Feature keys to keep as per-sample lists instead of
            stacking. Defaults to the previously hard-coded
            ``('lidar', 'lidars_warped')``.

    Returns:
        ``(feats, targets)``: two dicts. ``targets`` is empty when the samples
        carry no target dict.

    Raises:
        IndexError: if ``batch`` is empty.
    """
    first = batch[0]

    feats = {}
    for key in first[0]:
        values = [sample[0][key] for sample in batch]
        # Passthrough entries (e.g. lidar points) are left uncollated.
        feats[key] = values if key in passthrough_keys else collate_tensor_fn(values)

    # Ground truth is optional; when present it is kept as per-sample lists.
    targets = {}
    if len(first) > 1:
        targets = {key: [sample[1][key] for sample in batch] for key in first[1]}

    return feats, targets