File size: 2,633 Bytes
cd71bd3
 
 
 
 
 
 
 
 
14b2be0
 
cd71bd3
 
0d13fe8
cd71bd3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14b2be0
cd71bd3
 
 
 
 
 
 
 
 
 
 
2b97ade
cd71bd3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import torch
from src.obth_gnn import HGnn
from src.dft_data_to_grphs import MaterialDS, MaterialMesh, MyTensor
from src.utils import generate_heatmap
from torch_geometric.loader import DataLoader
import json

# All tensors and the model are kept on the CPU.
device = torch.device("cpu")

# Display name of a model -> base name of its checkpoint file.
model_name_map = {
    "wizard_tb 0.3": "demo_model_1",
    "wizard 0.3": "demo_model_0",
}

# Display name of a demo structure -> integer index used to select it.
data_map = {
    "aBN_01": 0,
    "aBN_02": 1,
}


def output_to_matrix(hii, hij, ij):
    """Assemble two dense square matrices from per-node and per-edge rows.

    Element 0 of every row goes into the first matrix and element 1 into the
    second. Rows of ``hii`` fill the diagonal; row ``k`` of ``hij`` fills the
    entry at ``(ij[0][k], ij[1][k])``.

    Args:
        hii: Sequence of rows, one per diagonal entry; ``len(hii)`` sets the
            matrix size.
        hij: Sequence of rows, one per off-diagonal entry to fill.
        ij: Two index sequences giving the row and column of each ``hij`` row.

    Returns:
        Tuple ``(mat_h, mat_s)`` of NumPy arrays of shape
        ``(len(hii), len(hii))``; entries not addressed above stay zero.
    """
    size = len(hii)
    h_dense = torch.zeros(size, size)
    s_dense = torch.zeros(size, size)

    # Diagonal entries come from the per-node rows.
    for node, row in enumerate(hii):
        h_dense[node, node] = row[0]
        s_dense[node, node] = row[1]

    # Off-diagonal entries are placed at the positions listed in ``ij``.
    for edge, row in enumerate(hij):
        r = ij[0][edge]
        c = ij[1][edge]
        h_dense[r][c] = row[0]
        s_dense[r][c] = row[1]

    # detach() drops any autograd history picked up from the assigned values.
    h_array = h_dense.detach().numpy()
    s_array = s_dense.detach().numpy()
    return h_array, s_array

def plot_mat(mat):
    """Build a heatmap for ``mat`` with ``generate_heatmap`` and return it.

    The helper is always called with the fixed second argument
    ``"OutputFiles/img/mat.jpg"`` (presumably the image output path; confirm
    against ``generate_heatmap``) and grid steps of 1 and 13.

    Args:
        mat: Matrix to visualise.

    Returns:
        Whatever ``generate_heatmap`` returns for this matrix.
    """
    return generate_heatmap(
        mat,
        "OutputFiles/img/mat.jpg",
        grid1_step=1,
        grid2_step=13,
    )


def text_out(h_mat, s_mat, model_name, data_name):
    """Write both matrices to a JSON file and return the file's path.

    The file is ``OutputFiles/<data_name>_<model_name>.json`` (relative to the
    current working directory) and contains an object with the keys
    ``"h_mat"`` and ``"s_mat"``.

    Args:
        h_mat: Array-like exposing ``tolist()``; stored under ``"h_mat"``.
        s_mat: Array-like exposing ``tolist()``; stored under ``"s_mat"``.
        model_name: Model label used in the file name.
        data_name: Data label used in the file name.

    Returns:
        The relative path of the JSON file that was written.
    """
    import os  # local import: only this function needs it

    file_name = f"OutputFiles/{data_name}_{model_name}.json"
    file_content = {"h_mat": h_mat.tolist(), "s_mat": s_mat.tolist()}

    # Create the output directory on first use instead of failing with
    # FileNotFoundError when it does not exist yet.
    os.makedirs(os.path.dirname(file_name), exist_ok=True)

    with open(file_name, "w", encoding="utf-8") as json_file:
        json.dump(file_content, json_file, indent=4)
    return file_name


def compute_mat(data_name, model_name):
    """Run the selected model on one demo structure and export the result.

    Loads the checkpoint for the requested model, picks one graph from
    ``DATA/demo-graph/train.pt``, evaluates the model on it, converts the
    output to two dense matrices, plots both and writes them to a JSON file.

    Args:
        data_name: Sequence whose first element is a key of ``data_map``.
        model_name: Sequence whose first element names the model. Only
            ``"wizard_tb 0.3"`` is wired up here.

    Returns:
        Tuple ``(h_plot, s_plot, file_rsp)``: the two values returned by
        ``plot_mat`` and the path returned by ``text_out``.

    Raises:
        ValueError: If the requested model is not supported.
        KeyError: If ``data_name[0]`` is not a key of ``data_map``.
    """
    print("model_name", model_name)
    if model_name[0] != "wizard_tb 0.3":
        # Fail here with a clear message; continuing would only crash later
        # because no model object exists.
        raise ValueError(
            f"Model {model_name[0]!r} is not in the list of available models"
        )

    model = HGnn(edge_shape=51,
                 node_shape=2,
                 u_shape=10,
                 embed_size=[20, 20, 10],
                 ham_graph_emb=[7, 7, 7],
                 n_blocks=3)
    model.load_state_dict(torch.load(f'Models/{model_name_map[model_name[0]]}.pt', map_location=device))
    model.to(device)
    # Inference only: switch off training-time behaviour of any dropout or
    # normalisation layers the model may contain.
    model.eval()

    test_data = torch.load("DATA/demo-graph/train.pt")
    loader = DataLoader(test_data, batch_size=1, shuffle=False)
    print("data_name", data_name)
    # batch_size=1 and shuffle=False, so batch k holds exactly graph k.
    inputs = list(loader)[data_map[data_name[0]]]

    # Cast to the dtypes passed to the model and keep every input on the same
    # device as the model.
    x = inputs.x.to(device=device, dtype=torch.float32)
    edge_index = inputs.edge_index.to(device=device, dtype=torch.int64)
    edge_attr = inputs.edge_attr.to(device=device, dtype=torch.float32)
    state = inputs.u.to(device=device, dtype=torch.float32)
    batch = inputs.batch.to(device)
    bond_batch = inputs.bond_batch.to(device)

    # No gradients are needed for a forward-only pass.
    with torch.no_grad():
        hii, hij, ij = model(x, edge_index, edge_attr, state, batch, bond_batch)

    h_mat, s_mat = output_to_matrix(hii, hij, ij)

    # NOTE(review): plot_mat passes the same fixed path to generate_heatmap
    # for both calls; confirm that the second call overwriting the first
    # image on disk is acceptable.
    h_plot = plot_mat(h_mat)
    s_plot = plot_mat(s_mat)

    file_rsp = text_out(h_mat, s_mat, model_name[0], data_name[0])

    return h_plot, s_plot, file_rsp


def upload_struct():
    """Placeholder with no implementation yet; does nothing and returns ``None``."""
    return None