import torch

from .ema import ExponentialMovingAverage


def load_model_weights(model, ckpt_path, use_ema=True, device='cuda:0'):
    """
    Load weights of a model from a checkpoint file.

    Args:
        model (torch.nn.Module): The model to load weights into.
        ckpt_path (str): Path to the checkpoint file.
        use_ema (bool): Whether to use Exponential Moving Average (EMA) weights if available.
        device (str): Device onto which the checkpoint tensors are mapped.

    Returns:
        int: The number of training iterations stored in the checkpoint.
    """
    checkpoint = torch.load(ckpt_path, map_location={'cuda:0': str(device)})
    total_iter = checkpoint.get('total_it', 0)

    if "model_ema" in checkpoint and use_ema:
        # Inspect the first key of the EMA state dict to detect whether it was
        # saved from an ExponentialMovingAverage wrapper (keys prefixed with
        # 'module' or containing 'n_averaged') or from the bare model.
        ema_key = next(iter(checkpoint["model_ema"]))
        if ('module' in ema_key) or ('n_averaged' in ema_key):
            # Wrap the model so its state dict layout matches the wrapped checkpoint.
            model = ExponentialMovingAverage(model, decay=1.0)
        model.load_state_dict(checkpoint["model_ema"], strict=True)
        if ('module' in ema_key) or ('n_averaged' in ema_key):
            # Unwrap back to the underlying module after loading.
            model = model.module
            print(f'\nLoading EMA module model from {ckpt_path} with {total_iter} iterations')
        else:
            print(f'\nLoading EMA model from {ckpt_path} with {total_iter} iterations')
    else:
        # Fall back to the raw (non-EMA) weights stored under the 'encoder' key.
        model.load_state_dict(checkpoint['encoder'], strict=True)
        print(f'\nLoading model from {ckpt_path} with {total_iter} iterations')

    return total_iter
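

# --- Usage sketch (illustration only; not part of the original file) ---
# A minimal example of how load_model_weights might be called when resuming
# training. The stand-in model, checkpoint path, and variable names below are
# hypothetical placeholders, not taken from this repository.
if __name__ == '__main__':
    import torch.nn as nn

    net = nn.Linear(128, 128)  # stand-in for the real model architecture
    total_it = load_model_weights(
        net,
        'checkpoints/latest.pt',  # hypothetical checkpoint path
        use_ema=True,
        device='cuda:0',
    )
    print(f'Resumed from iteration {total_it}')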