Technologic101 committed on
Commit
125ffe9
·
1 Parent(s): bb34640

task: [wip] initial graph

Browse files
Files changed (4) hide show
  1. src/app.py +25 -38
  2. src/graph.py +11 -38
  3. src/nodes/designer.py +0 -122
  4. src/tools/design_retriever.py +9 -24
src/app.py CHANGED
@@ -1,10 +1,6 @@
1
  import chainlit as cl
2
- from langchain_openai import ChatOpenAI
3
  from langchain_core.messages import HumanMessage, SystemMessage
4
- from nodes.design_rag import DesignRAG
5
-
6
- # Initialize components
7
- design_rag = DesignRAG()
8
 
9
  # System message focused on design analysis
10
  SYSTEM_MESSAGE = """You are a helpful design assistant that finds and explains design examples.
@@ -18,51 +14,42 @@ For every user message, analyze their design preferences and requirements, consi
18
 
19
  @cl.on_chat_start
20
  async def init():
21
- # Initialize LLM with callback handler inside the Chainlit context
22
- llm = ChatOpenAI(
23
- model="gpt-4o-mini",
24
- temperature=0,
25
- streaming=True,
26
- callbacks=[cl.LangchainCallbackHandler()]
27
- )
28
-
29
- # Store the LLM in the user session
30
- cl.user_session.set("design_llm", llm)
31
 
32
- # init conversation history for each user
33
- cl.user_session.set("conversation_history", [
34
- SystemMessage(content=SYSTEM_MESSAGE)
35
- ])
 
36
 
37
  # Send welcome message
38
  await cl.Message(content="Welcome to ImagineUI! I'm here to help you design beautiful and functional user interfaces. What kind of design are you looking for?").send()
39
 
40
  @cl.on_message
41
  async def main(message: cl.Message):
42
- # Get the LLM from the user session
43
- llm = cl.user_session.get("design_llm")
44
- conversation_history = cl.user_session.get("conversation_history")
45
 
46
- # Add user message to history
47
- conversation_history.append(HumanMessage(content=message.content))
48
 
49
- # Get LLM's analysis of requirements
50
- analysis = await llm.ainvoke(conversation_history)
51
 
52
- # Get best design example based on full conversation
53
- designs = await design_rag.query_similar_designs(
54
- [msg.content for msg in conversation_history],
55
- num_examples=1
56
- )
57
-
58
- # Combine analysis with designs
59
- response = f"{analysis.content}\n\nHere is the best match from the zen garden:\n\n{designs}"
60
 
61
- # Add assistant's response to history
62
- conversation_history.append(SystemMessage(content=response))
 
 
 
 
63
 
64
  # Send response to user
65
- await cl.Message(content=response).send()
66
 
67
  if __name__ == "__main__":
68
- cl.run()
 
import chainlit as cl
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage

from graph import graph
 
 
 
4
 
5
  # System message focused on design analysis
6
  SYSTEM_MESSAGE = """You are a helpful design assistant that finds and explains design examples.
 
14
 
15
@cl.on_chat_start
async def init():
    """Set up a fresh chat session: shared graph plus seeded conversation state."""
    # Expose the compiled graph to this user's session so handlers can reach it.
    cl.user_session.set("graph", graph)

    # The conversation begins with only the system prompt in history.
    cl.user_session.set(
        "state",
        {"messages": [SystemMessage(content=SYSTEM_MESSAGE)]},
    )

    # Greet the user.
    await cl.Message(content="Welcome to ImagineUI! I'm here to help you design beautiful and functional user interfaces. What kind of design are you looking for?").send()
28
 
29
@cl.on_message
async def main(message: cl.Message):
    """Handle one user turn: append the message, run the graph, show the reply."""
    # Get the graph and current state from the user session
    graph = cl.user_session.get("graph")
    state = cl.user_session.get("state")

    # Add user message to state
    state["messages"].append(HumanMessage(content=message.content))

    # Process message through the graph
    result = await graph.ainvoke(state)

    # The graph returns the *full* merged message history (add_messages
    # appends into the state it was given), so replace the stored list —
    # extending it duplicated every prior message on each turn.
    state["messages"] = result["messages"]

    # Assistant replies are AIMessage instances; the previous check for
    # SystemMessage matched the system prompt instead of the actual reply.
    last_message = next(
        (msg.content for msg in reversed(result["messages"])
         if isinstance(msg, AIMessage)),
        "I apologize, but I couldn't process your request."
    )

    # Send response to user
    await cl.Message(content=last_message).send()
53
 
54
if __name__ == "__main__":
    # NOTE(review): chainlit does not expose a public `cl.run()` — apps are
    # normally launched with `chainlit run app.py`. Confirm this entry point
    # is ever exercised / works as intended.
    cl.run()
src/graph.py CHANGED
@@ -1,10 +1,10 @@
1
  from typing import Annotated
2
  from typing_extensions import TypedDict
3
- from langgraph.graph import StateGraph, START, END
4
  from langgraph.graph.message import add_messages
5
- from langgraph.prebuilt import ToolInvoker
6
- from nodes.designer import DesignerNode
7
- from langchain.tools.render import format_tool_to_openai_function
 
8
 
9
  class State(TypedDict):
10
  # Messages have the type "list". The `add_messages` function
@@ -12,40 +12,13 @@ class State(TypedDict):
12
  # (in this case, it appends messages to the list, rather than overwriting them)
13
  messages: Annotated[list, add_messages]
14
 
15
- def create_graph():
16
- # Initialize nodes
17
- designer = DesignerNode()
18
-
19
- # Create graph
20
- graph = StateGraph(State)
21
-
22
- # Add designer node
23
- graph.add_node("designer", designer)
24
-
25
- # Create tool invoker node with designer's tools
26
- tools = designer.get_available_tools()
27
- tool_executor = ToolInvoker(tools=tools)
28
- graph.add_node("tools", tool_executor)
29
-
30
- # Add edges
31
- graph.add_edge(START, "designer")
32
-
33
- # Add conditional edges based on tool calls
34
- graph.add_conditional_edges(
35
- "designer",
36
- lambda state: "tools" if state["messages"][-1].get("tool_calls") else END,
37
- {
38
- "tools": "tools",
39
- END: END
40
- }
41
- )
42
-
43
- # After tool execution, return to designer
44
- graph.add_edge("tools", "designer")
45
-
46
- return graph.compile()
47
 
48
- # Create the graph
49
- graph = create_graph()
50
 
 
51
 
 
1
  from typing import Annotated
2
  from typing_extensions import TypedDict
 
3
  from langgraph.graph.message import add_messages
4
+ from langgraph.prebuilt import create_react_agent
5
+ from langchain_anthropic import ChatAnthropic
6
+ from tools.design_retriever import design_retriever_tool
7
+
8
 
9
class State(TypedDict):
    """Shared graph state: the running conversation as a message list."""
    # `add_messages` tells langgraph to merge updates by appending new
    # messages to the existing list rather than overwriting it.
    messages: Annotated[list, add_messages]
14
 
15
# Claude model backing the agent; temperature 0 for deterministic replies.
model = ChatAnthropic(model="claude-3-5-sonnet-20240620", temperature=0)

# Tools the agent may invoke.
tools = [
    design_retriever_tool
]

# Kept for any caller that imports it directly.
model_with_tools = model.bind_tools(tools)

# create_react_agent binds `tools` to the model itself; handing it a model
# that already has the tools bound (model_with_tools) binds them twice,
# duplicating the tool schemas sent with every request. Pass the bare model.
graph = create_react_agent(model, tools=tools)
24
 
src/nodes/designer.py DELETED
@@ -1,122 +0,0 @@
1
- from typing import Dict, List
2
- from anthropic import AsyncAnthropic
3
- import json
4
- from langchain_core.tools import tool
5
- from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
6
- from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
7
- from nodes.design_rag import DesignRAG
8
-
9
- class DesignerNode:
10
- """Main conversation node for discussing design requirements and retrieving examples"""
11
-
12
- def __init__(self):
13
- self.client = AsyncAnthropic()
14
- self.rag = DesignRAG()
15
-
16
- # Define the conversation prompt
17
- self.prompt = ChatPromptTemplate.from_messages([
18
- ("system", """You are an expert design assistant helping users find design inspiration.
19
- Your goal is to understand their design needs and requirements through conversation.
20
-
21
- Guidelines:
22
- 1. Focus on understanding visual design requirements, not implementation
23
- 2. Ask clarifying questions about style, mood, and visual elements
24
- 3. When the user asks to see examples, use the retrieve_design_examples tool
25
- 4. Track both must-have requirements and nice-to-have preferences
26
- 5. When showing examples, explain how they match the requirements
27
-
28
- Available tools:
29
- - retrieve_design_examples: Find relevant design examples based on conversation
30
-
31
- When the user asks to see examples, ALWAYS use the retrieve_design_examples tool.
32
- Format tool calls using the exact function name and parameters.
33
- """),
34
- MessagesPlaceholder(variable_name="chat_history"),
35
- ("human", "{input}"),
36
- ])
37
-
38
- @tool()
39
- async def retrieve_design_examples(self, conversation: List[str], num_examples: int = 1) -> str:
40
- """
41
- Find and retrieve relevant design examples based on the conversation history.
42
-
43
- Args:
44
- conversation: List of conversation messages
45
- num_examples: Number of examples to retrieve (default: 1)
46
-
47
- Returns:
48
- String containing design examples and their details
49
- """
50
- return await self.rag.query_similar_designs(conversation, num_examples)
51
-
52
- def get_available_tools(self):
53
- """Return list of available tools"""
54
- return [self.retrieve_design_examples]
55
-
56
- async def __call__(self, state: Dict) -> Dict:
57
- """Process messages and manage design discussion"""
58
- messages = state.get("messages", [])
59
-
60
- # Convert messages to chat history format
61
- chat_history = []
62
- for msg in messages[:-1]: # Exclude the last message which is the current input
63
- if isinstance(msg, dict):
64
- role = msg.get("role", "user")
65
- content = msg.get("content", "")
66
- chat_history.append(
67
- HumanMessage(content=content) if role == "user"
68
- else AIMessage(content=content)
69
- )
70
- elif isinstance(msg, BaseMessage):
71
- chat_history.append(msg)
72
-
73
- # Get the current input message
74
- current_input = messages[-1].get("content") if isinstance(messages[-1], dict) else messages[-1].content
75
-
76
- # Get response from Claude
77
- response = await self.client.messages.create(
78
- model="claude-3-haiku-20240307",
79
- max_tokens=500,
80
- messages=[{
81
- "role": "user",
82
- "content": self.prompt.format(
83
- chat_history=chat_history,
84
- input=current_input
85
- )
86
- }]
87
- )
88
-
89
- response_text = response.content[0].text
90
-
91
- # Check if response indicates need for examples
92
- should_retrieve = (
93
- "retrieve_design_examples" in response_text or
94
- any(phrase in current_input.lower()
95
- for phrase in ["show example", "find design", "get example"])
96
- )
97
-
98
- if should_retrieve:
99
- # Create tool call message
100
- state["messages"].append({
101
- "role": "assistant",
102
- "content": response_text,
103
- "tool_calls": [{
104
- "type": "function",
105
- "function": {
106
- "name": "retrieve_design_examples",
107
- "arguments": json.dumps({
108
- "conversation": [msg.get("content", msg) if isinstance(msg, dict) else msg
109
- for msg in messages],
110
- "num_examples": 1
111
- })
112
- }
113
- }]
114
- })
115
- else:
116
- # Regular response without tool calls
117
- state["messages"].append({
118
- "role": "assistant",
119
- "content": response_text
120
- })
121
-
122
- return state
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/tools/design_retriever.py CHANGED
@@ -1,25 +1,10 @@
1
- from typing import Dict, Optional
2
- from langchain.tools import BaseTool
3
- from chains.design_rag import DesignRAG
4
- from pydantic import Field
5
- import json
 
 
 
 
6
 
7
- class DesignRetrieverTool(BaseTool):
8
- """Tool for retrieving similar designs based on requirements."""
9
-
10
- name: str = "design_retriever"
11
- description: str = "Retrieves similar designs based on style requirements"
12
- rag: DesignRAG = Field(description="Design RAG system for retrieving similar designs")
13
-
14
- def __init__(self, rag: DesignRAG):
15
- """Initialize the tool with a DesignRAG instance."""
16
- super().__init__(rag=rag)
17
-
18
- def _run(self, requirements: Dict, num_examples: int = 3) -> str:
19
- """Sync version - not used but required by BaseTool"""
20
- raise NotImplementedError("Use async version")
21
-
22
- async def _arun(self, requirements: Dict, num_examples: int = 3) -> str:
23
- """Retrieve similar designs based on requirements"""
24
- print(f"Retrieving {num_examples} similar designs")
25
- return await self.rag.query_similar_designs(requirements, num_examples)
 
1
+ from nodes.design_rag import DesignRAG
2
+ from langgraph.graph import MessagesState
3
+
4
async def design_retriever_tool(state: MessagesState, num_examples: int = 2):
    """
    Retrieves similar designs based on style requirements
    Name: query_similar_designs
    """
    # query_similar_designs is an async *instance* method (the removed
    # DesignerNode awaited self.rag.query_similar_designs), so calling it
    # unbound on the class — and without await — returned a useless
    # coroutine/raised. Build an instance and await the call.
    rag = DesignRAG()
    # The RAG layer expects a list of message strings, not message objects
    # (the old caller passed [msg.content for msg in ...]).
    conversation = [msg.content for msg in state["messages"]]
    return await rag.query_similar_designs(conversation, num_examples)
10