Spaces:
Runtime error
Runtime error
File size: 4,045 Bytes
6b59850 38ed701 6b59850 0a172b6 6b59850 38ed701 6b59850 38ed701 6b59850 38ed701 6b59850 7084eba 6b59850 67ad3dd 6b59850 719a04f 6b59850 38ed701 6b59850 7084eba 6b59850 7084eba 6b59850 7084eba 38ed701 6b59850 38ed701 67ad3dd 6b59850 67ad3dd 6b59850 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 |
from omegaconf import OmegaConf
import gradio as gr
from dataset import init_dataset, compute_input_output_dims
from extra_features import ExtraFeatures
from demo_model import LGGMText2Graph_Demo
from analysis.spectre_utils import CrossDomainSamplingMetrics
import networkx as nx
import numpy as np
import matplotlib.pyplot as plt
import torch
# --- One-time startup: load config, dataset metadata and the pretrained model. ---
# Everything below runs at import time, before the Gradio UI further down is built.
cfg = OmegaConf.load('./config.yaml')
# Passed to init_dataset as its path argument; presumably the base directory for
# dataset files, relative to the working directory — TODO confirm.
hydra_path = '.'
# init_dataset returns nine values. NOTE(review): apart from data_loaders and
# max_n_nodes (used on the next lines), none of these names is referenced again in
# this file — presumably unpacked for parity with the training code; confirm.
data_loaders, num_classes, max_n_nodes, nodes_dist, edge_types, node_types, n_nodes, cond_dims, cond_emb = init_dataset(cfg.dataset.name, cfg.train.batch_size, hydra_path, cfg.general.condition, cfg.model.transition)
extra_features = ExtraFeatures(cfg.model.extra_features, max_n_nodes)
input_dims, output_dims = compute_input_output_dims(data_loaders['train'], extra_features)
# NOTE(review): input_dims, output_dims and sampling_metrics are not used elsewhere in
# this file; check whether these constructors are needed for side effects before
# removing them.
sampling_metrics = CrossDomainSamplingMetrics(data_loaders)
# Load the trained checkpoint and map all tensors onto the CPU.
model = LGGMText2Graph_Demo.load_from_checkpoint('cc-deg.ckpt', map_location=torch.device("cpu"))
# Presumably initialises the text-prompt encoder needed by generate_pretrained() — confirm.
model.init_prompt_encoder_pretrained()
def calculate_average_degree(graph):
    """Return the mean node degree of *graph*, computed as 2 * |E| / |V|.

    Every edge adds one to the degree of each of its two endpoints, hence the
    factor of two. A graph with no nodes yields ``0``.
    """
    node_total = graph.number_of_nodes()
    edge_total = graph.number_of_edges()
    if node_total <= 0:
        return 0
    return 2 * edge_total / node_total
def _graph_to_image(g):
    """Draw graph *g* with ``nx.draw`` and return the figure as an (H, W, 3) uint8 array."""
    fig, ax = plt.subplots()
    try:
        nx.draw(g, ax=ax)
        fig.canvas.draw()
        # tostring_rgb() was deprecated in Matplotlib 3.8 and removed in 3.10;
        # buffer_rgba() replaces it. Its array is already shaped
        # (height, width, 4) at the canvas's real pixel size, so no manual reshape
        # from get_width_height() is needed (that reshape is wrong whenever the
        # device pixel ratio is not 1). np.array() copies the pixels so they stay
        # valid after the figure is closed.
        rgba = np.array(fig.canvas.buffer_rgba())
        return np.ascontiguousarray(rgba[..., :3])
    finally:
        # Close even if drawing fails, so pyplot does not accumulate open
        # figures across requests.
        plt.close(fig)


def predict(text, num_nodes=None):
    """Generate graphs from a text prompt and summarise them for the demo UI.

    Args:
        text: Prompt passed to ``model.generate_pretrained``.
        num_nodes: Requested number of nodes. It is converted with ``int`` because
            the UI slider may deliver a float. Although it defaults to ``None``,
            a value is required.

    Returns:
        An 11-tuple in the order the UI outputs expect:
        three rendered images, three clustering coefficients, three average
        degrees, then the mean clustering coefficient and the mean average degree
        over *all* generated graphs. If fewer than three graphs are generated,
        the missing per-graph slots are ``None``. If no graphs are generated,
        both means are ``None``.

    Raises:
        TypeError: if ``num_nodes`` is ``None``.
    """
    num_displayed = 3  # the UI has three image / CC / DEG slots

    if num_nodes is None:
        raise TypeError("predict() requires num_nodes (number of nodes to generate)")

    graphs = model.generate_pretrained(text, int(num_nodes))

    ccs, degs, images = [], [], []
    for g in graphs:
        # average_clustering divides by the node count; treat an empty graph as 0,
        # matching calculate_average_degree's convention.
        ccs.append(nx.average_clustering(g) if g.number_of_nodes() else 0.0)
        degs.append(calculate_average_degree(g))
        images.append(_graph_to_image(g))

    avg_cc = np.mean(ccs) if ccs else None
    avg_deg = np.mean(degs) if degs else None

    def _slots(values):
        # Exactly num_displayed entries: truncate extras, pad missing with None.
        return (list(values) + [None] * num_displayed)[:num_displayed]

    return (*_slots(images), *_slots(ccs), *_slots(degs), avg_cc, avg_deg)
def clear(input_text):
    """Return eleven ``None`` values that blank every demo output.

    ``input_text`` is accepted so the function can be wired to the prompt box,
    but it is ignored. The tuple length matches the Clear button's output
    components: 3 images, 3 clustering-coefficient boxes, 3 degree boxes,
    and 2 average boxes.
    """
    blank_outputs = (None,) * 11
    return blank_outputs
# Build the Gradio interface. Component creation and event wiring must happen
# inside the Blocks context.
with gr.Blocks() as demo:
    gr.Markdown("## Text2Graph Generation Demo")

    # Input row: prompt, node-count slider and example prompts.
    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(label="Input your text prompt here", placeholder="Type here...")
        with gr.Column():
            input_num = gr.Slider(5, 100, value=25, step=1, label="Count", info="Number of nodes in the graph to be generated")
        with gr.Column():
            gr.Markdown("### Suggested Prompts")
            gr.Markdown("1. Create a complex network with high clustering coefficient.\n2. Create a graph with extremely low number of triangles.\n 3. Please give me a Power Network with extremely low number of triangles but with medium level of average degree.")

    # Three parallel result slots: image, clustering coefficient, average degree.
    with gr.Row() as output_row:
        output_images = [gr.Image(label=f"Generated Network #{i}") for i in range(3)]
    with gr.Row():
        output_texts_cc = [gr.Textbox(label=f"CC #{i}") for i in range(3)]
    with gr.Row():
        output_texts_deg = [gr.Textbox(label=f"DEG #{i}") for i in range(3)]
    with gr.Row():
        avg_cc_text = gr.Textbox(label="Average Clustering Coefficient")
        avg_deg_text = gr.Textbox(label="Average Degree")

    with gr.Row():
        submit_button = gr.Button("Submit")
        clear_button = gr.Button("Clear")

    # Single source of truth for output order. It must match the 11-tuple
    # returned by predict() and clear().
    outputs = output_images + output_texts_cc + output_texts_deg + [avg_cc_text, avg_deg_text]

    # predict runs both on the Submit button and on Enter in the prompt box.
    submit_button.click(fn=predict, inputs=[input_text, input_num], outputs=outputs)
    input_text.submit(fn=predict, inputs=[input_text, input_num], outputs=outputs)
    # clear blanks every output component. The prompt text itself is left as
    # typed, because input_text is not among the outputs.
    clear_button.click(fn=clear, inputs=[input_text], outputs=outputs)

demo.launch()