import os

import torch
from sentence_transformers import SentenceTransformer
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.utils import remove_self_loops, to_undirected

from distributions import DistributionNodes
from utils import to_dense

def arrange_data(adj_matrix, cond_emb, ind):
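    """Pack a single graph, given as a dense adjacency matrix, into a PyG
    `Data` object, attaching the text-prompt embedding `cond_emb` and the
    prompt index `ind`."""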
    n_nodes = adj_matrix.shape[0]
    edge_index = adj_matrix.nonzero().t()
    # One-hot edge attributes ([no-edge, edge]); float dtype so the 'mean'
    # reduction in to_undirected is well defined.
    edge_attr = torch.tensor([[0, 1] for _ in range(edge_index.shape[1])], dtype=torch.float)
    edge_index, edge_attr = to_undirected(edge_index, edge_attr, n_nodes, reduce='mean')
    edge_index, edge_attr = remove_self_loops(edge_index, edge_attr)
    x = torch.ones((n_nodes, 1))  # single dummy node type
    y = torch.empty(1, 0)         # no graph-level targets
    cond_emb = torch.as_tensor(cond_emb).unsqueeze(0)
    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y,
                prompt_id=torch.tensor(ind), cond_emb=cond_emb)

def load_dataset_cc(dataname, batch_size, hydra_path, condition):
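    """Load a clustering-coefficient dataset ('cc_high'/'cc_medium'/'cc_low'),
    encode its stored text prompts with SentenceTransformer, and return the
    loaders, the raw data lists, the prompt-embedding width, and the
    condition embeddings."""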
    domains = ['cc_high', 'cc_medium', 'cc_low']
    model = SentenceTransformer("all-MiniLM-L6-v2")
    cond_embs = model.encode(condition)
    train_data, val_data, test_data = [], [], []
    if dataname in domains:  # only for test
        train_d = torch.load(f'{hydra_path}/graphs/{dataname}/train.pt')
        val_d = torch.load(f'{hydra_path}/graphs/{dataname}/val.pt')
        test_d = torch.load(f'{hydra_path}/graphs/{dataname}/test.pt')
        train_indices = torch.load(f'{hydra_path}/graphs/{dataname}/train_indices.pt')
        val_indices = torch.load(f'{hydra_path}/graphs/{dataname}/val_indices.pt')
        test_indices = torch.load(f'{hydra_path}/graphs/{dataname}/test_indices.pt')
        with open(f'{hydra_path}/graphs/{dataname}/text_prompt_order.txt', 'r') as f:
            text_prompt = f.readlines()
        text_prompt = [x.strip() for x in text_prompt]
        print(text_prompt[0])
        text_embs = model.encode(text_prompt)
        cond_embs = torch.tensor(text_embs)
        train_data.extend([arrange_data(d, text_embs[ind.item()], ind.item()) for d, ind in zip(train_d, train_indices)])
        val_data.extend([arrange_data(d, text_embs[ind.item()], ind.item()) for d, ind in zip(val_d, val_indices)])
        if dataname != 'eco':
            test_data = [arrange_data(d, text_embs[ind.item()], ind.item()) for d, ind in zip(test_d, test_indices)]
        else:
            # 'eco' is too small for a held-out split: test on the union of
            # all three splits (the raw graphs, not the arranged lists).
            test_data = [arrange_data(d, text_embs[ind.item()], ind.item()) for d, ind in zip(train_d, train_indices)] \
                + [arrange_data(d, text_embs[ind.item()], ind.item()) for d, ind in zip(val_d, val_indices)] \
                + [arrange_data(d, text_embs[ind.item()], ind.item()) for d, ind in zip(test_d, test_indices)]
    print('Size of dataset', len(train_data), len(val_data), len(test_data))
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)
    # as_tensor avoids re-wrapping (and the copy warning) when cond_embs is
    # already a tensor.
    return train_loader, val_loader, test_loader, train_data, val_data, test_data, text_embs.shape[1], torch.as_tensor(cond_embs)

def load_dataset_deg(dataname, batch_size, hydra_path, condition):
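    """Load a degree dataset ('deg_high'/'deg_medium'/'deg_low'), creating and
    caching the train/val/test splits on first use. `dataname='all'`
    concatenates all three domains. Returns the same tuple as
    load_dataset_cc."""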
    domains = ['deg_high', 'deg_medium', 'deg_low']
    model = SentenceTransformer("all-MiniLM-L6-v2")
    cond_embs = model.encode(condition)
    for domain in domains:
        if not os.path.exists(f'{hydra_path}/graphs/{domain}/train.pt'):
            data = torch.load(f'{hydra_path}/graphs/{domain}/{domain}.pt')
            # Fix the seed so the random split is reproducible.
            torch.manual_seed(0)
            # Randomly permute, then split.
            n = len(data)
            indices = torch.randperm(n)
            if domain == 'eco':
                # 'eco' only has a handful of graphs: oversample train/val.
                train_indices = indices[:4].repeat(50)
                val_indices = indices[4:5].repeat(50)
                test_indices = indices[5:]
            else:
                train_indices = indices[:int(0.7 * n)]
                val_indices = indices[int(0.7 * n):int(0.8 * n)]
                test_indices = indices[int(0.8 * n):]
            train_data = [data[idx] for idx in train_indices]
            val_data = [data[idx] for idx in val_indices]
            test_data = [data[idx] for idx in test_indices]
            torch.save(train_indices, f'{hydra_path}/graphs/{domain}/train_indices.pt')
            torch.save(val_indices, f'{hydra_path}/graphs/{domain}/val_indices.pt')
            torch.save(test_indices, f'{hydra_path}/graphs/{domain}/test_indices.pt')
            torch.save(train_data, f'{hydra_path}/graphs/{domain}/train.pt')
            torch.save(val_data, f'{hydra_path}/graphs/{domain}/val.pt')
            torch.save(test_data, f'{hydra_path}/graphs/{domain}/test.pt')
    train_data, val_data, test_data = [], [], []
    if dataname in domains:  # only for test
        train_d = torch.load(f'{hydra_path}/graphs/{dataname}/train.pt')
        val_d = torch.load(f'{hydra_path}/graphs/{dataname}/val.pt')
        test_d = torch.load(f'{hydra_path}/graphs/{dataname}/test.pt')
        train_indices = torch.load(f'{hydra_path}/graphs/{dataname}/train_indices.pt')
        val_indices = torch.load(f'{hydra_path}/graphs/{dataname}/val_indices.pt')
        test_indices = torch.load(f'{hydra_path}/graphs/{dataname}/test_indices.pt')
        with open(f'{hydra_path}/graphs/{dataname}/text_prompt_order.txt', 'r') as f:
            text_prompt = f.readlines()
        text_prompt = [x.strip() for x in text_prompt]
        text_embs = model.encode(text_prompt)
        cond_embs = torch.tensor(text_embs)
        train_data.extend([arrange_data(d, text_embs[ind.item()], ind.item()) for d, ind in zip(train_d, train_indices)])
        val_data.extend([arrange_data(d, text_embs[ind.item()], ind.item()) for d, ind in zip(val_d, val_indices)])
        if dataname != 'eco':
            test_data = [arrange_data(d, text_embs[ind.item()], ind.item()) for d, ind in zip(test_d, test_indices)]
        else:
            # 'eco' is too small for a held-out split: test on the union of
            # all three splits (the raw graphs, not the arranged lists).
            test_data = [arrange_data(d, text_embs[ind.item()], ind.item()) for d, ind in zip(train_d, train_indices)] \
                + [arrange_data(d, text_embs[ind.item()], ind.item()) for d, ind in zip(val_d, val_indices)] \
                + [arrange_data(d, text_embs[ind.item()], ind.item()) for d, ind in zip(test_d, test_indices)]
    elif dataname == 'all':
        for i, domain in enumerate(domains):
            train_d = torch.load(f'{hydra_path}/graphs/{domain}/train.pt')
            val_d = torch.load(f'{hydra_path}/graphs/{domain}/val.pt')
            test_d = torch.load(f'{hydra_path}/graphs/{domain}/test.pt')
            train_indices = torch.load(f'{hydra_path}/graphs/{domain}/train_indices.pt')
            val_indices = torch.load(f'{hydra_path}/graphs/{domain}/val_indices.pt')
            test_indices = torch.load(f'{hydra_path}/graphs/{domain}/test_indices.pt')
            with open(f'{hydra_path}/graphs/{domain}/text_prompt_order.txt', 'r') as f:
                text_prompt = f.readlines()
            text_prompt = [x.strip() for x in text_prompt]
            print(domain, text_prompt[0])
            text_embs = model.encode(text_prompt)
            train_data.extend([arrange_data(d, text_embs[ind.item()], ind.item()) for d, ind in zip(train_d, train_indices)])
            val_data.extend([arrange_data(d, text_embs[ind.item()], ind.item()) for d, ind in zip(val_d, val_indices)])
            test_data.extend([arrange_data(d, text_embs[ind.item()], ind.item()) for d, ind in zip(test_d, test_indices)])
            print(i, domain, len(train_data), len(val_data), len(test_data))
    print('Size of dataset', len(train_data), len(val_data), len(test_data))
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)
    return train_loader, val_loader, test_loader, train_data, val_data, test_data, text_embs.shape[1], torch.as_tensor(cond_embs)

def init_dataset(dataname, batch_size, hydra_path, condition, transition):
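    """Build the data loaders plus the statistics the diffusion model needs:
    graph-size distribution, node/edge-type marginals, and the condition
    embeddings. Loading is routed through load_dataset_cc; `transition` is
    accepted for interface compatibility but unused here."""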
    train_loader, val_loader, test_loader, train_data, val_data, test_data, cond_dims, cond_emb = load_dataset_cc(dataname, batch_size, hydra_path, condition)
    n_nodes = node_counts(1000, train_loader, val_loader)
    node_types = torch.tensor([1])  # no node-type features: a single dummy type
    edge_types = edge_counts(train_loader)
    num_classes = len(node_types)
    max_n_nodes = len(n_nodes) - 1
    nodes_dist = DistributionNodes(n_nodes)
    print('Distribution of Number of Nodes:', n_nodes)
    print('Distribution of Node Types:', node_types)
    print('Distribution of Edge Types:', edge_types)
    data_loaders = {'train': train_loader, 'val': val_loader, 'test': test_loader}
    return data_loaders, num_classes, max_n_nodes, nodes_dist, edge_types, node_types, n_nodes, cond_dims, cond_emb

def node_counts(max_nodes_possible, train_loader, val_loader):
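    """Empirical distribution of graph sizes over the train and val loaders,
    truncated at the largest observed size and normalized to sum to 1."""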
    all_counts = torch.zeros(max_nodes_possible)
    for loader in [train_loader, val_loader]:
        for data in loader:
            _, counts = torch.unique(data.batch, return_counts=True)
            for count in counts:
                all_counts[count] += 1
    max_index = max(all_counts.nonzero())
    all_counts = all_counts[:max_index + 1]
    all_counts = all_counts / all_counts.sum()
    return all_counts

def node_counts_meta(max_nodes_possible, train_data, val_data, num_classes):
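    """Per-condition-class variant of node_counts: one size distribution per
    class, indexed by `data.cond_type`."""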
    all_counts = [torch.zeros(max_nodes_possible) for _ in range(num_classes)]
    for dataset in [train_data, val_data]:
        for data in dataset:
            all_counts[data.cond_type.item()][data.x.shape[0]] += 1
    for c in range(num_classes):
        tmp = all_counts[c].nonzero()
        if len(tmp) == 0:
            # No graphs observed for this class: fall back to a degenerate
            # distribution so the normalization below does not divide by zero.
            max_index = 1
            all_counts[c][0] = 1
        else:
            max_index = max(tmp)
        all_counts[c] = all_counts[c][:max_index + 1]
        all_counts[c] = all_counts[c] / all_counts[c].sum()
    return all_counts

def node_types(train_loader):
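    """Marginal distribution of node types over the training set."""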
    num_classes = None
    for data in train_loader:
        num_classes = data.x.shape[1]
        break
    counts = torch.zeros(num_classes)
    for data in train_loader:
        counts += data.x.sum(dim=0)
    counts = counts / counts.sum()
    return counts

def edge_counts(train_loader):
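    """Marginal distribution of edge types over the training set. Class 0
    absorbs every ordered node pair without an edge ('non-edge'); the
    remaining classes are accumulated from the one-hot `edge_attr`."""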
    num_classes = None
    for data in train_loader:
        num_classes = data.edge_attr.shape[1]
        break
    d = torch.zeros(num_classes, dtype=torch.float)
    for data in train_loader:
        _, counts = torch.unique(data.batch, return_counts=True)
        # Total ordered node pairs per graph; pairs without an edge count
        # toward the 'no-edge' class 0.
        all_pairs = 0
        for count in counts:
            all_pairs += count * (count - 1)
        num_edges = data.edge_index.shape[1]
        num_non_edges = all_pairs - num_edges
        edge_types = data.edge_attr.sum(dim=0)
        assert num_non_edges >= 0
        d[0] += num_non_edges
        d[1:] += edge_types[1:]
    d = d / d.sum()
    return d

def edge_counts_meta(train_data, num_classes):
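    """Per-condition-class variant of edge_counts, indexed by
    `data.cond_type`. Counts start at one so every class normalizes to a
    valid distribution even when unobserved."""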
    num_edge_classes = None
    for data in train_data:
        num_edge_classes = data.edge_attr.shape[1]
        break
    # Initialized with ones: a pseudo-count that keeps every class nonzero.
    d = [torch.ones(num_edge_classes, dtype=torch.float) for _ in range(num_classes)]
    for data in train_data:
        n_nodes = data.x.shape[0]
        all_pairs = n_nodes * (n_nodes - 1)
        num_edges = data.edge_index.shape[1]
        num_non_edges = all_pairs - num_edges
        edge_types = data.edge_attr.sum(dim=0)
        assert num_non_edges >= 0
        d[data.cond_type.item()][0] += num_non_edges
        d[data.cond_type.item()][1:] += edge_types[1:]
    for i in range(len(d)):
        d[i] = d[i] / d[i].sum()
    d = torch.stack(d)
    return d

def compute_input_output_dims(train_loader, extra_features):
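    """Infer the model's input/output widths for X (nodes), E (edges), and y
    (globals) from one example batch, including the extra features and the
    +1 time-conditioning dimension on y."""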
    example_batch = next(iter(train_loader))
    ex_dense, node_mask = to_dense(example_batch.x, example_batch.edge_index,
                                   example_batch.edge_attr, example_batch.batch)
    example_data = {'X_t': ex_dense.X, 'E_t': ex_dense.E,
                    'y_t': example_batch['y'], 'node_mask': node_mask}
    input_dims = {'X': example_batch['x'].size(1),
                  'E': example_batch['edge_attr'].size(1),
                  'y': example_batch['y'].size(1) + 1}  # + 1 due to time conditioning
    ex_extra_feat = extra_features(example_data)
    input_dims['X'] += ex_extra_feat.X.size(-1)
    input_dims['E'] += ex_extra_feat.E.size(-1)
    input_dims['y'] += ex_extra_feat.y.size(-1)
    output_dims = {'X': example_batch['x'].size(1),
                   'E': example_batch['edge_attr'].size(1),
                   'y': 0}
    return input_dims, output_dims
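
if __name__ == '__main__':
    # Minimal usage sketch (illustrative only, not part of the original
    # pipeline). Assumes a hypothetical run directory './outputs' that
    # already contains the preprocessed splits under './outputs/graphs/',
    # and an example prompt; both are placeholders, not values from the repo.
    loaders, num_classes, max_n_nodes, nodes_dist, edge_marg, node_marg, n_nodes, cond_dims, cond_emb = init_dataset(
        dataname='cc_high',
        batch_size=32,
        hydra_path='./outputs',
        condition=['a social network with high clustering coefficient'],
        transition=None,  # accepted but unused by init_dataset
    )
    print('max graph size:', max_n_nodes, '| prompt embedding dim:', cond_dims)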