# Copyright (c) Meta Platforms, Inc. and affiliates.

import collections
import os

import torch
from torch.utils.data import get_worker_info
from torch.utils.data._utils.collate import (
    default_collate_err_msg_format,
    np_str_obj_array_pattern,
)
from lightning_fabric.utilities.seed import pl_worker_init_function


def collate(batch):
    """Variant of PyTorch's default_collate that can also stack other tensor-like objects.

    Adapted from PixLoc, Paul-Edouard Sarlin, ETH Zurich
    https://github.com/cvg/pixloc
    Released under the Apache License 2.0
    """
    if not isinstance(batch, list):  # no batching
        return batch

    # Filter out None elements.
    batch = [elem for elem in batch if elem is not None]
    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # If we're in a background process, concatenate directly into a
            # shared-memory tensor to avoid an extra copy.
            numel = sum(x.numel() for x in batch)
            storage = elem.storage()._new_shared(numel, device=elem.device)
            out = elem.new(storage).resize_(len(batch), *list(elem.size()))
        return torch.stack(batch, 0, out=out)
    elif (
        elem_type.__module__ == "numpy"
        and elem_type.__name__ != "str_"
        and elem_type.__name__ != "string_"
    ):
        if elem_type.__name__ == "ndarray" or elem_type.__name__ == "memmap":
            # Arrays of strings or objects cannot be converted to tensors.
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(default_collate_err_msg_format.format(elem.dtype))
            return collate([torch.as_tensor(b) for b in batch])
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int):
        return torch.tensor(batch)
    elif isinstance(elem, (str, bytes)):
        return batch
    elif isinstance(elem, collections.abc.Mapping):
        return {key: collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, "_fields"):  # namedtuple
        return elem_type(*(collate(samples) for samples in zip(*batch)))
    elif isinstance(elem, collections.abc.Sequence):
        # Check that all elements in the batch have consistent size.
        it = iter(batch)
        elem_size = len(next(it))
        if not all(len(elem) == elem_size for elem in it):
            raise RuntimeError("each element in list of batch should be of equal size")
        transposed = zip(*batch)
        return [collate(samples) for samples in transposed]
    else:
        # Try to stack anyway in case the object implements stacking.
        try:
            return torch.stack(batch, 0)
        except TypeError as e:
            if "expected Tensor as element" in str(e):
                return batch
            else:
                raise e


def set_num_threads(nt):
    """Force numpy and other libraries to use a limited number of threads."""
    try:
        import mkl
    except ImportError:
        pass
    else:
        mkl.set_num_threads(nt)
    torch.set_num_threads(1)
    os.environ["IPC_ENABLE"] = "1"
    for o in [
        "OPENBLAS_NUM_THREADS",
        "NUMEXPR_NUM_THREADS",
        "OMP_NUM_THREADS",
        "MKL_NUM_THREADS",
    ]:
        os.environ[o] = str(nt)


def worker_init_fn(i):
    """Initialize a DataLoader worker: seed it deterministically and limit its thread count."""
    info = get_worker_info()
    pl_worker_init_function(info.id)
    num_threads = info.dataset.cfg.get("num_threads")
    if num_threads is not None:
        set_num_threads(num_threads)
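

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the library): wire the
    # custom `collate` and `worker_init_fn` into a DataLoader. `_ToyDataset` is a
    # hypothetical stand-in; the only requirement `worker_init_fn` places on a
    # dataset is a `cfg` attribute supporting `.get("num_threads")`.
    from torch.utils.data import DataLoader, Dataset

    class _ToyDataset(Dataset):
        def __init__(self):
            self.cfg = {"num_threads": 1}

        def __len__(self):
            return 8

        def __getitem__(self, idx):
            # Mixed tensor / scalar sample to exercise the custom collate.
            return {"image": torch.zeros(3, 4, 4), "index": idx}

    loader = DataLoader(
        _ToyDataset(),
        batch_size=4,
        num_workers=2,
        collate_fn=collate,
        worker_init_fn=worker_init_fn,
    )
    for batch in loader:
        print(batch["image"].shape, batch["index"])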