import ast
import json

import gradio as gr
import matplotlib.pyplot as plt
import networkx as nx
from huggingface_hub import InferenceClient
# Shared Hugging Face Inference API client used by every chat call below.
client = InferenceClient("Qwen/Qwen2.5-72B-Instruct")
def _chat_json(system_prompt, user_prompt, schema_properties, max_tokens):
    """Run one chat completion constrained to a JSON schema and decode it.

    Args:
        system_prompt: System message describing the required JSON shape.
        user_prompt: User message with the actual request.
        schema_properties: ``properties`` mapping for the JSON grammar.
        max_tokens: Generation budget for this call.

    Returns:
        The decoded JSON object as a dict.

    Note: the reply is parsed with ``json.loads`` rather than
    ``ast.literal_eval`` — the model is instructed to emit JSON, and
    ``literal_eval`` cannot parse JSON literals such as true/false/null.
    """
    completion = client.chat.completions.create(
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        response_format={
            "type": "json",
            "value": {"properties": schema_properties},
        },
        stream=False,
        max_tokens=max_tokens,
        temperature=0.7,
        top_p=0.1,
    )
    return json.loads(completion.choices[0].message.content)


def _render_graph(triplets, out_path='synthnet.png'):
    """Draw (subject, predicate, object) triplets as a directed graph PNG.

    Malformed entries (whose 'properties' list is not exactly 3 items) are
    skipped instead of crashing the whole run, since they come from model
    output that is not guaranteed to follow the schema.
    """
    graph = nx.DiGraph()
    for entry in triplets:
        props = entry.get('properties', [])
        if len(props) != 3:
            continue  # skip model output that violates the triplet schema
        source, label, target = props
        graph.add_node(source, label=source)
        graph.add_node(target, label=target)
        graph.add_edge(source, target, label=label)
    # Fresh figure so repeated Gradio calls don't pile drawings onto the
    # implicit current figure.
    plt.figure()
    pos = nx.spring_layout(graph)
    nx.draw_networkx_nodes(graph, pos, node_size=500, node_color='lightblue')
    edge_labels = nx.get_edge_attributes(graph, 'label')
    nx.draw_networkx_edges(graph, pos, arrowstyle='->', arrowsize=25)
    nx.draw_networkx_edge_labels(graph, pos, edge_labels=edge_labels)
    node_labels = nx.get_node_attributes(graph, 'label')
    nx.draw_networkx_labels(graph, pos, labels=node_labels, font_size=8, font_family="sans-serif")
    plt.axis('off')
    plt.tight_layout()
    plt.savefig(out_path)
    plt.close()
    return out_path


def sampling(num_samples, num_associations):
    """Synthesize a word-association knowledge graph and render it to a PNG.

    Pipeline: sample ``num_samples`` words, generate ``num_associations``
    associations per word, ask the model for a (subject, predicate, object)
    triplet per word/association pair, then draw the resulting digraph.

    Args:
        num_samples: Number of seed words (Gradio sliders may pass floats).
        num_associations: Associations generated per seed word.

    Returns:
        Path to the saved graph image ('synthnet.png').
    """
    # Sliders can deliver floats; coerce so prompts say "5", not "5.0".
    num_samples = int(num_samples)
    num_associations = int(num_associations)

    words = _chat_json(
        "generate one json object, no explanation or additional text, use the following structure:\n"
        "words: []\n"
        f"{num_samples} samples in a list",
        f"synthesize {num_samples} random but widespread words for semantic modeling",
        {"words": {"type": "array", "items": {"type": "string"}}},
        max_tokens=1024,
    )['words']

    fields = {}
    for word in words:
        fields[word] = _chat_json(
            'generate one json object, no explanation or additional text, use the following structure:\n'
            'associations: []',
            f"synthesize {num_associations} associations for the word {word}",
            {"associations": {"type": "array", "items": {"type": "string"}}},
            max_tokens=2000,
        )

    triplets = []
    for cluster, data in fields.items():
        for association in data['associations']:
            triplets.append(_chat_json(
                "generate one json object, no explanation or additional text, use the following structure:\n"
                "properties: [subject, predicate, object]\n"
                "use chain-of-thought for predictions",
                f"form triplet based on semantics: generate predicate between the word {cluster} (subject) and the word {association} (object); return list with [subject, predicate, object]",
                {"properties": {"type": "array", "items": {"type": "string"}}},
                max_tokens=128,
            ))

    return _render_graph(triplets)
# Gradio UI: two integer sliders in, a rendered graph image out.
# step=1 pins the sliders to whole numbers — without it Gradio can emit
# floats for what are inherently integral counts.
demo = gr.Interface(
    inputs=[
        gr.Slider(minimum=1, maximum=256, step=1, label="Number of Samples"),
        gr.Slider(minimum=1, maximum=256, step=1, label="Number of Associations to each Sample"),
    ],
    fn=sampling,
    outputs=gr.Image(type="filepath"),
    title="SynthNet",
    description="Select a number of samples and associations to each sample to generate a graph.",
)
# Entry point: launch the app when run as a script. share=True asks Gradio
# for a temporary public tunnel URL in addition to the local server.
# (Removed a stray trailing "|" scrape artifact that was a SyntaxError.)
if __name__ == "__main__":
    demo.launch(share=True)