import os
import logging

import numpy as np
import torch
import wandb


class CheckpointSaver:
    def __init__(self, dirpath, run_name='', decreasing=True, top_n=5):
        """
        dirpath: Directory path where to store all model weights
        run_name: Run name used in the checkpoint filenames
        decreasing: If `True`, a lower metric value is better
        top_n: Total number of models to keep, ranked by validation metric value
        """
        if not os.path.exists(dirpath):
            os.makedirs(dirpath)
        self.dirpath = dirpath
        self.top_n = top_n
        self.decreasing = decreasing
        self.top_model_paths = []
        self.best_metric_val = np.inf if decreasing else -np.inf
        self.run_name = run_name

    def __call__(self, model, epoch, metric_val, optimizer, loss):
        model_path = os.path.join(self.dirpath, model.__class__.__name__ + f'_{self.run_name}_epoch{epoch}.pt')
        save = metric_val < self.best_metric_val if self.decreasing else metric_val > self.best_metric_val
        if save:
            logging.info(
                f"Current metric value {metric_val} better than best {self.best_metric_val}, "
                f"saving model at {model_path}, & logging model weights to W&B.")
            self.best_metric_val = metric_val
            # Save the checkpoint locally
            torch.save(
                {
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': loss,
                }, model_path)
            self.log_artifact(f'model-ckpt-epoch-{epoch}.pt', model_path, metric_val)
            self.top_model_paths.append({'path': model_path, 'score': metric_val})
            # Keep the best model first: ascending sort if lower is better, descending otherwise
            self.top_model_paths = sorted(self.top_model_paths, key=lambda o: o['score'], reverse=not self.decreasing)
        if len(self.top_model_paths) > self.top_n:
            self.cleanup()

    def log_artifact(self, filename, model_path, metric_val):
        artifact = wandb.Artifact(filename, type='model', metadata={'Validation score': metric_val})
        artifact.add_file(model_path)
        wandb.run.log_artifact(artifact)

    def cleanup(self):
        to_remove = self.top_model_paths[self.top_n:]
        logging.info(f"Removing extra models: {to_remove}")
        for o in to_remove:
            os.remove(o['path'])
        self.top_model_paths = self.top_model_paths[:self.top_n]
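

# Usage sketch (an assumption, not part of the original snippet): a tiny,
# self-contained training loop on synthetic data showing how CheckpointSaver is
# called once per epoch. The toy model, project name, and hyperparameters below
# are illustrative placeholders; wandb.init() must be called beforehand so that
# wandb.run exists when log_artifact() runs.
if __name__ == '__main__':
    import torch.nn as nn

    wandb.init(project='checkpoint-saver-demo')     # hypothetical project name
    model = nn.Linear(10, 1)                        # toy model for illustration
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.MSELoss()
    x, y = torch.randn(64, 10), torch.randn(64, 1)  # synthetic data

    saver = CheckpointSaver(dirpath='checkpoints/', run_name='demo', decreasing=True, top_n=5)

    for epoch in range(5):
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()
        # Here the training loss stands in for a validation metric, purely for the
        # demo; the saver writes a checkpoint and logs a W&B artifact whenever the
        # metric improves, then prunes anything beyond the top_n best checkpoints.
        saver(model, epoch, loss.item(), optimizer, loss.item())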