# Spaces:
# Sleeping
# Sleeping
import torch | |
from src.obth_gnn import HGnn | |
from src.dft_data_to_grphs import MaterialDS, MaterialMesh, MyTensor | |
from src.utils import generate_heatmap | |
from torch_geometric.loader import DataLoader | |
import json | |
# Every tensor and the model itself are placed on the CPU.
device = torch.device("cpu")

# Display name of a selectable model -> base name of its checkpoint file.
model_name_map = dict(
    [
        ("wizard_tb 0.3", "demo_model_1"),
        ("wizard 0.3", "demo_model_0"),
    ]
)

# Display name of a demo structure -> its integer position (0-based).
data_map = {
    label: position for position, label in enumerate(("aBN_01", "aBN_02"))
}
def output_to_matrix(hii, hij, ij):
    """Scatter per-site and per-pair values into two dense square matrices.

    Args:
        hii: Sequence with one entry per site. Entry ``k`` provides the
            values written to position ``(k, k)``: element 0 goes into the
            first matrix, element 1 into the second.
        hij: Sequence with one entry per pair, with the same two-element
            layout as the entries of ``hii``.
        ij: Two-row index container; ``ij[0][k]`` is the row and
            ``ij[1][k]`` the column that pair ``k`` is written to.

    Returns:
        Tuple ``(mat_h, mat_s)`` of NumPy arrays, each of shape
        ``(len(hii), len(hii))``. Positions not written stay zero.
    """
    size = len(hii)
    mat_h = torch.zeros([size, size])
    mat_s = torch.zeros([size, size])

    # Diagonal entries come from the per-site values.
    for site, site_values in enumerate(hii):
        mat_h[site, site] = site_values[0]
        mat_s[site, site] = site_values[1]

    # Off-diagonal entries come from the per-pair values; a later pair that
    # targets an already-written position overwrites it.
    rows, cols = ij[0], ij[1]
    for pair, pair_values in enumerate(hij):
        row, col = rows[pair], cols[pair]
        mat_h[row, col] = pair_values[0]
        mat_s[row, col] = pair_values[1]

    return mat_h.detach().numpy(), mat_s.detach().numpy()
def plot_mat(mat, file_path="OutputFiles/img/mat.jpg", grid1_step=1, grid2_step=13):
    """Build a heatmap of ``mat`` with ``generate_heatmap`` and return its result.

    The output path and grid steps used to be hard-coded; they are now
    parameters whose defaults are the previous values, so ``plot_mat(mat)``
    behaves exactly as before.

    Args:
        mat: Matrix to draw; forwarded unchanged to ``generate_heatmap``.
        file_path: Forwarded as the second positional argument of
            ``generate_heatmap`` (presumably where the image is saved —
            confirm against ``src.utils``).
        grid1_step: Forwarded as the ``grid1_step`` keyword.
        grid2_step: Forwarded as the ``grid2_step`` keyword.

    Returns:
        Whatever ``generate_heatmap`` returns.
    """
    return generate_heatmap(
        mat,
        file_path,
        grid1_step=grid1_step,
        grid2_step=grid2_step,
    )
def text_out(h_mat, s_mat, model_name, data_name):
    """Write both matrices to a JSON file and return the file's path.

    Args:
        h_mat: Object exposing ``tolist()``; stored under the ``"h_mat"`` key.
        s_mat: Object exposing ``tolist()``; stored under the ``"s_mat"`` key.
        model_name: Model label, used as the second part of the file name.
        data_name: Data label, used as the first part of the file name.

    Returns:
        The relative path ``OutputFiles/<data_name>_<model_name>.json``.
    """
    # Function-scope import: keeps this block self-contained.
    import os

    file_content = {"h_mat": h_mat.tolist(), "s_mat": s_mat.tolist()}
    file_name = f"OutputFiles/{data_name}_{model_name}.json"

    # open(..., "w") does not create missing directories; without this a
    # fresh checkout with no OutputFiles/ folder fails with FileNotFoundError.
    os.makedirs(os.path.dirname(file_name), exist_ok=True)

    with open(file_name, "w", encoding="utf-8") as json_file:
        json.dump(file_content, json_file, indent=4)
    return file_name
def _first_selection(selection, what):
    """Return the chosen name from a plain string or a non-empty sequence."""
    if isinstance(selection, str):
        # Indexing a string would yield its first character, not the name.
        return selection
    if not selection:
        raise ValueError(f"No {what} selected")
    return selection[0]


def compute_mat(data_name, model_name):
    """Run the selected model on the selected demo graph and export the result.

    Args:
        data_name: Sequence whose first element is a key of ``data_map``
            (a plain string is also accepted).
        model_name: Sequence whose first element is the model's display name
            (a plain string is also accepted). Only ``"wizard_tb 0.3"`` has
            a model definition here.

    Returns:
        Tuple ``(h_plot, s_plot, file_rsp)``: the two ``plot_mat`` results
        and the path of the JSON file written by ``text_out``.

    Raises:
        ValueError: If a selection is empty or the model is not supported.
        KeyError: If the data name is not a key of ``data_map``.
        IndexError: If the data file holds fewer graphs than the mapped index.
    """
    print("model_name", model_name)
    chosen_model = _first_selection(model_name, "model")
    if chosen_model != "wizard_tb 0.3":
        # Previously this case only printed a message and then failed later
        # with an unbound `model` variable; fail early and explicitly instead.
        raise ValueError(
            f"Model not in the list of available models: {chosen_model!r}"
        )

    model = HGnn(
        edge_shape=51,
        node_shape=2,
        u_shape=10,
        embed_size=[20, 20, 10],
        ham_graph_emb=[7, 7, 7],
        n_blocks=3,
    )
    checkpoint_path = f"Models/{model_name_map[chosen_model]}.pt"
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model.to(device)
    # Inference only: switch layers such as dropout/batch-norm (if the model
    # has any) to evaluation behaviour.
    model.eval()

    test_data = torch.load("DATA/demo-graph/train.pt")
    loader = DataLoader(test_data, batch_size=1, shuffle=False)

    print("data_name", data_name)
    sample_index = data_map[_first_selection(data_name, "data set")]

    # Take only the requested batch instead of materialising every batch.
    inputs = next(
        (batch for position, batch in enumerate(loader) if position == sample_index),
        None,
    )
    if inputs is None:
        raise IndexError(
            f"Data index {sample_index} is out of range for the loaded data set"
        )

    x = inputs.x.to(device=device, dtype=torch.float32)
    edge_index = inputs.edge_index.to(device=device, dtype=torch.int64)
    edge_attr = inputs.edge_attr.to(device=device, dtype=torch.float32)
    state = inputs.u.to(device=device, dtype=torch.float32)
    batch = inputs.batch.to(device)
    bond_batch = inputs.bond_batch.to(device)

    # Gradients are never used here (the outputs are detached downstream).
    with torch.no_grad():
        hii, hij, ij = model(x, edge_index, edge_attr, state, batch, bond_batch)

    h_mat, s_mat = output_to_matrix(hii, hij, ij)

    # NOTE(review): plot_mat is called twice with its default image path, so
    # the second call may overwrite the file written by the first — confirm
    # that the returned plot objects do not depend on that file.
    h_plot = plot_mat(h_mat)
    s_plot = plot_mat(s_mat)

    file_rsp = text_out(h_mat, s_mat, chosen_model, sample_name := _first_selection(data_name, "data set"))
    return h_plot, s_plot, file_rsp
def upload_struct():
    """Placeholder that is not implemented yet; does nothing and returns ``None``."""
    return None