import os
import random

import torch
from sentence_transformers import SentenceTransformer
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.utils import remove_self_loops, to_undirected

from distributions import DistributionNodes
from utils import to_dense


def arrange_data(adj_matrix, cond_emb, ind):
    """Convert a dense adjacency matrix into a PyG Data object with a
    text-condition embedding attached."""
    n_nodes = adj_matrix.shape[0]
    edge_index = adj_matrix.nonzero().t()
    # Two-class one-hot edge attribute [no-edge, edge]; float so that the
    # 'mean' reduction in to_undirected is well defined.
    edge_attr = torch.tensor([[0., 1.] for _ in range(edge_index.shape[1])])
    edge_index, edge_attr = to_undirected(edge_index, edge_attr, n_nodes, reduce='mean')
    edge_index, edge_attr = remove_self_loops(edge_index, edge_attr)
    x = torch.ones((n_nodes, 1))  # single dummy node feature (no node types)
    y = torch.empty(1, 0)
    cond_emb = torch.tensor(cond_emb).unsqueeze(0)
    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y,
                prompt_id=torch.tensor(ind), cond_emb=cond_emb)
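

# Hedged usage sketch for arrange_data: the 3-node path graph and the zero
# embedding below are made-up illustration values (all-MiniLM-L6-v2 produces
# 384-dim embeddings), not part of the original pipeline.
def _demo_arrange_data():
    adj = torch.tensor([[0, 1, 0],
                        [1, 0, 1],
                        [0, 1, 0]])
    emb = torch.zeros(384)
    d = arrange_data(adj, emb, ind=0)
    # 2 undirected edges stored as 4 directed entries, self-loops removed.
    assert d.edge_index.shape == (2, 4)
    assert d.x.shape == (3, 1)
    assert d.cond_emb.shape == (1, 384)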


def load_dataset_cc(dataname, batch_size, hydra_path, condition):
    """Load one of the clustering-coefficient datasets and wrap the splits in
    DataLoaders, pairing each graph with its text-prompt embedding."""
    domains = ['cc_high', 'cc_medium', 'cc_low']
    model = SentenceTransformer("all-MiniLM-L6-v2")
    cond_embs = model.encode(condition)
    train_data, val_data, test_data = [], [], []
    if dataname in domains:  # only for test
        train_d = torch.load(f'{hydra_path}/graphs/{dataname}/train.pt')
        val_d = torch.load(f'{hydra_path}/graphs/{dataname}/val.pt')
        test_d = torch.load(f'{hydra_path}/graphs/{dataname}/test.pt')
        train_indices = torch.load(f'{hydra_path}/graphs/{dataname}/train_indices.pt')
        val_indices = torch.load(f'{hydra_path}/graphs/{dataname}/val_indices.pt')
        test_indices = torch.load(f'{hydra_path}/graphs/{dataname}/test_indices.pt')
        with open(f'{hydra_path}/graphs/{dataname}/text_prompt_order.txt', 'r') as f:
            text_prompt = f.readlines()
        text_prompt = [x.strip() for x in text_prompt]
        print(text_prompt[0])
        text_embs = model.encode(text_prompt)
        cond_embs = torch.tensor(text_embs)
        train_data.extend([arrange_data(d, text_embs[ind.item()], ind.item())
                           for d, ind in zip(train_d, train_indices)])
        val_data.extend([arrange_data(d, text_embs[ind.item()], ind.item())
                         for d, ind in zip(val_d, val_indices)])
        if dataname != 'eco':
            test_data = [arrange_data(d, text_embs[ind.item()], ind.item())
                         for d, ind in zip(test_d, test_indices)]
        else:
            # 'eco' is tiny, so evaluate on all three splits combined.
            test_data = (
                [arrange_data(d, text_embs[ind.item()], ind.item())
                 for d, ind in zip(train_d, train_indices)]
                + [arrange_data(d, text_embs[ind.item()], ind.item())
                   for d, ind in zip(val_d, val_indices)]
                + [arrange_data(d, text_embs[ind.item()], ind.item())
                   for d, ind in zip(test_d, test_indices)]
            )
    print('Size of dataset', len(train_data), len(val_data), len(test_data))
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)
    return (train_loader, val_loader, test_loader, train_data, val_data,
            test_data, text_embs.shape[1], torch.as_tensor(cond_embs))
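

# Hedged usage sketch for load_dataset_cc: the path and condition string are
# hypothetical placeholders; it assumes {hydra_path}/graphs/cc_high/ already
# holds the saved splits, index files, and text_prompt_order.txt.
def _demo_load_dataset_cc(hydra_path):
    out = load_dataset_cc('cc_high', batch_size=32, hydra_path=hydra_path,
                          condition='a graph with high clustering coefficient')
    train_loader, cond_dims = out[0], out[6]
    for batch in train_loader:
        # Each batch carries one prompt embedding per graph.
        print(batch.num_graphs, batch.cond_emb.shape, cond_dims)
        break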


def load_dataset_deg(dataname, batch_size, hydra_path, condition):
    """Load one of the degree datasets (or all of them), creating the
    train/val/test splits on first use."""
    domains = ['deg_high', 'deg_medium', 'deg_low']
    model = SentenceTransformer("all-MiniLM-L6-v2")
    cond_embs = model.encode(condition)
    for domain in domains:
        if not os.path.exists(f'{hydra_path}/graphs/{domain}/train.pt'):
            data = torch.load(f'{hydra_path}/graphs/{domain}/{domain}.pt')
            # Fix the seed, then randomly permute and split.
            torch.manual_seed(0)
            n = len(data)
            indices = torch.randperm(n)
            if domain == 'eco':  # only relevant if 'eco' is ever added to domains
                train_indices = indices[:4].repeat(50)
                val_indices = indices[4:5].repeat(50)
                test_indices = indices[5:]
            else:
                train_indices = indices[:int(0.7 * n)]
                val_indices = indices[int(0.7 * n):int(0.8 * n)]
                test_indices = indices[int(0.8 * n):]
            train_data = [data[i] for i in train_indices]
            val_data = [data[i] for i in val_indices]
            test_data = [data[i] for i in test_indices]
            torch.save(train_indices, f'{hydra_path}/graphs/{domain}/train_indices.pt')
            torch.save(val_indices, f'{hydra_path}/graphs/{domain}/val_indices.pt')
            torch.save(test_indices, f'{hydra_path}/graphs/{domain}/test_indices.pt')
            torch.save(train_data, f'{hydra_path}/graphs/{domain}/train.pt')
            torch.save(val_data, f'{hydra_path}/graphs/{domain}/val.pt')
            torch.save(test_data, f'{hydra_path}/graphs/{domain}/test.pt')
    train_data, val_data, test_data = [], [], []
    if dataname in domains:  # only for test
        train_d = torch.load(f'{hydra_path}/graphs/{dataname}/train.pt')
        val_d = torch.load(f'{hydra_path}/graphs/{dataname}/val.pt')
        test_d = torch.load(f'{hydra_path}/graphs/{dataname}/test.pt')
        train_indices = torch.load(f'{hydra_path}/graphs/{dataname}/train_indices.pt')
        val_indices = torch.load(f'{hydra_path}/graphs/{dataname}/val_indices.pt')
        test_indices = torch.load(f'{hydra_path}/graphs/{dataname}/test_indices.pt')
        with open(f'{hydra_path}/graphs/{dataname}/text_prompt_order.txt', 'r') as f:
            text_prompt = f.readlines()
        text_prompt = [x.strip() for x in text_prompt]
        text_embs = model.encode(text_prompt)
        cond_embs = torch.tensor(text_embs)
        train_data.extend([arrange_data(d, text_embs[ind.item()], ind.item())
                           for d, ind in zip(train_d, train_indices)])
        val_data.extend([arrange_data(d, text_embs[ind.item()], ind.item())
                         for d, ind in zip(val_d, val_indices)])
        if dataname != 'eco':
            test_data = [arrange_data(d, text_embs[ind.item()], ind.item())
                         for d, ind in zip(test_d, test_indices)]
        else:
            # 'eco' is tiny, so evaluate on all three splits combined.
            test_data = (
                [arrange_data(d, text_embs[ind.item()], ind.item())
                 for d, ind in zip(train_d, train_indices)]
                + [arrange_data(d, text_embs[ind.item()], ind.item())
                   for d, ind in zip(val_d, val_indices)]
                + [arrange_data(d, text_embs[ind.item()], ind.item())
                   for d, ind in zip(test_d, test_indices)]
            )
    elif dataname == 'all':
        for i, domain in enumerate(domains):
            train_d = torch.load(f'{hydra_path}/graphs/{domain}/train.pt')
            val_d = torch.load(f'{hydra_path}/graphs/{domain}/val.pt')
            test_d = torch.load(f'{hydra_path}/graphs/{domain}/test.pt')
            train_indices = torch.load(f'{hydra_path}/graphs/{domain}/train_indices.pt')
            val_indices = torch.load(f'{hydra_path}/graphs/{domain}/val_indices.pt')
            test_indices = torch.load(f'{hydra_path}/graphs/{domain}/test_indices.pt')
            with open(f'{hydra_path}/graphs/{domain}/text_prompt_order.txt', 'r') as f:
                text_prompt = f.readlines()
            text_prompt = [x.strip() for x in text_prompt]
            print(domain, text_prompt[0])
            text_embs = model.encode(text_prompt)
            train_data.extend([arrange_data(d, text_embs[ind.item()], ind.item())
                               for d, ind in zip(train_d, train_indices)])
            val_data.extend([arrange_data(d, text_embs[ind.item()], ind.item())
                             for d, ind in zip(val_d, val_indices)])
            test_data.extend([arrange_data(d, text_embs[ind.item()], ind.item())
                              for d, ind in zip(test_d, test_indices)])
            print(i, domain, len(train_data), len(val_data), len(test_data))
    print('Size of dataset', len(train_data), len(val_data), len(test_data))
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)
    return (train_loader, val_loader, test_loader, train_data, val_data,
            test_data, text_embs.shape[1], torch.as_tensor(cond_embs))
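

# Hedged usage sketch for load_dataset_deg: 'all' pools the three degree
# domains; the path and condition string are hypothetical placeholders.
def _demo_load_dataset_deg(hydra_path):
    out = load_dataset_deg('all', batch_size=64, hydra_path=hydra_path,
                           condition='a graph with low average degree')
    train_data, cond_dims = out[3], out[6]
    print('pooled training graphs:', len(train_data), 'embedding dims:', cond_dims)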


def init_dataset(dataname, batch_size, hydra_path, condition, transition):
    """Build data loaders plus the empirical size/type distributions the
    diffusion model needs. `transition` is currently unused."""
    # Hedged assumption: degree datasets ('deg_*' and the pooled 'all') go
    # through load_dataset_deg; everything else through load_dataset_cc.
    if dataname.startswith('deg') or dataname == 'all':
        load_fn = load_dataset_deg
    else:
        load_fn = load_dataset_cc
    train_loader, val_loader, test_loader, train_data, val_data, test_data, \
        cond_dims, cond_emb = load_fn(dataname, batch_size, hydra_path, condition)
    n_nodes = node_counts(1000, train_loader, val_loader)
    node_types = torch.tensor([1])  # no node types
    edge_types = edge_counts(train_loader)
    num_classes = len(node_types)
    max_n_nodes = len(n_nodes) - 1
    nodes_dist = DistributionNodes(n_nodes)
    print('Distribution of Number of Nodes:', n_nodes)
    print('Distribution of Node Types:', node_types)
    print('Distribution of Edge Types:', edge_types)
    data_loaders = {'train': train_loader, 'val': val_loader, 'test': test_loader}
    return (data_loaders, num_classes, max_n_nodes, nodes_dist, edge_types,
            node_types, n_nodes, cond_dims, cond_emb)
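

# Hedged usage sketch for init_dataset: the dataset name and condition are
# illustrative, and `transition` is passed as None since it is not used.
def _demo_init_dataset(hydra_path):
    data_loaders, num_classes, max_n_nodes, nodes_dist, edge_types, \
        node_marginals, n_nodes, cond_dims, cond_emb = init_dataset(
            'cc_high', 32, hydra_path, 'high clustering coefficient', None)
    print('largest graph size:', max_n_nodes)
    print('edge-type marginals:', edge_types)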


def node_counts(max_nodes_possible, train_loader, val_loader):
    """Empirical distribution of graph sizes over the train and val sets."""
    all_counts = torch.zeros(max_nodes_possible)
    for loader in [train_loader, val_loader]:
        for data in loader:
            _, counts = torch.unique(data.batch, return_counts=True)
            for count in counts:
                all_counts[count] += 1
    max_index = all_counts.nonzero().max()
    all_counts = all_counts[:max_index + 1]
    all_counts = all_counts / all_counts.sum()
    return all_counts
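

# Toy sketch for node_counts (illustration only, not from the original code):
# two edgeless graphs of sizes 3 and 5, with the same loader standing in for
# both the train and val sets.
def _demo_node_counts():
    g1 = Data(x=torch.ones(3, 1), edge_index=torch.empty(2, 0, dtype=torch.long))
    g2 = Data(x=torch.ones(5, 1), edge_index=torch.empty(2, 0, dtype=torch.long))
    loader = DataLoader([g1, g2], batch_size=2)
    dist = node_counts(10, loader, loader)
    # Sizes 3 and 5 are each seen twice, so each gets half the mass.
    assert torch.isclose(dist.sum(), torch.tensor(1.0))
    assert dist[3] == 0.5
    assert dist[5] == 0.5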


def node_counts_meta(max_nodes_possible, train_data, val_data, num_classes):
    """Per-condition-class empirical distribution of graph sizes (assumes each
    Data object carries a `cond_type` label)."""
    all_counts = [torch.zeros(max_nodes_possible) for _ in range(num_classes)]
    for dataset in [train_data, val_data]:
        for data in dataset:
            all_counts[data.cond_type.item()][data.x.shape[0]] += 1
    for c in range(num_classes):
        nz = all_counts[c].nonzero()
        if len(nz) == 0:
            # Unseen class: fall back to a point mass on size 0.
            max_index = 1
            all_counts[c][0] = 1
        else:
            max_index = nz.max()
        all_counts[c] = all_counts[c][:max_index + 1]
        all_counts[c] = all_counts[c] / all_counts[c].sum()
    return all_counts
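

# Toy sketch for node_counts_meta: assumes each Data object carries a
# `cond_type` label (set elsewhere in the pipeline, not by arrange_data).
def _demo_node_counts_meta():
    g = Data(x=torch.ones(3, 1), edge_index=torch.empty(2, 0, dtype=torch.long),
             cond_type=torch.tensor(0))
    dists = node_counts_meta(10, [g], [g], num_classes=2)
    assert dists[0][3] == 1.0  # class 0: all mass on size 3
    assert dists[1][0] == 1.0  # class 1 unseen: fallback mass on size 0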


def node_types(train_loader):
    """Marginal distribution of node types."""
    num_classes = None
    for data in train_loader:
        num_classes = data.x.shape[1]
        break
    counts = torch.zeros(num_classes)
    for data in train_loader:
        counts += data.x.sum(dim=0)
    counts = counts / counts.sum()
    return counts
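

# Toy sketch for node_types: with the single dummy node feature produced by
# arrange_data (x = ones(n, 1)), the marginal is trivially tensor([1.]).
def _demo_node_types():
    g = Data(x=torch.ones(4, 1), edge_index=torch.empty(2, 0, dtype=torch.long))
    assert torch.allclose(node_types(DataLoader([g], batch_size=1)),
                          torch.tensor([1.]))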


def edge_counts(train_loader):
    """Marginal distribution of edge types; class 0 is the no-edge class."""
    num_classes = None
    for data in train_loader:
        num_classes = data.edge_attr.shape[1]
        break
    d = torch.zeros(num_classes, dtype=torch.float)
    for data in train_loader:
        _, counts = torch.unique(data.batch, return_counts=True)
        # Ordered node pairs per graph: n * (n - 1).
        all_pairs = 0
        for count in counts:
            all_pairs += count * (count - 1)
        num_edges = data.edge_index.shape[1]
        num_non_edges = all_pairs - num_edges
        edge_types = data.edge_attr.sum(dim=0)
        assert num_non_edges >= 0
        d[0] += num_non_edges
        d[1:] += edge_types[1:]
    d = d / d.sum()
    return d
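

# Toy sketch for edge_counts: a single triangle, so every ordered node pair
# is connected and the no-edge class gets zero mass.
def _demo_edge_counts():
    edge_index = torch.tensor([[0, 1, 1, 2, 2, 0],
                               [1, 0, 2, 1, 0, 2]])
    edge_attr = torch.tensor([[0., 1.]] * 6)  # one-hot: [no-edge, edge]
    g = Data(x=torch.ones(3, 1), edge_index=edge_index, edge_attr=edge_attr)
    d = edge_counts(DataLoader([g], batch_size=1))
    # 3 * 2 = 6 ordered pairs, 6 directed edge entries -> marginals [0, 1].
    assert torch.allclose(d, torch.tensor([0., 1.]))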


def edge_counts_meta(train_data, num_classes):
    """Per-condition-class marginal distribution of edge types, with add-one
    smoothing so no class has zero mass."""
    num_edge_classes = None
    for data in train_data:
        num_edge_classes = data.edge_attr.shape[1]
        break
    d = [torch.ones(num_edge_classes, dtype=torch.float) for _ in range(num_classes)]
    for data in train_data:
        n_nodes = data.x.shape[0]
        all_pairs = n_nodes * (n_nodes - 1)
        num_edges = data.edge_index.shape[1]
        num_non_edges = all_pairs - num_edges
        edge_types = data.edge_attr.sum(dim=0)
        assert num_non_edges >= 0
        d[data.cond_type.item()][0] += num_non_edges
        d[data.cond_type.item()][1:] += edge_types[1:]
    for i in range(len(d)):
        d[i] = d[i] / d[i].sum()
    return torch.stack(d)
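

# Toy sketch for edge_counts_meta: one graph with a single undirected edge on
# 3 nodes; the add-one smoothing shows up in the class-0 marginals.
def _demo_edge_counts_meta():
    g = Data(x=torch.ones(3, 1),
             edge_index=torch.tensor([[0, 1], [1, 0]]),
             edge_attr=torch.tensor([[0., 1.]] * 2),
             cond_type=torch.tensor(0))
    d = edge_counts_meta([g], num_classes=2)
    # class 0: (1 + 4 non-edges, 1 + 2 edges) -> [5/8, 3/8]; class 1: prior only.
    assert torch.allclose(d[0], torch.tensor([5 / 8, 3 / 8]))
    assert torch.allclose(d[1], torch.tensor([0.5, 0.5]))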


def compute_input_output_dims(train_loader, extra_features):
    """Infer model input/output dimensions from one example batch."""
    example_batch = next(iter(train_loader))
    ex_dense, node_mask = to_dense(example_batch.x, example_batch.edge_index,
                                   example_batch.edge_attr, example_batch.batch)
    example_data = {'X_t': ex_dense.X, 'E_t': ex_dense.E,
                    'y_t': example_batch['y'], 'node_mask': node_mask}
    input_dims = {'X': example_batch['x'].size(1),
                  'E': example_batch['edge_attr'].size(1),
                  'y': example_batch['y'].size(1) + 1}  # +1 for time conditioning
    ex_extra_feat = extra_features(example_data)
    input_dims['X'] += ex_extra_feat.X.size(-1)
    input_dims['E'] += ex_extra_feat.E.size(-1)
    input_dims['y'] += ex_extra_feat.y.size(-1)
    output_dims = {'X': example_batch['x'].size(1),
                   'E': example_batch['edge_attr'].size(1),
                   'y': 0}
    return input_dims, output_dims
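

# Hedged sketch for compute_input_output_dims: `extra_features` is assumed to
# be a callable returning an object with X/E/y tensors (as in DiGress-style
# code); the zero-width stub below only illustrates that interface.
def _demo_compute_input_output_dims(train_loader):
    from types import SimpleNamespace

    def no_extra(data):
        bs, n = data['X_t'].shape[0], data['X_t'].shape[1]
        return SimpleNamespace(X=torch.zeros(bs, n, 0),
                               E=torch.zeros(bs, n, n, 0),
                               y=torch.zeros(bs, 0))

    input_dims, output_dims = compute_input_output_dims(train_loader, no_extra)
    print('input dims:', input_dims, 'output dims:', output_dims)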