import os import torch import numpy as np import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt def save_img_and_npy(path, matrix): plt.imsave(path + ".png", matrix, origin="lower") def save_checkpoint(state, state_dict_only, path, target): torch.save(state, os.path.join(path, target + ".chkpnt")) if state_dict_only: # save just the weights torch.save(state["state_dict"], os.path.join(path, target + ".pth")) class AverageMeter(object): """Computes and stores the average and current value""" def __init__(self): self.reset() def reset(self): self.val = 0 self.avg = 0 self.sum = 0 self.count = 0 def update(self, val, n=1): self.val = val self.sum += val * n self.count += n self.avg = self.sum / self.count class EarlyStopping(object): def __init__(self, mode="min", min_delta=0, patience=10): self.mode = mode self.min_delta = min_delta self.patience = patience self.best = None self.num_bad_epochs = 0 self.is_better = None self._init_is_better(mode, min_delta) if patience == 0: self.is_better = lambda a, b: True def step(self, metrics): if self.best is None: self.best = metrics return False if np.isnan(metrics): return True if self.is_better(metrics, self.best): self.num_bad_epochs = 0 self.best = metrics else: self.num_bad_epochs += 1 if self.num_bad_epochs >= self.patience: return True return False def _init_is_better(self, mode, min_delta): if mode not in {"min", "max"}: raise ValueError("mode " + mode + " is unknown!") if mode == "min": self.is_better = lambda a, best: a < best - min_delta if mode == "max": self.is_better = lambda a, best: a > best + min_delta