AGGILE / aggile.py
missvector's picture
Update aggile.py
c22fc23 verified
# aggile.py
class Aggile:
    """
    Graph generator for plain text.

    Queries a chat-completion client to extract subjects (key concepts),
    objects (related collocations) and predicates from a text, and combines
    them into subject-predicate-object triplets.
    """

    def __init__(self, client):
        """
        :client: chat client exposing client.chat.completions.create(...)
                 (presumably a huggingface_hub InferenceClient - TODO confirm)
        """
        self.client = client
        # "{n}" is a placeholder: _get_subj substitutes the requested number of
        # entities at call time (previously the prompt was rendered once here
        # with n = None, so the model was asked for "None collocations").
        self.subj_prompt = (
            "\n"
            "extract {n} collocations describing key concepts, keywords, named entities from the provided source\n"
        )
        self.obj_prompt = (
            "\n"
            "extract 5-10 most representative collocations from the provided source that are related to the provided concept\n"
        )
        self.pred_prompt = (
            "\n"
            "define the relationship between two words: generate a verb or a phrase decribing a relationship between two entities; return a predicate for a knowledge graph triplet\n"
        )

    @staticmethod
    def _parse_response(raw):
        """
        Parse the text returned by the model into a Python object.
        :raw: message content returned by the client (str)
        :return: parsed object (dict for the schemas used in this class)
        """
        import ast
        import json
        try:
            # The response is requested as JSON, so JSON is the primary parser
            # (handles true / false / null, which ast.literal_eval rejects)
            return json.loads(raw)
        except (TypeError, ValueError):
            # Fallback for Python-literal style output (e.g. single-quoted keys);
            # raises ValueError / SyntaxError on unparsable content, as before
            return ast.literal_eval(raw)

    def _request_json(self, system_prompt, user_content, properties, max_tokens):
        """
        Send one system + user message pair and parse the structured reply.
        :system_prompt: instruction for the model (str)
        :user_content: payload of the user message (str)
        :properties: "properties" section of the schema passed in response_format (dict)
        :max_tokens: generation limit for this request (int)
        :return: parsed response (dict)
        """
        response = self.client.chat.completions.create(
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_content},
            ],
            response_format={
                "type": "json",
                "value": {"properties": properties},
            },
            stream=False,
            max_tokens=max_tokens,
            temperature=0.5,
            top_p=0.1,
        )
        raw = response.choices[0].get('message')['content']
        return self._parse_response(raw)

    def _get_subj(self, text, n=10):
        """
        Extract entities from the text:
        - named entities
        - keywords
        - concepts
        :text: input text (str)
        :n: the number of generated entities requested in the prompt (int)
        :return: {core_concepts: list of extracted keywords (subjects that will form triplets)} (dict)
        """
        # str.replace (not str.format) so that a customised prompt containing
        # other braces cannot raise
        prompt = self.subj_prompt.replace("{n}", str(n))
        return self._request_json(
            prompt,
            text,
            {"core_concepts": {"type": "array", "items": {"type": "string"}}},
            max_tokens=1024,
        )

    def __extract_relations(self, word, text):
        """
        Extract collocations related to one concept (subject) based on the text.
        :word: one subject generated by _get_subj (str)
        :text: input text (str)
        :return: {related_concepts: list of related words and collocations (objects that will form triplets)} (dict)
        """
        return self._request_json(
            self.obj_prompt,
            f"concept = {word}, source = {text}",
            {"related_concepts": {"type": "array", "items": {"type": "string"}}},
            max_tokens=512,
        )

    def _get_obj(self, text, n=10):
        """
        Execute the extraction of related concepts for every generated subject.
        :text: input text (str)
        :n: the number of subjects to request from _get_subj (int)
        :return: {subject: {related_concepts: list of related words and collocations}} (dict)
        """
        # Generate list of subjects
        core_concepts = self._get_subj(text, n=n)
        # Get objects for each subject
        return {
            word: self.__extract_relations(word, text)
            for word in core_concepts['core_concepts']
        }

    def __generate_predicates(self, subj, obj):
        """
        Generate a predicate between a subject and an object.
        :subj: one generated subject from core_concepts (str)
        :obj: one generated object from relations (str)
        :return: one relevant predicate to form triplets (str)
        """
        parsed = self._request_json(
            self.pred_prompt,
            f"what is the relationship between {subj} and {obj}? return a predicate only",
            {"predicate": {"type": "string"}},
            max_tokens=512,
        )
        # Return predicate only, not the whole dictionary
        return parsed['predicate']

    def form_triples(self, text, n=10):
        """
        Build subject-predicate-object triplets for the text.
        :text: input text (str)
        :n: the number of subjects to request (int)
        :return: {subject: [{'subject': ..., 'predicate': ..., 'object': ...}, ...]} (dict)
        """
        # Generate objects from text
        relations = self._get_obj(text, n=n)
        triplets = {}
        for subj, related in relations.items():
            # Hallucination check: pairs whose object equals the subject are
            # dropped. The filter runs before the element expression, so no
            # predicate request is spent on a pair that would be discarded.
            triplets[subj] = [
                {
                    'subject': subj,
                    'predicate': self.__generate_predicates(subj, obj),
                    'object': obj,
                }
                for obj in related['related_concepts']
                if obj != subj
            ]
        return triplets
class Graph:
    """
    Interactive visualisation of subject-predicate-object triplets.
    """

    def __init__(self, triplets):
        """
        :triplets: {subject: [{'subject': ..., 'predicate': ..., 'object': ...}, ...]} (dict)
        """
        self.triplets = triplets

    def _collect_edges(self):
        """
        Flatten the triplets into a list of edges.
        :return: list of (subject, object, predicate) tuples, in insertion order (list)
        """
        return [
            (rel['subject'], rel['object'], rel['predicate'])
            for relations in self.triplets.values()
            for rel in relations
        ]

    def build_graph(self):
        """
        Build a Plotly figure of the triplets laid out with a force-directed layout.
        Nodes are coloured by their number of connections; each edge carries
        its predicate as a text label at the edge midpoint.
        :return: the assembled figure (plotly.graph_objects.Figure)
        """
        import plotly.graph_objects as go
        import networkx as nx
        from collections import Counter

        # Extract edges from the set of triplets
        edges = self._collect_edges()

        # Create a networkx graph; adding an edge also adds its end nodes
        G = nx.Graph()
        for subj, obj, predicate in edges:
            G.add_edge(subj, obj, label=predicate)

        # Generate positions for nodes using force-directed layout
        # (fixed seed keeps the layout reproducible between calls)
        pos = nx.spring_layout(G, seed=42)

        # Extract node data for Plotly
        node_labels = list(G.nodes())
        node_x = [pos[node][0] for node in node_labels]
        node_y = [pos[node][1] for node in node_labels]

        # Count connections per node (every listed edge counts, including
        # repeated subject-object pairs)
        node_degrees = Counter(
            node for subj, obj, _ in edges for node in (subj, obj)
        )

        # Plotly data for edges: None separates the individual line segments
        edge_x = []
        edge_y = []
        for subj, obj, _ in edges:
            x0, y0 = pos[subj]
            x1, y1 = pos[obj]
            edge_x += [x0, x1, None]
            edge_y += [y0, y1, None]

        # Create the figure
        fig = go.Figure()

        # Add edges
        fig.add_trace(go.Scatter(
            x=edge_x, y=edge_y,
            line=dict(width=0.5, color='#888'),
            hoverinfo='text',
            mode='lines'
        ))

        # Add nodes with uniform size and labels
        fig.add_trace(go.Scatter(
            x=node_x, y=node_y,
            mode='markers+text',
            marker=dict(
                size=25,  # Uniform node size for all nodes
                color=[node_degrees[node] for node in node_labels],
                colorbar=dict(title='Connections')
            ),
            text=node_labels,
            hoverinfo='text',
            textposition='top center',
            textfont=dict(size=13, weight="bold")
        ))

        # Add one text trace per edge with the predicate at the edge midpoint
        for subj, obj, predicate_label in edges:
            x0, y0 = pos[subj]
            x1, y1 = pos[obj]
            fig.add_trace(go.Scatter(
                x=[(x0 + x1) / 2], y=[(y0 + y1) / 2],
                mode='text',
                text=[predicate_label],
                textposition='middle center',
                showlegend=False,
                textfont=dict(size=10)
            ))

        # Update layout
        fig.update_layout(
            showlegend=False,
            margin=dict(l=0, r=0, t=0, b=0),
            xaxis=dict(showgrid=False, zeroline=False),
            yaxis=dict(showgrid=False, zeroline=False),
            title="Force-Directed Graph with Predicate Labels on Nodes"
        )
        return fig