Spaces:
Runtime error
Runtime error
Commit
·
4698077
1
Parent(s):
be98ff0
task: working graph
Browse files- src/app.py +6 -7
- src/graph.ipynb +104 -0
- src/graph.py +4 -1
- src/nodes/design_rag.py +48 -21
- src/tools/design_retriever.py +8 -2
src/app.py
CHANGED
@@ -28,25 +28,24 @@ async def init():
|
|
28 |
|
29 |
@cl.on_message
|
30 |
async def main(message: cl.Message):
|
31 |
-
# Get the graph
|
32 |
graph = cl.user_session.get("graph")
|
|
|
|
|
33 |
state = cl.user_session.get("state")
|
34 |
|
35 |
-
# Add user message to state
|
36 |
state["messages"].append(HumanMessage(content=message.content))
|
37 |
|
38 |
# Process message through the graph
|
39 |
result = await graph.ainvoke(state)
|
|
|
40 |
|
41 |
# Update state with the result
|
42 |
state["messages"].extend(result["messages"])
|
43 |
|
44 |
# Extract the last assistant message for display
|
45 |
-
last_message =
|
46 |
-
(msg.content for msg in reversed(result["messages"])
|
47 |
-
if isinstance(msg, SystemMessage)),
|
48 |
-
"I apologize, but I couldn't process your request."
|
49 |
-
)
|
50 |
|
51 |
# Send response to user
|
52 |
await cl.Message(content=last_message).send()
|
|
|
28 |
|
29 |
@cl.on_message
|
30 |
async def main(message: cl.Message):
|
31 |
+
# Get the graph from the user session
|
32 |
graph = cl.user_session.get("graph")
|
33 |
+
|
34 |
+
# Get current state
|
35 |
state = cl.user_session.get("state")
|
36 |
|
37 |
+
# Add the new user message to the state
|
38 |
state["messages"].append(HumanMessage(content=message.content))
|
39 |
|
40 |
# Process message through the graph
|
41 |
result = await graph.ainvoke(state)
|
42 |
+
print("Here's the result: ", result)
|
43 |
|
44 |
# Update state with the result
|
45 |
state["messages"].extend(result["messages"])
|
46 |
|
47 |
# Extract the last assistant message for display
|
48 |
+
last_message = result["messages"][-1].content
|
|
|
|
|
|
|
|
|
49 |
|
50 |
# Send response to user
|
51 |
await cl.Message(content=last_message).send()
|
src/graph.ipynb
ADDED
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 1,
|
6 |
+
"metadata": {},
|
7 |
+
"outputs": [],
|
8 |
+
"source": [
|
9 |
+
"from typing import Annotated\n",
|
10 |
+
"from typing_extensions import TypedDict\n",
|
11 |
+
"from langgraph.graph.message import add_messages\n",
|
12 |
+
"from langgraph.prebuilt import create_react_agent\n",
|
13 |
+
"from langchain.tools.render import format_tool_to_openai_function\n",
|
14 |
+
"from nodes.design_rag import DesignRAG\n",
|
15 |
+
"from langchain_openai import ChatOpenAI\n",
|
16 |
+
"from langgraph.prebuilt import ToolNode\n",
|
17 |
+
"from tools.design_retriever import design_retriever_tool\n",
|
18 |
+
"\n",
|
19 |
+
"class State(TypedDict):\n",
|
20 |
+
" # Messages have the type \"list\". The `add_messages` function\n",
|
21 |
+
" # in the annotation defines how this state key should be updated\n",
|
22 |
+
" # (in this case, it appends messages to the list, rather than overwriting them)\n",
|
23 |
+
" messages: Annotated[list, add_messages]\n",
|
24 |
+
"\n",
|
25 |
+
"model = ChatOpenAI(model=\"gpt-4o-mini\", temperature=0)\n",
|
26 |
+
"\n",
|
27 |
+
"tools = [\n",
|
28 |
+
" design_retriever_tool\n",
|
29 |
+
"]\n",
|
30 |
+
"\n",
|
31 |
+
"tool_node = ToolNode(tools=tools)\n",
|
32 |
+
"\n",
|
33 |
+
"model_with_tools = model.bind_tools(tools)\n",
|
34 |
+
"\n"
|
35 |
+
]
|
36 |
+
},
|
37 |
+
{
|
38 |
+
"cell_type": "code",
|
39 |
+
"execution_count": 2,
|
40 |
+
"metadata": {},
|
41 |
+
"outputs": [
|
42 |
+
{
|
43 |
+
"data": {
|
44 |
+
"text/plain": [
|
45 |
+
"[{'name': 'design_retriever_tool',\n",
|
46 |
+
" 'args': {'state': {'messages': [{'content': 'Can you show me two designs with a comic book style?',\n",
|
47 |
+
" 'type': 'human'}]},\n",
|
48 |
+
" 'num_examples': 2},\n",
|
49 |
+
" 'id': 'call_6GZ0WpBwYGzmioC0p9IyyZ7h',\n",
|
50 |
+
" 'type': 'tool_call'}]"
|
51 |
+
]
|
52 |
+
},
|
53 |
+
"execution_count": 2,
|
54 |
+
"metadata": {},
|
55 |
+
"output_type": "execute_result"
|
56 |
+
}
|
57 |
+
],
|
58 |
+
"source": [
|
59 |
+
"model_with_tools.invoke(\"Can you show me two designs with a comic book style?\").tool_calls"
|
60 |
+
]
|
61 |
+
},
|
62 |
+
{
|
63 |
+
"cell_type": "code",
|
64 |
+
"execution_count": 3,
|
65 |
+
"metadata": {},
|
66 |
+
"outputs": [
|
67 |
+
{
|
68 |
+
"data": {
|
69 |
+
"text/plain": [
|
70 |
+
"<coroutine object RunnableCallable.ainvoke at 0x110215fc0>"
|
71 |
+
]
|
72 |
+
},
|
73 |
+
"execution_count": 3,
|
74 |
+
"metadata": {},
|
75 |
+
"output_type": "execute_result"
|
76 |
+
}
|
77 |
+
],
|
78 |
+
"source": [
|
79 |
+
"tool_node.ainvoke({\"messages\": [model_with_tools.invoke(\"Can you show me a design with a comic book style?\")]})"
|
80 |
+
]
|
81 |
+
}
|
82 |
+
],
|
83 |
+
"metadata": {
|
84 |
+
"kernelspec": {
|
85 |
+
"display_name": ".venv",
|
86 |
+
"language": "python",
|
87 |
+
"name": "python3"
|
88 |
+
},
|
89 |
+
"language_info": {
|
90 |
+
"codemirror_mode": {
|
91 |
+
"name": "ipython",
|
92 |
+
"version": 3
|
93 |
+
},
|
94 |
+
"file_extension": ".py",
|
95 |
+
"mimetype": "text/x-python",
|
96 |
+
"name": "python",
|
97 |
+
"nbconvert_exporter": "python",
|
98 |
+
"pygments_lexer": "ipython3",
|
99 |
+
"version": "3.11.11"
|
100 |
+
}
|
101 |
+
},
|
102 |
+
"nbformat": 4,
|
103 |
+
"nbformat_minor": 2
|
104 |
+
}
|
src/graph.py
CHANGED
@@ -20,5 +20,8 @@ tools = [
|
|
20 |
|
21 |
model_with_tools = model.bind_tools(tools)
|
22 |
|
23 |
-
graph = create_react_agent(
|
|
|
|
|
|
|
24 |
|
|
|
20 |
|
21 |
model_with_tools = model.bind_tools(tools)
|
22 |
|
23 |
+
graph = create_react_agent(
|
24 |
+
model_with_tools,
|
25 |
+
tools=tools
|
26 |
+
)
|
27 |
|
src/nodes/design_rag.py
CHANGED
@@ -29,7 +29,7 @@ class DesignRAG:
|
|
29 |
# Create retriever with tracing
|
30 |
self.retriever = self.vector_store.as_retriever(
|
31 |
search_type="similarity",
|
32 |
-
search_kwargs={"k":
|
33 |
tags=["design_retriever"] # Add tags for tracing
|
34 |
)
|
35 |
|
@@ -50,25 +50,19 @@ class DesignRAG:
|
|
50 |
# Load all metadata files
|
51 |
for design_dir in designs_dir.glob("**/metadata.json"):
|
52 |
try:
|
|
|
53 |
with open(design_dir, "r") as f:
|
54 |
metadata = json.load(f)
|
55 |
|
56 |
# Create document text from metadata with safe gets
|
57 |
text = f"""
|
58 |
-
|
59 |
-
|
|
|
60 |
Categories: {', '.join(metadata.get('categories', []))}
|
61 |
Visual Characteristics: {', '.join(metadata.get('visual_characteristics', []))}
|
|
|
62 |
"""
|
63 |
-
|
64 |
-
# Load associated CSS
|
65 |
-
'''
|
66 |
-
css_path = design_dir.parent / "style.css"
|
67 |
-
if css_path.exists():
|
68 |
-
with open(css_path, "r") as f:
|
69 |
-
css = f.read()
|
70 |
-
text += f"\nCSS:\n{css}"
|
71 |
-
'''
|
72 |
|
73 |
# Create Document object with minimal metadata
|
74 |
documents.append(
|
@@ -76,12 +70,14 @@ class DesignRAG:
|
|
76 |
page_content=text.strip(),
|
77 |
metadata={
|
78 |
"id": metadata.get('id', 'unknown'),
|
79 |
-
"path": str(design_dir.parent)
|
|
|
|
|
80 |
}
|
81 |
)
|
82 |
)
|
83 |
except Exception as e:
|
84 |
-
print(f"Error processing design {design_dir}: {e}")
|
85 |
continue
|
86 |
|
87 |
if not documents:
|
@@ -140,17 +136,48 @@ class DesignRAG:
|
|
140 |
docs = self.retriever.get_relevant_documents(
|
141 |
query_response.content,
|
142 |
k=num_examples,
|
143 |
-
callbacks=[ConsoleCallbackHandler()]
|
144 |
)
|
145 |
|
146 |
-
# Format examples
|
147 |
examples = []
|
148 |
for doc in docs:
|
149 |
design_id = doc.metadata.get("id", "unknown")
|
|
|
|
|
|
|
|
|
150 |
content_lines = doc.page_content.strip().split("\n")
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
155 |
|
156 |
-
return "\n\n".join(examples)
|
|
|
29 |
# Create retriever with tracing
|
30 |
self.retriever = self.vector_store.as_retriever(
|
31 |
search_type="similarity",
|
32 |
+
search_kwargs={"k": 4},
|
33 |
tags=["design_retriever"] # Add tags for tracing
|
34 |
)
|
35 |
|
|
|
50 |
# Load all metadata files
|
51 |
for design_dir in designs_dir.glob("**/metadata.json"):
|
52 |
try:
|
53 |
+
print(f"Processing design: {design_dir.parent.name}")
|
54 |
with open(design_dir, "r") as f:
|
55 |
metadata = json.load(f)
|
56 |
|
57 |
# Create document text from metadata with safe gets
|
58 |
text = f"""
|
59 |
+
Title: {metadata.get('title', 'Untitled')}
|
60 |
+
Author: {metadata.get('author', 'Unknown')}
|
61 |
+
Description: {metadata.get('description', {}).get('summary', 'No description available')}
|
62 |
Categories: {', '.join(metadata.get('categories', []))}
|
63 |
Visual Characteristics: {', '.join(metadata.get('visual_characteristics', []))}
|
64 |
+
Artistic Style: {metadata.get('artistic_context', 'None Provided').get('style_influences', "None specified")}
|
65 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
66 |
|
67 |
# Create Document object with minimal metadata
|
68 |
documents.append(
|
|
|
70 |
page_content=text.strip(),
|
71 |
metadata={
|
72 |
"id": metadata.get('id', 'unknown'),
|
73 |
+
"path": str(design_dir.parent),
|
74 |
+
"title": metadata.get('title', 'Untitled'),
|
75 |
+
"author": metadata.get('author', 'Unknown')
|
76 |
}
|
77 |
)
|
78 |
)
|
79 |
except Exception as e:
|
80 |
+
print(f"Error processing design {design_dir.parent.name}: {str(e)}")
|
81 |
continue
|
82 |
|
83 |
if not documents:
|
|
|
136 |
docs = self.retriever.get_relevant_documents(
|
137 |
query_response.content,
|
138 |
k=num_examples,
|
139 |
+
callbacks=[ConsoleCallbackHandler()]
|
140 |
)
|
141 |
|
142 |
+
# Format examples with improved readability
|
143 |
examples = []
|
144 |
for doc in docs:
|
145 |
design_id = doc.metadata.get("id", "unknown")
|
146 |
+
title = doc.metadata.get("title", "Untitled")
|
147 |
+
author = doc.metadata.get("author", "Unknown")
|
148 |
+
|
149 |
+
# Parse the content into sections
|
150 |
content_lines = doc.page_content.strip().split("\n")
|
151 |
+
sections = {}
|
152 |
+
current_section = None
|
153 |
+
|
154 |
+
for line in content_lines:
|
155 |
+
line = line.strip()
|
156 |
+
if not line:
|
157 |
+
continue
|
158 |
+
if ":" in line:
|
159 |
+
current_section, value = line.split(":", 1)
|
160 |
+
sections[current_section.strip()] = value.strip()
|
161 |
+
|
162 |
+
# Format the example with clear sections
|
163 |
+
example = f"""
|
164 |
+
Design: {title}
|
165 |
+
By: {author}
|
166 |
+
|
167 |
+
Description:
|
168 |
+
{sections.get('Description', 'No description available')}
|
169 |
+
|
170 |
+
Categories:
|
171 |
+
{sections.get('Categories', 'No categories available')}
|
172 |
+
|
173 |
+
Visual Characteristics:
|
174 |
+
{sections.get('Visual Characteristics', 'No characteristics available')}
|
175 |
+
|
176 |
+
Artistic Style:
|
177 |
+
{sections.get('Artistic Style', 'No style information available')}
|
178 |
+
|
179 |
+
View at: https://csszengarden.com/{design_id}
|
180 |
+
"""
|
181 |
+
examples.append(example.strip())
|
182 |
|
183 |
+
return "\n\n" + "="*50 + "\n\n".join(examples)
|
src/tools/design_retriever.py
CHANGED
@@ -1,10 +1,16 @@
|
|
1 |
from nodes.design_rag import DesignRAG
|
2 |
from langgraph.graph import MessagesState
|
|
|
3 |
|
4 |
-
def design_retriever_tool(state: MessagesState, num_examples: int = 2):
|
5 |
"""
|
6 |
Retrieves similar designs based on style requirements
|
7 |
Name: query_similar_designs
|
8 |
"""
|
9 |
-
|
|
|
|
|
|
|
|
|
|
|
10 |
|
|
|
1 |
from nodes.design_rag import DesignRAG
|
2 |
from langgraph.graph import MessagesState
|
3 |
+
from langchain_core.messages import SystemMessage
|
4 |
|
5 |
+
async def design_retriever_tool(state: MessagesState, num_examples: int = 2):
|
6 |
"""
|
7 |
Retrieves similar designs based on style requirements
|
8 |
Name: query_similar_designs
|
9 |
"""
|
10 |
+
rag = DesignRAG() # Create instance
|
11 |
+
|
12 |
+
result = await rag.query_similar_designs(state["messages"], num_examples)
|
13 |
+
print("Here's the result: ", result)
|
14 |
+
|
15 |
+
return SystemMessage(content=result)
|
16 |
|