Spaces:
Runtime error
Runtime error
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. | |
import logging | |
import os | |
import sys | |
from tensorboardX import SummaryWriter | |
def setup_logger(name, save_dir, distributed_rank=0):
    """Configure and return the logger called ``name``.

    The logger level is set to DEBUG. On the master process a DEBUG-level
    handler writing to stdout is attached, plus a DEBUG-level handler writing
    to ``<save_dir>/log.txt`` when ``save_dir`` is truthy. Both handlers use
    the format ``"%(asctime)s %(name)s %(levelname)s: %(message)s"``.

    Calling this again for the same ``name`` does not attach duplicate
    handlers, so each record is emitted only once per destination.

    Args:
        name: Name passed to ``logging.getLogger``.
        save_dir: Directory that receives ``log.txt``. A falsy value
            (``""``/``None``) disables file logging. The directory is
            created if it does not already exist.
        distributed_rank: Process rank. Handlers are attached only when the
            rank is not greater than 0.

    Returns:
        The ``logging.Logger`` instance for ``name``.
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    # don't log results for the non-master process
    if distributed_rank > 0:
        return logger

    formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s")

    # getLogger returns the same object for the same name, so a repeated call
    # must not stack a second stdout handler (that would duplicate every line).
    has_stdout_handler = any(
        isinstance(h, logging.StreamHandler) and h.stream is sys.stdout
        for h in logger.handlers
    )
    if not has_stdout_handler:
        ch = logging.StreamHandler(stream=sys.stdout)
        ch.setLevel(logging.DEBUG)
        ch.setFormatter(formatter)
        logger.addHandler(ch)

    if save_dir:
        # FileHandler fails with FileNotFoundError if the directory is missing.
        os.makedirs(save_dir, exist_ok=True)
        # FileHandler stores an absolute baseFilename, so compare on that.
        log_path = os.path.abspath(os.path.join(save_dir, "log.txt"))
        has_file_handler = any(
            isinstance(h, logging.FileHandler) and h.baseFilename == log_path
            for h in logger.handlers
        )
        if not has_file_handler:
            fh = logging.FileHandler(log_path)
            fh.setLevel(logging.DEBUG)
            fh.setFormatter(formatter)
            logger.addHandler(fh)

    return logger
class Logger(object):
    """Thin wrapper around a TensorBoard ``SummaryWriter`` for distributed runs.

    Only the master process (``distributed_rank == 0``) creates a writer.
    On every other rank ``writer`` is ``None`` and all logging calls are
    no-ops.
    """

    def __init__(self, log_dir, distributed_rank=0):
        """Create a summary writer logging to log_dir (master rank only).

        Args:
            log_dir: Directory passed to ``SummaryWriter``.
            distributed_rank: Process rank; a writer is created only for 0.
        """
        self.distributed_rank = distributed_rank
        # Always define the attribute so non-master ranks can safely inspect it.
        self.writer = None
        if distributed_rank == 0:
            self.writer = SummaryWriter(log_dir)

    def scalar_summary(self, tag, value, step):
        """Log a scalar variable ``value`` under ``tag`` at ``step``.

        This is a no-op on non-master ranks and after ``close()``.
        """
        if self.writer is not None:
            self.writer.add_scalar(tag, value, step)

    def close(self):
        """Close the underlying writer, if any.

        Safe to call more than once.
        """
        # Closing flushes pending events to disk and releases the file handle.
        if self.writer is not None:
            self.writer.close()
            self.writer = None