Spaces:
Runtime error
Runtime error
File size: 4,445 Bytes
6b59850 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
import os
import torch_geometric.utils
from omegaconf import OmegaConf, open_dict
from torch_geometric.utils import to_dense_adj, to_dense_batch
import torch
import omegaconf
import wandb
def create_folders(args):
    """Create the output directories for generated graphs and chains.

    Ensures that ``graphs``, ``chains``, ``graphs/<name>`` and
    ``chains/<name>`` exist relative to the current working directory,
    where ``<name>`` is ``args.general.name``.

    Every directory is created independently, so a directory that already
    exists never prevents its siblings from being created. Any other
    ``OSError`` (e.g. a permission problem) is logged as a warning rather
    than raised, keeping the call best-effort.

    Args:
        args: configuration object exposing ``args.general.name``.
    """
    import logging

    log = logging.getLogger(__name__)

    def _ensure_dir(path):
        # exist_ok=True makes "already there" a success instead of an error.
        try:
            os.makedirs(path, exist_ok=True)
        except OSError as err:
            log.warning("Could not create directory %r: %s", path, err)

    base_dirs = ('graphs', 'chains')
    # NOTE: 'checkpoints' directories are not created here (that code was
    # already disabled in the previous version of this function).
    for base in base_dirs:
        _ensure_dir(base)

    run_name = args.general.name
    for base in base_dirs:
        _ensure_dir(base + '/' + run_name)
def normalize(X, E, y, norm_values, norm_biases, node_mask):
    """Shift and rescale node, edge and global features, then mask them.

    Each feature is mapped to ``(value - bias) / scale``, using index 0 of
    ``norm_biases`` / ``norm_values`` for ``X``, index 1 for ``E`` and
    index 2 for ``y``. Entries of ``E`` on the diagonal of its dimensions
    1 and 2 are then set to zero.

    Args:
        X: node features.
        E: edge features; indexed as (batch, node, node, ...) here.
        y: global features.
        norm_values: [scale for X, scale for E, scale for y].
        norm_biases: biases, in the same order.
        node_mask: mask forwarded to ``PlaceHolder.mask``.

    Returns:
        The result of ``PlaceHolder(X, E, y).mask(node_mask)``.
    """
    X_centered = X - norm_biases[0]
    X = X_centered / norm_values[0]

    E_centered = E - norm_biases[1]
    E = E_centered / norm_values[1]

    y_centered = y - norm_biases[2]
    y = y_centered / norm_values[2]

    # Boolean identity broadcast over the batch: selects the self-loop entries.
    batch_size, num_nodes = E.shape[0], E.shape[1]
    self_loops = torch.eye(num_nodes, dtype=torch.bool).unsqueeze(0)
    self_loops = self_loops.expand(batch_size, -1, -1)
    E[self_loops] = 0

    normalized = PlaceHolder(X=X, E=E, y=y)
    return normalized.mask(node_mask)
def unnormalize(X, E, y, norm_values, norm_biases, node_mask, collapse=False):
    """Undo the affine normalization of node, edge and global features.

    Each feature is mapped to ``value * scale + bias``.

    X : node features
    E : edge features
    y : global features
    norm_values : [norm value X, norm value E, norm value y]
    norm_biases : same order
    node_mask : mask forwarded to ``PlaceHolder.mask``
    collapse : forwarded to ``PlaceHolder.mask``

    Returns the result of ``PlaceHolder(X, E, y).mask(node_mask, collapse)``.
    """
    restored = PlaceHolder(
        X=X * norm_values[0] + norm_biases[0],
        E=E * norm_values[1] + norm_biases[1],
        y=y * norm_values[2] + norm_biases[2],
    )
    return restored.mask(node_mask, collapse)
def to_dense(x, edge_index, edge_attr, batch):
    """Turn a sparse batched graph into dense node and edge tensors.

    Args:
        x: node features, passed to ``to_dense_batch``.
        edge_index: edge indices; self-loops are removed before densifying.
        edge_attr: edge attributes matching ``edge_index``.
        batch: batch assignment vector, passed to ``to_dense_batch`` and
            ``to_dense_adj``.

    Returns:
        A tuple ``(PlaceHolder(X, E, y=None), node_mask)``. ``E`` is the
        dense adjacency after ``encode_no_edge``; ``node_mask`` is returned
        exactly as produced by ``to_dense_batch`` (it is not cast to float).
    """
    dense_nodes, node_mask = to_dense_batch(x=x, batch=batch)
    loop_free_index, loop_free_attr = torch_geometric.utils.remove_self_loops(
        edge_index, edge_attr
    )
    # TODO: carefully check if setting node_mask as a bool breaks the continuous case
    dense_edges = to_dense_adj(
        edge_index=loop_free_index,
        batch=batch,
        edge_attr=loop_free_attr,
        # Pad the adjacency to the same node count as the dense node tensor.
        max_num_nodes=dense_nodes.size(1),
    )
    dense_edges = encode_no_edge(dense_edges)
    return PlaceHolder(X=dense_nodes, E=dense_edges, y=None), node_mask
def encode_no_edge(E):
    """Mark absent edges in channel 0 and clear the diagonal, in place.

    For every node pair whose features sum to zero over the last dimension,
    channel 0 is set to 1. Afterwards all channels of the diagonal entries
    (dimension 1 index == dimension 2 index) are set to 0.

    Args:
        E: 4-dimensional edge tensor, indexed as (batch, node, node, channel).
            It is modified in place.

    Returns:
        The same tensor object ``E``. A tensor with an empty last dimension
        is returned untouched.
    """
    assert E.ndim == 4
    if E.shape[-1] == 0:
        return E

    missing = E.sum(dim=-1) == 0
    # E[..., 0] is a view, so the masked assignment writes straight into E.
    E[..., 0][missing] = 1

    batch_size, num_nodes = E.shape[0], E.shape[1]
    self_loops = torch.eye(num_nodes, dtype=torch.bool).unsqueeze(0)
    self_loops = self_loops.expand(batch_size, -1, -1)
    E[self_loops] = 0
    return E
def _copy_missing_keys(target, saved):
    """Copy into ``target`` every key of ``saved`` that ``target`` lacks."""
    OmegaConf.set_struct(target, True)
    # open_dict temporarily lifts struct mode so new keys can be added.
    with open_dict(target):
        for key, val in saved.items():
            # Membership is tested on .keys() on purpose, so the check is a
            # plain key lookup regardless of how the container defines `in`.
            if key not in target.keys():
                setattr(target, key, val)


def update_config_with_new_keys(cfg, saved_cfg):
    """Fill ``cfg`` with keys that exist only in ``saved_cfg``.

    For each of the ``general``, ``train`` and ``model`` sections, every key
    present in the saved section but absent from the current one is copied
    over. Keys already present in ``cfg`` keep their current value.

    All three sections are handled identically: struct mode is enabled on
    the section once, then the missing keys are added inside a single
    ``open_dict`` context. (Previously the ``general`` section re-entered
    ``set_struct`` / ``open_dict`` for every key, and skipped them when the
    saved section was empty.)

    Args:
        cfg: current configuration; modified in place.
        saved_cfg: configuration to take missing keys from.

    Returns:
        ``cfg``, the same object that was passed in.
    """
    section_names = ('general', 'train', 'model')
    # Read every saved section up front so a missing section fails before
    # cfg has been modified at all.
    saved_sections = [getattr(saved_cfg, name) for name in section_names]

    for name, saved_section in zip(section_names, saved_sections):
        _copy_missing_keys(getattr(cfg, name), saved_section)
    return cfg
class PlaceHolder:
    """Mutable container bundling node (X), edge (E) and global (y) features."""

    def __init__(self, X, E, y):
        """Store the three feature objects as given; ``y`` may be ``None``."""
        self.X = X
        self.E = E
        self.y = y

    def type_as(self, x: torch.Tensor):
        """ Changes the device and dtype of X, E, y.

        ``y`` is left untouched when it is ``None``, so placeholders built
        without global features can be converted as well.

        Returns ``self`` to allow chaining.
        """
        self.X = self.X.type_as(x)
        self.E = self.E.type_as(x)
        if self.y is not None:
            self.y = self.y.type_as(x)
        return self

    def mask(self, node_mask, collapse=False):
        """Apply ``node_mask`` to X and E in place and return ``self``.

        Args:
            node_mask: per-node mask, indexed as (batch, node).
            collapse: if True, X and E are replaced by the argmax over their
                last dimension and masked-out entries are set to -1;
                otherwise X and E are multiplied by the mask.

        Raises:
            AssertionError: in the non-collapsed case, if the masked E is
                not equal to its transpose over dimensions 1 and 2.
        """
        x_mask = node_mask.unsqueeze(-1)          # bs, n, 1
        e_mask1 = x_mask.unsqueeze(2)             # bs, n, 1, 1
        e_mask2 = x_mask.unsqueeze(1)             # bs, 1, n, 1

        if collapse:
            self.X = torch.argmax(self.X, dim=-1)
            self.E = torch.argmax(self.E, dim=-1)

            # -1 marks entries that involve at least one masked-out node.
            self.X[node_mask == 0] = - 1
            self.E[(e_mask1 * e_mask2).squeeze(-1) == 0] = - 1
        else:
            self.X = self.X * x_mask
            self.E = self.E * e_mask1 * e_mask2
            assert torch.allclose(self.E, torch.transpose(self.E, 1, 2))
        return self
def setup_wandb(cfg):
    """Start a Weights & Biases run configured from ``cfg``.

    The whole config is converted to a plain container (interpolations
    resolved, missing values raising) and attached to the run. The run is
    named ``cfg.general.name``, belongs to project
    ``graph_ddm_<cfg.dataset.name>`` and uses ``cfg.general.wandb`` as the
    wandb mode. After initialisation, ``wandb.save('*.txt')`` is called.

    Args:
        cfg: OmegaConf configuration exposing ``general.name``,
            ``general.wandb`` and ``dataset.name``.
    """
    resolved_cfg = omegaconf.OmegaConf.to_container(
        cfg, resolve=True, throw_on_missing=True
    )
    wandb.init(
        name=cfg.general.name,
        project=f'graph_ddm_{cfg.dataset.name}',
        config=resolved_cfg,
        settings=wandb.Settings(_disable_stats=True),
        reinit=True,
        mode=cfg.general.wandb,
    )
    wandb.save('*.txt')