denoising / checkpoing_saver.py
BorisovMaksim's picture
add try except for calculating pesq scores
95d8ea8
raw
history blame
2.41 kB
import os
import numpy as np
import logging
import torch
import wandb
class CheckpointSaver:
def __init__(self, dirpath, run_name='', decreasing=True, top_n=5):
"""
dirpath: Directory path where to store all model weights
decreasing: If decreasing is `True`, then lower metric is better
top_n: Total number of models to track based on validation metric value
"""
if not os.path.exists(dirpath): os.makedirs(dirpath)
self.dirpath = dirpath
self.top_n = top_n
self.decreasing = decreasing
self.top_model_paths = []
self.best_metric_val = np.Inf if decreasing else -np.Inf
self.run_name = run_name
def __call__(self, model, epoch, metric_val, optimizer, loss):
model_path = os.path.join(self.dirpath, model.__class__.__name__ + f'_{self.run_name}_epoch{epoch}.pt')
save = metric_val < self.best_metric_val if self.decreasing else metric_val > self.best_metric_val
if save:
logging.info(
f"Current metric value better than {metric_val} better than best {self.best_metric_val}, saving model at {model_path}, & logging model weights to W&B.")
self.best_metric_val = metric_val
torch.save(
{ # Save our checkpoint loc
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss,
}, model_path)
self.log_artifact(f'model-ckpt-epoch-{epoch}.pt', model_path, metric_val)
self.top_model_paths.append({'path': model_path, 'score': metric_val})
self.top_model_paths = sorted(self.top_model_paths, key=lambda o: o['score'], reverse=not self.decreasing)
if len(self.top_model_paths) > self.top_n:
self.cleanup()
def log_artifact(self, filename, model_path, metric_val):
artifact = wandb.Artifact(filename, type='model', metadata={'Validation score': metric_val})
artifact.add_file(model_path)
wandb.run.log_artifact(artifact)
def cleanup(self):
to_remove = self.top_model_paths[self.top_n:]
logging.info(f"Removing extra models.. {to_remove}")
for o in to_remove:
os.remove(o['path'])
self.top_model_paths = self.top_model_paths[:self.top_n]