Spaces:
Runtime error
Runtime error
| import torch | |
| import numpy as np | |
| from graph_decoder.diffusion_model import GraphDiT | |
| # model_state = load_model() | |
| # generate_graph(2.5, 15.4, 21.0, 1.5, 2.8, 2, 0, 1, model_state, 50) | |
def count_parameters(model):
    r"""Count the parameters of ``model``.

    Walks ``model.parameters()`` once. For each parameter it records the
    element count from ``numel()`` and whether ``requires_grad`` is set.

    Args:
        model: Any object exposing a ``parameters()`` iterable, such as a
            ``torch.nn.Module``.

    Returns:
        tuple[int, int]: ``(trainable_params, all_params)``. The first is
        the element count over parameters with ``requires_grad`` set; the
        second is the element count over all parameters.
    """
    # Read numel / requires_grad exactly once per parameter.
    sizes = [(param.numel(), param.requires_grad) for param in model.parameters()]
    total = sum(count for count, _ in sizes)
    trainable = sum(count for count, needs_grad in sizes if needs_grad)
    return trainable, total
def load_graph_decoder(path='model_labeled', *, model_dtype=torch.float32, verbose=False):
    """Build a GraphDiT from the files under ``path``, initialise it and freeze it.

    Args:
        path: Directory expected to contain ``config.yaml`` and
            ``data.meta.json``. The same ``path`` is also passed to
            ``GraphDiT.init_model``.
        model_dtype: dtype forwarded to the ``GraphDiT`` constructor. Defaults
            to ``torch.float32``, as before; ``torch.float16`` was the
            previously commented-out alternative.
        verbose: If True, print trainable/total parameter counts after the
            model is initialised and its grads are disabled.

    Returns:
        The ``GraphDiT`` instance after ``init_model(path)`` and
        ``disable_grads()`` have been called on it.
    """
    model_config_path = f"{path}/config.yaml"
    data_info_path = f"{path}/data.meta.json"
    model = GraphDiT(
        model_config_path=model_config_path,
        data_info_path=data_info_path,
        model_dtype=model_dtype,
    )
    model.init_model(path)
    model.disable_grads()

    if verbose:
        trainable_params, all_param = count_parameters(model)
        # Guard the percentage so an empty model cannot raise ZeroDivisionError.
        trainable_pct = 100 * trainable_params / all_param if all_param else 0.0
        print(
            "Loaded Graph DiT from {} trainable params: {:,} || all params: {:,} "
            "|| trainable%: {:.4f}".format(path, trainable_params, all_param, trainable_pct)
        )
    return model