Technologic101 commited on
Commit
4698077
·
1 Parent(s): be98ff0

task: working graph

Browse files
src/app.py CHANGED
@@ -28,25 +28,24 @@ async def init():
28
 
29
  @cl.on_message
30
  async def main(message: cl.Message):
31
- # Get the graph and current state from the user session
32
  graph = cl.user_session.get("graph")
 
 
33
  state = cl.user_session.get("state")
34
 
35
- # Add user message to state
36
  state["messages"].append(HumanMessage(content=message.content))
37
 
38
  # Process message through the graph
39
  result = await graph.ainvoke(state)
 
40
 
41
  # Update state with the result
42
  state["messages"].extend(result["messages"])
43
 
44
  # Extract the last assistant message for display
45
- last_message = next(
46
- (msg.content for msg in reversed(result["messages"])
47
- if isinstance(msg, SystemMessage)),
48
- "I apologize, but I couldn't process your request."
49
- )
50
 
51
  # Send response to user
52
  await cl.Message(content=last_message).send()
 
28
 
29
  @cl.on_message
30
  async def main(message: cl.Message):
31
+ # Get the graph from the user session
32
  graph = cl.user_session.get("graph")
33
+
34
+ # Get current state
35
  state = cl.user_session.get("state")
36
 
37
+ # Add the new user message to the state
38
  state["messages"].append(HumanMessage(content=message.content))
39
 
40
  # Process message through the graph
41
  result = await graph.ainvoke(state)
42
+ print("Here's the result: ", result)
43
 
44
  # Update state with the result
45
  state["messages"].extend(result["messages"])
46
 
47
  # Extract the last assistant message for display
48
+ last_message = result["messages"][-1].content
 
 
 
 
49
 
50
  # Send response to user
51
  await cl.Message(content=last_message).send()
src/graph.ipynb ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "from typing import Annotated\n",
10
+ "from typing_extensions import TypedDict\n",
11
+ "from langgraph.graph.message import add_messages\n",
12
+ "from langgraph.prebuilt import create_react_agent\n",
13
+ "from langchain.tools.render import format_tool_to_openai_function\n",
14
+ "from nodes.design_rag import DesignRAG\n",
15
+ "from langchain_openai import ChatOpenAI\n",
16
+ "from langgraph.prebuilt import ToolNode\n",
17
+ "from tools.design_retriever import design_retriever_tool\n",
18
+ "\n",
19
+ "class State(TypedDict):\n",
20
+ " # Messages have the type \"list\". The `add_messages` function\n",
21
+ " # in the annotation defines how this state key should be updated\n",
22
+ " # (in this case, it appends messages to the list, rather than overwriting them)\n",
23
+ " messages: Annotated[list, add_messages]\n",
24
+ "\n",
25
+ "model = ChatOpenAI(model=\"gpt-4o-mini\", temperature=0)\n",
26
+ "\n",
27
+ "tools = [\n",
28
+ " design_retriever_tool\n",
29
+ "]\n",
30
+ "\n",
31
+ "tool_node = ToolNode(tools=tools)\n",
32
+ "\n",
33
+ "model_with_tools = model.bind_tools(tools)\n",
34
+ "\n"
35
+ ]
36
+ },
37
+ {
38
+ "cell_type": "code",
39
+ "execution_count": 2,
40
+ "metadata": {},
41
+ "outputs": [
42
+ {
43
+ "data": {
44
+ "text/plain": [
45
+ "[{'name': 'design_retriever_tool',\n",
46
+ " 'args': {'state': {'messages': [{'content': 'Can you show me two designs with a comic book style?',\n",
47
+ " 'type': 'human'}]},\n",
48
+ " 'num_examples': 2},\n",
49
+ " 'id': 'call_6GZ0WpBwYGzmioC0p9IyyZ7h',\n",
50
+ " 'type': 'tool_call'}]"
51
+ ]
52
+ },
53
+ "execution_count": 2,
54
+ "metadata": {},
55
+ "output_type": "execute_result"
56
+ }
57
+ ],
58
+ "source": [
59
+ "model_with_tools.invoke(\"Can you show me two designs with a comic book style?\").tool_calls"
60
+ ]
61
+ },
62
+ {
63
+ "cell_type": "code",
64
+ "execution_count": 3,
65
+ "metadata": {},
66
+ "outputs": [
67
+ {
68
+ "data": {
69
+ "text/plain": [
70
+ "<coroutine object RunnableCallable.ainvoke at 0x110215fc0>"
71
+ ]
72
+ },
73
+ "execution_count": 3,
74
+ "metadata": {},
75
+ "output_type": "execute_result"
76
+ }
77
+ ],
78
+ "source": [
79
+ "tool_node.ainvoke({\"messages\": [model_with_tools.invoke(\"Can you show me a design with a comic book style?\")]})"
80
+ ]
81
+ }
82
+ ],
83
+ "metadata": {
84
+ "kernelspec": {
85
+ "display_name": ".venv",
86
+ "language": "python",
87
+ "name": "python3"
88
+ },
89
+ "language_info": {
90
+ "codemirror_mode": {
91
+ "name": "ipython",
92
+ "version": 3
93
+ },
94
+ "file_extension": ".py",
95
+ "mimetype": "text/x-python",
96
+ "name": "python",
97
+ "nbconvert_exporter": "python",
98
+ "pygments_lexer": "ipython3",
99
+ "version": "3.11.11"
100
+ }
101
+ },
102
+ "nbformat": 4,
103
+ "nbformat_minor": 2
104
+ }
src/graph.py CHANGED
@@ -20,5 +20,8 @@ tools = [
20
 
21
  model_with_tools = model.bind_tools(tools)
22
 
23
- graph = create_react_agent(model_with_tools, tools=tools)
 
 
 
24
 
 
20
 
21
  model_with_tools = model.bind_tools(tools)
22
 
23
+ graph = create_react_agent(
24
+ model_with_tools,
25
+ tools=tools
26
+ )
27
 
src/nodes/design_rag.py CHANGED
@@ -29,7 +29,7 @@ class DesignRAG:
29
  # Create retriever with tracing
30
  self.retriever = self.vector_store.as_retriever(
31
  search_type="similarity",
32
- search_kwargs={"k": 1},
33
  tags=["design_retriever"] # Add tags for tracing
34
  )
35
 
@@ -50,25 +50,19 @@ class DesignRAG:
50
  # Load all metadata files
51
  for design_dir in designs_dir.glob("**/metadata.json"):
52
  try:
 
53
  with open(design_dir, "r") as f:
54
  metadata = json.load(f)
55
 
56
  # Create document text from metadata with safe gets
57
  text = f"""
58
- Design {metadata.get('id', 'unknown')}:
59
- Description: {metadata.get('description', 'No description available')}
 
60
  Categories: {', '.join(metadata.get('categories', []))}
61
  Visual Characteristics: {', '.join(metadata.get('visual_characteristics', []))}
 
62
  """
63
-
64
- # Load associated CSS
65
- '''
66
- css_path = design_dir.parent / "style.css"
67
- if css_path.exists():
68
- with open(css_path, "r") as f:
69
- css = f.read()
70
- text += f"\nCSS:\n{css}"
71
- '''
72
 
73
  # Create Document object with minimal metadata
74
  documents.append(
@@ -76,12 +70,14 @@ class DesignRAG:
76
  page_content=text.strip(),
77
  metadata={
78
  "id": metadata.get('id', 'unknown'),
79
- "path": str(design_dir.parent)
 
 
80
  }
81
  )
82
  )
83
  except Exception as e:
84
- print(f"Error processing design {design_dir}: {e}")
85
  continue
86
 
87
  if not documents:
@@ -140,17 +136,48 @@ class DesignRAG:
140
  docs = self.retriever.get_relevant_documents(
141
  query_response.content,
142
  k=num_examples,
143
- callbacks=[ConsoleCallbackHandler()] # Use standard callback instead
144
  )
145
 
146
- # Format examples
147
  examples = []
148
  for doc in docs:
149
  design_id = doc.metadata.get("id", "unknown")
 
 
 
 
150
  content_lines = doc.page_content.strip().split("\n")
151
- examples.append(
152
- "\n".join(line.strip() for line in content_lines if line.strip()) +
153
- f"\nURL: https://csszengarden.com/{design_id}"
154
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
 
156
- return "\n\n".join(examples)
 
29
  # Create retriever with tracing
30
  self.retriever = self.vector_store.as_retriever(
31
  search_type="similarity",
32
+ search_kwargs={"k": 4},
33
  tags=["design_retriever"] # Add tags for tracing
34
  )
35
 
 
50
  # Load all metadata files
51
  for design_dir in designs_dir.glob("**/metadata.json"):
52
  try:
53
+ print(f"Processing design: {design_dir.parent.name}")
54
  with open(design_dir, "r") as f:
55
  metadata = json.load(f)
56
 
57
  # Create document text from metadata with safe gets
58
  text = f"""
59
+ Title: {metadata.get('title', 'Untitled')}
60
+ Author: {metadata.get('author', 'Unknown')}
61
+ Description: {metadata.get('description', {}).get('summary', 'No description available')}
62
  Categories: {', '.join(metadata.get('categories', []))}
63
  Visual Characteristics: {', '.join(metadata.get('visual_characteristics', []))}
64
+ Artistic Style: {metadata.get('artistic_context', 'None Provided').get('style_influences', "None specified")}
65
  """
 
 
 
 
 
 
 
 
 
66
 
67
  # Create Document object with minimal metadata
68
  documents.append(
 
70
  page_content=text.strip(),
71
  metadata={
72
  "id": metadata.get('id', 'unknown'),
73
+ "path": str(design_dir.parent),
74
+ "title": metadata.get('title', 'Untitled'),
75
+ "author": metadata.get('author', 'Unknown')
76
  }
77
  )
78
  )
79
  except Exception as e:
80
+ print(f"Error processing design {design_dir.parent.name}: {str(e)}")
81
  continue
82
 
83
  if not documents:
 
136
  docs = self.retriever.get_relevant_documents(
137
  query_response.content,
138
  k=num_examples,
139
+ callbacks=[ConsoleCallbackHandler()]
140
  )
141
 
142
+ # Format examples with improved readability
143
  examples = []
144
  for doc in docs:
145
  design_id = doc.metadata.get("id", "unknown")
146
+ title = doc.metadata.get("title", "Untitled")
147
+ author = doc.metadata.get("author", "Unknown")
148
+
149
+ # Parse the content into sections
150
  content_lines = doc.page_content.strip().split("\n")
151
+ sections = {}
152
+ current_section = None
153
+
154
+ for line in content_lines:
155
+ line = line.strip()
156
+ if not line:
157
+ continue
158
+ if ":" in line:
159
+ current_section, value = line.split(":", 1)
160
+ sections[current_section.strip()] = value.strip()
161
+
162
+ # Format the example with clear sections
163
+ example = f"""
164
+ Design: {title}
165
+ By: {author}
166
+
167
+ Description:
168
+ {sections.get('Description', 'No description available')}
169
+
170
+ Categories:
171
+ {sections.get('Categories', 'No categories available')}
172
+
173
+ Visual Characteristics:
174
+ {sections.get('Visual Characteristics', 'No characteristics available')}
175
+
176
+ Artistic Style:
177
+ {sections.get('Artistic Style', 'No style information available')}
178
+
179
+ View at: https://csszengarden.com/{design_id}
180
+ """
181
+ examples.append(example.strip())
182
 
183
+ return "\n\n" + "="*50 + "\n\n".join(examples)
src/tools/design_retriever.py CHANGED
@@ -1,10 +1,16 @@
1
  from nodes.design_rag import DesignRAG
2
  from langgraph.graph import MessagesState
 
3
 
4
- def design_retriever_tool(state: MessagesState, num_examples: int = 2):
5
  """
6
  Retrieves similar designs based on style requirements
7
  Name: query_similar_designs
8
  """
9
- return DesignRAG.query_similar_designs(state["messages"], num_examples)
 
 
 
 
 
10
 
 
1
  from nodes.design_rag import DesignRAG
2
  from langgraph.graph import MessagesState
3
+ from langchain_core.messages import SystemMessage
4
 
5
+ async def design_retriever_tool(state: MessagesState, num_examples: int = 2):
6
  """
7
  Retrieves similar designs based on style requirements
8
  Name: query_similar_designs
9
  """
10
+ rag = DesignRAG() # Create instance
11
+
12
+ result = await rag.query_similar_designs(state["messages"], num_examples)
13
+ print("Here's the result: ", result)
14
+
15
+ return SystemMessage(content=result)
16