Spaces:
Paused
Paused
| import os | |
| import numpy as np | |
| import torch | |
| from torch.autograd import Variable | |
| from pdb import set_trace as st | |
| from IPython import embed | |
class BaseModel:
    """Minimal base class for models.

    Subclasses override the hook methods (``forward``, ``optimize_parameters``,
    ``save``, ...) and can reuse the checkpoint helpers ``save_network`` /
    ``load_network`` and the ``save_done`` marker writer.
    """

    def __init__(self):
        pass

    def name(self):
        """Return the model's display name."""
        return "BaseModel"

    def initialize(self, use_gpu=True, gpu_ids=None):
        """Record device settings on the instance.

        Args:
            use_gpu: stored as ``self.use_gpu``; not acted on by this class.
            gpu_ids: stored as ``self.gpu_ids``. Defaults to ``[0]``. A fresh
                list is built per call so instances never share (and mutate)
                a single default list object.
        """
        self.use_gpu = use_gpu
        self.gpu_ids = [0] if gpu_ids is None else gpu_ids

    def forward(self):
        """Forward-pass hook; no-op in the base class."""
        pass

    def get_image_paths(self):
        """Return ``self.image_paths`` (expected to be set by a subclass)."""
        return self.image_paths

    def optimize_parameters(self):
        """Optimization-step hook; no-op in the base class."""
        pass

    def get_current_visuals(self):
        """Return ``self.input`` (expected to be set by a subclass)."""
        return self.input

    def get_current_errors(self):
        """Return the current error values; empty in the base class."""
        return {}

    def save(self, label):
        """Checkpoint hook; no-op in the base class."""
        pass

    @staticmethod
    def _checkpoint_filename(network_label, epoch_label):
        # Single source of truth for the on-disk name used by save/load.
        return "%s_net_%s.pth" % (epoch_label, network_label)

    def save_network(self, network, path, network_label, epoch_label):
        """Save ``network.state_dict()`` to ``<path>/<epoch>_net_<label>.pth``.

        Helper intended for use by subclasses.
        """
        save_path = os.path.join(
            path, self._checkpoint_filename(network_label, epoch_label)
        )
        torch.save(network.state_dict(), save_path)

    def load_network(self, network, network_label, epoch_label, map_location=None):
        """Load ``<self.save_dir>/<epoch>_net_<label>.pth`` into ``network``.

        Helper intended for use by subclasses. ``self.save_dir`` must be set.

        Args:
            map_location: forwarded to ``torch.load``; e.g. ``"cpu"`` to load
                a GPU-saved checkpoint on a machine without CUDA. ``None``
                keeps ``torch.load``'s default behavior.
        """
        save_path = os.path.join(
            self.save_dir, self._checkpoint_filename(network_label, epoch_label)
        )
        print("Loading network from %s" % save_path)
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        network.load_state_dict(torch.load(save_path, map_location=map_location))

    def update_learning_rate(self):
        """Learning-rate schedule hook; no-op in the base class."""
        pass

    def save_done(self, flag=False):
        """Write a completion marker into ``self.save_dir``.

        The flag is written twice: as ``done_flag.npy`` (``np.save`` appends
        the ``.npy`` suffix) and as a plain-text ``done_flag`` file holding
        ``flag`` formatted as an integer (``0``/``1`` for a bool).
        """
        np.save(os.path.join(self.save_dir, "done_flag"), flag)
        np.savetxt(
            os.path.join(self.save_dir, "done_flag"),
            [
                flag,
            ],
            fmt="%i",
        )