"""Demo inference helpers: load a pretrained HGnn, predict H/S matrices for a
demo graph, and export them as heatmaps and a JSON file."""
import itertools
import json
import os

import torch
from torch_geometric.loader import DataLoader

from src.obth_gnn import HGnn
# MaterialDS / MaterialMesh / MyTensor are not referenced directly below;
# presumably they are imported so torch.load can unpickle the saved dataset
# objects -- TODO confirm before removing.
from src.dft_data_to_grphs import MaterialDS, MaterialMesh, MyTensor
from src.utils import generate_heatmap

device = torch.device('cpu')

# UI-facing model label -> checkpoint file stem under Models/.
model_name_map = {
    "wizard_tb 0.3": "demo_model_1",
    "wizard 0.3": "demo_model_0",
}

# UI-facing dataset label -> index of the graph in DATA/demo-graph/train.pt.
data_map = {
    "aBN_01": 0,
    "aBN_02": 1,
}


def output_to_matrix(hii, hij, ij):
    """Assemble dense H and S matrices from the model's on-site and hopping outputs.

    Args:
        hii: per-node outputs. ``hii[i][0]`` is written to H[i, i] and
            ``hii[i][1]`` to S[i, i].
        hij: per-edge outputs. ``hij[k][0]`` / ``hij[k][1]`` are written to H / S
            at row ``ij[0][k]``, column ``ij[1][k]``.
        ij: pair of index sequences giving the (row, column) of each ``hij`` entry.

    Returns:
        Tuple ``(mat_h, mat_s)`` of NumPy arrays of shape ``(len(hii), len(hii))``.
        Positions not covered by ``hii``/``hij`` stay zero. When a (row, column)
        pair repeats, the later entry wins.
    """
    n = len(hii)
    mat_h = torch.zeros([n, n])
    mat_s = torch.zeros([n, n])
    for i, hi in enumerate(hii):
        mat_h[i, i] = hi[0]
        mat_s[i, i] = hi[1]
    for k, hx in enumerate(hij):
        row, col = ij[0][k], ij[1][k]
        mat_h[row, col] = hx[0]
        mat_s[row, col] = hx[1]
    # .cpu() keeps .numpy() valid if `device` is ever switched away from CPU.
    return mat_h.detach().cpu().numpy(), mat_s.detach().cpu().numpy()


def plot_mat(mat, file_path="OutputFiles/img/mat.jpg"):
    """Render ``mat`` as a heatmap via ``generate_heatmap`` and return its result.

    Args:
        mat: 2-D matrix to plot.
        file_path: image path handed to ``generate_heatmap``. The default matches
            the historical hard-coded path; pass distinct paths if several plots
            must be kept on disk, otherwise later calls overwrite earlier ones.
    """
    # NOTE(review): grid1_step=1 / grid2_step=13 are passed through unchanged;
    # their meaning is defined by generate_heatmap -- confirm there.
    return generate_heatmap(mat, file_path, grid1_step=1, grid2_step=13)


def text_out(h_mat, s_mat, model_name, data_name):
    """Write H and S matrices to ``OutputFiles/{data_name}_{model_name}.json``.

    The JSON object has keys ``"h_mat"`` and ``"s_mat"``, each a nested list
    produced by ``.tolist()``. The output directory is created if missing.

    Returns:
        The path of the written file.
    """
    file_content = {"h_mat": h_mat.tolist(), "s_mat": s_mat.tolist()}
    file_name = f"OutputFiles/{data_name}_{model_name}.json"
    os.makedirs(os.path.dirname(file_name), exist_ok=True)
    with open(file_name, 'w', encoding="utf-8") as json_file:
        json.dump(file_content, json_file, indent=4)
    return file_name


def compute_mat(data_name, model_name):
    """Run the selected pretrained model on the selected demo graph.

    Args:
        data_name: sequence whose first element is a key of ``data_map``.
        model_name: sequence whose first element is a key of ``model_name_map``.

    Returns:
        ``(h_plot, s_plot, file_path)``: the two ``plot_mat`` results for the
        predicted H and S matrices and the JSON path returned by ``text_out``.

    Raises:
        ValueError: if the selected model has no architecture configured here.
        KeyError: if the selected dataset is not in ``data_map``.
        IndexError: if the dataset has no graph at the mapped index.
    """
    print("model_name", model_name)
    selected_model = model_name[0]
    if selected_model != "wizard_tb 0.3":
        # Only this architecture is configured below. Previously execution
        # continued here and crashed later on the unbound `model`.
        raise ValueError(
            f"Model {selected_model!r} is not in the list of available models"
        )
    model = HGnn(edge_shape=51, node_shape=2, u_shape=10, embed_size=[20, 20, 10],
                 ham_graph_emb=[7, 7, 7], n_blocks=3)
    state_dict = torch.load(f'Models/{model_name_map[selected_model]}.pt',
                            map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()  # inference: disable any training-only layer behaviour

    # NOTE(review): newer torch releases default torch.load to weights_only=True,
    # which may reject a pickled dataset object -- confirm against the pinned torch.
    test_data = torch.load("DATA/demo-graph/train.pt")
    loader = DataLoader(test_data, batch_size=1, shuffle=False)

    print("data_name", data_name)
    dn = data_map[data_name[0]]
    # Jump to the requested graph instead of collating every batch into a list.
    inputs = next(itertools.islice(loader, dn, None), None)
    if inputs is None:
        raise IndexError(f"Dataset has no sample at index {dn}")

    with torch.no_grad():
        hii, hij, ij = model(
            inputs.x.to(device=device, dtype=torch.float32),
            inputs.edge_index.to(device=device, dtype=torch.int64),
            inputs.edge_attr.to(device=device, dtype=torch.float32),
            inputs.u.to(device=device, dtype=torch.float32),
            inputs.batch.to(device),
            inputs.bond_batch.to(device),
        )

    h_mat, s_mat = output_to_matrix(hii, hij, ij)
    h_plot = plot_mat(h_mat)
    s_plot = plot_mat(s_mat)
    file_rsp = text_out(h_mat, s_mat, selected_model, data_name[0])
    return h_plot, s_plot, file_rsp


def upload_struct():
    """Placeholder for structure upload; not implemented yet."""
    pass