import os |
import random |
import numpy as np |
from PIL import Image |
from loguru import logger |
import sys |
import inspect |
import torch |
from torch import nn |
import torch.distributed as dist |
def init_random_seed(seed=None, device='cuda', rank=0, world_size=1):
    """Choose the random seed to use, sharing one value across processes.

    A seed supplied by the caller is returned untouched. Otherwise a value is
    drawn from NumPy's global generator; when ``world_size`` is not 1, the
    value held by rank 0 is broadcast with ``dist.broadcast`` so that every
    rank returns the same number.

    Args:
        seed (int, optional): Seed chosen by the caller. Default: None.
        device (str): Device on which the broadcast tensor is created.
            Default: 'cuda'.
        rank (int): Rank of the calling process. Default: 0.
        world_size (int): Total number of processes. Default: 1.

    Returns:
        int: The seed to use.
    """
    if seed is None:
        drawn = np.random.randint(2**31)
        if world_size != 1:
            # Only rank 0 contributes its draw; the other ranks send a
            # placeholder that the broadcast overwrites.
            payload = drawn if rank == 0 else 0
            shared = torch.tensor(payload, dtype=torch.int32, device=device)
            dist.broadcast(shared, src=0)
            drawn = shared.item()
        return drawn
    return seed
def set_random_seed(seed, deterministic=False):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed (int): Value passed to ``random.seed``, ``np.random.seed``,
            ``torch.manual_seed`` and ``torch.cuda.manual_seed_all``.
        deterministic (bool): When true, additionally set
            ``torch.backends.cudnn.deterministic`` to True and
            ``torch.backends.cudnn.benchmark`` to False. Default: False.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    if not deterministic:
        return
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
@torch.no_grad()
def concat_all_gather(tensor):
    """Gather ``tensor`` from every process and concatenate along dim 0.

    *** Warning ***: torch.distributed.all_gather has no gradient, and the
    whole function runs under ``torch.no_grad()``.

    The receive buffers are created with the shape of the local ``tensor``,
    so presumably every rank has to pass a tensor of the same shape.
    """
    local = tensor.contiguous()
    world_size = torch.distributed.get_world_size()
    # One receive buffer per rank; all_gather overwrites their contents.
    gathered = [torch.ones_like(local) for _ in range(world_size)]
    torch.distributed.all_gather(gathered, local, async_op=False)
    return torch.cat(gathered, dim=0)
def worker_init_fn(worker_id, num_workers, rank, seed):
    """Seed NumPy and Python's ``random`` for one data-loading worker.

    The seed is ``num_workers * rank + worker_id + seed``, which differs for
    every (rank, worker) pair as long as ``worker_id < num_workers``.
    """
    derived_seed = num_workers * rank + worker_id + seed
    for reseed in (np.random.seed, random.seed):
        reseed(derived_seed)
class AverageMeter(object):
    """Computes and stores the average and current value.

    Attributes:
        name: Label used when the meter is printed.
        fmt: Format spec (e.g. ``":.4f"``) applied to the printed numbers.
        val: Most recent value passed to ``update``.
        sum: Weighted sum of all values seen since the last ``reset``.
        count: Total weight (number of samples) seen since the last ``reset``.
        avg: ``sum / count``, or 0 while ``count`` is 0.
    """

    def __init__(self, name, fmt=":f"):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, weighted by ``n`` samples."""
        self.val = val
        self.sum += val * n
        self.count += n
        # The count can still be 0 (e.g. an update with n=0 before any real
        # sample); report 0 instead of raising ZeroDivisionError.
        self.avg = self.sum / self.count if self.count else 0

    def __str__(self):
        # The meter named "Lr" prints only its current value; every other
        # meter also prints the running average in parentheses.
        if self.name == "Lr":
            fmtstr = "{name}={val" + self.fmt + "}"
        else:
            fmtstr = "{name}={val" + self.fmt + "} ({avg" + self.fmt + "})"
        return fmtstr.format(**self.__dict__)
class ProgressMeter(object):
    """Logs a ``[batch/total]`` counter followed by a list of meters."""

    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Log the prefix, the batch counter and every meter on one line."""
        header = self.prefix + self.batch_fmtstr.format(batch)
        rendered = [header]
        rendered.extend(str(meter) for meter in self.meters)
        logger.info(" ".join(rendered))

    def _get_batch_fmtstr(self, num_batches):
        """Build a template such as ``"[{:3d}/100]"`` for ``num_batches``."""
        # Pad the running index to the width of the total.
        width = len(str(num_batches // 1))
        field = "{:%dd}" % width
        return "[%s/%s]" % (field, field.format(num_batches))
def trainMetricGPU(output, target, threshold=0.35, pr_iou=0.5):
    """Compute the mean IoU and the precision at ``pr_iou`` for a batch.

    Args:
        output (Tensor): Raw logits with 2, 3 or 4 dimensions; a sigmoid is
            applied before thresholding.
        target (Tensor): Ground-truth mask of the same shape; non-zero
            entries count as foreground.
        threshold (float): Probabilities not below this value are treated as
            foreground. Default: 0.35.
        pr_iou (float): A sample counts as a hit when its IoU is strictly
            greater than this value. Default: 0.5.

    Returns:
        tuple: ``(100 * mean IoU, 100 * fraction of hits)`` as scalar tensors.
    """
    assert (output.dim() in [2, 3, 4])
    assert output.shape == target.shape

    # Everything after the first dimension is flattened per sample.
    probs = torch.sigmoid(output.flatten(1))
    gt_mask = target.flatten(1).bool()
    # Foreground is everything that is not strictly below the threshold.
    pred_mask = ~(probs < threshold)

    overlap = (pred_mask & gt_mask).sum(dim=1)
    combined = (pred_mask | gt_mask).sum(dim=1)
    # The small epsilon keeps the ratio finite when both masks are empty.
    per_sample_iou = overlap / (combined + 1e-6)

    mean_iou = per_sample_iou.mean()
    hit_rate = (per_sample_iou > pr_iou).float().mean()
    return 100. * mean_iou, 100. * hit_rate
def ValMetricGPU(output, target, threshold=0.35):
    """Compute the IoU of a single thresholded prediction.

    Args:
        output (Tensor): Raw logits whose first dimension has size 1; a
            sigmoid is applied before thresholding.
        target (Tensor): Ground-truth mask; non-zero entries count as
            foreground.
        threshold (float): Probabilities not below this value are treated as
            foreground. Default: 0.35.

    Returns:
        Tensor: One IoU value per row of the flattened input.
    """
    assert output.size(0) == 1

    probs = torch.sigmoid(output.flatten(1))
    gt_mask = target.flatten(1).bool()
    # Foreground is everything that is not strictly below the threshold.
    pred_mask = ~(probs < threshold)

    overlap = (pred_mask & gt_mask).sum(dim=1)
    combined = (pred_mask | gt_mask).sum(dim=1)
    # The small epsilon keeps the ratio finite when both masks are empty.
    return overlap / (combined + 1e-6)
def intersectionAndUnionGPU(output, target, K, threshold=0.5):
    """Compute the foreground intersection and union of a binary prediction.

    Args:
        output (Tensor): Raw logits with 1, 2 or 3 dimensions; a sigmoid is
            applied and the result is binarised at ``threshold``.
        target (Tensor): Ground-truth labels of the same shape.
        K (int): Number of histogram bins. Bin 1 is read from the result, so
            ``K`` must be at least 2.
        threshold (float): Probabilities not below this value become 1, the
            rest become 0. Default: 0.5.

    Returns:
        tuple: ``(area_intersection[1], area_union[1])``, the counts that fall
        into histogram bin 1.
    """
    assert (output.dim() in [1, 2, 3])
    assert output.shape == target.shape
    # reshape instead of view: view(-1) raises on non-contiguous inputs
    # (transposed or sliced tensors), while reshape copies only when needed
    # and is identical for contiguous inputs.
    output = output.reshape(-1)
    target = target.reshape(-1)
    # sigmoid returns a new tensor, so the in-place thresholding below never
    # touches the caller's data.
    output = torch.sigmoid(output)
    output[output < threshold] = 0.
    output[output >= threshold] = 1.
    # Keep the predicted labels only where they agree with the target.
    intersection = output[output == target]
    area_intersection = torch.histc(intersection.float(),
                                    bins=K,
                                    min=0,
                                    max=K - 1)
    area_output = torch.histc(output.float(), bins=K, min=0, max=K - 1)
    area_target = torch.histc(target.float(), bins=K, min=0, max=K - 1)
    area_union = area_output + area_target - area_intersection
    return area_intersection[1], area_union[1]
def group_weight(weight_group, module, lr):
    """Split a module's parameters into decay and no-decay optimizer groups.

    Weights of ``nn.Linear`` and convolution layers go into the first group.
    Their biases, together with batch-norm weights and biases, go into a
    second group created with ``weight_decay=0``.

    Args:
        weight_group (list): List that the two group dicts are appended to.
        module (nn.Module): Module whose sub-modules are scanned.
        lr (float): Learning rate stored in both groups.

    Returns:
        list: ``weight_group`` itself, after the two groups were appended.

    Raises:
        AssertionError: If the module owns parameters that the scan did not
            place in either group (for example from other layer types).
    """
    decayed, undecayed = [], []
    weight_decay_types = (nn.Linear, nn.modules.conv._ConvNd)
    for layer in module.modules():
        if isinstance(layer, weight_decay_types):
            decayed.append(layer.weight)
            if layer.bias is not None:
                undecayed.append(layer.bias)
        elif isinstance(layer, nn.modules.batchnorm._BatchNorm):
            undecayed.extend(
                p for p in (layer.weight, layer.bias) if p is not None)

    # Every parameter must have landed in exactly one of the two groups.
    n_params = len(list(module.parameters()))
    assert n_params == len(decayed) + len(undecayed)

    weight_group.append(dict(params=decayed, lr=lr))
    weight_group.append(dict(params=undecayed, weight_decay=.0, lr=lr))
    return weight_group
def colorize(gray, palette):
    """Turn an index map into a palettised PIL image.

    Args:
        gray: Array-like with an ``astype`` method (presumably a NumPy
            array); it is cast to ``uint8`` before conversion.
        palette: Palette data handed to ``Image.putpalette``.

    Returns:
        The mode-'P' image with ``palette`` attached.
    """
    indices = gray.astype(np.uint8)
    paletted = Image.fromarray(indices)
    paletted = paletted.convert('P')
    paletted.putpalette(palette)
    return paletted
def find_free_port():
    """Ask the OS for a currently unused TCP port and return its number.

    A socket is bound to port 0 so the operating system picks a free port;
    the socket is closed again before returning. The port is therefore only
    known to have been free at the time of the call, and another process may
    claim it before the caller binds to it.

    Returns:
        int: The port number chosen by the operating system.
    """
    import socket

    # The context manager closes the socket even if bind() or getsockname()
    # raises, so no file descriptor is leaked on the error path.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(("", 0))
        return sock.getsockname()[1]
def get_caller_name(depth=0):
    """
    Look up the module name of a frame on the current call stack.

    Args:
        depth (int): Depth of caller context, use 0 for caller depth.
            Default value: 0.

    Returns:
        str: module name of the caller
    """
    frame = inspect.currentframe()
    # One step leaves this function; each unit of depth climbs one more frame.
    for _ in range(1 + max(depth, 0)):
        frame = frame.f_back
    return frame.f_globals["__name__"]
class StreamToLoguru:
    """
    Stream object that redirects writes to a logger instance.

    Text written by one of the modules listed in ``caller_names`` is sent to
    loguru line by line; everything else is passed through unchanged to the
    interpreter's original stdout (``sys.__stdout__``).
    """

    def __init__(self, level="INFO", caller_names=("apex", "pycocotools")):
        """
        Args:
            level(str): log level string of loguru. Default value: "INFO".
            caller_names(tuple): caller names of redirected module.
                Default value: (apex, pycocotools).
        """
        self.level = level
        self.linebuf = ""
        self.caller_names = caller_names

    def write(self, buf):
        """Route ``buf`` to loguru or to the original stdout."""
        full_name = get_caller_name(depth=1)
        # Only the top-level package name is compared with caller_names.
        module_name = full_name.split(".", maxsplit=1)[0]
        if module_name in self.caller_names:
            for line in buf.rstrip().splitlines():
                # depth=2 asks loguru to attribute the record to a frame
                # further up the stack than this method.
                logger.opt(depth=2).log(self.level, line.rstrip())
        else:
            sys.__stdout__.write(buf)

    def flush(self):
        """Flush the original stdout, where pass-through writes end up."""
        # A no-op here would silently ignore print(..., flush=True).
        sys.__stdout__.flush()

    def isatty(self):
        """Report whether the original stdout is attached to a terminal."""
        # Code that probes sys.stdout.isatty() would otherwise fail with
        # AttributeError once the streams have been redirected.
        return sys.__stdout__.isatty()
def redirect_sys_output(log_level="INFO"):
    """Replace ``sys.stderr`` and ``sys.stdout`` with one ``StreamToLoguru``.

    Args:
        log_level (str): Level handed to the ``StreamToLoguru`` instance.
            Default: "INFO".
    """
    # Both streams share a single redirector object.
    sys.stderr = sys.stdout = StreamToLoguru(log_level)
def setup_logger(save_dir, distributed_rank=0, filename="log.txt", mode="a"):
    """Setup logger for training and testing.

    All existing loguru sinks are removed. On rank 0 a stderr sink and a file
    sink are added; other ranks get no sink. Finally ``sys.stdout`` and
    ``sys.stderr`` are redirected through ``redirect_sys_output``.

    Args:
        save_dir(str): location to save log file
        distributed_rank(int): device rank when multi-gpu environment
        filename (string): log save name.
        mode(str): log file write mode. ``"o"`` deletes an existing log file
            first (override); any other value keeps it. Default is ``"a"``.

    Return:
        logger instance.
    """
    loguru_format = (
        "<green>{time:YYYY-MM-DD HH:mm:ss}</green> | "
        "<level>{level: <8}</level> | "
        "<cyan>{name}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>")

    logger.remove()
    save_file = os.path.join(save_dir, filename)

    # Only rank 0 keeps sinks, so only rank 0 may delete the old log file.
    # If every rank deleted it, a late rank could unlink the file after
    # rank 0 had already opened it.
    if distributed_rank == 0:
        if mode == "o":
            try:
                os.remove(save_file)
            except FileNotFoundError:
                pass
        logger.add(
            sys.stderr,
            format=loguru_format,
            level="INFO",
            enqueue=True,
        )
        logger.add(save_file)

    # Redirect stdout/stderr through loguru on every rank.
    redirect_sys_output("INFO")
    return logger