File size: 3,779 Bytes
6b59850
 
 
 
 
 
 
 
 
 
38ed701
6b59850
 
 
 
 
 
 
 
 
 
 
 
 
 
bad805a
6b59850
38ed701
6b59850
 
 
 
 
 
 
 
 
38ed701
6b59850
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38ed701
 
 
6b59850
38ed701
6b59850
 
 
 
 
 
 
 
 
 
 
 
 
 
38ed701
6b59850
 
 
 
 
 
 
38ed701
 
 
 
6b59850
 
 
 
 
 
38ed701
6b59850
 
38ed701
6b59850
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
from omegaconf import OmegaConf
import gradio as gr

from dataset import init_dataset, compute_input_output_dims
from extra_features import ExtraFeatures
from demo_model import LGGMText2Graph_Demo
from analysis.spectre_utils import CrossDomainSamplingMetrics
import networkx as nx
import numpy as np
import matplotlib.pyplot as plt
import torch


# Load the experiment configuration; its fields drive the dataset and feature
# construction below.
cfg = OmegaConf.load('./config.yaml')
# Root path handed to init_dataset (presumably where dataset files are looked
# up relative to — verify against init_dataset).
hydra_path = '.'


# Build the data loaders plus dataset-level statistics from the config.
data_loaders, num_classes, max_n_nodes, nodes_dist, edge_types, node_types, n_nodes, cond_dims, cond_emb = init_dataset(cfg.dataset.name, cfg.train.batch_size, hydra_path, cfg.general.condition, cfg.model.transition)

# NOTE(review): extra_features, input_dims/output_dims and sampling_metrics are
# not referenced anywhere else in the visible code. They may be leftovers from
# the training script or kept only for construction side effects — confirm
# before removing; building them may be expensive (the train loader is passed
# to compute_input_output_dims).
extra_features = ExtraFeatures(cfg.model.extra_features, max_n_nodes)

input_dims, output_dims = compute_input_output_dims(data_loaders['train'], extra_features)

sampling_metrics = CrossDomainSamplingMetrics(data_loaders)

# Restore the trained model from the local 'cc-deg.ckpt' checkpoint, mapping
# all tensors onto the CPU so the demo does not need a GPU.
model = LGGMText2Graph_Demo.load_from_checkpoint('cc-deg.ckpt', map_location=torch.device("cpu"))

# Initialise the model's pretrained prompt encoder; presumably required before
# predict() calls model.generate_pretrained — verify in demo_model.
model.init_prompt_encoder_pretrained()

def calculate_average_degree(graph):
    """Return ``2 * |E| / |V|`` for ``graph``, or ``0`` when it has no nodes.

    Every edge contributes two endpoint incidences, so this is the mean number
    of edge endpoints per node. The empty-graph guard avoids dividing by zero.
    """
    node_count = graph.number_of_nodes()
    if node_count <= 0:
        return 0
    edge_count = graph.number_of_edges()
    return 2 * edge_count / node_count


def _render_graph(graph):
    """Draw ``graph`` with networkx and return the rendered figure as an RGB image.

    Returns:
        A ``uint8`` NumPy array of shape ``(height, width, 3)``.
    """
    fig, ax = plt.subplots()
    try:
        nx.draw(graph, ax=ax)
        fig.canvas.draw()
        # buffer_rgba() replaces tostring_rgb() (deprecated in Matplotlib 3.8,
        # removed in 3.10) and already has shape (height, width, 4) in physical
        # pixels, so no manual reshape via get_width_height() is needed (that
        # reshape is wrong on HiDPI canvases).
        rgba = np.asarray(fig.canvas.buffer_rgba())
        # Drop the alpha channel and copy, so the array no longer references
        # the figure's render buffer once the figure is closed.
        return np.ascontiguousarray(rgba[..., :3])
    finally:
        # Close even if drawing fails so a long-running server doesn't leak figures.
        plt.close(fig)


def predict(text, num_nodes = None):
    """Generate graphs from a text prompt and summarise them for the UI.

    Args:
        text: Prompt forwarded to ``model.generate_pretrained``.
        num_nodes: Requested number of nodes per graph; converted with
            ``int()`` before being passed to the model.

    Returns:
        A flat 17-tuple matching the Submit button's outputs: five rendered
        images, five average clustering coefficients, five average degrees,
        then the mean clustering coefficient and mean average degree over all
        generated graphs. Slots without a corresponding graph are ``None``;
        the two means are ``None`` when no graph was generated.

    Raises:
        ValueError: if ``num_nodes`` is ``None``.
    """
    if num_nodes is None:
        raise ValueError("num_nodes must be provided")

    # The UI exposes exactly five image / CC / DEG slots.
    num_slots = 5

    # Materialise once in case the model returns a one-shot iterator.
    graphs = list(model.generate_pretrained(text, int(num_nodes)))
    ccs = [nx.average_clustering(g) for g in graphs]
    degs = [calculate_average_degree(g) for g in graphs]
    # Only the graphs that can actually be displayed are rendered.
    images = [_render_graph(g) for g in graphs[:num_slots]]

    # Means are taken over every generated graph, not just the displayed ones.
    avg_cc = np.mean(ccs) if ccs else None
    avg_deg = np.mean(degs) if degs else None

    def _fit(values):
        """Truncate or pad ``values`` with ``None`` to exactly ``num_slots`` entries."""
        kept = list(values[:num_slots])
        return kept + [None] * (num_slots - len(kept))

    return (*_fit(images), *_fit(ccs), *_fit(degs), avg_cc, avg_deg)

def clear(input_text):
    """Reset every result component of the demo.

    ``input_text`` is accepted only because the Clear button passes the prompt
    textbox as its input; its value is ignored.

    Returns:
        A tuple of ``None`` values, one per component wired as an output of the
        Clear button. Gradio rejects a handler that returns fewer values than
        it has outputs, so the count must match exactly.
    """
    # 5 images + 5 CC boxes + 5 DEG boxes + average CC + average degree.
    num_outputs = 5 + 5 + 5 + 2
    return (None,) * num_outputs


with gr.Blocks() as demo:
    gr.Markdown("## Text2Graph Generation Demo")
    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(label="Input your text prompt here", placeholder="Type here...")
        with gr.Column():
            input_num = gr.Slider(5, 200, value=10, label="Count", info="Number of nodes in the graph to be generated")
        with gr.Column():
            gr.Markdown("### Suggested Prompts")
            gr.Markdown("1. Create a complex network with high clustering coefficient.\n2. Create a graph with extremely low number of triangles.\n 3. Please give me a Power Network with extremely low number of triangles but with medium level of average degree.")

    # One row per result kind, five slots each (one per displayed graph).
    with gr.Row() as output_row:
        output_images = [gr.Image(label=f"Generated Network #{i}") for i in range(5)]
    with gr.Row():
        output_texts_cc = [gr.Textbox(label=f"CC #{i}") for i in range(5)]
    with gr.Row():
        output_texts_deg = [gr.Textbox(label=f"DEG #{i}") for i in range(5)]

    with gr.Row():
        avg_cc_text = gr.Textbox(label="Average Clustering Coefficient")
        avg_deg_text = gr.Textbox(label="Average Degree")

    with gr.Row():
        submit_button = gr.Button("Submit")
        clear_button = gr.Button("Clear")

    # Single list of result components, in the order predict() returns its
    # values: 5 images, 5 CCs, 5 DEGs, average CC, average degree. Both buttons
    # use it so their handlers must agree on the same 17 outputs.
    all_outputs = output_images + output_texts_cc + output_texts_deg + [avg_cc_text, avg_deg_text]

    # Submit runs predict() on the prompt and node count and fills every result component.
    submit_button.click(fn=predict, inputs=[input_text, input_num], outputs=all_outputs)

    # Clear is wired to every result component; the prompt textbox is not an
    # output, so its text is left as-is.
    clear_button.click(fn=clear, inputs=input_text, outputs=all_outputs)

demo.launch()