File size: 5,949 Bytes
7ce5a68
 
608d7cd
71f8a1b
 
7ce5a68
608d7cd
7ce5a68
608d7cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7ce5a68
608d7cd
7ce5a68
608d7cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7ce5a68
608d7cd
7ce5a68
608d7cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7ce5a68
71f8a1b
7ce5a68
608d7cd
 
71f8a1b
 
 
7ce5a68
71f8a1b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
608d7cd
7ce5a68
6493d23
608d7cd
 
 
 
 
 
 
1f58b80
7ce5a68
 
 
71f8a1b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import ast
import json

import gradio as gr
import matplotlib.pyplot as plt
import networkx as nx
from huggingface_hub import InferenceClient

# Remote LLM endpoint used for every sampling call below.
client = InferenceClient("Qwen/Qwen2.5-72B-Instruct")

def _chat_json(system_msg, user_msg, properties, max_tokens):
  """Send one grammar-constrained chat request and return the parsed JSON.

  Args:
    system_msg: system-role prompt describing the required JSON structure.
    user_msg: user-role prompt with the actual task.
    properties: JSON-schema ``properties`` mapping for the response grammar.
    max_tokens: generation budget for this call.

  Returns:
    The decoded JSON object (a dict) from the model's reply.

  Raises:
    json.JSONDecodeError: if the reply is not valid JSON despite the grammar.
  """
  response = client.chat.completions.create(
      messages=[
          {"role": "system", "content": system_msg},
          {"role": "user", "content": user_msg},
      ],
      response_format={
          "type": "json",
          "value": {"properties": properties},
      },
      stream=False,
      max_tokens=max_tokens,
      temperature=0.7,
      top_p=0.1,
  )
  # The response is real JSON (enforced by response_format), so json.loads
  # is the right parser; ast.literal_eval would choke on true/false/null.
  return json.loads(response.choices[0].get('message')['content'])


def sampling(num_samples, num_associations):
  """Sample words and associations from the LLM and render them as a graph.

  Pipeline: (1) ask for ``num_samples`` words, (2) ask for
  ``num_associations`` associations per word, (3) ask for a
  [subject, predicate, object] triplet per (word, association) pair,
  (4) draw the triplets as a labeled directed graph.

  Args:
    num_samples: how many seed words to request (slider value; may be float).
    num_associations: associations requested per word.

  Returns:
    Path of the saved PNG ('synthnet.png').
  """
  # Gradio sliders can deliver floats; prompts must show whole counts.
  num_samples = int(num_samples)
  num_associations = int(num_associations)

  outputs = _chat_json(
      "generate one json object, no explanation or additional text, use the following structure:\n"
      "words: []\n"
      f"{num_samples} samples in a list",
      f"synthesize {num_samples} random but widespread words for semantic modeling",
      {"words": {"type": "array", "items": {"type": "string"}}},
      max_tokens=1024,
  )

  # One associations request per sampled word.
  fields = {}
  for word in outputs['words']:
    fields[word] = _chat_json(
        'generate one json object, no explanation or additional text, use the following structure:\n'
        'associations: []',
        f"synthesize {num_associations} associations for the word {word}",
        {"associations": {"type": "array", "items": {"type": "string"}}},
        max_tokens=2000,
    )

  # One triplet request per (word, association) pair.
  triplets = []
  for cluster, data in fields.items():
    for association in data['associations']:
      triplets.append(_chat_json(
          "generate one json object, no explanation or additional text, use the following structure:\n"
          "properties: [subject, predicate, object]\n"
          "use chain-of-thought for predictions",
          f"form triplet based on semantics: generate predicate between the word {cluster} (subject) and the word {association} (object); return list with [subject, predicate, object]",
          {"properties": {"type": "array", "items": {"type": "string"}}},
          max_tokens=128,
      ))

  G = nx.DiGraph()
  for entry in triplets:
    triple = entry.get('properties', [])
    if len(triple) != 3:
      # Skip malformed model output instead of crashing on the unpack.
      continue
    source, label, target = triple
    G.add_node(source, label=source)
    G.add_node(target, label=target)
    G.add_edge(source, target, label=label)

  # Draw on an explicit figure so concurrent/earlier pyplot state can't leak in.
  fig = plt.figure()
  pos = nx.spring_layout(G)
  nx.draw_networkx_nodes(G, pos, node_size=500, node_color='lightblue')

  edge_labels = nx.get_edge_attributes(G, 'label')
  nx.draw_networkx_edges(G, pos, arrowstyle='->', arrowsize=25)
  nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels)

  node_labels = nx.get_node_attributes(G, 'label')
  nx.draw_networkx_labels(G, pos, labels=node_labels, font_size=8, font_family="sans-serif")

  plt.axis('off')
  plt.tight_layout()

  plt.savefig('synthnet.png')
  plt.close(fig)
  return 'synthnet.png'

# Gradio UI: two sliders feed sampling(); the result is the rendered PNG.
demo = gr.Interface(
    inputs=[
        # step=1 guarantees integer counts; without it the slider can hand
        # fractional values straight into the LLM prompts.
        gr.Slider(minimum=1, maximum=256, step=1, label="Number of Samples"),
        gr.Slider(minimum=1, maximum=256, step=1, label="Number of Associations to each Sample"),
        ],
    fn=sampling,
    outputs=gr.Image(type="filepath"),
    title="SynthNet",
    description="Select a number of samples and associations to each sample to generate a graph.",
)

if __name__ == "__main__":
    # share=True additionally exposes a temporary public Gradio URL.
    demo.launch(share=True)