import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from models.transformer_model import GraphTransformer
from diffusion.noise_schedule import DiscreteUniformTransition, PredefinedNoiseScheduleDiscrete
from diffusion import diffusion_utils
import utils
import networkx as nx
from sentence_transformers import SentenceTransformer
import pytorch_lightning as pl
from transformers import BertTokenizer, BertForSequenceClassification


class LGGMText2Graph_Demo(pl.LightningModule):
    """Text-conditioned discrete graph diffusion model used for demo generation.

    A text prompt is encoded into an embedding, either with a SentenceTransformer
    (``init_prompt_encoder_basic``) or with a fine-tuned BERT classifier
    (``init_prompt_encoder_pretrained``). Graphs are then sampled by running the
    reverse diffusion chain of a ``GraphTransformer`` conditioned on that embedding.

    ``__init__`` does not create a text encoder. Call the matching
    ``init_prompt_encoder_*`` method before ``generate_basic`` /
    ``generate_pretrained``.
    """

    def __init__(self, cfg, input_dims, output_dims, cond_dims, cond_emb,
                 nodes_dist, node_types, edge_types, extra_features, data_loaders):
        """Build the denoising network, noise schedule and transition model.

        :param cfg: config object. ``cfg.model`` supplies the number of diffusion
            steps, the noise-schedule name, the transformer sizes and ``transition``.
        :param input_dims: dict with keys ``'X'``, ``'E'``, ``'y'`` giving the
            network's input feature sizes.
        :param output_dims: dict with keys ``'X'``, ``'E'``, ``'y'``. Used as the
            number of node/edge/global classes.
        :param cond_dims: forwarded to ``GraphTransformer``. Presumably the width of
            the prompt embedding (TODO confirm).
        :param nodes_dist: node-count distribution. Its ``sample_n`` is used when
            ``sample_batch`` is called without ``num_nodes``.
        :param extra_features: callable that returns extra X/E/y features for a
            noisy graph dict.

        ``cond_emb``, ``node_types``, ``edge_types`` and ``data_loaders`` are
        accepted for interface compatibility but are not used here.
        """
        super().__init__()
        self.cfg = cfg
        self.T = cfg.model.diffusion_steps

        self.Xdim = input_dims['X']
        self.Edim = input_dims['E']
        self.ydim = input_dims['y']
        self.Xdim_output = output_dims['X']
        self.Edim_output = output_dims['E']
        self.ydim_output = output_dims['y']
        self.node_dist = nodes_dist

        self.extra_features = extra_features

        self.model = GraphTransformer(n_layers=cfg.model.n_layers,
                                      input_dims=input_dims,
                                      hidden_mlp_dims=cfg.model.hidden_mlp_dims,
                                      hidden_dims=cfg.model.hidden_dims,
                                      output_dims=output_dims,
                                      cond_dims=cond_dims,
                                      act_fn_in=nn.ReLU(),
                                      act_fn_out=nn.ReLU()).to(self.device)

        self.noise_schedule = PredefinedNoiseScheduleDiscrete(cfg.model.diffusion_noise_schedule,
                                                              timesteps=cfg.model.diffusion_steps).to(self.device)

        self.transition_model = DiscreteUniformTransition(x_classes=self.Xdim_output,
                                                          e_classes=self.Edim_output,
                                                          y_classes=self.ydim_output)
        # Uniform limit distribution over classes, matching the uniform transition.
        x_limit = torch.ones(self.Xdim_output) / self.Xdim_output
        e_limit = torch.ones(self.Edim_output) / self.Edim_output
        y_limit = torch.ones(self.ydim_output) / self.ydim_output
        self.limit_dist = utils.PlaceHolder(X=x_limit, E=e_limit, y=y_limit)

    @staticmethod
    def _samples_to_nx(samples):
        """Convert ``sample_batch`` output to a list of ``networkx`` graphs.

        Each non-zero entry of a sample's edge-type matrix becomes an edge.
        Node types are discarded.
        """
        return [nx.from_numpy_array(edge_types.bool().cpu().numpy())
                for _node_types, edge_types in samples]

    def generate_basic(self, text, num_nodes, *, num_samples=5) -> list:
        """Sample graphs for *text* using the SentenceTransformer prompt encoder.

        Requires ``init_prompt_encoder_basic`` to have been called first.

        :param text: prompt string.
        :param num_nodes: node count(s) per graph; see ``sample_batch``.
        :param num_samples: number of graphs to generate (default 5).
        :return: list of ``networkx.Graph`` objects, one per sample.
        """
        prompt_emb = torch.tensor(self.text_encoder.encode([text])).to(self.device)
        samples = self.sample_batch(num_samples, cond_emb=prompt_emb, num_nodes=num_nodes)
        return self._samples_to_nx(samples)

    def generate_pretrained(self, text, num_nodes, *, num_samples=3) -> list:
        """Sample graphs for *text* using the fine-tuned BERT prompt encoder.

        Requires ``init_prompt_encoder_pretrained`` to have been called first.

        :param text: prompt string (truncated to 512 tokens).
        :param num_nodes: node count(s) per graph; see ``sample_batch``.
        :param num_samples: number of graphs to generate (default 3).
        :return: list of ``networkx.Graph`` objects, one per sample.
        """
        encoded_input = self.tokenizer(text, return_tensors='pt', padding=True,
                                       truncation=True, max_length=512)
        encoded_input = {key: val.to(self.text_encoder.device) for key, val in encoded_input.items()}

        with torch.no_grad():
            # The last-layer hidden state of the first token ([CLS] for BERT)
            # serves as the prompt embedding.
            prompt_emb = self.text_encoder(**encoded_input).hidden_states[-1][:, 0]

        samples = self.sample_batch(num_samples, cond_emb=prompt_emb.to(self.device), num_nodes=num_nodes)
        return self._samples_to_nx(samples)

    def init_prompt_encoder_basic(self):
        """Load the ``all-MiniLM-L6-v2`` SentenceTransformer as the prompt encoder."""
        self.text_encoder = SentenceTransformer("all-MiniLM-L6-v2")

    def init_prompt_encoder_pretrained(self):
        """Load the BERT tokenizer and the fine-tuned local classifier checkpoint on CPU."""
        model_name = "./checkpoint-900"  # or "bert-base-uncased" if starting from the base model
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.text_encoder = BertForSequenceClassification.from_pretrained(
            model_name, num_labels=8, output_hidden_states=True, device_map='cpu')

    @torch.no_grad()
    def sample_batch(self, batch_size: int, cond_emb=None, num_nodes=None):
        """Generate ``batch_size`` graphs by running the full reverse diffusion chain.

        :param batch_size: number of graphs to sample.
        :param cond_emb: prompt embedding passed to the denoiser at every step. It
            can be a single embedding shared by the whole batch or one row per
            graph. It defaults to ``None``, but ``sample_p_zs_given_zt`` uses it,
            so in practice it is required.
        :param num_nodes: node count(s). Options:
            ``None`` draws node counts from ``self.node_dist``;
            an ``int`` gives every graph the same size;
            a tensor or sequence gives one entry per graph, or a single entry
            that is broadcast to all graphs.
        :return: list of ``[node_types, edge_types]`` pairs on CPU, one per graph,
            each cropped to that graph's node count. The entries are whatever
            ``PlaceHolder.mask(..., collapse=True)`` yields; presumably these are
            class indices (TODO confirm).
        :raises ValueError: if ``num_nodes`` has neither 1 nor ``batch_size`` entries.
        """
        if num_nodes is None:
            n_nodes = self.node_dist.sample_n(batch_size, self.device)
        else:
            # Accept ints, numpy ints, sequences and tensors alike. Without this,
            # any non-int value would leave n_nodes undefined.
            n_nodes = torch.as_tensor(num_nodes, device=self.device).to(torch.int).reshape(-1)
            if n_nodes.numel() == 1:
                n_nodes = n_nodes.repeat(batch_size)
            elif n_nodes.numel() != batch_size:
                raise ValueError(
                    f"num_nodes has {n_nodes.numel()} entries, expected 1 or {batch_size}")
        n_max = torch.max(n_nodes).item()

        # Build the masks
        arange = torch.arange(n_max, device=self.device).unsqueeze(0).expand(batch_size, -1)
        node_mask = arange < n_nodes.unsqueeze(1)

        # Sample noise -- z has size (n_samples, n_nodes, n_features)
        z_T = diffusion_utils.sample_discrete_feature_noise(limit_dist=self.limit_dist,
                                                            node_mask=node_mask,
                                                            transition=self.cfg.model.transition)
        X, E, y = z_T.X, z_T.E, z_T.y

        # Iteratively sample p(z_s | z_t) for t = 1, ..., T, with s = t - 1.
        for s_int in tqdm(reversed(range(self.T)), total=self.T):
            s_array = s_int * torch.ones((batch_size, 1)).type_as(y)
            t_array = s_array + 1
            s_norm = s_array / self.T
            t_norm = t_array / self.T

            # Sample z_s
            sampled_s = self.sample_p_zs_given_zt(s_norm, t_norm, X, E, y, node_mask, cond_emb)
            X, E, y = sampled_s.X, sampled_s.E, sampled_s.y

        # Collapse the final one-hot sample and crop each graph to its own size.
        sampled_s = sampled_s.mask(node_mask, collapse=True)
        X, E = sampled_s.X, sampled_s.E

        graph_list = []
        for x_i, e_i, n in zip(X, E, n_nodes.tolist()):
            graph_list.append([x_i[:n].cpu(), e_i[:n, :n].cpu()])
        return graph_list

    def sample_p_zs_given_zt(self, s, t, X_t, E_t, y_t, node_mask, cond_emb):
        """Draw one reverse-diffusion step z_s ~ p(z_s | z_t). Only used during sampling.

        :param s: normalized timestep s = t - 1, shape (bs, 1).
        :param t: normalized timestep t, shape (bs, 1).
        :param X_t: noisy node features, shape (bs, n, dx).
        :param E_t: noisy edge features, shape (bs, n, n, de).
        :param y_t: global features; only its batch size and dtype are used for the output.
        :param node_mask: boolean mask of real nodes, shape (bs, n).
        :param cond_emb: prompt embedding, either shared (shape (D,) or (1, D)) or
            per-graph (shape (bs, D)).
        :return: ``PlaceHolder`` with one-hot sampled X and E, and an empty y.
        """
        bs, n, _ = X_t.shape
        beta_t = self.noise_schedule(t_normalized=t)  # (bs, 1)
        alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s)
        alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t)

        # Retrieve transitions matrix
        Qtb = self.transition_model.get_Qt_bar(alpha_t_bar, self.device)
        Qsb = self.transition_model.get_Qt_bar(alpha_s_bar, self.device)
        Qt = self.transition_model.get_Qt(beta_t, self.device)

        # A single shared prompt embedding is tiled across the batch. A per-graph
        # embedding (already bs rows) is used as-is; repeating it would give
        # bs * bs rows and break the concatenation in forward().
        if cond_emb.dim() == 1:
            cond_emb = cond_emb.unsqueeze(0)
        if cond_emb.shape[0] == 1:
            cond_emb = cond_emb.repeat(bs, 1)

        noisy_data = {'X_t': X_t, 'E_t': E_t, 'y_t': y_t, 't': t, 'node_mask': node_mask,
                      'cond_emb': cond_emb}
        extra_data = self.compute_extra_data(noisy_data)
        pred = self.forward(noisy_data, extra_data, node_mask)

        # Normalize predictions
        pred_X = F.softmax(pred.X, dim=-1)  # bs, n, d0
        pred_E = F.softmax(pred.E, dim=-1)  # bs, n, n, d0

        p_s_and_t_given_0_X = diffusion_utils.compute_batched_over0_posterior_distribution(X_t=X_t,
                                                                                           Qt=Qt.X,
                                                                                           Qsb=Qsb.X,
                                                                                           Qtb=Qtb.X)

        p_s_and_t_given_0_E = diffusion_utils.compute_batched_over0_posterior_distribution(X_t=E_t,
                                                                                           Qt=Qt.E,
                                                                                           Qsb=Qsb.E,
                                                                                           Qtb=Qtb.E)
        # Dim of these two tensors: bs, N, d0, d_t-1
        weighted_X = pred_X.unsqueeze(-1) * p_s_and_t_given_0_X  # bs, n, d0, d_t-1
        unnormalized_prob_X = weighted_X.sum(dim=2)  # bs, n, d_t-1
        # Avoid division by zero for rows whose probability mass vanished.
        unnormalized_prob_X[torch.sum(unnormalized_prob_X, dim=-1) == 0] = 1e-5
        prob_X = unnormalized_prob_X / torch.sum(unnormalized_prob_X, dim=-1, keepdim=True)  # bs, n, d_t-1

        pred_E = pred_E.reshape((bs, -1, pred_E.shape[-1]))
        weighted_E = pred_E.unsqueeze(-1) * p_s_and_t_given_0_E  # bs, N, d0, d_t-1
        unnormalized_prob_E = weighted_E.sum(dim=-2)
        unnormalized_prob_E[torch.sum(unnormalized_prob_E, dim=-1) == 0] = 1e-5
        prob_E = unnormalized_prob_E / torch.sum(unnormalized_prob_E, dim=-1, keepdim=True)
        prob_E = prob_E.reshape(bs, n, n, pred_E.shape[-1])

        assert ((prob_X.sum(dim=-1) - 1).abs() < 1e-4).all()
        assert ((prob_E.sum(dim=-1) - 1).abs() < 1e-4).all()

        sampled_s = diffusion_utils.sample_discrete_features(prob_X, prob_E, node_mask=node_mask)

        X_s = F.one_hot(sampled_s.X, num_classes=self.Xdim_output).float()
        E_s = F.one_hot(sampled_s.E, num_classes=self.Edim_output).float()

        assert (E_s == torch.transpose(E_s, 1, 2)).all()
        assert (X_t.shape == X_s.shape) and (E_t.shape == E_s.shape)

        out_one_hot = utils.PlaceHolder(X=X_s, E=E_s, y=torch.zeros(y_t.shape[0], 0))

        return out_one_hot.mask(node_mask).type_as(y_t)

    def compute_extra_data(self, noisy_data):
        """Compute extra network inputs for a noisy graph.

        Called at every training step (after adding noise) and at every
        sampling step. The features come from ``self.extra_features``. The
        normalized timestep ``noisy_data['t']`` is appended to the global (y)
        features along dim 1.
        """
        extra_features = self.extra_features(noisy_data)

        extra_X = extra_features.X
        extra_E = extra_features.E
        extra_y = torch.cat((extra_features.y, noisy_data['t']), dim=1)

        return utils.PlaceHolder(X=extra_X, E=extra_E, y=extra_y)

    def forward(self, noisy_data, extra_data, node_mask):
        """Run the denoising network on a noisy graph plus its prompt conditioning.

        The prompt embedding ``noisy_data['cond_emb']`` (one row per graph) is
        broadcast to every node and every node pair. It is concatenated onto the
        node and edge features respectively.
        """
        cond = noisy_data['cond_emb']
        n = noisy_data['X_t'].shape[1]
        cond_nodes = cond.unsqueeze(1).expand(-1, n, -1).to(self.device)  # (bs, n, D)
        cond_edges = cond.unsqueeze(1).unsqueeze(2).expand(-1, n, n, -1).to(self.device)  # (bs, n, n, D)

        X = torch.cat((noisy_data['X_t'], extra_data.X, cond_nodes), dim=2).float()
        E = torch.cat((noisy_data['E_t'], extra_data.E, cond_edges), dim=3).float()
        y = torch.hstack((noisy_data['y_t'], extra_data.y)).float()
        return self.model(X, E, y, node_mask)