import os
import random
import numpy as np
from PIL import Image
from loguru import logger
import sys
import inspect

import torch
from torch import nn
import torch.distributed as dist
def init_random_seed(seed=None, device='cuda', rank=0, world_size=1):
    """Initialize the random seed.

    If ``seed`` is None, a seed is drawn at random and, in the distributed
    setting, broadcast from rank 0 so that all processes share the same one.
    """
    if seed is not None:
        return seed

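    # Draw a seed, then broadcast rank 0's value so that data shuffling and
    # augmentation stay consistent across processes.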
    seed = np.random.randint(2**31)
    if world_size == 1:
        return seed

    if rank == 0:
        random_num = torch.tensor(seed, dtype=torch.int32, device=device)
    else:
        random_num = torch.tensor(0, dtype=torch.int32, device=device)
    dist.broadcast(random_num, src=0)
    return random_num.item()


def set_random_seed(seed, deterministic=False):
    """Set the random seed for Python, NumPy and PyTorch.

    If ``deterministic`` is True, cuDNN is switched to deterministic mode
    (slower, but reproducible).
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


@torch.no_grad()
def concat_all_gather(tensor):
    """
    Performs all_gather operation on the provided tensors.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    tensor = tensor.contiguous()
    tensors_gather = [
        torch.ones_like(tensor)
        for _ in range(torch.distributed.get_world_size())
    ]
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)

    output = torch.cat(tensors_gather, dim=0)
    return output


def worker_init_fn(worker_id, num_workers, rank, seed):
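    # Give every dataloader worker a distinct, reproducible seed derived from
    # the user seed, the process rank and the worker id.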
    worker_seed = num_workers * rank + worker_id + seed
    np.random.seed(worker_seed)
    random.seed(worker_seed)


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self, name, fmt=":f"):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        if self.name == "Lr":
            fmtstr = "{name}={val" + self.fmt + "}"
        else:
            fmtstr = "{name}={val" + self.fmt + "} ({avg" + self.fmt + "})"
        return fmtstr.format(**self.__dict__)


class ProgressMeter(object):
    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        logger.info(" ".join(entries))

    def _get_batch_fmtstr(self, num_batches):
        num_digits = len(str(num_batches))
        fmt = "{:" + str(num_digits) + "d}"
        return "[" + fmt + "/" + fmt.format(num_batches) + "]"


def trainMetricGPU(output, target, threshold=0.35, pr_iou=0.5):
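    # Batch training metric: mean IoU between the thresholded predictions and
    # the ground-truth masks, plus precision@pr_iou (fraction of samples whose
    # IoU exceeds pr_iou). Both are returned as percentages.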
    assert output.dim() in [2, 3, 4]
    assert output.shape == target.shape
    output = output.flatten(1)
    target = target.flatten(1)
    output = torch.sigmoid(output)
    output[output < threshold] = 0.
    output[output >= threshold] = 1.

    inter = (output.bool() & target.bool()).sum(dim=1)
    union = (output.bool() | target.bool()).sum(dim=1)
    ious = inter / (union + 1e-6)

    iou = ious.mean()
    prec = (ious > pr_iou).float().mean()
    return 100. * iou, 100. * prec


def ValMetricGPU(output, target, threshold=0.35):
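    # Per-sample validation metric: IoU of the thresholded prediction against
    # the ground-truth mask for a single sample (batch size must be 1).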
    assert output.size(0) == 1
    output = output.flatten(1)
    target = target.flatten(1)
    output = torch.sigmoid(output)
    output[output < threshold] = 0.
    output[output >= threshold] = 1.

    inter = (output.bool() & target.bool()).sum(dim=1)
    union = (output.bool() | target.bool()).sum(dim=1)
    ious = inter / (union + 1e-6)
    return ious


def intersectionAndUnionGPU(output, target, K, threshold=0.5):
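    # K is the number of classes; output and target are flattened to 1-D and
    # binarized, so with K=2 the foreground bin (index 1) of the intersection
    # and union histograms is what gets returned.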
    assert output.dim() in [1, 2, 3]
    assert output.shape == target.shape
    output = output.view(-1)
    target = target.view(-1)

    output = torch.sigmoid(output)
    output[output < threshold] = 0.
    output[output >= threshold] = 1.

    intersection = output[output == target]
    area_intersection = torch.histc(intersection.float(),
                                    bins=K,
                                    min=0,
                                    max=K - 1)
    area_output = torch.histc(output.float(), bins=K, min=0, max=K - 1)
    area_target = torch.histc(target.float(), bins=K, min=0, max=K - 1)
    area_union = area_output + area_target - area_intersection
    return area_intersection[1], area_union[1]


def group_weight(weight_group, module, lr):
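    # Split the module's parameters into a weight-decay group (conv / linear
    # weights) and a no-decay group (biases and normalization parameters).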
    group_decay = []
    group_no_decay = []
    for m in module.modules():
        if isinstance(m, nn.Linear):
            group_decay.append(m.weight)
            if m.bias is not None:
                group_no_decay.append(m.bias)
        elif isinstance(m, nn.modules.conv._ConvNd):
            group_decay.append(m.weight)
            if m.bias is not None:
                group_no_decay.append(m.bias)
        elif isinstance(m, nn.modules.batchnorm._BatchNorm):
            if m.weight is not None:
                group_no_decay.append(m.weight)
            if m.bias is not None:
                group_no_decay.append(m.bias)
    assert len(list(
        module.parameters())) == len(group_decay) + len(group_no_decay)
    weight_group.append(dict(params=group_decay, lr=lr))
    weight_group.append(dict(params=group_no_decay, weight_decay=.0, lr=lr))
    return weight_group


def colorize(gray, palette):
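    # 'gray' is a 2-D label array; 'palette' is a flat list of RGB values that
    # maps each label index to a color in the returned 'P'-mode image.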
    color = Image.fromarray(gray.astype(np.uint8)).convert('P')
    color.putpalette(palette)
    return color


def find_free_port():
    import socket
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
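    # Binding to port 0 lets the OS pick an unused port for us.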
sock.bind(("", 0)) |
|
port = sock.getsockname()[1] |
|
sock.close() |
|
|
|
return port |
|
|
|
|
|
def get_caller_name(depth=0):
    """
    Args:
        depth (int): Depth of the caller context; use 0 for the immediate
            caller. Default value: 0.

    Returns:
        str: module name of the caller
    """
    frame = inspect.currentframe().f_back
    for _ in range(depth):
        frame = frame.f_back

    return frame.f_globals["__name__"]


class StreamToLoguru:
    """
    Stream object that redirects writes to a logger instance.
    """
    def __init__(self, level="INFO", caller_names=("apex", "pycocotools")):
        """
        Args:
            level(str): log level string of loguru. Default value: "INFO".
            caller_names(tuple): caller names of redirected module.
                Default value: (apex, pycocotools).
        """
        self.level = level
        self.linebuf = ""
        self.caller_names = caller_names

    def write(self, buf):
        full_name = get_caller_name(depth=1)
        module_name = full_name.rsplit(".", maxsplit=-1)[0]
        if module_name in self.caller_names:
            for line in buf.rstrip().splitlines():
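                # depth=2 makes loguru attribute the message to the original
                # call site rather than to this wrapper.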
                logger.opt(depth=2).log(self.level, line.rstrip())
        else:
            sys.__stdout__.write(buf)

    def flush(self):
        pass


def redirect_sys_output(log_level="INFO"):
    redirect_logger = StreamToLoguru(log_level)
    sys.stderr = redirect_logger
    sys.stdout = redirect_logger


def setup_logger(save_dir, distributed_rank=0, filename="log.txt", mode="a"):
    """Set up the loguru logger for training and testing.

    Args:
        save_dir (str): location to save the log file.
        distributed_rank (int): device rank in a multi-GPU environment; only
            rank 0 logs to the console and the log file.
        filename (str): log file name.
        mode (str): log file write mode, "a" (append) or "o" (override).
            Default: "a".
    """
    loguru_format = (
        "<green>{time:YYYY-MM-DD HH:mm:ss}</green> | "
        "<level>{level: <8}</level> | "
        "<cyan>{name}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>")

    logger.remove()
    save_file = os.path.join(save_dir, filename)
    if mode == "o" and os.path.exists(save_file):
        os.remove(save_file)

    if distributed_rank == 0:
        logger.add(
            sys.stderr,
            format=loguru_format,
            level="INFO",
            enqueue=True,
        )
        logger.add(save_file)

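    # Redirect raw stdout/stderr writes (e.g. from third-party libraries)
    # through loguru as well.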
    redirect_sys_output("INFO")