# LGGM-Text2Graph / app.py
# Last commit 0a172b6 ("Update app.py") by YuWang0103.
from omegaconf import OmegaConf
import gradio as gr
from dataset import init_dataset, compute_input_output_dims
from extra_features import ExtraFeatures
from demo_model import LGGMText2Graph_Demo
from analysis.spectre_utils import CrossDomainSamplingMetrics
import networkx as nx
import numpy as np
import matplotlib.pyplot as plt
import torch
# --- Module-level setup: executed once when this script starts, before the UI is built. ---

# Load the configuration file that sits next to the app.
cfg = OmegaConf.load('./config.yaml')
# Base path forwarded to init_dataset; the name suggests it replaces a Hydra run
# directory used during training — TODO confirm.
hydra_path = '.'
# init_dataset returns nine values; in the visible code only data_loaders and
# max_n_nodes are used afterwards.
data_loaders, num_classes, max_n_nodes, nodes_dist, edge_types, node_types, n_nodes, cond_dims, cond_emb = init_dataset(cfg.dataset.name, cfg.train.batch_size, hydra_path, cfg.general.condition, cfg.model.transition)
extra_features = ExtraFeatures(cfg.model.extra_features, max_n_nodes)
# NOTE(review): input_dims/output_dims and sampling_metrics are not referenced
# elsewhere in the visible code; they may be kept for side effects of these
# calls — confirm before removing.
input_dims, output_dims = compute_input_output_dims(data_loaders['train'], extra_features)
sampling_metrics = CrossDomainSamplingMetrics(data_loaders)
# Load the trained weights from cc-deg.ckpt onto the CPU, so no GPU is required.
model = LGGMText2Graph_Demo.load_from_checkpoint('cc-deg.ckpt', map_location=torch.device("cpu"))
# Presumably initialises the pretrained text-prompt encoder that
# generate_pretrained relies on — TODO confirm.
model.init_prompt_encoder_pretrained()
def calculate_average_degree(graph):
    """Return the mean node degree of *graph*, computed as 2 * |E| / |V|.

    Every edge adds one to the degree of each of its two endpoints, hence the
    factor of two. A graph without nodes yields 0.
    """
    node_count = graph.number_of_nodes()
    if node_count <= 0:
        return 0
    return 2 * graph.number_of_edges() / node_count
def _render_graph(graph):
    """Draw *graph* with networkx and return the picture as an (H, W, 3) uint8 RGB array."""
    fig, ax = plt.subplots()
    try:
        nx.draw(graph, ax=ax)
        fig.canvas.draw()
        # buffer_rgba() replaces tostring_rgb(), which was deprecated in
        # Matplotlib 3.8 and removed in 3.10. The returned buffer already has
        # the real pixel shape (including HiDPI scaling), so no manual reshape.
        rgba = np.asarray(fig.canvas.buffer_rgba())
        # Drop alpha and copy so the pixels outlive the figure's render buffer.
        return rgba[..., :3].copy()
    finally:
        # Release the figure even if drawing fails, to avoid leaking figures.
        plt.close(fig)


def _pad(values, size):
    """Return the first *size* items of *values*, padded with None to exactly *size* items."""
    head = list(values[:size])
    return head + [None] * (size - len(head))


def predict(text, num_nodes = None):
    """Generate graphs from a text prompt and summarise them for the UI.

    Args:
        text: The natural-language prompt passed to the model.
        num_nodes: Number of nodes per generated graph (converted with int()).
            Required despite the None default, which is kept for compatibility.

    Returns:
        An 11-tuple matching the UI outputs, in this order:
        3 rendered images, 3 average clustering coefficients, 3 average degrees,
        then the mean clustering coefficient and mean degree over *all*
        generated graphs. Slots for missing graphs are None. The two means are
        None when the model returns no graphs.

    Raises:
        TypeError: If num_nodes is None.
    """
    if num_nodes is None:
        raise TypeError("predict() requires num_nodes, the number of nodes per generated graph")
    # Materialise once: the result is iterated several times below.
    graphs = list(model.generate_pretrained(text, int(num_nodes)))

    ccs = [nx.average_clustering(g) for g in graphs]
    degs = [calculate_average_degree(g) for g in graphs]
    images = [_render_graph(g) for g in graphs]

    # Averages cover every generated graph, not only the three displayed ones.
    avg_cc = np.mean(ccs) if ccs else None
    avg_deg = np.mean(degs) if degs else None

    # The UI has exactly three slots per metric; pad or truncate to fit.
    slots = 3
    return (*_pad(images, slots), *_pad(ccs, slots), *_pad(degs, slots), avg_cc, avg_deg)
def clear(input_text):
    """Return empty values for all 11 result components of the demo.

    *input_text* is accepted only to match the event wiring; its value is
    ignored.
    """
    return (None,) * 11
# Build the Gradio UI. Every handler writes to the same 11 output components,
# in this order: three images, three clustering coefficients, three average
# degrees, then the two averages over all generated graphs.
with gr.Blocks() as demo:
    gr.Markdown("## Text2Graph Generation Demo")
    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(label="Input your text prompt here", placeholder="Type here...")
        with gr.Column():
            input_num = gr.Slider(5, 100, value=25, step=1, label="Count", info="Number of nodes in the graph to be generated")
        with gr.Column():
            gr.Markdown("### Suggested Prompts")
            gr.Markdown("1. Create a complex network with high clustering coefficient.\n2. Create a graph with extremely low number of triangles.\n 3. Please give me a Power Network with extremely low number of triangles but with medium level of average degree.")
    with gr.Row() as output_row:
        output_images = [gr.Image(label = f"Generated Network #{i}") for i in range(3)]
    with gr.Row():
        output_texts_cc = [gr.Textbox(label=f"CC #{i}") for i in range(3)]
    with gr.Row():
        output_texts_deg = [gr.Textbox(label=f"DEG #{i}") for i in range(3)]
    with gr.Row():
        avg_cc_text = gr.Textbox(label="Average Clustering Coefficient")
        avg_deg_text = gr.Textbox(label="Average Degree")
    with gr.Row():
        submit_button = gr.Button("Submit")
        clear_button = gr.Button("Clear")

    # Single source of truth for the output order expected by predict() and clear().
    all_outputs = output_images + output_texts_cc + output_texts_deg + [avg_cc_text, avg_deg_text]

    # Generate graphs when the Submit button is clicked or Enter is pressed in the prompt box.
    submit_button.click(fn=predict, inputs=[input_text, input_num], outputs=all_outputs)
    input_text.submit(fn=predict, inputs=[input_text, input_num], outputs=all_outputs)

    # Clear empties every result component and also resets the prompt box.
    clear_button.click(fn=clear, inputs=[input_text], outputs=all_outputs)
    clear_button.click(fn=lambda: "", inputs=None, outputs=input_text)
demo.launch()