Spaces:
Running
on
Zero
Running
on
Zero
import json
import os

import torch
from torch.nn.utils import remove_weight_norm

from .models.autoencoders import create_autoencoder_from_config
def remove_all_weight_norm(model):
    """Strip legacy weight normalization from every submodule of ``model``, in place.

    A module is treated as weight-normalized when it has a ``weight_g``
    attribute. ``torch.nn.utils.weight_norm`` adds that attribute (together
    with ``weight_v``) for its default ``name='weight'``.
    ``remove_weight_norm`` folds the pair back into a plain ``weight``
    parameter.

    Args:
        model: A ``torch.nn.Module``. ``model`` itself is included in the scan.

    Returns:
        None; ``model`` is modified in place.
    """
    # NOTE(review): modules normalized with the newer
    # ``torch.nn.utils.parametrizations.weight_norm`` presumably do not expose
    # ``weight_g`` as a direct attribute and would be skipped here -- confirm
    # if such modules can appear in this model.
    # Only the modules themselves are needed, so iterate ``modules()`` rather
    # than ``named_modules()`` and discard the names.
    for module in model.modules():
        if hasattr(module, 'weight_g'):
            remove_weight_norm(module)
def load_vae(ckpt_path, remove_weight_norm=False):
    """Build an autoencoder from its config and load its weights from a checkpoint.

    The model configuration is read from ``config.json`` in the same directory
    as ``ckpt_path``. The checkpoint must be a mapping with a ``'state_dict'``
    entry. Only the entries whose keys start with ``"autoencoder."`` are
    loaded, with that prefix stripped.

    Args:
        ckpt_path: Path to the checkpoint file.
        remove_weight_norm: If true, strip weight normalization from every
            submodule after loading. Inside this function the flag shadows the
            imported ``torch.nn.utils.remove_weight_norm``. The module-level
            helper ``remove_all_weight_norm`` still resolves the imported
            function.

    Returns:
        The loaded model, switched to evaluation mode.

    Raises:
        OSError: If ``config.json`` or the checkpoint cannot be opened.
        KeyError: If the checkpoint has no ``'state_dict'`` entry.
        RuntimeError: If no ``"autoencoder."``-prefixed keys are present, or
            if the stripped keys do not match the model (strict
            ``load_state_dict``).
    """
    config_file = os.path.join(os.path.dirname(ckpt_path), 'config.json')

    # JSON is UTF-8 by spec; don't depend on the platform's default encoding.
    with open(config_file, encoding='utf-8') as f:
        model_config = json.load(f)

    model = create_autoencoder_from_config(model_config)

    # NOTE(review): torch.load unpickles the file (unless weights_only=True is
    # in effect), so only load checkpoints from trusted sources.
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    state_dict = checkpoint['state_dict']

    # The checkpoint may hold weights for several submodules; keep only the
    # autoencoder's and strip its prefix so the keys match ``model``.
    prefix = "autoencoder."
    model_dict = {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }
    if not model_dict:
        # Without this check the empty dict would only surface as a long
        # "missing keys" error from load_state_dict.
        raise RuntimeError(
            f"Checkpoint {ckpt_path!r} contains no state_dict keys with prefix {prefix!r}"
        )
    model.load_state_dict(model_dict)

    if remove_weight_norm:
        remove_all_weight_norm(model)

    model.eval()
    return model