import torch

def collate_tensor_fn(batch):
    """Stack a list of same-shaped tensors along a new batch dimension.

    Mirrors PyTorch's built-in tensor collation: when running inside a
    DataLoader worker process, the output is allocated in shared memory so
    the stacked batch can be handed to the main process without an extra
    copy.
    """
    elem = batch[0]
    out = None
    if torch.utils.data.get_worker_info() is not None:
        # In a background worker: allocate a shared-memory tensor and stack
        # directly into it to avoid an extra copy on transfer.
        numel = sum(x.numel() for x in batch)
        storage = elem._typed_storage()._new_shared(numel, device=elem.device)
        out = elem.new(storage).resize_(len(batch), *list(elem.size()))
    return torch.stack(batch, 0, out=out)


def collate_fn_pad_lidar(batch):
    """Collate (feats, targets) sample tuples into a batch.

    Fixed-size feature tensors are stacked along dim 0. The variable-length
    'lidar' / 'lidars_warped' point clouds and all ground-truth targets
    (e.g. boxes) are kept as per-sample Python lists instead of being
    stacked.
    """
    feats = dict()
    for k in batch[0][0]:
        if k == 'lidar' or k == 'lidars_warped':
            # Ragged point clouds: keep as a list, one tensor per sample.
            feats[k] = [tmp[0][k] for tmp in batch]
        else:
            feats[k] = collate_tensor_fn([tmp[0][k] for tmp in batch])
    targets = dict()
    # Samples may arrive without ground truth (e.g. at inference time).
    if len(batch[0]) > 1:
        for k in batch[0][1]:
            # Targets vary in length per sample (e.g. number of boxes),
            # so they are also kept as lists rather than stacked.
            targets[k] = [tmp[1][k] for tmp in batch]
    return feats, targets
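

# Illustrative usage sketch, under assumptions: `_ToyLidarDataset` and its
# field names ('image', 'boxes') are hypothetical stand-ins for whatever the
# real dataset returns; the only contract collate_fn_pad_lidar relies on is
# that each sample is a (feats_dict, targets_dict) tuple.
class _ToyLidarDataset(torch.utils.data.Dataset):
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        feats = {
            'lidar': torch.randn(100 + idx, 4),  # ragged point cloud
            'image': torch.randn(3, 32, 32),     # fixed-size tensor
        }
        targets = {'boxes': torch.randn(idx % 3 + 1, 7)}  # ragged gt boxes
        return feats, targets


if __name__ == '__main__':
    from torch.utils.data import DataLoader

    # num_workers > 0 exercises the shared-memory path in collate_tensor_fn.
    loader = DataLoader(_ToyLidarDataset(), batch_size=4, num_workers=2,
                        collate_fn=collate_fn_pad_lidar)
    for feats, targets in loader:
        assert isinstance(feats['lidar'], list)        # kept per-sample
        assert feats['image'].shape == (4, 3, 32, 32)  # stacked along dim 0
        assert isinstance(targets['boxes'], list)      # kept per-sample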