import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from models.transformer_model import GraphTransformer
from diffusion.noise_schedule import DiscreteUniformTransition, PredefinedNoiseScheduleDiscrete
from diffusion import diffusion_utils
import utils
import networkx as nx
from sentence_transformers import SentenceTransformer
import pytorch_lightning as pl
from transformers import BertTokenizer, BertForSequenceClassification
class LGGMText2Graph_Demo(pl.LightningModule):
    def __init__(self, cfg, input_dims, output_dims, cond_dims, cond_emb,
                 nodes_dist, node_types, edge_types, extra_features, data_loaders):
        super().__init__()

        self.cfg = cfg
        self.T = cfg.model.diffusion_steps

        self.Xdim = input_dims['X']
        self.Edim = input_dims['E']
        self.ydim = input_dims['y']
        self.Xdim_output = output_dims['X']
        self.Edim_output = output_dims['E']
        self.ydim_output = output_dims['y']

        self.node_dist = nodes_dist
        self.extra_features = extra_features

        self.model = GraphTransformer(n_layers=cfg.model.n_layers,
                                      input_dims=input_dims,
                                      hidden_mlp_dims=cfg.model.hidden_mlp_dims,
                                      hidden_dims=cfg.model.hidden_dims,
                                      output_dims=output_dims,
                                      cond_dims=cond_dims,
                                      act_fn_in=nn.ReLU(),
                                      act_fn_out=nn.ReLU()).to(self.device)

        self.noise_schedule = PredefinedNoiseScheduleDiscrete(cfg.model.diffusion_noise_schedule,
                                                              timesteps=cfg.model.diffusion_steps).to(self.device)

        self.transition_model = DiscreteUniformTransition(x_classes=self.Xdim_output, e_classes=self.Edim_output,
                                                          y_classes=self.ydim_output)

        x_limit = torch.ones(self.Xdim_output) / self.Xdim_output
        e_limit = torch.ones(self.Edim_output) / self.Edim_output
        y_limit = torch.ones(self.ydim_output) / self.ydim_output
        self.limit_dist = utils.PlaceHolder(X=x_limit, E=e_limit, y=y_limit)
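        # With a uniform transition model, the t -> infinity limit of the forward
        # process is uniform over classes: e.g. with Edim_output == 2 (edge absent /
        # present), e_limit == tensor([0.5, 0.5]), so reverse sampling starts from a
        # dense Erdos-Renyi-like random graph that is iteratively denoised.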
    def generate_basic(self, text, num_nodes):
        # Encode the text prompt with the sentence-transformer set up in
        # init_prompt_encoder_basic, then condition sampling on the embedding.
        prompt_emb = torch.tensor(self.text_encoder.encode([text])).to(self.device)

        samples = self.sample_batch(5, cond_emb=prompt_emb, num_nodes=num_nodes)

        nx_graphs = []
        for graph in samples:
            node_types, edge_types = graph
            A = edge_types.bool().cpu().numpy()
            nx_graph = nx.from_numpy_array(A)
            nx_graphs.append(nx_graph)

        return nx_graphs
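    # Hypothetical usage sketch (the `model` handle, prompt, and node count are
    # illustrative and assume trained weights have already been loaded):
    #   model.init_prompt_encoder_basic()
    #   graphs = model.generate_basic("a sparse collaboration network", num_nodes=20)
    #   print(nx.density(graphs[0]))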
    def generate_pretrained(self, text, num_nodes):
        encoded_input = self.tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=512)
        encoded_input = {key: val.to(self.text_encoder.device) for key, val in encoded_input.items()}

        # Use the [CLS] token of the last hidden layer as the prompt embedding
        with torch.no_grad():
            prompt_emb = self.text_encoder(**encoded_input).hidden_states[-1][:, 0]

        samples = self.sample_batch(3, cond_emb=prompt_emb.to(self.device), num_nodes=num_nodes)

        nx_graphs = []
        for graph in samples:
            node_types, edge_types = graph
            A = edge_types.bool().cpu().numpy()
            nx_graph = nx.from_numpy_array(A)
            nx_graphs.append(nx_graph)

        return nx_graphs
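    # Same post-processing as generate_basic: edge class 0 is read as "no edge",
    # so edge_types.bool() gives the binary adjacency matrix passed to networkx.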
    def init_prompt_encoder_basic(self):
        self.text_encoder = SentenceTransformer("all-MiniLM-L6-v2")
    def init_prompt_encoder_pretrained(self):
        model_name = "./checkpoint-900"  # or "bert-base-uncased" if starting from the base model
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.text_encoder = BertForSequenceClassification.from_pretrained(model_name, num_labels=8, output_hidden_states=True, device_map='cpu')
    def sample_batch(self, batch_size: int, cond_emb=None, num_nodes=None):
        """
        :param batch_size: number of graphs to sample
        :param cond_emb: text-prompt embedding used to condition the model
        :param num_nodes: int or int tensor of shape (batch_size,) fixing the number of nodes
            per graph (optional; sampled from the training node-count distribution if None)
        :return: graph_list. Each element is a pair [node_types, edge_types].
        """
        if num_nodes is None:
            n_nodes = self.node_dist.sample_n(batch_size, self.device)
        elif isinstance(num_nodes, int):
            n_nodes = num_nodes * torch.ones(batch_size, device=self.device, dtype=torch.int)
        else:
            # Otherwise assume a per-graph tensor of node counts, as documented above.
            assert isinstance(num_nodes, torch.Tensor)
            n_nodes = num_nodes
        n_max = torch.max(n_nodes).item()

        # Build the masks
        arange = torch.arange(n_max, device=self.device).unsqueeze(0).expand(batch_size, -1)
        node_mask = arange < n_nodes.unsqueeze(1)

        # Sample noise -- z_T has size (batch_size, n_max, n_features)
        z_T = diffusion_utils.sample_discrete_feature_noise(limit_dist=self.limit_dist, node_mask=node_mask, transition=self.cfg.model.transition)
        X, E, y = z_T.X, z_T.E, z_T.y
        # Iteratively sample p(z_s | z_t) for t = T, ..., 1, with s = t - 1.
        for s_int in tqdm(reversed(range(self.T)), total=self.T):
            s_array = s_int * torch.ones((batch_size, 1)).type_as(y)
            t_array = s_array + 1
            s_norm = s_array / self.T
            t_norm = t_array / self.T

            # Sample z_s
            sampled_s = self.sample_p_zs_given_zt(s_norm, t_norm, X, E, y, node_mask, cond_emb)
            X, E, y = sampled_s.X, sampled_s.E, sampled_s.y

        # Collapse the final one-hot sample to integer class labels per node/edge
        sampled_s = sampled_s.mask(node_mask, collapse=True)
        X, E, y = sampled_s.X, sampled_s.E, sampled_s.y
        graph_list = []
        for i in range(batch_size):
            n = n_nodes[i]
            node_types = X[i, :n].cpu()
            edge_types = E[i, :n, :n].cpu()
            graph_list.append([node_types, edge_types])

        return graph_list
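    # After mask(collapse=True), node_types is an (n,) tensor and edge_types an
    # (n, n) tensor of integer class labels, with edge class 0 meaning "no edge".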
    def sample_p_zs_given_zt(self, s, t, X_t, E_t, y_t, node_mask, cond_emb):
        """Samples from zs ~ p(zs | zt). Only used during sampling."""
        bs, n, dxs = X_t.shape
        beta_t = self.noise_schedule(t_normalized=t)  # (bs, 1)
        alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s)
        alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t)

        # Retrieve transition matrices
        Qtb = self.transition_model.get_Qt_bar(alpha_t_bar, self.device)
        Qsb = self.transition_model.get_Qt_bar(alpha_s_bar, self.device)
        Qt = self.transition_model.get_Qt(beta_t, self.device)

        noisy_data = {'X_t': X_t, 'E_t': E_t, 'y_t': y_t, 't': t, 'node_mask': node_mask,
                      'cond_emb': cond_emb.repeat(X_t.shape[0], 1)}
        extra_data = self.compute_extra_data(noisy_data)
        pred = self.forward(noisy_data, extra_data, node_mask)

        # Normalize predictions
        pred_X = F.softmax(pred.X, dim=-1)  # bs, n, d0
        pred_E = F.softmax(pred.E, dim=-1)  # bs, n, n, d0
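        # D3PM-style posterior: p(z_s | z_t) = sum_{x0} q(z_s | z_t, x0) * p_theta(x0 | z_t).
        # The helper below tabulates q(z_s | z_t, x0) for every possible x0, which is
        # then weighted by the network's predicted distribution over x0.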
        p_s_and_t_given_0_X = diffusion_utils.compute_batched_over0_posterior_distribution(X_t=X_t,
                                                                                           Qt=Qt.X,
                                                                                           Qsb=Qsb.X,
                                                                                           Qtb=Qtb.X)
        p_s_and_t_given_0_E = diffusion_utils.compute_batched_over0_posterior_distribution(X_t=E_t,
                                                                                           Qt=Qt.E,
                                                                                           Qsb=Qsb.E,
                                                                                           Qtb=Qtb.E)
        # Dim of these two tensors: bs, N, d0, d_t-1
        weighted_X = pred_X.unsqueeze(-1) * p_s_and_t_given_0_X  # bs, n, d0, d_t-1
        unnormalized_prob_X = weighted_X.sum(dim=2)              # bs, n, d_t-1
        unnormalized_prob_X[torch.sum(unnormalized_prob_X, dim=-1) == 0] = 1e-5
        prob_X = unnormalized_prob_X / torch.sum(unnormalized_prob_X, dim=-1, keepdim=True)  # bs, n, d_t-1

        pred_E = pred_E.reshape((bs, -1, pred_E.shape[-1]))
        weighted_E = pred_E.unsqueeze(-1) * p_s_and_t_given_0_E  # bs, N, d0, d_t-1
        unnormalized_prob_E = weighted_E.sum(dim=-2)
        unnormalized_prob_E[torch.sum(unnormalized_prob_E, dim=-1) == 0] = 1e-5
        prob_E = unnormalized_prob_E / torch.sum(unnormalized_prob_E, dim=-1, keepdim=True)
        prob_E = prob_E.reshape(bs, n, n, pred_E.shape[-1])

        assert ((prob_X.sum(dim=-1) - 1).abs() < 1e-4).all()
        assert ((prob_E.sum(dim=-1) - 1).abs() < 1e-4).all()

        sampled_s = diffusion_utils.sample_discrete_features(prob_X, prob_E, node_mask=node_mask)

        X_s = F.one_hot(sampled_s.X, num_classes=self.Xdim_output).float()
        E_s = F.one_hot(sampled_s.E, num_classes=self.Edim_output).float()

        assert (E_s == torch.transpose(E_s, 1, 2)).all()
        assert (X_t.shape == X_s.shape) and (E_t.shape == E_s.shape)

        out_one_hot = utils.PlaceHolder(X=X_s, E=E_s, y=torch.zeros(y_t.shape[0], 0))
        return out_one_hot.mask(node_mask).type_as(y_t)
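    # The symmetry assertion on E_s relies on sample_discrete_features drawing one
    # label per node pair and mirroring it across the diagonal.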
    def compute_extra_data(self, noisy_data):
        """ At every training step (after adding noise) and at every sampling step, compute extra
        information and append it to the network input. """
        extra_features = self.extra_features(noisy_data)

        extra_X = extra_features.X
        extra_E = extra_features.E
        extra_y = extra_features.y

        # Append the normalized timestep to the global features
        t = noisy_data['t']
        extra_y = torch.cat((extra_y, t), dim=1)

        return utils.PlaceHolder(X=extra_X, E=extra_E, y=extra_y)
    def forward(self, noisy_data, extra_data, node_mask):
        # Broadcast the prompt embedding across nodes (A) and node pairs (B)
        # so it can be concatenated to the node and edge features.
        B = noisy_data['cond_emb'].unsqueeze(1).unsqueeze(2).expand(-1, noisy_data['X_t'].shape[1], noisy_data['X_t'].shape[1], -1).to(self.device)
        A = noisy_data['cond_emb'].unsqueeze(1).expand(-1, noisy_data['X_t'].shape[1], -1).to(self.device)

        X = torch.cat((noisy_data['X_t'], extra_data.X, A), dim=2).float()
        E = torch.cat((noisy_data['E_t'], extra_data.E, B), dim=3).float()
        y = torch.hstack((noisy_data['y_t'], extra_data.y)).float()

        return self.model(X, E, y, node_mask)
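

if __name__ == "__main__":
    # Illustrative sketch, independent of the demo pipeline: a smoke test of the
    # node-mask construction used in sample_batch. It needs no trained model and
    # only shows how graphs with different node counts are padded and masked.
    n_nodes = torch.tensor([3, 5])
    n_max = int(n_nodes.max())
    arange = torch.arange(n_max).unsqueeze(0).expand(len(n_nodes), -1)
    node_mask = arange < n_nodes.unsqueeze(1)
    print(node_mask)
    # tensor([[ True,  True,  True, False, False],
    #         [ True,  True,  True,  True,  True]])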