Spaces:
Sleeping
Sleeping
"""
Based on rllab's logger.

https://github.com/rll/rllab
"""
from enum import Enum | |
from contextlib import contextmanager | |
import numpy as np | |
import os | |
import os.path as osp | |
import sys | |
import datetime | |
import dateutil.tz | |
import csv | |
import json | |
import pickle | |
import errno | |
import torch | |
from src.rlkit.core.tabulate import tabulate | |
class TerminalTablePrinter(object):
    """Repaints a table of the most recent tabular rows on the terminal.

    Column headers are taken from the first row handed to
    ``print_tabular``. Every call appends one row of values and redraws
    the screen with as many of the newest rows as fit.
    """

    # Terminal lines subtracted from the terminal height when choosing how
    # many rows to show (presumably header, separator and trailing newline).
    _RESERVED_LINES = 3

    def __init__(self):
        # Column names; set by the first print_tabular() call.
        self.headers = None
        # One list of values per print_tabular() call, oldest first.
        self.tabulars = []

    def print_tabular(self, new_tabular):
        """Append one row and redraw the table.

        :param new_tabular: sequence of ``(key, value)`` pairs. The keys of
            the first call become the headers; later calls must have the
            same number of pairs.
        """
        if self.headers is None:
            self.headers = [x[0] for x in new_tabular]
        else:
            assert len(self.headers) == len(new_tabular)
        self.tabulars.append([x[1] for x in new_tabular])
        self.refresh()

    @staticmethod
    def _visible_rows(tabulars, terminal_lines):
        """Return the newest rows that fit in ``terminal_lines`` lines.

        At least one row is always kept, so a very small terminal shows
        the latest row instead of the entire history (a ``[-0:]`` slice
        would return every row).
        """
        keep = max(int(terminal_lines) - TerminalTablePrinter._RESERVED_LINES, 1)
        return tabulars[-keep:]

    def refresh(self):
        """Clear the screen and print the newest rows under the headers."""
        # Local import, mirroring the function-scope import this method
        # already used. get_terminal_size() needs no subprocess and falls
        # back to a default size when stdout is not attached to a terminal.
        import shutil
        terminal_lines = shutil.get_terminal_size().lines
        tabulars = self._visible_rows(self.tabulars, terminal_lines)
        # ANSI escapes: erase the display, then move the cursor home.
        sys.stdout.write("\x1b[2J\x1b[H")
        sys.stdout.write(tabulate(tabulars, self.headers))
        sys.stdout.write("\n")
class MyEncoder(json.JSONEncoder):
    """JSON encoder that writes classes, enum members and callables as
    single-key dicts holding their dotted name.
    """

    def default(self, o):
        """Encode ``o`` as ``{'$class' | '$enum' | '$function': dotted_name}``.

        Anything that is not a class, an ``Enum`` member or a callable is
        passed on to ``json.JSONEncoder.default``.
        """
        # Classes are checked first: they are callable too.
        if isinstance(o, type):
            tag = '$class'
            parts = (o.__module__, o.__name__)
        elif isinstance(o, Enum):
            tag = '$enum'
            parts = (o.__module__, o.__class__.__name__, o.name)
        elif callable(o):
            tag = '$function'
            parts = (o.__module__, o.__name__)
        else:
            return json.JSONEncoder.default(self, o)
        return {tag: ".".join(parts)}
def mkdir_p(path):
    """Create ``path`` and any missing parent directories, like ``mkdir -p``.

    An already existing directory is not an error. An empty path is
    treated as a no-op: ``os.path.dirname`` returns ``''`` for a bare file
    name, and ``os.makedirs('')`` would raise.

    :param path: directory to create.
    :raises OSError: if the directory cannot be created, e.g. ``path``
        exists but is not a directory.
    """
    if isinstance(path, (str, bytes)) and not path:
        return
    os.makedirs(path, exist_ok=True)
class Logger(object):
    """Logger for prefixed text lines, tabular (CSV) rows and snapshots.

    Text lines go to stdout and to every file registered with
    ``add_text_output``. Key/value pairs collected with ``record_tabular``
    are written by ``dump_tabular`` to every file registered with
    ``add_tabular_output``. Snapshots and extra data are saved under the
    directory set with ``set_snapshot_dir``.
    """

    def __init__(self):
        # Stack of text prefixes and its concatenation.
        self._prefixes = []
        self._prefix_str = ''

        # Stack of tabular-key prefixes and its concatenation.
        self._tabular_prefixes = []
        self._tabular_prefix_str = ''

        # Pending (key, value) string pairs for the next dump_tabular().
        self._tabular = []

        # Registered output file names and their open file objects.
        self._text_outputs = []
        self._tabular_outputs = []
        self._text_fds = {}
        self._tabular_fds = {}
        # File objects that already received a CSV header.
        self._tabular_header_written = set()

        self._snapshot_dir = None
        self._snapshot_mode = 'all'
        self._snapshot_gap = 1

        self._log_tabular_only = False
        self._header_printed = False
        self.table_printer = TerminalTablePrinter()

    def reset(self):
        """Close every open output file and return to the initial state."""
        # Close the handles first: __init__ replaces the dicts that are the
        # only references to them.
        open_fds = list(self._text_fds.values())
        open_fds.extend(self._tabular_fds.values())
        for fd in open_fds:
            fd.close()
        self.__init__()

    def _add_output(self, file_name, arr, fds, mode='a'):
        """Open ``file_name`` with ``mode`` and register it in ``arr``/``fds``.

        Does nothing if ``file_name`` is already registered in ``arr``.
        """
        if file_name not in arr:
            mkdir_p(os.path.dirname(file_name))
            arr.append(file_name)
            fds[file_name] = open(file_name, mode)

    def _remove_output(self, file_name, arr, fds):
        """Close and unregister ``file_name``; no-op if it is not in ``arr``."""
        if file_name in arr:
            fds[file_name].close()
            del fds[file_name]
            arr.remove(file_name)

    def push_prefix(self, prefix):
        """Push ``prefix`` onto the text prefix stack."""
        self._prefixes.append(prefix)
        self._prefix_str = ''.join(self._prefixes)

    def add_text_output(self, file_name):
        """Register ``file_name`` (opened in append mode) for text lines."""
        self._add_output(file_name, self._text_outputs, self._text_fds,
                         mode='a')

    def remove_text_output(self, file_name):
        """Close and unregister a text output file."""
        self._remove_output(file_name, self._text_outputs, self._text_fds)

    def add_tabular_output(self, file_name, relative_to_snapshot_dir=False):
        """Register ``file_name`` (opened in write mode) for tabular rows.

        :param relative_to_snapshot_dir: if true, ``file_name`` is joined
            onto the snapshot directory first.
        """
        if relative_to_snapshot_dir:
            file_name = osp.join(self._snapshot_dir, file_name)
        self._add_output(file_name, self._tabular_outputs, self._tabular_fds,
                         mode='w')

    def remove_tabular_output(self, file_name, relative_to_snapshot_dir=False):
        """Close and unregister a tabular output file.

        A file that was never registered is ignored, matching
        ``_remove_output``.
        """
        if relative_to_snapshot_dir:
            file_name = osp.join(self._snapshot_dir, file_name)
        fd = self._tabular_fds.get(file_name)
        if fd is not None:
            # Forget the header state so the set does not keep a closed file.
            self._tabular_header_written.discard(fd)
        self._remove_output(file_name, self._tabular_outputs, self._tabular_fds)

    def set_snapshot_dir(self, dir_name):
        self._snapshot_dir = dir_name

    def get_snapshot_dir(self):
        return self._snapshot_dir

    def get_snapshot_mode(self):
        return self._snapshot_mode

    def set_snapshot_mode(self, mode):
        self._snapshot_mode = mode

    def get_snapshot_gap(self):
        return self._snapshot_gap

    def set_snapshot_gap(self, gap):
        self._snapshot_gap = gap

    def set_log_tabular_only(self, log_tabular_only):
        self._log_tabular_only = log_tabular_only

    def get_log_tabular_only(self):
        return self._log_tabular_only

    def log(self, s, with_prefix=True, with_timestamp=True):
        """Print ``s`` and append it to every registered text output.

        Nothing is emitted while tabular-only logging is enabled.

        :param s: message text.
        :param with_prefix: prepend the current prefix string.
        :param with_timestamp: prepend the local time and ``" | "``.
        """
        out = s
        if with_prefix:
            out = self._prefix_str + out
        if with_timestamp:
            now = datetime.datetime.now(dateutil.tz.tzlocal())
            timestamp = now.strftime('%Y-%m-%d %H:%M:%S.%f %Z')
            out = "%s | %s" % (timestamp, out)
        if not self._log_tabular_only:
            # Also log to stdout
            print(out)
            for fd in list(self._text_fds.values()):
                fd.write(out + '\n')
                fd.flush()
            sys.stdout.flush()

    def record_tabular(self, key, val):
        """Queue one ``(key, val)`` pair, both converted to ``str``.

        The current tabular prefix is prepended to the key.
        """
        self._tabular.append((self._tabular_prefix_str + str(key), str(val)))

    def record_dict(self, d, prefix=None):
        """Queue every item of ``d``, optionally under a tabular ``prefix``."""
        if prefix is not None:
            self.push_tabular_prefix(prefix)
        for k, v in d.items():
            self.record_tabular(k, v)
        if prefix is not None:
            self.pop_tabular_prefix()

    def push_tabular_prefix(self, key):
        """Push ``key`` onto the tabular prefix stack."""
        self._tabular_prefixes.append(key)
        self._tabular_prefix_str = ''.join(self._tabular_prefixes)

    def pop_tabular_prefix(self):
        """Pop the most recent tabular prefix."""
        del self._tabular_prefixes[-1]
        self._tabular_prefix_str = ''.join(self._tabular_prefixes)

    def save_extra_data(self, data, file_name='extra_data.pkl', mode='joblib'):
        """
        Data saved here will always override the last entry

        :param data: Something pickle'able.
        :param file_name: name of the file inside the snapshot directory.
        :param mode: ``'joblib'`` or ``'pickle'``.
        :return: full path of the written file.
        :raises ValueError: if ``mode`` is neither of the above.
        """
        file_name = osp.join(self._snapshot_dir, file_name)
        if mode == 'joblib':
            import joblib
            joblib.dump(data, file_name, compress=3)
        elif mode == 'pickle':
            # Context manager so the handle is closed even if dump() raises.
            with open(file_name, "wb") as f:
                pickle.dump(data, f)
        else:
            raise ValueError("Invalid mode: {}".format(mode))
        return file_name

    def get_table_dict(self):
        """Return the pending tabular pairs as a dict."""
        return dict(self._tabular)

    def get_table_key_set(self):
        """Return the set of pending tabular keys."""
        return set(key for key, value in self._tabular)

    @contextmanager
    def prefix(self, key):
        """Context manager: apply text prefix ``key`` inside the block."""
        self.push_prefix(key)
        try:
            yield
        finally:
            self.pop_prefix()

    @contextmanager
    def tabular_prefix(self, key):
        """Context manager: apply tabular prefix ``key`` inside the block."""
        self.push_tabular_prefix(key)
        try:
            yield
        finally:
            # Restore the prefix even when the body raises.
            self.pop_tabular_prefix()

    def log_variant(self, log_file, variant_data):
        """Write ``variant_data`` to ``log_file`` as indented, key-sorted JSON."""
        mkdir_p(os.path.dirname(log_file))
        with open(log_file, "w") as f:
            json.dump(variant_data, f, indent=2, sort_keys=True, cls=MyEncoder)

    def record_tabular_misc_stat(self, key, values, placement='back'):
        """Queue Average/Std/Median/Min/Max of ``values``.

        :param key: name combined with each statistic name.
        :param values: sized collection of numbers; NaN is recorded for
            every statistic when it is empty.
        :param placement: ``'front'`` puts the statistic name before
            ``key``; any other value puts it after.
        """
        if placement == 'front':
            prefix = ""
            suffix = key
        else:
            prefix = key
            suffix = ""
        stats = (
            ("Average", np.average),
            ("Std", np.std),
            ("Median", np.median),
            ("Min", np.min),
            ("Max", np.max),
        )
        has_values = len(values) > 0
        for stat_name, stat_fn in stats:
            stat = stat_fn(values) if has_values else np.nan
            self.record_tabular(prefix + stat_name + suffix, stat)

    def dump_tabular(self, *args, **kwargs):
        """Emit the pending tabular pairs and clear them.

        The row is shown either through the table printer (tabular-only
        mode) or through ``log``, which receives the remaining arguments,
        and is then appended to every tabular output file.

        :param write_header: keyword-only. ``True`` always writes a CSV
            header, ``False`` never does, and ``None`` (default) writes it
            once per file.
        """
        wh = kwargs.pop("write_header", None)
        if len(self._tabular) > 0:
            if self._log_tabular_only:
                self.table_printer.print_tabular(self._tabular)
            else:
                for line in tabulate(self._tabular).split('\n'):
                    self.log(line, *args, **kwargs)
            tabular_dict = dict(self._tabular)
            # Also write to the csv files
            # This assumes that the keys in each iteration won't change!
            for tabular_fd in list(self._tabular_fds.values()):
                writer = csv.DictWriter(tabular_fd,
                                        fieldnames=list(tabular_dict.keys()))
                if wh or (
                        wh is None
                        and tabular_fd not in self._tabular_header_written):
                    writer.writeheader()
                    self._tabular_header_written.add(tabular_fd)
                writer.writerow(tabular_dict)
                tabular_fd.flush()
            del self._tabular[:]

    def pop_prefix(self):
        """Pop the most recent text prefix."""
        del self._prefixes[-1]
        self._prefix_str = ''.join(self._prefixes)

    def save_itr_params(self, itr, params):
        """Save ``params`` with ``torch.save`` according to the snapshot mode.

        Does nothing when no snapshot directory is set.

        * ``'all'``: ``itr_<itr>.pkl`` on every call.
        * ``'last'``: ``params.pkl``, overwritten on every call.
        * ``'gap'``: ``itr_<itr>.pkl`` when ``itr`` is a multiple of the gap.
        * ``'gap_and_last'``: both of the above.
        * ``'none'``: nothing.

        :raises NotImplementedError: for any other snapshot mode.
        """
        if self._snapshot_dir:
            if self._snapshot_mode == 'all':
                file_name = osp.join(self._snapshot_dir, 'itr_%d.pkl' % itr)
                torch.save(params, file_name)
            elif self._snapshot_mode == 'last':
                # override previous params
                file_name = osp.join(self._snapshot_dir, 'params.pkl')
                torch.save(params, file_name)
            elif self._snapshot_mode == "gap":
                if itr % self._snapshot_gap == 0:
                    file_name = osp.join(self._snapshot_dir, 'itr_%d.pkl' % itr)
                    torch.save(params, file_name)
            elif self._snapshot_mode == "gap_and_last":
                if itr % self._snapshot_gap == 0:
                    file_name = osp.join(self._snapshot_dir, 'itr_%d.pkl' % itr)
                    torch.save(params, file_name)
                file_name = osp.join(self._snapshot_dir, 'params.pkl')
                torch.save(params, file_name)
            elif self._snapshot_mode == 'none':
                pass
            else:
                raise NotImplementedError
# Module-level Logger instance, created when this module is imported.
logger = Logger()