|
import glob |
|
import os |
|
import re |
|
import torch |
|
|
|
|
|
def get_last_checkpoint(work_dir, steps=None):
    """Locate and load the newest checkpoint under ``work_dir``.

    ``work_dir`` may also be a direct path to a ``.ckpt`` file, in which
    case that file is loaded as-is.  Returns a ``(checkpoint, path)`` pair,
    or ``(None, None)`` when nothing matching is found.
    """
    if work_dir.endswith(".ckpt"):
        candidates = [work_dir]
    else:
        candidates = get_all_ckpts(work_dir, steps)
    if not candidates:
        return None, None
    newest_path = candidates[0]  # get_all_ckpts sorts newest-first
    return torch.load(newest_path, map_location='cpu'), newest_path
|
|
|
|
|
def get_all_ckpts(work_dir, steps=None):
    """List checkpoint files in ``work_dir``, highest step count first.

    Args:
        work_dir: directory containing ``model_ckpt_steps_<N>.ckpt`` files.
        steps: if given, match only the checkpoint saved at exactly this step.

    Returns:
        List of matching paths sorted by step number, descending.
    """
    if steps is None:
        ckpt_path_pattern = f'{work_dir}/model_ckpt_steps_*.ckpt'
    else:
        ckpt_path_pattern = f'{work_dir}/model_ckpt_steps_{steps}.ckpt'
    # Raw string for the regex: the original non-raw '\_' and '\d' are
    # invalid string escape sequences (SyntaxWarning on modern Python),
    # and '_' needs no escaping in a regular expression at all.
    return sorted(glob.glob(ckpt_path_pattern),
                  key=lambda x: -int(re.findall(r'.*steps_(\d+)\.ckpt', x)[0]))
|
|
|
|
|
def load_ckpt(cur_model, ckpt_base_dir, model_name='model', force=True, strict=True, steps=None, verbose=True):
    """Load the weights named ``model_name`` from a checkpoint into ``cur_model``.

    Args:
        cur_model: module to receive the weights; if it has no
            ``load_state_dict`` method, the raw state is assigned to its
            ``.data`` attribute instead.
        ckpt_base_dir: path to a ``.ckpt`` file, or a directory searched via
            ``get_last_checkpoint`` for the newest checkpoint.
        model_name: key (possibly dotted) selecting this model's weights
            inside the checkpoint's ``"state_dict"``.
        force: if True, abort via ``assert`` when no checkpoint is found;
            otherwise just print a message and return.
        strict: forwarded to ``load_state_dict``; when False, keys whose
            tensor shapes mismatch the current model are dropped first.
        steps: optional step count selecting a specific checkpoint.
        verbose: print each key deleted during non-strict shape filtering.
    """
    if os.path.isfile(ckpt_base_dir):
        base_dir = os.path.dirname(ckpt_base_dir)
        ckpt_path = ckpt_base_dir
        checkpoint = torch.load(ckpt_base_dir, map_location='cpu')
    else:
        base_dir = ckpt_base_dir
        checkpoint, ckpt_path = get_last_checkpoint(ckpt_base_dir, steps)
    if checkpoint is not None:
        state_dict = checkpoint["state_dict"]
        # Two checkpoint layouts are handled:
        #   1) flat: keys look like "<model_name>.<param>" -> strip the prefix;
        #   2) nested: "state_dict" maps model names to their own sub-dicts.
        # NOTE(review): the flat-layout test only checks that *some* key
        # contains a dot, not that keys actually start with `model_name` —
        # a flat checkpoint without this model's prefix yields an empty dict.
        if len([k for k in state_dict.keys() if '.' in k]) > 0:
            state_dict = {k[len(model_name) + 1:]: v for k, v in state_dict.items()
                          if k.startswith(f'{model_name}.')}
        else:
            if '.' not in model_name:
                state_dict = state_dict[model_name]
            else:
                # Dotted model_name: the first component selects the nested
                # sub-dict, the remainder is stripped from each key inside it.
                base_model_name = model_name.split('.')[0]
                rest_model_name = model_name[len(base_model_name) + 1:]
                state_dict = {
                    k[len(rest_model_name) + 1:]: v for k, v in state_dict[base_model_name].items()
                    if k.startswith(f'{rest_model_name}.')}
        if not strict:
            # Drop entries whose shapes don't match the current model so a
            # non-strict load won't fail after architecture changes.
            cur_model_state_dict = cur_model.state_dict()
            unmatched_keys = []
            for key, param in state_dict.items():
                if key in cur_model_state_dict:
                    new_param = cur_model_state_dict[key]
                    if new_param.shape != param.shape:
                        unmatched_keys.append(key)
                        print("| Unmatched keys (shape mismatch): ", key, new_param.shape, param.shape)
                else:
                    # Keys absent from the current model are reported but
                    # left in place (deleted only on shape mismatch above).
                    print(f"Skipping unmatched keys (in state_dict but not in cur_model): {key}")
            for key in unmatched_keys:
                if verbose:
                    print(f"Del unmatched keys {key}")
                del state_dict[key]
        if hasattr(cur_model, 'load_state_dict'):
            cur_model.load_state_dict(state_dict, strict=strict)
        else:
            # Plain tensor/parameter-like target: assign the raw state.
            cur_model.data = state_dict
        print(f"| load '{model_name}' from '{ckpt_path}', strict={strict}")
    else:
        e_msg = f"| ckpt not found in {base_dir}."
        if force:
            # NOTE(review): `assert False` is stripped under `python -O`; a
            # raised exception would be more robust, but callers may catch
            # AssertionError, so the behavior is left unchanged here.
            assert False, e_msg
        else:
            print(e_msg)
|
|
|
def restore_weights(task_ref, checkpoint):
    """Copy per-module state dicts from ``checkpoint`` onto ``task_ref``.

    Each top-level key of ``checkpoint['state_dict']`` is treated as the
    name of a sub-module attribute on ``task_ref``.  Matching attributes
    are loaded strictly; names without a matching attribute are reported
    and skipped.
    """
    for name, module_state in checkpoint['state_dict'].items():
        if hasattr(task_ref, name):
            getattr(task_ref, name).load_state_dict(module_state, strict=True)
            # Fixed typo in the log message: "resotred" -> "restored".
            print(f"| restored {name} from pretrained checkpoints")
        else:
            print(f"| the checkpoint has unmatched keys {name}")
|
|
|
def restore_opt_state(optimizers, checkpoint):
    """Restore optimizer states from ``checkpoint['optimizer_states']``.

    Optimizers and saved states are paired positionally.  A state that no
    longer matches its optimizer's parameters (``ValueError``) is skipped
    with a warning rather than aborting the whole restore.
    """
    optimizer_states = checkpoint['optimizer_states']
    for optimizer, opt_state in zip(optimizers, optimizer_states):
        if optimizer is None:
            # NOTE(review): `return` (not `continue`) stops restoring all
            # remaining optimizers once a None placeholder is hit.
            # Preserved as-is since callers may rely on it.
            return
        try:
            optimizer.load_state_dict(opt_state)
        except ValueError:
            # Fixed typo in the warning message: "WARMING" -> "WARNING".
            print("| WARNING: optimizer parameters not match !!!")
|
|