# LGGM-Text2Graph / demo_model.py
# Last commit: bfd34aa (verified), "Update demo_model.py", by YuWang0103.
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from models.transformer_model import GraphTransformer
from diffusion.noise_schedule import DiscreteUniformTransition, PredefinedNoiseScheduleDiscrete
from diffusion import diffusion_utils
import utils
import networkx as nx
from sentence_transformers import SentenceTransformer
import pytorch_lightning as pl
from transformers import BertTokenizer, BertForSequenceClassification
class LGGMText2Graph_Demo(pl.LightningModule):
    """Text-conditioned discrete graph diffusion model used by the demo.

    Wraps a ``GraphTransformer`` denoiser with a uniform discrete transition
    model and a predefined noise schedule. A text prompt is encoded into a
    conditioning embedding by one of the ``init_prompt_encoder_*`` encoders.
    At every denoising step that embedding is concatenated to every node
    feature and every edge feature (see ``forward``).
    """

    def __init__(self, cfg, input_dims, output_dims, cond_dims, cond_emb,
                 nodes_dist, node_types, edge_types, extra_features, data_loaders):
        """Build the denoiser, the noise schedule and the transition model.

        :param cfg: config object. Reads ``cfg.model.diffusion_steps``,
            ``n_layers``, ``hidden_mlp_dims``, ``hidden_dims`` and
            ``diffusion_noise_schedule`` here, and ``cfg.model.transition``
            during sampling.
        :param input_dims: dict with keys ``'X'``, ``'E'``, ``'y'``: network
            input feature sizes.
        :param output_dims: dict with keys ``'X'``, ``'E'``, ``'y'``: number
            of predicted node / edge / global classes.
        :param cond_dims: size of the conditioning input, forwarded to
            ``GraphTransformer``.
        :param nodes_dist: node-count distribution. Its ``sample_n`` is used
            when ``sample_batch`` is called without ``num_nodes``.
        :param extra_features: callable mapping the noisy-data dict to extra
            X / E / y features (see ``compute_extra_data``).
        :param cond_emb, node_types, edge_types, data_loaders: accepted for
            interface compatibility. This class does not use them.
        """
        super().__init__()
        self.cfg = cfg
        self.T = cfg.model.diffusion_steps

        self.Xdim = input_dims['X']
        self.Edim = input_dims['E']
        self.ydim = input_dims['y']
        self.Xdim_output = output_dims['X']
        self.Edim_output = output_dims['E']
        self.ydim_output = output_dims['y']

        self.node_dist = nodes_dist
        self.extra_features = extra_features

        self.model = GraphTransformer(n_layers=cfg.model.n_layers,
                                      input_dims=input_dims,
                                      hidden_mlp_dims=cfg.model.hidden_mlp_dims,
                                      hidden_dims=cfg.model.hidden_dims,
                                      output_dims=output_dims,
                                      cond_dims=cond_dims,
                                      act_fn_in=nn.ReLU(),
                                      act_fn_out=nn.ReLU()).to(self.device)

        self.noise_schedule = PredefinedNoiseScheduleDiscrete(cfg.model.diffusion_noise_schedule,
                                                              timesteps=cfg.model.diffusion_steps).to(self.device)
        self.transition_model = DiscreteUniformTransition(x_classes=self.Xdim_output,
                                                          e_classes=self.Edim_output,
                                                          y_classes=self.ydim_output)

        # Uniform distribution over classes. sample_batch draws the initial
        # noise z_T from it.
        x_limit = torch.ones(self.Xdim_output) / self.Xdim_output
        e_limit = torch.ones(self.Edim_output) / self.Edim_output
        y_limit = torch.ones(self.ydim_output) / self.ydim_output
        self.limit_dist = utils.PlaceHolder(X=x_limit, E=e_limit, y=y_limit)

    @staticmethod
    def _to_networkx(samples):
        """Convert ``sample_batch`` output into ``networkx.Graph`` objects.

        Any non-zero edge type counts as an edge. Node types are discarded.
        """
        return [nx.from_numpy_array(edge_types.bool().cpu().numpy())
                for _node_types, edge_types in samples]

    def generate_basic(self, text, num_nodes, num_samples=5) -> list:
        """Sample graphs for ``text`` using the SentenceTransformer encoder.

        Requires ``init_prompt_encoder_basic()`` to have been called first.

        :param text: prompt string.
        :param num_nodes: node count per graph (see ``sample_batch``).
        :param num_samples: number of graphs to sample.
        :return: list of ``num_samples`` ``networkx.Graph`` objects.
        """
        prompt_emb = torch.tensor(self.text_encoder.encode([text])).to(self.device)
        samples = self.sample_batch(num_samples, cond_emb=prompt_emb, num_nodes=num_nodes)
        return self._to_networkx(samples)

    def generate_pretrained(self, text, num_nodes, num_samples=3) -> list:
        """Sample graphs for ``text`` using the fine-tuned BERT encoder.

        Requires ``init_prompt_encoder_pretrained()`` to have been called first.

        :param text: prompt string.
        :param num_nodes: node count per graph (see ``sample_batch``).
        :param num_samples: number of graphs to sample.
        :return: list of ``num_samples`` ``networkx.Graph`` objects.
        """
        encoded_input = self.tokenizer(text, return_tensors='pt', padding=True,
                                       truncation=True, max_length=512)
        encoded_input = {key: val.to(self.text_encoder.device) for key, val in encoded_input.items()}
        with torch.no_grad():
            # The first token of the last hidden layer ([CLS] for BERT
            # tokenizers) is used as the prompt embedding.
            prompt_emb = self.text_encoder(**encoded_input).hidden_states[-1][:, 0]
        samples = self.sample_batch(num_samples, cond_emb=prompt_emb.to(self.device), num_nodes=num_nodes)
        return self._to_networkx(samples)

    def init_prompt_encoder_basic(self, model_name="all-MiniLM-L6-v2"):
        """Load a SentenceTransformer prompt encoder for ``generate_basic``."""
        self.text_encoder = SentenceTransformer(model_name)

    def init_prompt_encoder_pretrained(self, checkpoint_path="./checkpoint-900"):
        """Load the fine-tuned BERT prompt encoder for ``generate_pretrained``.

        :param checkpoint_path: local fine-tuned checkpoint. Pass
            ``"bert-base-uncased"`` to start from the base model instead.
        """
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.text_encoder = BertForSequenceClassification.from_pretrained(
            checkpoint_path, num_labels=8, output_hidden_states=True, device_map='cpu')

    @torch.no_grad()
    def sample_batch(self, batch_size: int, cond_emb=None, num_nodes=None):
        """Run the full reverse diffusion chain and return the sampled graphs.

        :param batch_size: number of graphs to sample.
        :param cond_emb: conditioning embedding tensor (required). A single
            row (or 1-D vector) is shared by every graph in the batch.
            Otherwise it must have ``batch_size`` rows.
        :param num_nodes: one of:
            - ``None``: draw node counts from ``self.node_dist``;
            - a number: every graph gets ``int(num_nodes)`` nodes;
            - an integer tensor with one count per graph (a single-element
              tensor is broadcast to the batch).
        :return: list of ``batch_size`` ``[node_types, edge_types]`` CPU
            tensors, cropped to each graph's node count. They are presumably
            class indices after ``mask(..., collapse=True)``; confirm in
            ``utils.PlaceHolder``.
        :raises ValueError: if ``cond_emb`` is ``None``. The denoiser always
            concatenates the condition to its inputs.
        """
        if cond_emb is None:
            raise ValueError("sample_batch requires a conditioning embedding (cond_emb)")

        if num_nodes is None:
            n_nodes = self.node_dist.sample_n(batch_size, self.device)
        elif isinstance(num_nodes, torch.Tensor):
            n_nodes = num_nodes.to(device=self.device, dtype=torch.int).reshape(-1)
            if n_nodes.numel() == 1:
                n_nodes = n_nodes.expand(batch_size)
        else:
            # Plain ints, and also floats / numpy ints coming from UI widgets.
            n_nodes = torch.full((batch_size,), int(num_nodes), device=self.device, dtype=torch.int)
        n_max = torch.max(n_nodes).item()

        # node_mask[i, j] is True iff node j exists in graph i; shape (bs, n_max).
        arange = torch.arange(n_max, device=self.device).unsqueeze(0).expand(batch_size, -1)
        node_mask = arange < n_nodes.unsqueeze(1)

        # Initial noise z_T drawn from the limit distribution.
        z_T = diffusion_utils.sample_discrete_feature_noise(limit_dist=self.limit_dist, node_mask=node_mask,
                                                            transition=self.cfg.model.transition)
        X, E, y = z_T.X, z_T.E, z_T.y

        # Iteratively sample p(z_s | z_t) for t = T, ..., 1, with s = t - 1.
        for s_int in tqdm(reversed(range(self.T)), total=self.T):
            s_array = s_int * torch.ones((batch_size, 1)).type_as(y)
            t_array = s_array + 1
            s_norm = s_array / self.T
            t_norm = t_array / self.T
            sampled_s = self.sample_p_zs_given_zt(s_norm, t_norm, X, E, y, node_mask, cond_emb)
            X, E, y = sampled_s.X, sampled_s.E, sampled_s.y

        sampled_s = sampled_s.mask(node_mask, collapse=True)
        X, E = sampled_s.X, sampled_s.E

        # Crop the padding away: graph i keeps only its first n_nodes[i] nodes.
        return [[x[:n].cpu(), e[:n, :n].cpu()] for x, e, n in zip(X, E, n_nodes)]

    def sample_p_zs_given_zt(self, s, t, X_t, E_t, y_t, node_mask, cond_emb):
        """Sample z_s ~ p(z_s | z_t) for one reverse diffusion step (sampling only).

        :param s, t: normalised timesteps of shape (bs, 1), with s = t - 1/T.
        :param X_t: one-hot noisy node features, shape (bs, n, dx).
        :param E_t: one-hot noisy edge features, shape (bs, n, n, de).
        :param y_t: noisy global features (bs, ...).
        :param node_mask: boolean mask of real nodes, shape (bs, n).
        :param cond_emb: conditioning embedding. A single row (or 1-D vector)
            is repeated for the whole batch. Otherwise it must already have
            ``bs`` rows.
        :return: ``utils.PlaceHolder`` holding one-hot X_s, E_s and an empty y,
            masked by ``node_mask``.
        """
        bs, n, dxs = X_t.shape

        # Broadcast a shared prompt embedding to every graph. A condition that
        # already has one row per graph is passed through unchanged.
        if cond_emb.dim() == 1:
            cond_emb = cond_emb.unsqueeze(0)
        if cond_emb.shape[0] == 1:
            cond_emb = cond_emb.repeat(bs, 1)

        beta_t = self.noise_schedule(t_normalized=t)  # (bs, 1)
        alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s)
        alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t)

        # Transition matrices: cumulative to t, cumulative to s, and one-step at t.
        Qtb = self.transition_model.get_Qt_bar(alpha_t_bar, self.device)
        Qsb = self.transition_model.get_Qt_bar(alpha_s_bar, self.device)
        Qt = self.transition_model.get_Qt(beta_t, self.device)

        noisy_data = {'X_t': X_t, 'E_t': E_t, 'y_t': y_t, 't': t, 'node_mask': node_mask,
                      'cond_emb': cond_emb}
        extra_data = self.compute_extra_data(noisy_data)
        pred = self.forward(noisy_data, extra_data, node_mask)

        # Predicted clean-graph distributions p(x_0 | z_t).
        pred_X = F.softmax(pred.X, dim=-1)  # bs, n, d0
        pred_E = F.softmax(pred.E, dim=-1)  # bs, n, n, d0

        p_s_and_t_given_0_X = diffusion_utils.compute_batched_over0_posterior_distribution(X_t=X_t,
                                                                                           Qt=Qt.X,
                                                                                           Qsb=Qsb.X,
                                                                                           Qtb=Qtb.X)
        p_s_and_t_given_0_E = diffusion_utils.compute_batched_over0_posterior_distribution(X_t=E_t,
                                                                                           Qt=Qt.E,
                                                                                           Qsb=Qsb.E,
                                                                                           Qtb=Qtb.E)

        # Marginalise over x_0: sum_{x0} p(x0 | z_t) * q(z_s, z_t | x0).
        weighted_X = pred_X.unsqueeze(-1) * p_s_and_t_given_0_X  # bs, n, d0, d_t-1
        unnormalized_prob_X = weighted_X.sum(dim=2)  # bs, n, d_t-1
        # Rows summing to zero would divide by zero below; give them a tiny mass.
        unnormalized_prob_X[torch.sum(unnormalized_prob_X, dim=-1) == 0] = 1e-5
        prob_X = unnormalized_prob_X / torch.sum(unnormalized_prob_X, dim=-1, keepdim=True)  # bs, n, d_t-1

        pred_E = pred_E.reshape((bs, -1, pred_E.shape[-1]))
        weighted_E = pred_E.unsqueeze(-1) * p_s_and_t_given_0_E  # bs, N, d0, d_t-1
        unnormalized_prob_E = weighted_E.sum(dim=-2)
        unnormalized_prob_E[torch.sum(unnormalized_prob_E, dim=-1) == 0] = 1e-5
        prob_E = unnormalized_prob_E / torch.sum(unnormalized_prob_E, dim=-1, keepdim=True)
        prob_E = prob_E.reshape(bs, n, n, pred_E.shape[-1])

        assert ((prob_X.sum(dim=-1) - 1).abs() < 1e-4).all()
        assert ((prob_E.sum(dim=-1) - 1).abs() < 1e-4).all()

        sampled_s = diffusion_utils.sample_discrete_features(prob_X, prob_E, node_mask=node_mask)

        X_s = F.one_hot(sampled_s.X, num_classes=self.Xdim_output).float()
        E_s = F.one_hot(sampled_s.E, num_classes=self.Edim_output).float()

        # Sampled adjacency must stay symmetric (undirected graphs).
        assert (E_s == torch.transpose(E_s, 1, 2)).all()
        assert (X_t.shape == X_s.shape) and (E_t.shape == E_s.shape)

        out_one_hot = utils.PlaceHolder(X=X_s, E=E_s, y=y_t.new_zeros((y_t.shape[0], 0)))
        return out_one_hot.mask(node_mask).type_as(y_t)

    def compute_extra_data(self, noisy_data):
        """Compute extra network inputs for a noisy graph.

        Used at every training step (after adding noise) and at every
        sampling step. The normalised timestep ``noisy_data['t']`` is appended
        to the global (y) extra features.
        """
        extra_features = self.extra_features(noisy_data)
        extra_y = torch.cat((extra_features.y, noisy_data['t']), dim=1)
        return utils.PlaceHolder(X=extra_features.X, E=extra_features.E, y=extra_y)

    def forward(self, noisy_data, extra_data, node_mask):
        """Run the denoiser with the condition appended to node and edge features.

        :param noisy_data: dict with ``'X_t'``, ``'E_t'``, ``'y_t'`` and a
            per-graph ``'cond_emb'`` of shape (bs, d).
        :param extra_data: ``utils.PlaceHolder`` from ``compute_extra_data``.
        :param node_mask: boolean mask of real nodes, shape (bs, n).
        :return: the ``GraphTransformer`` output (unnormalised X/E/y predictions).
        """
        cond = noisy_data['cond_emb'].to(self.device)
        n = noisy_data['X_t'].shape[1]
        # Broadcast the condition to every node (bs, n, d) and every node pair
        # (bs, n, n, d). Moving the condition to the device before expanding
        # avoids copying the expanded tensor.
        cond_nodes = cond.unsqueeze(1).expand(-1, n, -1)
        cond_edges = cond.unsqueeze(1).unsqueeze(2).expand(-1, n, n, -1)

        X = torch.cat((noisy_data['X_t'], extra_data.X, cond_nodes), dim=2).float()
        E = torch.cat((noisy_data['E_t'], extra_data.E, cond_edges), dim=3).float()
        y = torch.hstack((noisy_data['y_t'], extra_data.y)).float()
        return self.model(X, E, y, node_mask)