"""Gradio demo: generate graphs from a text prompt with LGGMText2Graph_Demo.

Running this module loads the config, dataset metadata and a model
checkpoint, builds a Gradio Blocks UI and launches it.
"""

from omegaconf import OmegaConf
import gradio as gr
import networkx as nx
import numpy as np
import matplotlib.pyplot as plt

from dataset import init_dataset, compute_input_output_dims
from extra_features import ExtraFeatures
from demo_model import LGGMText2Graph_Demo
from analysis.spectre_utils import CrossDomainSamplingMetrics

# Number of graph slots shown in the UI (one image, one CC box and one DEG
# box per slot).
NUM_SAMPLES = 5
# Initial slider value; also the fallback when predict() gets no node count.
DEFAULT_NUM_NODES = 10

cfg = OmegaConf.load('./config.yaml')
hydra_path = '.'

(
    data_loaders,
    num_classes,
    max_n_nodes,
    nodes_dist,
    edge_types,
    node_types,
    n_nodes,
    cond_dims,
    cond_emb,
) = init_dataset(
    cfg.dataset.name,
    cfg.train.batch_size,
    hydra_path,
    cfg.general.condition,
    cfg.model.transition,
)

# NOTE(review): the objects below are not referenced again in this file.
# They are kept because constructing them may have side effects the
# checkpoint/model relies on -- confirm before removing.
extra_features = ExtraFeatures(cfg.model.extra_features, max_n_nodes)
input_dims, output_dims = compute_input_output_dims(
    data_loaders['train'], extra_features
)
sampling_metrics = CrossDomainSamplingMetrics(data_loaders)

model = LGGMText2Graph_Demo.load_from_checkpoint('last-v1.ckpt')
model.init_prompt_encoder()


def calculate_average_degree(graph):
    """Return the average node degree (2 * |E| / |V|), or 0 for an empty graph."""
    num_nodes = graph.number_of_nodes()
    num_edges = graph.number_of_edges()
    return (2 * num_edges) / num_nodes if num_nodes > 0 else 0


def _render_graph(graph):
    """Draw ``graph`` with networkx and return it as an (H, W, 3) uint8 array."""
    fig, ax = plt.subplots()
    try:
        nx.draw(graph, ax=ax)
        fig.canvas.draw()
        # buffer_rgba() replaces the deprecated tostring_rgb(); it is already
        # shaped (H, W, 4), so drop alpha and copy out of the canvas buffer
        # before the figure is closed.
        return np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
    finally:
        # Always release the figure so repeated requests do not leak memory.
        plt.close(fig)


def predict(text, num_nodes=None):
    """Generate graphs for a prompt and return images plus summary statistics.

    Args:
        text: Prompt passed to ``model.generate``.
        num_nodes: Requested node count; converted with ``int``. Falls back
            to ``DEFAULT_NUM_NODES`` when ``None``.

    Returns:
        A flat tuple of ``3 * NUM_SAMPLES`` values: the rendered images, then
        the average clustering coefficients, then the average degrees. Slots
        for which the model produced no graph are ``None``.
    """
    if num_nodes is None:
        num_nodes = DEFAULT_NUM_NODES

    # Only the first NUM_SAMPLES graphs can be displayed, so skip the rest.
    graphs = list(model.generate(text, int(num_nodes)))[:NUM_SAMPLES]

    images = []
    ccs = []
    degs = []
    for g in graphs:
        # average_clustering divides by the node count; guard the empty graph
        # the same way calculate_average_degree does.
        ccs.append(nx.average_clustering(g) if g.number_of_nodes() > 0 else 0)
        degs.append(calculate_average_degree(g))
        images.append(_render_graph(g))

    # Pad so the number of returned values always matches the wired outputs.
    missing = NUM_SAMPLES - len(graphs)
    images += [None] * missing
    ccs += [None] * missing
    degs += [None] * missing

    return tuple(images + ccs + degs)


def clear(input_text):
    """Reset the prompt box and blank every output component.

    Args:
        input_text: Current prompt text (supplied by the click event; unused).

    Returns:
        A list with an empty string for the prompt followed by ``None`` for
        each of the ``3 * NUM_SAMPLES`` output components.
    """
    return [""] + [None] * (3 * NUM_SAMPLES)


with gr.Blocks() as demo:
    gr.Markdown("## Text2Graph Generation Demo")

    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(
                label="Input your text prompt here",
                placeholder="Type here...",
            )
        with gr.Column():
            input_num = gr.Slider(
                5,
                200,
                value=DEFAULT_NUM_NODES,
                label="Count",
                info="Number of nodes in the graph to be generated",
            )
        with gr.Column():
            gr.Markdown("### Suggested Prompts")
            gr.Markdown("1. Create a complex network with high clustering coefficient.\n2. Create a graph with extremely low number of triangles.")

    with gr.Row() as output_row:
        output_images = [
            gr.Image(label=f"Generated Network #{i}") for i in range(NUM_SAMPLES)
        ]

    with gr.Row():
        output_texts_cc = [
            gr.Textbox(label=f"CC #{i}") for i in range(NUM_SAMPLES)
        ]

    with gr.Row():
        output_texts_deg = [
            gr.Textbox(label=f"DEG #{i}") for i in range(NUM_SAMPLES)
        ]

    with gr.Row():
        submit_button = gr.Button("Submit")
        clear_button = gr.Button("Clear")

    all_outputs = output_images + output_texts_cc + output_texts_deg

    # Submit runs generation and fills every image / CC / DEG slot.
    submit_button.click(
        fn=predict,
        inputs=[input_text, input_num],
        outputs=all_outputs,
    )

    # Clear resets the text input and clears the outputs; the output list
    # must line up one-to-one with the values returned by clear().
    clear_button.click(
        fn=clear,
        inputs=input_text,
        outputs=[input_text] + all_outputs,
    )

demo.launch()