SRGAN / utils.py
# Copyright 2022 Dakewe Biotech Corporation. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import os
import shutil
from enum import Enum
from typing import Any
import torch
from torch import nn
from torch.nn import Module
from torch.optim import Optimizer
__all__ = [
"load_state_dict", "make_directory", "save_checkpoint",
"Summary", "AverageMeter", "ProgressMeter"
]
def load_state_dict(
        model: nn.Module,
        model_weights_path: str,
        ema_model: nn.Module | None = None,
        optimizer: Optimizer | None = None,
        scheduler: torch.optim.lr_scheduler._LRScheduler | None = None,
        load_mode: str | None = None,
) -> tuple[Module, Module, int, float, float, Optimizer, Any] | Module:
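    """Load model weights, and optionally the full training state, from a checkpoint.

    When ``load_mode == "resume"``, the epoch counter, best PSNR/SSIM metrics,
    optimizer state, and (if given) scheduler and EMA model states are restored,
    and ``(model, ema_model, start_epoch, best_psnr, best_ssim, optimizer,
    scheduler)`` is returned. Otherwise only the checkpoint weights whose names
    and shapes match the current model are loaded, and ``model`` is returned.
    """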
# Load model weights
checkpoint = torch.load(model_weights_path, map_location=lambda storage, loc: storage)
if load_mode == "resume":
        # Restore the training state (epoch and best metrics) recorded in the checkpoint
start_epoch = checkpoint["epoch"]
best_psnr = checkpoint["best_psnr"]
best_ssim = checkpoint["best_ssim"]
# Load model state dict. Extract the fitted model weights
model_state_dict = model.state_dict()
state_dict = {k: v for k, v in checkpoint["state_dict"].items() if k in model_state_dict.keys()}
        # Merge the matching checkpoint weights into the current model (base model)
model_state_dict.update(state_dict)
model.load_state_dict(model_state_dict)
        # Load the optimizer state
optimizer.load_state_dict(checkpoint["optimizer"])
if scheduler is not None:
            # Load the scheduler state
scheduler.load_state_dict(checkpoint["scheduler"])
if ema_model is not None:
# Load ema model state dict. Extract the fitted model weights
ema_model_state_dict = ema_model.state_dict()
ema_state_dict = {k: v for k, v in checkpoint["ema_state_dict"].items() if k in ema_model_state_dict.keys()}
            # Merge the matching checkpoint weights into the current model (ema model)
ema_model_state_dict.update(ema_state_dict)
ema_model.load_state_dict(ema_model_state_dict)
return model, ema_model, start_epoch, best_psnr, best_ssim, optimizer, scheduler
else:
        # Load model state dict. Keep only the checkpoint weights whose names and shapes match the current model
model_state_dict = model.state_dict()
state_dict = {k: v for k, v in checkpoint["state_dict"].items() if
k in model_state_dict.keys() and v.size() == model_state_dict[k].size()}
        # Merge the matching checkpoint weights into the current model
model_state_dict.update(state_dict)
model.load_state_dict(model_state_dict)
return model
def make_directory(dir_path: str) -> None:
    # Create the directory (and any parents) if it does not already exist
    os.makedirs(dir_path, exist_ok=True)
def save_checkpoint(
state_dict: dict,
file_name: str,
samples_dir: str,
results_dir: str,
best_file_name: str,
last_file_name: str,
is_best: bool = False,
is_last: bool = False,
) -> None:
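    """Write the checkpoint to ``samples_dir/file_name``, then copy it into
    ``results_dir`` as ``best_file_name`` if ``is_best`` and as
    ``last_file_name`` if ``is_last``.
    """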
checkpoint_path = os.path.join(samples_dir, file_name)
torch.save(state_dict, checkpoint_path)
if is_best:
shutil.copyfile(checkpoint_path, os.path.join(results_dir, best_file_name))
if is_last:
shutil.copyfile(checkpoint_path, os.path.join(results_dir, last_file_name))
class Summary(Enum):
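    """Statistic that ``AverageMeter.summary()`` reports."""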
NONE = 0
AVERAGE = 1
SUM = 2
COUNT = 3
class AverageMeter(object):
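    """Computes and stores the current value, sum, count, and running average."""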
def __init__(self, name, fmt=":f", summary_type=Summary.AVERAGE):
self.name = name
self.fmt = fmt
self.summary_type = summary_type
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
def __str__(self):
fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
return fmtstr.format(**self.__dict__)
def summary(self):
if self.summary_type is Summary.NONE:
fmtstr = ""
elif self.summary_type is Summary.AVERAGE:
fmtstr = "{name} {avg:.2f}"
elif self.summary_type is Summary.SUM:
fmtstr = "{name} {sum:.2f}"
elif self.summary_type is Summary.COUNT:
fmtstr = "{name} {count:.2f}"
else:
raise ValueError(f"Invalid summary type {self.summary_type}")
return fmtstr.format(**self.__dict__)
class ProgressMeter(object):
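    """Prints the current values of a set of meters with a batch-progress prefix."""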
def __init__(self, num_batches, meters, prefix=""):
self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
self.meters = meters
self.prefix = prefix
def display(self, batch):
entries = [self.prefix + self.batch_fmtstr.format(batch)]
entries += [str(meter) for meter in self.meters]
print("\t".join(entries))
def display_summary(self):
entries = [" *"]
entries += [meter.summary() for meter in self.meters]
print(" ".join(entries))
def _get_batch_fmtstr(self, num_batches):
        num_digits = len(str(num_batches))
fmt = "{:" + str(num_digits) + "d}"
return "[" + fmt + "/" + fmt.format(num_batches) + "]"