# HamiltonianMagic / interface_connection.py
# Author: AndreiVoicuT
# Commit 14b2be0 (verified): "Update interface_connection.py"
import itertools
import json
import os

import torch
from torch_geometric.loader import DataLoader

from src.dft_data_to_grphs import MaterialDS, MaterialMesh, MyTensor
from src.obth_gnn import HGnn
from src.utils import generate_heatmap
# Device used both as ``map_location`` when loading checkpoints and as the
# target for the model and its inputs.
device = torch.device("cpu")

# Model label -> checkpoint file stem; the checkpoint is read from
# ``Models/<stem>.pt``.
model_name_map = {
    "wizard_tb 0.3": "demo_model_1",
    "wizard 0.3": "demo_model_0",
}

# Dataset label -> position of that sample in ``DATA/demo-graph/train.pt``.
data_map = {
    "aBN_01": 0,
    "aBN_02": 1,
}
def _stack_rows(values):
    """Return *values* as a detached tensor, stacking it row by row if it is a sequence."""
    if isinstance(values, torch.Tensor):
        return values.detach()
    return torch.stack([torch.as_tensor(row) for row in values])


def output_to_matrix(hii, hij, ij):
    """Assemble dense H and S matrices from per-node and per-edge model outputs.

    Parameters
    ----------
    hii : tensor or sequence of rows
        One row per node (``n = len(hii)``). Column 0 goes on the diagonal of
        the H matrix, column 1 on the diagonal of the S matrix.
    hij : tensor or sequence of rows
        One row per edge. Column 0 goes into H, column 1 into S.
    ij : tensor or sequence of two index rows
        Edge ``k`` is written at row ``ij[0][k]``, column ``ij[1][k]``. Only the
        first ``len(hij)`` index pairs are used.

    Returns
    -------
    (numpy.ndarray, numpy.ndarray)
        ``(mat_h, mat_s)``, each of shape ``(n, n)`` in torch's default float
        dtype. Positions not covered by ``hii`` or ``ij`` are zero. Edge values
        are written after the diagonal, so an edge with ``ij[0][k] == ij[1][k]``
        overrides the onsite value.
    """
    n = len(hii)
    mat_h = torch.zeros([n, n])
    mat_s = torch.zeros([n, n])
    # Scatter everything in one vectorised step without autograd. Assigning
    # entry by entry into a tensor records one autograd node per entry, and the
    # result is only ever detached for export.
    with torch.no_grad():
        if n:
            onsite = _stack_rows(hii).to(device=mat_h.device, dtype=mat_h.dtype)
            diag = torch.arange(n, device=mat_h.device)
            mat_h[diag, diag] = onsite[:, 0]
            mat_s[diag, diag] = onsite[:, 1]
        if len(hij):
            hop = _stack_rows(hij).to(device=mat_h.device, dtype=mat_h.dtype)
            index = _stack_rows(ij).to(device=mat_h.device, dtype=torch.long)
            m = hop.shape[0]
            rows, cols = index[0, :m], index[1, :m]
            # NOTE(review): (i, j) pairs are assumed unique. With duplicate pairs,
            # torch leaves the choice of which value wins undefined.
            mat_h[rows, cols] = hop[:, 0]
            mat_s[rows, cols] = hop[:, 1]
    return mat_h.detach().numpy(), mat_s.detach().numpy()
def plot_mat(mat, *, path="OutputFiles/img/mat.jpg", grid1_step=1, grid2_step=13):
    """Render *mat* with ``generate_heatmap`` and return the figure it produces.

    Parameters
    ----------
    mat :
        Matrix to plot, passed straight to ``generate_heatmap``.
    path : str, keyword-only
        Path handed to ``generate_heatmap``. Presumably the image is saved
        there; TODO confirm against ``src.utils``. The default is the original
        hard-coded path. Calls that rely on the default all share the same
        file, so each such call may overwrite the previous image.
    grid1_step, grid2_step : int, keyword-only
        Forwarded unchanged to ``generate_heatmap``. The defaults (1 and 13)
        are the original hard-coded values.

    Returns
    -------
    Whatever ``generate_heatmap`` returns (the figure object).
    """
    return generate_heatmap(mat, path, grid1_step=grid1_step, grid2_step=grid2_step)
def text_out(h_mat, s_mat, model_name, data_name, output_dir="OutputFiles"):
    """Write the H and S matrices to a JSON file and return the file's path.

    Parameters
    ----------
    h_mat, s_mat :
        Objects exposing ``tolist()``, e.g. the NumPy arrays from
        ``output_to_matrix``. They are stored under the keys ``"h_mat"`` and
        ``"s_mat"``.
    model_name, data_name : str
        Used to build the file name ``<data_name>_<model_name>.json``.
    output_dir : str, optional
        Directory to write into; it is created if missing. Defaults to
        ``"OutputFiles"``.

    Returns
    -------
    str
        ``f"{output_dir}/{data_name}_{model_name}.json"``, the path just written.
    """
    # Create the directory first: open(..., "w") fails if it does not exist.
    os.makedirs(output_dir, exist_ok=True)
    file_name = f"{output_dir}/{data_name}_{model_name}.json"
    payload = {"h_mat": h_mat.tolist(), "s_mat": s_mat.tolist()}
    with open(file_name, "w", encoding="utf-8") as json_file:
        json.dump(payload, json_file, indent=4)
    return file_name
def _load_model(model_key):
    """Build the HGnn for *model_key*, load its checkpoint and switch it to eval mode."""
    if model_key != "wizard_tb 0.3":
        raise ValueError(f"Model {model_key!r} is not in the list of available models")
    model = HGnn(edge_shape=51,
                 node_shape=2,
                 u_shape=10,
                 embed_size=[20, 20, 10],
                 ham_graph_emb=[7, 7, 7],
                 n_blocks=3)
    state_dict = torch.load(f"Models/{model_name_map[model_key]}.pt", map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    # Inference only: turn off training-time behaviour (e.g. dropout), if HGnn has any.
    model.eval()
    return model


def _load_sample(data_key):
    """Return the collated batch-size-1 sample that ``data_map[data_key]`` points to."""
    index = data_map[data_key]
    # NOTE(review): torch.load unpickles arbitrary objects; only load trusted files.
    graphs = torch.load("DATA/demo-graph/train.pt")
    loader = DataLoader(graphs, batch_size=1, shuffle=False)
    # Skip ahead to the requested sample instead of collating every batch into a list.
    sample = next(itertools.islice(loader, index, None), None)
    if sample is None:
        raise IndexError(
            f"Sample {data_key!r} (index {index}) is not in DATA/demo-graph/train.pt"
        )
    return sample


def compute_mat(data_name, model_name):
    """Run the selected model on the selected demo structure and export H and S.

    Parameters
    ----------
    data_name : sequence of str
        Only ``data_name[0]`` is used; it must be a key of ``data_map``.
    model_name : sequence of str
        Only ``model_name[0]`` is used; only ``"wizard_tb 0.3"`` is supported.

    Returns
    -------
    tuple
        ``(h_plot, s_plot, file_name)``: the ``plot_mat`` results for the H and
        S matrices, and the path of the JSON file written by ``text_out``.

    Raises
    ------
    ValueError
        If ``model_name[0]`` is not a supported model.
    KeyError
        If ``data_name[0]`` is not a key of ``data_map``.
    IndexError
        If the mapped sample index is past the end of the demo graph file.
    """
    print("model_name", model_name)
    model = _load_model(model_name[0])
    print("data_name", data_name)
    sample = _load_sample(data_name[0])

    # Pure inference: no autograd graph is needed for the exported matrices.
    with torch.no_grad():
        hii, hij, ij = model(
            sample.x.to(device=device, dtype=torch.float32),
            sample.edge_index.to(device=device, dtype=torch.int64),
            sample.edge_attr.to(device=device, dtype=torch.float32),
            sample.u.to(device=device, dtype=torch.float32),
            sample.batch.to(device),
            sample.bond_batch.to(device),
        )

    h_mat, s_mat = output_to_matrix(hii, hij, ij)
    h_plot = plot_mat(h_mat)
    s_plot = plot_mat(s_mat)
    file_rsp = text_out(h_mat, s_mat, model_name[0], data_name[0])
    return h_plot, s_plot, file_rsp
def upload_struct():
    """Placeholder for uploading a structure: currently does nothing and returns None."""
    return None