import os
import random
import sys
import inspect

import numpy as np
from PIL import Image
from loguru import logger

import torch
from torch import nn
import torch.distributed as dist


def init_random_seed(seed=None, device='cuda', rank=0, world_size=1):
    """Initialize random seed."""
    if seed is not None:
        return seed

    # Make sure all ranks share the same random seed to prevent
    # some potential bugs. Please refer to
    # https://github.com/open-mmlab/mmdetection/issues/6339
    seed = np.random.randint(2**31)
    if world_size == 1:
        return seed

    if rank == 0:
        random_num = torch.tensor(seed, dtype=torch.int32, device=device)
    else:
        random_num = torch.tensor(0, dtype=torch.int32, device=device)
    dist.broadcast(random_num, src=0)
    return random_num.item()


def set_random_seed(seed, deterministic=False):
    """Set random seed."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


@torch.no_grad()
def concat_all_gather(tensor):
    """
    Performs all_gather operation on the provided tensors.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    tensor = tensor.contiguous()
    tensors_gather = [
        torch.ones_like(tensor)
        for _ in range(torch.distributed.get_world_size())
    ]
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)

    output = torch.cat(tensors_gather, dim=0)
    return output


def worker_init_fn(worker_id, num_workers, rank, seed):
    # The seed of each worker equals
    # num_workers * rank + worker_id + user_seed
    worker_seed = num_workers * rank + worker_id + seed
    np.random.seed(worker_seed)
    random.seed(worker_seed)


class AverageMeter(object):
    """Computes and stores the average and current value."""

    def __init__(self, name, fmt=":f"):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        if self.name == "Lr":
            fmtstr = "{name}={val" + self.fmt + "}"
        else:
            fmtstr = "{name}={val" + self.fmt + "} ({avg" + self.fmt + "})"
        return fmtstr.format(**self.__dict__)


class ProgressMeter(object):
    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        logger.info(" ".join(entries))

    def _get_batch_fmtstr(self, num_batches):
        num_digits = len(str(num_batches // 1))
        fmt = "{:" + str(num_digits) + "d}"
        return "[" + fmt + "/" + fmt.format(num_batches) + "]"


def trainMetricGPU(output, target, threshold=0.35, pr_iou=0.5):
    """Mean IoU and Prec@pr_iou over a training batch (logits vs. binary masks)."""
    assert output.dim() in [2, 3, 4]
    assert output.shape == target.shape
    output = output.flatten(1)
    target = target.flatten(1)
    output = torch.sigmoid(output)
    output[output < threshold] = 0.
    output[output >= threshold] = 1.
    # inter & union
    inter = (output.bool() & target.bool()).sum(dim=1)  # b
    union = (output.bool() | target.bool()).sum(dim=1)  # b
    ious = inter / (union + 1e-6)  # 0 ~ 1
    # iou & pr@5
    iou = ious.mean()
    prec = (ious > pr_iou).float().mean()
    return 100. * iou, 100. * prec


def ValMetricGPU(output, target, threshold=0.35):
    """Per-sample IoU for validation (expects a batch of size 1)."""
    assert output.size(0) == 1
    output = output.flatten(1)
    target = target.flatten(1)
    output = torch.sigmoid(output)
    output[output < threshold] = 0.
    output[output >= threshold] = 1.
    # inter & union
    inter = (output.bool() & target.bool()).sum(dim=1)  # b
    union = (output.bool() | target.bool()).sum(dim=1)  # b
    ious = inter / (union + 1e-6)  # 0 ~ 1
    return ious


def intersectionAndUnionGPU(output, target, K, threshold=0.5):
    """Foreground intersection and union areas after thresholding the logits."""
    # 'K' classes, output and target sizes are N or N * L or N * H * W,
    # each value in range 0 to K - 1.
    assert output.dim() in [1, 2, 3]
    assert output.shape == target.shape
    output = output.view(-1)
    target = target.view(-1)
    output = torch.sigmoid(output)
    output[output < threshold] = 0.
    output[output >= threshold] = 1.
    intersection = output[output == target]
    area_intersection = torch.histc(intersection.float(), bins=K, min=0, max=K - 1)
    area_output = torch.histc(output.float(), bins=K, min=0, max=K - 1)
    area_target = torch.histc(target.float(), bins=K, min=0, max=K - 1)
    area_union = area_output + area_target - area_intersection
    return area_intersection[1], area_union[1]


def group_weight(weight_group, module, lr):
    """Split parameters into weight-decay and no-weight-decay groups."""
    group_decay = []
    group_no_decay = []
    for m in module.modules():
        if isinstance(m, nn.Linear):
            group_decay.append(m.weight)
            if m.bias is not None:
                group_no_decay.append(m.bias)
        elif isinstance(m, nn.modules.conv._ConvNd):
            group_decay.append(m.weight)
            if m.bias is not None:
                group_no_decay.append(m.bias)
        elif isinstance(m, nn.modules.batchnorm._BatchNorm):
            if m.weight is not None:
                group_no_decay.append(m.weight)
            if m.bias is not None:
                group_no_decay.append(m.bias)
    assert len(list(
        module.parameters())) == len(group_decay) + len(group_no_decay)
    weight_group.append(dict(params=group_decay, lr=lr))
    weight_group.append(dict(params=group_no_decay, weight_decay=.0, lr=lr))
    return weight_group


def colorize(gray, palette):
    # gray: numpy array of the label and 1*3N size list palette
    color = Image.fromarray(gray.astype(np.uint8)).convert('P')
    color.putpalette(palette)
    return color


def find_free_port():
    import socket

    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    # Binding to port 0 will cause the OS to find an available port for us
    sock.bind(("", 0))
    port = sock.getsockname()[1]
    sock.close()
    # NOTE: there is still a chance the port could be taken by other processes.
    return port


def get_caller_name(depth=0):
    """
    Args:
        depth (int): Depth of caller context, use 0 for caller depth.
            Default value: 0.

    Returns:
        str: module name of the caller
    """
    # the following logic is a little bit faster than inspect.stack() logic
    frame = inspect.currentframe().f_back
    for _ in range(depth):
        frame = frame.f_back
    return frame.f_globals["__name__"]


class StreamToLoguru:
    """
    Stream object that redirects writes to a logger instance.
    """

    def __init__(self, level="INFO", caller_names=("apex", "pycocotools")):
        """
        Args:
            level(str): log level string of loguru. Default value: "INFO".
            caller_names(tuple): caller names of redirected modules.
                Default value: (apex, pycocotools).
        """
        self.level = level
        self.linebuf = ""
        self.caller_names = caller_names

    def write(self, buf):
        full_name = get_caller_name(depth=1)
        module_name = full_name.rsplit(".", maxsplit=-1)[0]
        if module_name in self.caller_names:
            for line in buf.rstrip().splitlines():
                # use caller level log
                logger.opt(depth=2).log(self.level, line.rstrip())
        else:
            sys.__stdout__.write(buf)

    def flush(self):
        pass


def redirect_sys_output(log_level="INFO"):
    redirect_logger = StreamToLoguru(log_level)
    sys.stderr = redirect_logger
    sys.stdout = redirect_logger


def setup_logger(save_dir, distributed_rank=0, filename="log.txt", mode="a"):
    """Set up a logger for training and testing.
    Args:
        save_dir(str): location to save log file
        distributed_rank(int): device rank when multi-gpu environment
        filename (string): log save name.
        mode(str): log file write mode, `a` (append) or `o` (override).
            default is `a`.
    """
    loguru_format = (
        "{time:YYYY-MM-DD HH:mm:ss} | "
        "{level: <8} | "
        "{name}:{line} - {message}")

    logger.remove()
    save_file = os.path.join(save_dir, filename)
    if mode == "o" and os.path.exists(save_file):
        os.remove(save_file)
    # only keep logger in rank0 process
    if distributed_rank == 0:
        logger.add(
            sys.stderr,
            format=loguru_format,
            level="INFO",
            enqueue=True,
        )
        logger.add(save_file)

    # redirect stdout/stderr to loguru
    redirect_sys_output("INFO")
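

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the training code):
# it wires together the seeding, logging and meter helpers defined above.
# The save directory, batch count and meter values below are placeholder
# assumptions chosen for the demo.
if __name__ == "__main__":
    setup_logger("./exp/debug", distributed_rank=0, filename="log.txt", mode="a")

    # Draw and fix a seed; with world_size=1 no broadcast is performed.
    seed = init_random_seed(seed=None, world_size=1)
    set_random_seed(seed, deterministic=False)
    logger.info(f"Using random seed {seed}")

    # Track a dummy loss/lr for a few steps and print progress lines.
    loss_meter = AverageMeter("Loss", ":.4f")
    lr_meter = AverageMeter("Lr", ":.6f")
    progress = ProgressMeter(100, [loss_meter, lr_meter], prefix="Train: ")
    for step in range(1, 4):
        loss_meter.update(1.0 / step, n=8)
        lr_meter.update(1e-4)
        progress.display(step)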