navsim_ours / det_map /utils.py
lkllkl's picture
Upload folder using huggingface_hub
da2e2ac verified
raw
history blame
1.17 kB
import torch
from torch.utils.data.dataloader import default_collate
def collate_tensor_fn(batch):
    """Stack a sequence of tensors along a new leading (batch) dimension.

    Mirrors the tensor branch of PyTorch's default collation. When called from
    inside a DataLoader worker process, the stacked result is written directly
    into a newly allocated shared-memory storage, so that sending it back to
    the main process avoids an extra copy.

    Args:
        batch: non-empty sequence of tensors; ``batch[0]`` provides the dtype,
            device and per-sample shape used for the output buffer
            (presumably all elements share that shape — ``torch.stack``
            raises otherwise).

    Returns:
        Tensor of shape ``(len(batch), *batch[0].shape)``.
    """
    first = batch[0]
    target = None

    in_worker = torch.utils.data.get_worker_info() is not None
    if in_worker:
        # Pre-size a shared-memory buffer big enough for every element and
        # let torch.stack fill it in place.
        total = 0
        for item in batch:
            total += item.numel()
        shared = first._typed_storage()._new_shared(total, device=first.device)
        target = first.new(shared).resize_(len(batch), *first.size())

    return torch.stack(batch, 0, out=target)
def collate_fn_pad_lidar(batch):
    """Collate a list of samples into ``(feats, targets)`` dicts.

    Each sample is indexed as ``sample[0]`` (feature dict) and, when the first
    sample has a second element, ``sample[1]`` (target dict). Key sets are
    taken from the first sample.

    Features:
        * ``'lidar'`` and ``'lidars_warped'`` are gathered into plain lists,
          one entry per sample, and are NOT stacked (presumably because point
          clouds vary in size between samples — TODO confirm).
        * every other feature is stacked via ``collate_tensor_fn``.

    Targets:
        every value is gathered into a plain list, one entry per sample, with
        no stacking. If the first sample has only one element, ``targets`` is
        an empty dict.

    Note: despite the function name, no padding is performed here.

    Returns:
        Tuple ``(feats, targets)`` of dicts keyed like the first sample.
    """
    kept_as_list = ('lidar', 'lidars_warped')
    head = batch[0]

    def gather(slot, key):
        # Collect ``sample[slot][key]`` across the whole batch, in order.
        return [sample[slot][key] for sample in batch]

    feats = {}
    for key in head[0]:
        values = gather(0, key)
        if key in kept_as_list:
            feats[key] = values
        else:
            feats[key] = collate_tensor_fn(values)

    targets = {}
    if len(head) > 1:  # samples also carry ground truth
        targets = {key: gather(1, key) for key in head[1]}

    return feats, targets